#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch {
namespace jit {

namespace {

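// Rewrite aten::batch_norm into a prepacked Vulkan batchnorm:
// create_batchnorm_context packs the parameters and running stats, and
// run_batchnorm_context consumes the packed context at inference time.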
void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
  std::string batchnorm_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
            %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
        return (%res))";

  SubgraphRewriter batchnorm_rewriter;
  batchnorm_rewriter.RegisterRewritePattern(
      batchnorm_pattern, prepacked_ops_pattern);
  batchnorm_rewriter.runOnGraph(graph);
}

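// Collapse decomposed linear (addmm / matmul + add) into aten::linear, then
// rewrite it into a prepacked Vulkan linear context; the weight is transposed
// before packing.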
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %weight_t = aten::t(%weight)
        %packed_weight_bias = vulkan_prepack::create_linear_context(
            %weight_t, %bias)
        %res = vulkan_prepack::run_linear_context(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

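// Normalize aten::_convolution to aten::conv2d / aten::conv_transpose2d, then
// rewrite both into prepacked Vulkan contexts. The output clamp bounds are
// left as None so a later clamp-fusion pass can fill them in.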
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);

  std::string conv_2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
          %output_padding:int[], %groups:int):
        %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
        return (%res) )";

  std::string prepacked_ops_conv2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_tconv2d_context(
            %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_tconv2d_context(%input, %packed_weight_bias)
        return (%res) )";

  SubgraphRewriter transpose_rewriter;
  transpose_rewriter.RegisterRewritePattern(
      conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
  transpose_rewriter.runOnGraph(graph);
}

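// Let callers keep passing and receiving CPU tensors: insert aten::to("vulkan")
// after every used tensor input and aten::to("cpu") before every tensor output.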
void transferInputOutputBackends(std::shared_ptr<Graph>& graph) {
  // Move inputs to Vulkan backend
  for (Value* input : graph->inputs()) {
    NamedValue named_input = NamedValue("", input);
    if (named_input.type()->kind() == TypeKind::TensorType &&
        !input->uses().empty()) {
      // find the insertion point
      WithInsertPoint ip(input->uses()[0].user->prev());
      Value* replaced_input = graph->insert(
          Symbol::fromQualString("aten::to"), {named_input, "vulkan"});
      // replace the input
      input->replaceAllUsesAfterNodeWith(
          replaced_input->node(), replaced_input);
    }
  }

  // Move outputs to CPU backend
  at::ArrayRef<Value*> outputs = graph->outputs();
  for (size_t i = 0; i < outputs.size(); i++) {
    Value* output = outputs[i];
    NamedValue named_output = NamedValue("", output);
    if (named_output.type()->kind() == TypeKind::TensorType) {
      // find the insertion point
      WithInsertPoint ip(output->node()->next());
      Value* replaced_output = graph->insert(
          Symbol::fromQualString("aten::to"), {named_output, "cpu"});
      // replace the output
      graph->block()->replaceOutput(i, replaced_output);
    }
  }

  SubgraphRewriter rewriter;
  rewriter.runOnGraph(graph);
}

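// Module-level overload: only the first method's graph receives the
// backend-transfer rewrite.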
void transferInputOutputBackends(script::Module& module) {
  std::shared_ptr<Graph> graph = module.get_methods()[0].graph();
  transferInputOutputBackends(graph);
}

void eliminateDeadCode(script::Module& module) {
  for (auto& method : module.get_methods()) {
    EliminateDeadCode(method.graph());
  }
}

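// Rewrite aten::gru into a prepacked Vulkan GRU context. The match filter only
// accepts graphs whose flat params argument is a Tensor[] list.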
void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) {
  std::string gru_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor = aten::gru(%input.1, %hx.1, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_gru_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %y.1 : Tensor, %hn.1 : Tensor = vulkan_prepack::run_gru_context(%input.1, %hx.1, %packed_weights_biases)
        return (%y.1, %hn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("params_cpu"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter gru_rewriter;
  gru_rewriter.RegisterRewritePattern(gru_pattern, prepacked_ops_pattern);
  gru_rewriter.runOnGraph(graph, filter);
}

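// Rewrite aten::lstm into a prepacked Vulkan LSTM context. The (h0, c0) list is
// unpacked and passed as separate tensors; the match filter only accepts graphs
// whose hidden-state argument is a Tensor[] list.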
void insertPrePackedLstmOp(std::shared_ptr<Graph>& graph) {
  std::string lstm_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1, %cn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_lstm_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx)
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases)
        return (%y.1, %hn.1, %cn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("hx"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter lstm_rewriter;
  lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern);
  lstm_rewriter.runOnGraph(graph, filter);
}

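// Fold aten::hardtanh / aten::hardtanh_ that directly follow a prepacked Vulkan
// conv2d into the context's output_min / output_max arguments, so the clamp is
// applied by the packed conv op itself.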
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

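// Fold aten::relu / aten::relu_ that directly follow a prepacked Vulkan conv2d
// into the context by setting output_min = 0.0 and leaving output_max unset.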
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu_(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

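// Insert all Vulkan prepack/run pairs into a single graph; the module overload
// applies the rewrite to every method and recurses into child modules.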
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
  insertPrePackedGruOp(graph);
  insertPrePackedLstmOp(graph);
  insertPrePackedBatchNormOp(graph);
}

void vulkanInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    vulkanInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    vulkanInsertPrePackedOps(m);
  }
}

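// Fuse clamp-style activations (relu / hardtanh) that follow prepacked convs in
// the forward method.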
void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

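// Constant-fold the vulkan_prepack::create_*_context ops so weight packing runs
// once ahead of time rather than on every forward call.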
void vulkanFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_tconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

void vulkanRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

void vulkanRunCanonicalOptimizations(script::Module& module) {
  for (const auto& method : module.get_methods()) {
    auto graph = method.graph();
    runOptimization(graph, false /* no loop unrolling */);
  }
}

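// Full Vulkan mobile-optimization pipeline: clone and eval the module, fold
// Conv+BatchNorm, freeze, insert prepacked Vulkan ops, fuse clamps, fold
// prepacking, remove dropout and tensor mutation, optionally insert
// CPU<->Vulkan transfers, then run canonical cleanups.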
script::Module vulkanOptimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  vulkanInsertPrePackedOps(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  vulkanRemoveMutation(cloned_module);

  if (!optimization_blocklist.count(
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
    transferInputOutputBackends(cloned_module);
    cloned_module.register_attribute(
        "requires_backend_transfers", BoolType::get(), false);
  }

  // remove duplicated constants
  vulkanRunCanonicalOptimizations(cloned_module);
  eliminateDeadCode(cloned_module);

  cloned_module.register_attribute(
      "optimized_for_vulkan", BoolType::get(), true);
  return cloned_module;
}

} // namespace jit
} // namespace torch