#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch {
namespace jit {

namespace {

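// Rewrite aten::linear into a metal_prepack::linear_prepack /
// metal_prepack::linear_run pair so the weight and bias can be packed ahead
// of time. Decomposed addmm/matmul+add forms are first re-fused into
// aten::linear by FuseLinear.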
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %output_min_max, %output_min_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

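// Rewrite aten::conv2d into a metal_prepack::conv2d_prepack /
// metal_prepack::conv2d_run pair. Any aten::_convolution nodes are first
// normalized to aten::conv2d by replaceConvolutionWithAtenConv.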
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);
}

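// Fold an aten::relu / aten::relu_ that follows a prepacked linear or conv2d
// run into the prepack op itself, by rewriting the packed op's clamp bounds
// to output_min = 0.0 and output_max = None (unbounded above).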
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu(%linear_res)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu, linear_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu(%r)
        return (%r) )";

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string linear_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu_(%linear_res)
        return (%res))";

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu_(%r)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

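// Fold an aten::hardtanh / aten::hardtanh_ that follows a prepacked linear or
// conv2d run into the prepack op, by forwarding the hardtanh min/max values
// as the packed op's clamp bounds.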
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(%weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh(%r, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%r, %output_min, %output_max)
        return (%r) )";

  std::string linear_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

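// Rewrite all aten::linear and aten::conv2d nodes in a graph into their
// metal_prepack equivalents.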
void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
}

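// Apply the prepack rewrite to every method of the module and, recursively,
// to all of its child modules.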
void metalInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    metalInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    metalInsertPrePackedOps(m);
  }
}

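// Fold the metal_prepack::conv2d_prepack and metal_prepack::linear_prepack
// ops so that weight packing happens once at optimization time rather than
// on every forward call.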
void metalFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("metal_prepack::conv2d_prepack")) ||
        (n->kind() == Symbol::fromQualString("metal_prepack::linear_prepack")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

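// Fuse relu/hardtanh clamps that follow prepacked linear and conv2d runs in
// the module's forward graph.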
void metalFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

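// Insert a metal::copy_to_host op for every tensor output of forward so that
// results are copied back to the host before being returned.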
void metalInsertCopyOps(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  auto&& outputs = graph->outputs();
  for (const auto i : c10::irange(outputs.size())) {
    Value* output = outputs[i];
    auto namedValue = NamedValue("", output);
    if (namedValue.type()->kind() == TypeKind::TensorType) {
      // find the insertion point
      WithInsertPoint ip(output->node()->next());
      Value* replaced_output = graph->insert(
          Symbol::fromQualString("metal::copy_to_host"), {namedValue});
      // replace the output
      graph->block()->replaceOutput(i, replaced_output);
    }
  }
  SubgraphRewriter rewriter;
  rewriter.runOnGraph(graph);
}

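// Rewrite in-place tensor mutations in forward into out-of-place ops where
// possible.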
void metalRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

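// Run the standard JIT graph optimizations on forward, with loop unrolling
// disabled.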
void metalRunCanonicalOptimizations(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  runOptimization(graph, false /* no loop unrolling */);
}

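// End-to-end Metal optimization pipeline for mobile: clone the module, fold
// Conv+BatchNorm, insert and fold prepacked ops, freeze, fuse clamps, remove
// dropout and mutation, run canonical optimizations, and mark the result with
// an "optimized_for_metal" attribute. (This is the pass typically reached
// from torch.utils.mobile_optimizer.optimize_for_mobile with backend='metal'.)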
script::Module metalOptimizeForMobile(
    const script::Module& m,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  metalInsertPrePackedOps(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  metalFusePrePackedConvWithClamp(cloned_module);
  metalFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  metalRemoveMutation(cloned_module);
  // remove duplicated constants and run other canonical optimizations
  metalRunCanonicalOptimizations(cloned_module);
  cloned_module.register_attribute(
      "optimized_for_metal", BoolType::get(), true);
  return cloned_module;
}

} // namespace jit
} // namespace torch