#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch {
namespace jit {

namespace {

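// Rewrites aten::linear into a pair of Metal-specific ops: a
// metal_prepack::linear_prepack call that packs the weight and bias
// (with no output clamp yet, hence the None min/max constants), and a
// metal_prepack::linear_run call that consumes the packed context.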
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // Fuse decomposed linear ops (e.g. matmul + add) back into aten::linear
  // first, so the single pattern below matches them all.
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %output_min_max, %output_min_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

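// Rewrites aten::conv2d into metal_prepack::conv2d_prepack +
// metal_prepack::conv2d_run. aten::_convolution nodes are normalized to
// aten::conv2d first so one pattern covers both spellings.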
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);
}

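// Folds a following aten::relu / aten::relu_ into the prepacked op by
// re-prepacking with output_min = 0.0 and output_max = None (an unbounded
// upper clamp), so the clamp happens inside the packed op and the separate
// relu node disappears from the graph.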
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu(%linear_res)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu, linear_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu(%r)
        return (%r) )";

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string linear_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu_(%linear_res)
        return (%res))";

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu_(%r)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

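// Same idea as the relu fusion above, but for aten::hardtanh /
// aten::hardtanh_: the hardtanh's min and max arguments become the
// output_min/output_max of the re-prepacked op context.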
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh(%r, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%r, %output_min, %output_max)
        return (%r) )";

  std::string linear_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

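// Graph-level entry point for the prepack rewrite: insert the prepacked
// linear and conv2d ops described above.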
void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
}

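// Module-level overload: apply the rewrite to every method's graph and
// recurse into child modules.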
void metalInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    metalInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    metalInsertPrePackedOps(m);
  }
}

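// Folds the prepack calls out of the graph: once the module is frozen,
// the prepack inputs are constants, so each matching
// metal_prepack::*_prepack node can be evaluated once at optimization
// time and its packed result stored on the module as an attribute,
// instead of repacking on every forward call.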
void metalFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("metal_prepack::conv2d_prepack")) ||
        (n->kind() == Symbol::fromQualString("metal_prepack::linear_prepack")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

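// Runs the relu and hardtanh fusions on the module's forward graph. Note
// that only the "forward" method is rewritten here.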
void metalFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

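// Metal tensors live in GPU memory: for every tensor-typed graph output,
// insert a metal::copy_to_host op right after its producing node and make
// the copied value the new output, so callers receive CPU tensors.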
void metalInsertCopyOps(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  auto&& outputs = graph->outputs();
  for (const auto i : c10::irange(outputs.size())) {
    Value* output = outputs[i];
    auto namedValue = NamedValue("", output);
    if (namedValue.type()->kind() == TypeKind::TensorType) {
      // find the insertion point
      WithInsertPoint ip(output->node()->next());
      Value* replaced_output = graph->insert(
          Symbol::fromQualString("metal::copy_to_host"), {namedValue});
      // replace the graph output with the copied value
      graph->block()->replaceOutput(i, replaced_output);
    }
  }
}

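// Rewrites in-place tensor mutations in the forward graph (e.g. aten::add_)
// into their functional equivalents where it is safe to do so.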
void metalRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

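// Runs the standard graph-executor optimization pipeline (e.g. constant
// pooling and common subexpression elimination) on the forward graph,
// with loop unrolling disabled.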
void metalRunCanonicalOptimizations(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  runOptimization(graph, false /* no loop unrolling */);
}

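// Top-level entry point: clones the module, then applies the full Metal
// mobile pipeline in order (Conv+BatchNorm folding, prepack insertion,
// freezing, clamp fusion, prepack folding, dropout and mutation removal,
// canonical optimizations), and tags the result with an
// "optimized_for_metal" attribute.
//
// A minimal usage sketch (file names are illustrative; requires a
// Metal-enabled build):
//
//   torch::jit::Module m = torch::jit::load("model.pt");
//   auto optimized = torch::jit::metalOptimizeForMobile(m, /*preserved_methods=*/{});
//   optimized._save_for_mobile("model_metal.ptl");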
script::Module metalOptimizeForMobile(
    const script::Module& m,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  metalInsertPrePackedOps(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  metalFusePrePackedConvWithClamp(cloned_module);
  metalFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  metalRemoveMutation(cloned_module);
  // remove duplicated constants
  metalRunCanonicalOptimizations(cloned_module);
  cloned_module.register_attribute(
      "optimized_for_metal", BoolType::get(), true);
  return cloned_module;
}

} // namespace jit
} // namespace torch