1 | #include <ATen/core/jit_type.h> |
2 | #include <ATen/native/xnnpack/OpContext.h> |
3 | |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
6 | #include <torch/csrc/jit/passes/constant_pooling.h> |
7 | #include <torch/csrc/jit/passes/constant_propagation.h> |
8 | #include <torch/csrc/jit/passes/fold_conv_bn.h> |
9 | #include <torch/csrc/jit/passes/freeze_module.h> |
10 | #include <torch/csrc/jit/passes/fuse_linear.h> |
11 | #include <torch/csrc/jit/passes/fuse_relu.h> |
12 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
13 | #include <torch/csrc/jit/passes/hoist_conv_packed_params.h> |
14 | #include <torch/csrc/jit/passes/inliner.h> |
15 | #include <torch/csrc/jit/passes/mobile_optimizer_type.h> |
16 | #include <torch/csrc/jit/passes/prepack_folding.h> |
17 | #include <torch/csrc/jit/passes/remove_dropout.h> |
18 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
19 | #include <torch/csrc/jit/passes/xnnpack_rewrite.h> |
20 | #include <torch/csrc/jit/runtime/graph_executor_impl.h> |
21 | |
22 | namespace torch { |
23 | namespace jit { |
24 | |
25 | namespace { |
26 | |
27 | void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) { |
28 | std::string conv_1d_pattern = R"( |
29 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): |
30 | %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) |
31 | return (%res) )" ; |
32 | |
33 | std::string conv_2d_pattern = R"( |
34 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): |
35 | %zero : int = prim::Constant[value=0]() |
36 | %one : int = prim::Constant[value=1]() |
37 | %stride_w : int = prim::ListUnpack(%stride) |
38 | %stride_2d : int[] = prim::ListConstruct(%one, %stride_w) |
39 | %padding_w : int = prim::ListUnpack(%padding) |
40 | %padding_2d : int[] = prim::ListConstruct(%zero, %padding_w) |
41 | %dilation_w : int = prim::ListUnpack(%dilation) |
42 | %dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w) |
43 | %two : int = prim::Constant[value=2]() |
44 | %input_2d : Tensor = aten::unsqueeze(%input, %two) |
45 | %weight_2d : Tensor = aten::unsqueeze(%weight, %two) |
46 | %output_2d = aten::conv2d( |
47 | %input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups) |
48 | %output : Tensor = aten::squeeze(%output_2d, %two) |
49 | return (%output) )" ; |
50 | |
51 | std::vector<std::pair<std::string, std::string>> value_mappings( |
52 | {{"zero" , "res" }, |
53 | {"one" , "res" }, |
54 | {"stride_w" , "res" }, |
55 | {"stride_2d" , "res" }, |
56 | {"padding_w" , "res" }, |
57 | {"padding_2d" , "res" }, |
58 | {"dilation_w" , "res" }, |
59 | {"dilation_2d" , "res" }, |
60 | {"two" , "res" }, |
61 | {"input_2d" , "res" }, |
62 | {"weight_2d" , "res" }, |
63 | {"output_2d" , "res" }, |
64 | {"output" , "res" }}); |
65 | |
66 | SubgraphRewriter rewriter; |
67 | rewriter.RegisterRewritePattern( |
68 | conv_1d_pattern, conv_2d_pattern, value_mappings); |
69 | rewriter.runOnGraph(graph); |
70 | } |
71 | |
72 | } // namespace |
73 | |
// Rewrites all conv1d calls in `graph` as equivalent conv2d calls.
// aten::_convolution is first normalized into aten::conv1d/conv2d so the
// conv1d pattern in replaceConv1dWithConv2d can match.
void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
  // Replace _convolution with conv1d and conv2d
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
  replaceConv1dWithConv2d(graph);
}
79 | |
80 | void transformConv1dToConv2d(script::Module& module) { |
81 | for (auto& method : module.get_methods()) { |
82 | auto graph = method.graph(); |
83 | transformConv1dToConv2d(graph); |
84 | } |
85 | for (script::Module m : module.children()) { |
86 | transformConv1dToConv2d(m); |
87 | } |
88 | } |
89 | |
90 | #ifdef USE_XNNPACK |
91 | |
92 | namespace { |
93 | |
94 | void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) { |
95 | // fuse decomposed linear into aten::linear |
96 | FuseLinear(graph); |
97 | |
98 | std::string linear_pattern = R"( |
99 | graph(%input, %weight, %bias): |
100 | %res = aten::linear(%input, %weight, %bias) |
101 | return (%res))" ; |
102 | std::string prepacked_ops_pattern = R"( |
103 | graph(%input, %weight, %bias): |
104 | %output_min_max : None = prim::Constant() |
105 | %packed_weight_bias = prepacked::linear_clamp_prepack( |
106 | %weight, %bias, %output_min_max, %output_min_max) |
107 | %res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
108 | return (%res))" ; |
109 | |
110 | std::vector<std::pair<std::string, std::string>> value_mappings( |
111 | {{"output_min_max" , "res" }, |
112 | {"packed_weight_bias" , "res" }, |
113 | {"res" , "res" }}); |
114 | |
115 | SubgraphRewriter linear_rewriter; |
116 | linear_rewriter.RegisterRewritePattern( |
117 | linear_pattern, prepacked_ops_pattern, value_mappings); |
118 | linear_rewriter.runOnGraph(graph); |
119 | } |
120 | |
121 | void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) { |
122 | // Replace _convolution with conv2d |
123 | graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); |
124 | |
125 | std::string conv_2d_pattern = R"( |
126 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): |
127 | %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) |
128 | return (%res) )" ; |
129 | |
130 | std::string prepacked_ops_conv2d_pattern = R"( |
131 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): |
132 | %output_min_max : None = prim::Constant() |
133 | %packed_weight_bias = prepacked::conv2d_clamp_prepack( |
134 | %weight, %bias, %stride, %padding, %dilation, %groups, |
135 | %output_min_max, %output_min_max) |
136 | %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
137 | return (%res) )" ; |
138 | |
139 | std::vector<std::pair<std::string, std::string>> value_mappings( |
140 | {{"output_min_max" , "res" }, |
141 | {"packed_weight_bias" , "res" }, |
142 | {"res" , "res" }}); |
143 | |
144 | SubgraphRewriter rewriter; |
145 | rewriter.RegisterRewritePattern( |
146 | conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings); |
147 | rewriter.runOnGraph(graph); |
148 | |
149 | std::string conv_2d_transpose_pattern = R"( |
150 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], |
151 | %output_padding:int[], %groups:int): |
152 | %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation) |
153 | return (%res) )" ; |
154 | |
155 | std::string prepacked_ops_conv2d_transpose_pattern = R"( |
156 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int): |
157 | %output_min_max : None = prim::Constant() |
158 | %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack( |
159 | %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups, |
160 | %output_min_max, %output_min_max) |
161 | %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias) |
162 | return (%res) )" ; |
163 | |
164 | value_mappings = { |
165 | {"output_min_max" , "res" }, {"packed_weight_bias" , "res" }, {"res" , "res" }}; |
166 | |
167 | SubgraphRewriter transpose_rewriter; |
168 | transpose_rewriter.RegisterRewritePattern( |
169 | conv_2d_transpose_pattern, |
170 | prepacked_ops_conv2d_transpose_pattern, |
171 | value_mappings); |
172 | transpose_rewriter.runOnGraph(graph); |
173 | } |
174 | |
175 | void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) { |
176 | SubgraphRewriter rewriter; |
177 | |
178 | std::string linear_prepack_run_hardtanh_fused = R"( |
179 | graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max): |
180 | %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack( |
181 | %weight, %bias, %output_min, %output_max) |
182 | %res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
183 | return (%res))" ; |
184 | |
185 | std::string conv2d_prepack_run_hardtanh_fused = R"( |
186 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
187 | %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): |
188 | %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack( |
189 | %weight, %bias, %stride, %padding, %dilation, %groups, |
190 | %output_min, %output_max) |
191 | %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
192 | return (%res) )" ; |
193 | |
194 | std::string linear_prepack_run_hardtanh = R"( |
195 | graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max): |
196 | %packed_weight_bias = prepacked::linear_clamp_prepack( |
197 | %weight, %bias, %dummy_min_max, %dummy_min_max) |
198 | %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
199 | %res = aten::hardtanh(%linear_res, %output_min, %output_max) |
200 | return (%res))" ; |
201 | |
202 | std::vector<std::pair<std::string, std::string>> value_mappings( |
203 | {{"packed_weight_bias" , "packed_weight_bias" }, {"res" , "res" }}); |
204 | |
205 | rewriter.RegisterRewritePattern( |
206 | linear_prepack_run_hardtanh, |
207 | linear_prepack_run_hardtanh_fused, |
208 | value_mappings); |
209 | |
210 | std::string conv2d_prepack_run_hardtanh = R"( |
211 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
212 | %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): |
213 | %packed_weight_bias = prepacked::conv2d_clamp_prepack( |
214 | %weight, %bias, %stride, %padding, %dilation, %groups, |
215 | %dummy_min_max, %dummy_min_max) |
216 | %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
217 | %res = aten::hardtanh(%conv2d_res, %output_min, %output_max) |
218 | return (%res) )" ; |
219 | |
220 | value_mappings = { |
221 | {"packed_weight_bias" , "packed_weight_bias" }, {"res" , "res" }}; |
222 | |
223 | rewriter.RegisterRewritePattern( |
224 | conv2d_prepack_run_hardtanh, |
225 | conv2d_prepack_run_hardtanh_fused, |
226 | value_mappings); |
227 | |
228 | std::string linear_prepack_run_hardtanh_inplace = R"( |
229 | graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max): |
230 | %packed_weight_bias = prepacked::linear_clamp_prepack( |
231 | %weight, %bias, %dummy_min_max, %dummy_min_max) |
232 | %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
233 | %res = aten::hardtanh_(%linear_res, %output_min, %output_max) |
234 | return (%res))" ; |
235 | |
236 | std::string conv2d_prepack_run_hardtanh_inplace = R"( |
237 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
238 | %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): |
239 | %packed_weight_bias = prepacked::conv2d_clamp_prepack( |
240 | %weight, %bias, %stride, %padding, %dilation, %groups, |
241 | %dummy_min_max, %dummy_min_max) |
242 | %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
243 | %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max) |
244 | return (%res) )" ; |
245 | |
246 | value_mappings = { |
247 | {"packed_weight_bias" , "packed_weight_bias" }, {"res" , "res" }}; |
248 | |
249 | rewriter.RegisterRewritePattern( |
250 | linear_prepack_run_hardtanh_inplace, |
251 | linear_prepack_run_hardtanh_fused, |
252 | value_mappings); |
253 | |
254 | value_mappings = { |
255 | {"packed_weight_bias" , "packed_weight_bias" }, {"res" , "res" }}; |
256 | |
257 | rewriter.RegisterRewritePattern( |
258 | conv2d_prepack_run_hardtanh_inplace, |
259 | conv2d_prepack_run_hardtanh_fused, |
260 | value_mappings); |
261 | |
262 | rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); |
263 | } |
264 | |
265 | void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) { |
266 | SubgraphRewriter rewriter; |
267 | |
268 | std::string linear_prepack_run_relu_fused = R"( |
269 | graph(%input, %weight, %bias, %dummy_min_max): |
270 | %output_min: float = prim::Constant[value=0.0]() |
271 | %output_max: None = prim::Constant() |
272 | %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack( |
273 | %weight, %bias, %output_min, %output_max) |
274 | %res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
275 | return (%res))" ; |
276 | |
277 | std::string conv2d_prepack_run_relu_fused = R"( |
278 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
279 | %dilation:int[], %groups:int, %dummy_min_max): |
280 | %output_min: float = prim::Constant[value=0.0]() |
281 | %output_max: None = prim::Constant() |
282 | %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack( |
283 | %weight, %bias, %stride, %padding, %dilation, %groups, |
284 | %output_min, %output_max) |
285 | %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
286 | return (%res) )" ; |
287 | |
288 | std::string linear_prepack_run_relu = R"( |
289 | graph(%input, %weight, %bias, %dummy_min_max): |
290 | %packed_weight_bias = prepacked::linear_clamp_prepack( |
291 | %weight, %bias, %dummy_min_max, %dummy_min_max) |
292 | %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
293 | %res = aten::relu(%linear_res) |
294 | return (%res))" ; |
295 | |
296 | std::vector<std::pair<std::string, std::string>> value_mappings( |
297 | {{"output_min" , "packed_weight_bias" }, |
298 | {"output_max" , "packed_weight_bias" }, |
299 | {"packed_weight_bias" , "packed_weight_bias" }, |
300 | {"res" , "res" }}); |
301 | |
302 | rewriter.RegisterRewritePattern( |
303 | linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings); |
304 | |
305 | std::string conv2d_prepack_run_relu = R"( |
306 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
307 | %dilation:int[], %groups:int, %dummy_min_max): |
308 | %packed_weight_bias = prepacked::conv2d_clamp_prepack( |
309 | %weight, %bias, %stride, %padding, %dilation, %groups, |
310 | %dummy_min_max, %dummy_min_max) |
311 | %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
312 | %res = aten::relu(%conv2d_res) |
313 | return (%res) )" ; |
314 | |
315 | value_mappings = { |
316 | {"output_min" , "packed_weight_bias" }, |
317 | {"output_max" , "packed_weight_bias" }, |
318 | {"packed_weight_bias" , "packed_weight_bias" }, |
319 | {"res" , "res" }}; |
320 | |
321 | rewriter.RegisterRewritePattern( |
322 | conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings); |
323 | |
324 | std::string linear_prepack_run_relu_inplace = R"( |
325 | graph(%input, %weight, %bias, %dummy_min_max): |
326 | %packed_weight_bias = prepacked::linear_clamp_prepack( |
327 | %weight, %bias, %dummy_min_max, %dummy_min_max) |
328 | %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias) |
329 | %res = aten::relu_(%linear_res) |
330 | return (%res))" ; |
331 | |
332 | std::string conv2d_prepack_run_relu_inplace = R"( |
333 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
334 | %dilation:int[], %groups:int, %dummy_min_max): |
335 | %packed_weight_bias = prepacked::conv2d_clamp_prepack( |
336 | %weight, %bias, %stride, %padding, %dilation, %groups, |
337 | %dummy_min_max, %dummy_min_max) |
338 | %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) |
339 | %res = aten::relu_(%conv2d_res) |
340 | return (%res) )" ; |
341 | |
342 | value_mappings = { |
343 | {"output_min" , "packed_weight_bias" }, |
344 | {"output_max" , "packed_weight_bias" }, |
345 | {"packed_weight_bias" , "packed_weight_bias" }, |
346 | {"res" , "res" }}; |
347 | |
348 | rewriter.RegisterRewritePattern( |
349 | linear_prepack_run_relu_inplace, |
350 | linear_prepack_run_relu_fused, |
351 | value_mappings); |
352 | |
353 | value_mappings = { |
354 | {"output_min" , "packed_weight_bias" }, |
355 | {"output_max" , "packed_weight_bias" }, |
356 | {"packed_weight_bias" , "packed_weight_bias" }, |
357 | {"res" , "res" }}; |
358 | |
359 | rewriter.RegisterRewritePattern( |
360 | conv2d_prepack_run_relu_inplace, |
361 | conv2d_prepack_run_relu_fused, |
362 | value_mappings); |
363 | rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); |
364 | } |
365 | |
366 | void runCanonicalOptimizations(script::Module& module) { |
367 | for (const auto& method : module.get_methods()) { |
368 | auto graph = method.graph(); |
369 | // Not sure if we have models running on mobile that require loop unrolling. |
370 | // Perhaps language/speech models? Conservatively setting that to false. |
371 | runOptimization(graph, false /* no loop unrolling */); |
372 | } |
373 | } |
374 | |
375 | } // namespace |
376 | |
// Inserts XNNPACK prepacked ops for linear and conv2d/conv_transpose2d into
// `graph`. Linear runs first; it fuses decomposed linear patterns into
// aten::linear before matching.
void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
}
381 | |
382 | void insertPrePackedOps(script::Module& module) { |
383 | for (auto& method : module.get_methods()) { |
384 | auto graph = method.graph(); |
385 | insertPrePackedOps(graph); |
386 | } |
387 | for (script::Module m : module.children()) { |
388 | insertPrePackedOps(m); |
389 | } |
390 | } |
391 | |
392 | void fusePrePackedLinearConvWithClamp(script::Module& module) { |
393 | for (auto& method : module.get_methods()) { |
394 | auto graph = method.graph(); |
395 | fuseReluWithPackedOps(graph); |
396 | fuseHardtanhWithPackedOps(graph); |
397 | |
398 | // Ignore user defined classes for later passes |
399 | ConstantPropagation(graph, true); |
400 | } |
401 | } |
402 | |
403 | void FoldPrePackingOps(script::Module& m) { |
404 | PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool { |
405 | return ( |
406 | (n->kind() == |
407 | Symbol::fromQualString("prepacked::linear_clamp_prepack" )) || |
408 | n->kind() == |
409 | Symbol::fromQualString("prepacked::conv2d_clamp_prepack" ) || |
410 | n->kind() == |
411 | Symbol::fromQualString( |
412 | "prepacked::conv2d_transpose_clamp_prepack" )); |
413 | }; |
414 | PrePackingOpsFolder(m, filter_fn, "prepack_folding" ); |
415 | for (auto& method : m.get_methods()) { |
416 | auto graph = method.graph(); |
417 | // Folding requires a const propagation through user defined classes |
418 | ConstantPropagation(graph, false); |
419 | } |
420 | } |
421 | |
// Clones `m` and applies the mobile optimization pipeline: conv1d->conv2d,
// conv+batchnorm folding, XNNPACK prepacked-op insertion/fusion/folding,
// packed-param hoisting, canonical optimizations, dropout removal, and
// add+relu fusion. Each step can be disabled via `optimization_blocklist`;
// `preserved_methods` are kept through every freeze. The original module is
// not modified. Note the pass ordering below is deliberate — several steps
// require a frozen module, and freezing is repeated as passes create new
// foldable state.
script::Module optimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();

  if (!optimization_blocklist.count(MobileOptimizerType::CONV_1D_TO_2D)) {
    transformConv1dToConv2d(cloned_module);
  }

  if (!optimization_blocklist.count(MobileOptimizerType::CONV_BN_FUSION)) {
    cloned_module = FoldConvBatchNorm(cloned_module);
  }

  // Many optimizations require a frozen module, but ConvBatchNorm requires
  // an unfrozen module
  cloned_module = freeze_module(cloned_module, preserved_methods);

  if (!optimization_blocklist.count(
          MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)) {
    // TODO fix duplication caused by referencing same op across multiple
    // functions
    insertPrePackedOps(cloned_module);
    // Re-freeze so the newly inserted prepack ops see constant weights,
    // then fuse clamps and fold the prepacking into attributes.
    cloned_module = freeze_module(cloned_module, preserved_methods);
    fusePrePackedLinearConvWithClamp(cloned_module);
    FoldPrePackingOps(cloned_module);
  }

  if (!optimization_blocklist.count(
          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS) &&
      cloned_module.find_method("forward")) {
    // freeze again in case it was not done in previous optional passes
    cloned_module = freeze_module(cloned_module, preserved_methods);
    HoistConvPackedParams(cloned_module);
    // and freeze yet again to remove the empty QuantizedConv modules
    cloned_module = freeze_module(cloned_module, preserved_methods);
  }

  // Run canonical optimizations post freezing
  // since freezing inlines the graph. Otherwise we
  // will have to explicitly call Inlining pass.
  runCanonicalOptimizations(cloned_module);

  if (!optimization_blocklist.count(MobileOptimizerType::REMOVE_DROPOUT)) {
    for (const auto& method : cloned_module.get_methods()) {
      auto graph = method.graph();
      // Module must be not be in training mode but optimize calls eval()
      removeDropout(graph);
    }
  }

  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
    for (const auto& method : cloned_module.get_methods()) {
      auto graph = method.graph();
      FuseAddRelu(graph);
    }
  }
  // Tag the result so callers can detect an already-optimized module.
  cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
  return cloned_module;
}
483 | |
484 | #else |
485 | |
// Stub: XNNPACK support was compiled out; calling this is a hard error.
void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
490 | |
// Stub: XNNPACK support was compiled out; calling this is a hard error.
void insertPrePackedOps(script::Module& module) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
495 | |
// Stub: XNNPACK support was compiled out; calling this is a hard error.
void fusePrePackedLinearConvWithClamp(script::Module& module) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
500 | |
// Stub: XNNPACK support was compiled out; calling this is a hard error.
void FoldPrePackingOps(script::Module& m) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
505 | |
// Stub: XNNPACK support was compiled out. Asserts; the trailing `return` only
// satisfies the signature and is unreachable in practice.
script::Module optimizeForMobile(
    const script::Module& module,
    const std::set<MobileOptimizerType>& blocklist,
    const std::vector<std::string>& preserved_methods) {
  TORCH_INTERNAL_ASSERT(
      false,
      "Mobile optimization only available with XNNPACK at the moment. "
      "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
  return module;
}
516 | |
517 | #endif |
518 | } // namespace jit |
519 | } // namespace torch |
520 | |