1 | #include <ATen/core/jit_type.h> |
2 | #include <torch/csrc/jit/ir/ir.h> |
3 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
4 | #include <torch/csrc/jit/passes/constant_pooling.h> |
5 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
6 | #include <torch/csrc/jit/passes/fold_conv_bn.h> |
7 | #include <torch/csrc/jit/passes/freeze_module.h> |
8 | #include <torch/csrc/jit/passes/fuse_linear.h> |
9 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
10 | #include <torch/csrc/jit/passes/prepack_folding.h> |
11 | #include <torch/csrc/jit/passes/remove_dropout.h> |
12 | #include <torch/csrc/jit/passes/remove_mutation.h> |
13 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
14 | #include <torch/csrc/jit/passes/vulkan_rewrite.h> |
15 | #include <torch/csrc/jit/runtime/graph_executor_impl.h> |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
20 | namespace { |
21 | |
22 | void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) { |
23 | std::string batchnorm_pattern = R"( |
24 | graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable): |
25 | %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable) |
26 | return (%r))" ; |
27 | std::string prepacked_ops_pattern = R"( |
28 | graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable): |
29 | %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context( |
30 | %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable) |
31 | %res = vulkan_prepack::run_batchnorm_context(%input, %op_context) |
32 | return (%res))" ; |
33 | |
34 | SubgraphRewriter batchnorm_rewriter; |
35 | batchnorm_rewriter.RegisterRewritePattern( |
36 | batchnorm_pattern, prepacked_ops_pattern); |
37 | batchnorm_rewriter.runOnGraph(graph); |
38 | } |
39 | |
40 | void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) { |
41 | // fuse decomposed linear into aten::linear |
42 | FuseLinear(graph); |
43 | |
44 | std::string linear_pattern = R"( |
45 | graph(%input, %weight, %bias): |
46 | %r = aten::linear(%input, %weight, %bias) |
47 | return (%r))" ; |
48 | std::string prepacked_ops_pattern = R"( |
49 | graph(%input, %weight, %bias): |
50 | %weight_t = aten::t(%weight) |
51 | %packed_weight_bias = vulkan_prepack::create_linear_context( |
52 | %weight_t, %bias) |
53 | %res = vulkan_prepack::run_linear_context(%input, %packed_weight_bias) |
54 | return (%res))" ; |
55 | |
56 | SubgraphRewriter linear_rewriter; |
57 | linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern); |
58 | linear_rewriter.runOnGraph(graph); |
59 | } |
60 | |
61 | void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) { |
62 | graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); |
63 | |
64 | std::string conv_2d_pattern = R"( |
65 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): |
66 | %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) |
67 | return (%r) )" ; |
68 | |
69 | std::string prepacked_ops_conv2d_pattern = R"( |
70 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): |
71 | %output_min_max : None = prim::Constant() |
72 | %packed_weight_bias = vulkan_prepack::create_conv2d_context( |
73 | %weight, %bias, %stride, %padding, %dilation, %groups, |
74 | %output_min_max, %output_min_max) |
75 | %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
76 | return (%r) )" ; |
77 | |
78 | SubgraphRewriter rewriter; |
79 | rewriter.RegisterRewritePattern( |
80 | conv_2d_pattern, prepacked_ops_conv2d_pattern); |
81 | rewriter.runOnGraph(graph); |
82 | |
83 | std::string conv_2d_transpose_pattern = R"( |
84 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], |
85 | %output_padding:int[], %groups:int): |
86 | %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation) |
87 | return (%res) )" ; |
88 | |
89 | std::string prepacked_ops_conv2d_transpose_pattern = R"( |
90 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int): |
91 | %output_min_max : None = prim::Constant() |
92 | %packed_weight_bias = vulkan_prepack::create_tconv2d_context( |
93 | %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups, |
94 | %output_min_max, %output_min_max) |
95 | %res = vulkan_prepack::run_tconv2d_context(%input, %packed_weight_bias) |
96 | return (%res) )" ; |
97 | |
98 | SubgraphRewriter transpose_rewriter; |
99 | transpose_rewriter.RegisterRewritePattern( |
100 | conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern); |
101 | transpose_rewriter.runOnGraph(graph); |
102 | } |
103 | |
104 | void transferInputOutputBackends(std::shared_ptr<Graph>& graph) { |
105 | // Move inputs to Vulkan backend |
106 | for (Value* input : graph->inputs()) { |
107 | NamedValue named_input = NamedValue("" , input); |
108 | if (named_input.type()->kind() == TypeKind::TensorType && |
109 | !input->uses().empty()) { |
110 | // find the insertion point |
111 | WithInsertPoint ip(input->uses()[0].user->prev()); |
112 | Value* replaced_input = graph->insert( |
113 | Symbol::fromQualString("aten::to" ), {named_input, "vulkan" }); |
114 | // replace the input |
115 | input->replaceAllUsesAfterNodeWith( |
116 | replaced_input->node(), replaced_input); |
117 | } |
118 | } |
119 | |
120 | // Move outputs to CPU backend |
121 | at::ArrayRef<Value*>&& outputs = graph->outputs(); |
122 | for (size_t i = 0; i < outputs.size(); i++) { |
123 | Value* output = outputs[i]; |
124 | NamedValue named_output = NamedValue("" , output); |
125 | if (named_output.type()->kind() == TypeKind::TensorType) { |
126 | // find the insertion point |
127 | WithInsertPoint ip(output->node()->next()); |
128 | Value* replaced_output = graph->insert( |
129 | Symbol::fromQualString("aten::to" ), {named_output, "cpu" }); |
130 | // replace the output |
131 | graph->block()->replaceOutput(i, replaced_output); |
132 | } |
133 | } |
134 | |
135 | SubgraphRewriter rewriter; |
136 | rewriter.runOnGraph(graph); |
137 | } |
138 | |
139 | void transferInputOutputBackends(script::Module& module) { |
140 | std::shared_ptr<Graph> graph = module.get_methods()[0].graph(); |
141 | transferInputOutputBackends(graph); |
142 | } |
143 | |
144 | void eliminateDeadCode(script::Module& module) { |
145 | for (auto& method : module.get_methods()) { |
146 | EliminateDeadCode(method.graph()); |
147 | } |
148 | } |
149 | |
150 | void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) { |
151 | std::string gru_pattern = R"( |
152 | graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): |
153 | %y.1 : Tensor, %hn.1 : Tensor = aten::gru(%input.1, %hx.1, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first) |
154 | return (%y.1, %hn.1) )" ; |
155 | std::string prepacked_ops_pattern = R"( |
156 | graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): |
157 | %packed_weights_biases = vulkan_prepack::create_gru_context( |
158 | %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first) |
159 | %y.1 : Tensor, %hn.1 : Tensor = vulkan_prepack::run_gru_context(%input.1, %hx.1, %packed_weights_biases) |
160 | return (%y.1, %hn.1) )" ; |
161 | |
162 | auto filter = [&](const Match& match, |
163 | const std::unordered_map<std::string, Value*>& vmap) { |
164 | auto node = match.values_map.at(vmap.at("params_cpu" ))->node(); |
165 | return node->output()->type()->str() == "Tensor[]" ; |
166 | }; |
167 | |
168 | SubgraphRewriter gru_rewriter; |
169 | gru_rewriter.RegisterRewritePattern(gru_pattern, prepacked_ops_pattern); |
170 | gru_rewriter.runOnGraph(graph, filter); |
171 | } |
172 | |
173 | void insertPrePackedLstmOp(std::shared_ptr<Graph>& graph) { |
174 | std::string lstm_pattern = R"( |
175 | graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): |
176 | %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first) |
177 | return (%y.1, %hn.1, %cn.1) )" ; |
178 | std::string prepacked_ops_pattern = R"( |
179 | graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): |
180 | %packed_weights_biases = vulkan_prepack::create_lstm_context( |
181 | %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first) |
182 | %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx) |
183 | %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases) |
184 | return (%y.1, %hn.1, %cn.1) )" ; |
185 | |
186 | auto filter = [&](const Match& match, |
187 | const std::unordered_map<std::string, Value*>& vmap) { |
188 | auto node = match.values_map.at(vmap.at("hx" ))->node(); |
189 | return node->output()->type()->str() == "Tensor[]" ; |
190 | }; |
191 | |
192 | SubgraphRewriter lstm_rewriter; |
193 | lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern); |
194 | lstm_rewriter.runOnGraph(graph, filter); |
195 | } |
196 | |
197 | void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) { |
198 | SubgraphRewriter rewriter; |
199 | |
200 | std::string conv2d_prepack_run_hardtanh_fused = R"( |
201 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
202 | %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): |
203 | %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context( |
204 | %weight, %bias, %stride, %padding, %dilation, %groups, |
205 | %output_min, %output_max) |
206 | %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
207 | return (%r) )" ; |
208 | |
209 | std::string conv2d_prepack_run_hardtanh = R"( |
210 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
211 | %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): |
212 | %packed_weight_bias = vulkan_prepack::create_conv2d_context( |
213 | %weight, %bias, %stride, %padding, %dilation, %groups, |
214 | %dummy_min_max, %dummy_min_max) |
215 | %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
216 | %r = aten::hardtanh(%conv2d_res, %output_min, %output_max) |
217 | return (%r) )" ; |
218 | |
219 | rewriter.RegisterRewritePattern( |
220 | conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused); |
221 | |
222 | std::string conv2d_prepack_run_hardtanh_inplace = R"( |
223 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
224 | %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): |
225 | %packed_weight_bias = vulkan_prepack::create_conv2d_context( |
226 | %weight, %bias, %stride, %padding, %dilation, %groups, |
227 | %dummy_min_max, %dummy_min_max) |
228 | %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
229 | %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max) |
230 | return (%r) )" ; |
231 | |
232 | rewriter.RegisterRewritePattern( |
233 | conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused); |
234 | |
235 | rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); |
236 | } |
237 | |
238 | void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) { |
239 | SubgraphRewriter rewriter; |
240 | |
241 | std::string conv2d_prepack_run_relu_fused = R"( |
242 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
243 | %dilation:int[], %groups:int, %dummy_min_max): |
244 | %output_min: float = prim::Constant[value=0.0]() |
245 | %output_max: None = prim::Constant() |
246 | %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context( |
247 | %weight, %bias, %stride, %padding, %dilation, %groups, |
248 | %output_min, %output_max) |
249 | %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
250 | return (%r) )" ; |
251 | |
252 | std::string conv2d_prepack_run_relu = R"( |
253 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
254 | %dilation:int[], %groups:int, %dummy_min_max): |
255 | %packed_weight_bias = vulkan_prepack::create_conv2d_context( |
256 | %weight, %bias, %stride, %padding, %dilation, %groups, |
257 | %dummy_min_max, %dummy_min_max) |
258 | %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
259 | %r = aten::relu(%conv2d_res) |
260 | return (%r) )" ; |
261 | |
262 | rewriter.RegisterRewritePattern( |
263 | conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused); |
264 | |
265 | std::string conv2d_prepack_run_relu_inplace = R"( |
266 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
267 | %dilation:int[], %groups:int, %dummy_min_max): |
268 | %packed_weight_bias = vulkan_prepack::create_conv2d_context( |
269 | %weight, %bias, %stride, %padding, %dilation, %groups, |
270 | %dummy_min_max, %dummy_min_max) |
271 | %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias) |
272 | %r = aten::relu_(%conv2d_res) |
273 | return (%r) )" ; |
274 | |
275 | rewriter.RegisterRewritePattern( |
276 | conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused); |
277 | rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); |
278 | } |
279 | |
280 | } // namespace |
281 | |
282 | void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) { |
283 | insertPrePackedLinearOp(graph); |
284 | insertPrePackedConv2dOp(graph); |
285 | insertPrePackedGruOp(graph); |
286 | insertPrePackedLstmOp(graph); |
287 | insertPrePackedBatchNormOp(graph); |
288 | } |
289 | |
290 | void vulkanInsertPrePackedOps(script::Module& module) { |
291 | for (auto& method : module.get_methods()) { |
292 | auto graph = method.graph(); |
293 | vulkanInsertPrePackedOps(graph); |
294 | } |
295 | for (script::Module m : module.children()) { |
296 | vulkanInsertPrePackedOps(m); |
297 | } |
298 | } |
299 | |
300 | void vulkanFusePrePackedConvWithClamp(script::Module& module) { |
301 | auto graph = module.get_method("forward" ).graph(); |
302 | fuseReluWithPackedOps(graph); |
303 | fuseHardtanhWithPackedOps(graph); |
304 | } |
305 | |
306 | void vulkanFoldPrePackingOps(script::Module& m) { |
307 | PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool { |
308 | return ( |
309 | (n->kind() == |
310 | Symbol::fromQualString("vulkan_prepack::create_conv2d_context" )) || |
311 | (n->kind() == |
312 | Symbol::fromQualString("vulkan_prepack::create_tconv2d_context" )) || |
313 | (n->kind() == |
314 | Symbol::fromQualString("vulkan_prepack::create_linear_context" )) || |
315 | (n->kind() == |
316 | Symbol::fromQualString("vulkan_prepack::create_gru_context" )) || |
317 | (n->kind() == |
318 | Symbol::fromQualString("vulkan_prepack::create_lstm_context" )) || |
319 | (n->kind() == |
320 | Symbol::fromQualString("vulkan_prepack::create_batchnorm_context" ))); |
321 | }; |
322 | PrePackingOpsFolder(m, filter_fn, "prepack_folding" ); |
323 | } |
324 | |
325 | void vulkanRemoveMutation(script::Module& module) { |
326 | auto graph = module.get_method("forward" ).graph(); |
327 | RemoveTensorMutation(graph); |
328 | } |
329 | |
330 | void vulkanRunCanonicalOptimizations(script::Module& module) { |
331 | auto graph = module.get_method("forward" ).graph(); |
332 | for (const auto& method : module.get_methods()) { |
333 | auto graph = method.graph(); |
334 | runOptimization(graph, false /* no loop unrolling */); |
335 | } |
336 | } |
337 | |
// Produces a Vulkan-optimized clone of the given module for mobile
// deployment; the input module is left untouched. The pass order is
// significant: conv+bn folding and freezing run before prepack insertion so
// that weights are constants, and prepack folding runs before mutation
// removal and DCE.
//
// @param m                      source module (cloned, not mutated)
// @param optimization_blocklist optimizations to skip; only
//                               VULKAN_AUTOMATIC_GPU_TRANSFER is consulted here
// @param preserved_methods      methods kept alive through freeze_module
// @return the optimized clone, tagged with "optimized_for_vulkan" = true
script::Module vulkanOptimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  // Fold batch-norm into preceding conv weights, then freeze so parameters
  // become constants that the prepack ops can capture.
  cloned_module = FoldConvBatchNorm(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  vulkanInsertPrePackedOps(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  vulkanRemoveMutation(cloned_module);

  if (!optimization_blocklist.count(
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
    // Insert CPU<->Vulkan aten::to transfers at the graph boundary and mark
    // the module as not requiring external backend transfers.
    transferInputOutputBackends(cloned_module);
    cloned_module.register_attribute(
        "requires_backend_transfers", BoolType::get(), false);
  }

  // remove duplicated constants
  vulkanRunCanonicalOptimizations(cloned_module);
  eliminateDeadCode(cloned_module);

  cloned_module.register_attribute(
      "optimized_for_vulkan", BoolType::get(), true);
  return cloned_module;
}
367 | |
368 | } // namespace jit |
369 | } // namespace torch |
370 | |