#include <ATen/core/jit_type.h>
#include <ATen/native/xnnpack/OpContext.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch {
namespace jit {

namespace {

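// Rewrites aten::conv1d into an equivalent aten::conv2d: the input and weight
// are unsqueezed along a dummy height dimension (dim 2), stride/padding/dilation
// are promoted to 2-D lists ([1, stride], [0, padding], [1, dilation]), and the
// extra dimension is squeezed away from the output. This lets downstream conv2d
// optimizations, such as the XNNPACK prepack rewrite below, apply to 1-D convs.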
void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
  std::string conv_1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%res) )";

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %zero : int = prim::Constant[value=0]()
        %one : int = prim::Constant[value=1]()
        %stride_w : int = prim::ListUnpack(%stride)
        %stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
        %padding_w : int = prim::ListUnpack(%padding)
        %padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
        %dilation_w : int = prim::ListUnpack(%dilation)
        %dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
        %two : int = prim::Constant[value=2]()
        %input_2d : Tensor = aten::unsqueeze(%input, %two)
        %weight_2d : Tensor = aten::unsqueeze(%weight, %two)
        %output_2d = aten::conv2d(
            %input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
        %output : Tensor = aten::squeeze(%output_2d, %two)
        return (%output) )";

  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"zero", "res"},
       {"one", "res"},
       {"stride_w", "res"},
       {"stride_2d", "res"},
       {"padding_w", "res"},
       {"padding_2d", "res"},
       {"dilation_w", "res"},
       {"dilation_2d", "res"},
       {"two", "res"},
       {"input_2d", "res"},
       {"weight_2d", "res"},
       {"output_2d", "res"},
       {"output", "res"}});

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_1d_pattern, conv_2d_pattern, value_mappings);
  rewriter.runOnGraph(graph);
}

} // namespace

void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
  // Replace _convolution with conv1d and conv2d
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
  replaceConv1dWithConv2d(graph);
}

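// Module-level overload: applies the graph rewrite to every method of the
// module and recurses into child modules.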
void transformConv1dToConv2d(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    transformConv1dToConv2d(graph);
  }
  for (script::Module m : module.children()) {
    transformConv1dToConv2d(m);
  }
}

#ifdef USE_XNNPACK

namespace {

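// Rewrites aten::linear into prepacked::linear_clamp_prepack +
// prepacked::linear_clamp_run. The clamp bounds are left as None here; the
// fuse*WithPackedOps passes below fill them in when a following relu/hardtanh
// can be folded into the packed op.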
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %res = aten::linear(%input, %weight, %bias)
        return (%res))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = prepacked::linear_clamp_prepack(
            %weight, %bias, %output_min_max, %output_min_max)
        %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        return (%res))";

  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"output_min_max", "res"},
       {"packed_weight_bias", "res"},
       {"res", "res"}});

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(
      linear_pattern, prepacked_ops_pattern, value_mappings);
  linear_rewriter.runOnGraph(graph);
}

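// Same rewrite for aten::conv2d and aten::conv_transpose2d: weights, bias and
// the conv parameters are captured by a prepack op whose result feeds the
// corresponding clamp_run op.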
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  // Replace _convolution with conv2d
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%res) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        return (%res) )";

  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"output_min_max", "res"},
       {"packed_weight_bias", "res"},
       {"res", "res"}});

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings);
  rewriter.runOnGraph(graph);

  std::string conv_2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
          %output_padding:int[], %groups:int):
        %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
        return (%res) )";

  std::string prepacked_ops_conv2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
            %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
        return (%res) )";

  value_mappings = {
      {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};

  SubgraphRewriter transpose_rewriter;
  transpose_rewriter.RegisterRewritePattern(
      conv_2d_transpose_pattern,
      prepacked_ops_conv2d_transpose_pattern,
      value_mappings);
  transpose_rewriter.runOnGraph(graph);
}

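// Folds aten::hardtanh / aten::hardtanh_ that directly consume a prepacked
// linear/conv2d run into the prepack op's output_min/output_max arguments.
// The rewrite is gated on isClampFusable, which checks that the clamp bounds
// can actually be fused.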
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        return (%res))";

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        return (%res) )";

  std::string linear_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = prepacked::linear_clamp_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        %res = aten::hardtanh(%linear_res, %output_min, %output_max)
        return (%res))";

  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}});

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh,
      linear_prepack_run_hardtanh_fused,
      value_mappings);

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        %res = aten::hardtanh(%conv2d_res, %output_min, %output_max)
        return (%res) )";

  value_mappings = {
      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh,
      conv2d_prepack_run_hardtanh_fused,
      value_mappings);

  std::string linear_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = prepacked::linear_clamp_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
        return (%res))";

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
        return (%res) )";

  value_mappings = {
      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh_inplace,
      linear_prepack_run_hardtanh_fused,
      value_mappings);

  value_mappings = {
      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace,
      conv2d_prepack_run_hardtanh_fused,
      value_mappings);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

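// Folds aten::relu / aten::relu_ that directly consume a prepacked
// linear/conv2d run into the prepack op by setting output_min = 0.0 and
// leaving output_max unbounded (None).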
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        return (%res))";

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        return (%res) )";

  std::string linear_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = prepacked::linear_clamp_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        %res = aten::relu(%linear_res)
        return (%res))";

  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"output_min", "packed_weight_bias"},
       {"output_max", "packed_weight_bias"},
       {"packed_weight_bias", "packed_weight_bias"},
       {"res", "res"}});

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings);

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        %res = aten::relu(%conv2d_res)
        return (%res) )";

  value_mappings = {
      {"output_min", "packed_weight_bias"},
      {"output_max", "packed_weight_bias"},
      {"packed_weight_bias", "packed_weight_bias"},
      {"res", "res"}};

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings);

  std::string linear_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = prepacked::linear_clamp_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
        %res = aten::relu_(%linear_res)
        return (%res))";

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
        %res = aten::relu_(%conv2d_res)
        return (%res) )";

  value_mappings = {
      {"output_min", "packed_weight_bias"},
      {"output_max", "packed_weight_bias"},
      {"packed_weight_bias", "packed_weight_bias"},
      {"res", "res"}};

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu_inplace,
      linear_prepack_run_relu_fused,
      value_mappings);

  value_mappings = {
      {"output_min", "packed_weight_bias"},
      {"output_max", "packed_weight_bias"},
      {"packed_weight_bias", "packed_weight_bias"},
      {"res", "res"}};

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace,
      conv2d_prepack_run_relu_fused,
      value_mappings);
  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

void runCanonicalOptimizations(script::Module& module) {
  for (const auto& method : module.get_methods()) {
    auto graph = method.graph();
    // Not sure if we have models running on mobile that require loop unrolling.
    // Perhaps language/speech models? Conservatively setting that to false.
    runOptimization(graph, false /* no loop unrolling */);
  }
}

} // namespace

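// Public entry points (declared in xnnpack_rewrite.h). insertPrePackedOps
// rewrites linear/conv2d/conv_transpose2d into their prepacked XNNPACK
// counterparts; the module overload recurses into submodules.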
void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
}

void insertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    insertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    insertPrePackedOps(m);
  }
}

void fusePrePackedLinearConvWithClamp(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    fuseReluWithPackedOps(graph);
    fuseHardtanhWithPackedOps(graph);

    // Ignore user defined classes for later passes
    ConstantPropagation(graph, true);
  }
}

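// Folds the prepack nodes selected by filter_fn into module attributes via
// PrePackingOpsFolder, so the weights are packed once at optimization time
// rather than on every invocation.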
void FoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("prepacked::linear_clamp_prepack")) ||
        n->kind() ==
            Symbol::fromQualString("prepacked::conv2d_clamp_prepack") ||
        n->kind() ==
            Symbol::fromQualString(
                "prepacked::conv2d_transpose_clamp_prepack"));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
  for (auto& method : m.get_methods()) {
    auto graph = method.graph();
    // Folding requires a const propagation through user defined classes
    ConstantPropagation(graph, false);
  }
}

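// Top-level mobile optimization pipeline. Works on a cloned, eval()'d copy of
// the module; individual passes can be skipped via optimization_blocklist, and
// the module is frozen (inlined, with attributes folded) at the points where
// subsequent passes require it. A minimal usage sketch from C++ (file names
// are placeholders):
//
//   auto m = torch::jit::load("model.pt");
//   auto optimized = torch::jit::optimizeForMobile(
//       m, /*optimization_blocklist=*/{}, /*preserved_methods=*/{});
//   optimized._save_for_mobile("model.mobile.ptl");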
script::Module optimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();

  if (!optimization_blocklist.count(MobileOptimizerType::CONV_1D_TO_2D)) {
    transformConv1dToConv2d(cloned_module);
  }

  if (!optimization_blocklist.count(MobileOptimizerType::CONV_BN_FUSION)) {
    cloned_module = FoldConvBatchNorm(cloned_module);
  }

  // Many optimizations require a frozen module, but ConvBatchNorm requires
  // an unfrozen module
  cloned_module = freeze_module(cloned_module, preserved_methods);

  if (!optimization_blocklist.count(
          MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)) {
    // TODO fix duplication caused by referencing same op across multiple
    // functions
    insertPrePackedOps(cloned_module);
    cloned_module = freeze_module(cloned_module, preserved_methods);
    fusePrePackedLinearConvWithClamp(cloned_module);
    FoldPrePackingOps(cloned_module);
  }

  if (!optimization_blocklist.count(
          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS) &&
      cloned_module.find_method("forward")) {
    // freeze again in case it was not done in previous optional passes
    cloned_module = freeze_module(cloned_module, preserved_methods);
    HoistConvPackedParams(cloned_module);
    // and freeze yet again to remove the empty QuantizedConv modules
    cloned_module = freeze_module(cloned_module, preserved_methods);
  }

  // Run canonical optimizations post freezing
  // since freezing inlines the graph. Otherwise we
  // will have to explicitly call Inlining pass.
  runCanonicalOptimizations(cloned_module);

  if (!optimization_blocklist.count(MobileOptimizerType::REMOVE_DROPOUT)) {
    for (const auto& method : cloned_module.get_methods()) {
      auto graph = method.graph();
      // Module must not be in training mode; optimizeForMobile calls eval() above
      removeDropout(graph);
    }
  }

  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
    for (const auto& method : cloned_module.get_methods()) {
      auto graph = method.graph();
      FuseAddRelu(graph);
    }
  }
  cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
  return cloned_module;
}

#else

void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}

void insertPrePackedOps(script::Module& module) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}

void fusePrePackedLinearConvWithClamp(script::Module& module) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}

void FoldPrePackingOps(script::Module& m) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}

script::Module optimizeForMobile(
    const script::Module& module,
    const std::set<MobileOptimizerType>& blocklist,
    const std::vector<std::string>& preserved_methods) {
  TORCH_INTERNAL_ASSERT(
      false,
      "Mobile optimization only available with XNNPACK at the moment. "
      "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
  return module;
}

#endif
} // namespace jit
} // namespace torch