#include <ATen/Utils.h>

#include <ATen/code_template.h>
#include <ATen/cuda/CUDAConfig.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

namespace {
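// Fuses frozen conv2d/conv3d -> [add ->] relu chains into the single fused
// cuDNN (or MIOpen, on ROCm) convolution ops. For example, after freezing,
//   %x = aten::conv2d(...)
//   %y = aten::add_(%x, %z, %alpha)
//   %r = aten::relu_(%y)
// can be rewritten into one aten::cudnn_convolution_add_relu call. The graph
// is expected to be frozen, so weight/bias (and z) are constants that the
// match filter below can inspect.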
void fuseFrozenConvAddReluImpl(std::shared_ptr<Graph>& graph) {
#if AT_CUDNN_ENABLED() || AT_ROCM_ENABLED()
  GRAPH_DEBUG("Before fuseFrozenConvAddReluImpl: ", *graph);
  SubgraphRewriter rewriter;

  // cuDNN does not support conv1d, so only conv2d and conv3d are matched.
  std::array<std::string, 2> conv_operators = {"conv2d", "conv3d"};
  std::array<std::string, 2> add_operators = {"add", "add_"};
  std::array<std::string, 2> relu_operators = {"relu", "relu_"};

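  // Templated match pattern: a conv followed directly by a relu. ${conv} and
  // ${relu} are substituted from the operator arrays above, so both the
  // functional and in-place relu variants are covered.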
  auto conv_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %res = aten::${relu}(%x)
      return (%res))");

#ifdef USE_ROCM
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::miopen_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#else
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::cudnn_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#endif

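  // Same idea for conv -> add -> relu: %z is the tensor being added to the
  // conv output and %alpha its scale factor, matching aten::add(%x, %z, %alpha).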
  auto conv_add_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %y = aten::${add}(%x, %z, %alpha)
      %res = aten::${relu}(%y)
      return (%res))");

#ifdef USE_ROCM
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::miopen_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#else
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::cudnn_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#endif

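  // Register one rewrite per (conv, relu) and (conv, add, relu) combination:
  // 2 conv ops x 2 relu variants gives 4 conv-relu patterns, and the extra
  // 2 add variants give 8 conv-add-relu patterns.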
  for (const auto& conv : conv_operators) {
    for (const auto& relu : relu_operators) {
      at::jit::TemplateEnv env;
      env.s("conv", conv);
      env.s("relu", relu);
      rewriter.RegisterRewritePattern(
          conv_relu_rstring.format(env), conv_relu_fused);
      for (const auto& add : add_operators) {
        env.s("add", add);
        rewriter.RegisterRewritePattern(
            conv_add_relu_rstring.format(env), conv_add_relu_fused);
      }
    }
  }

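  // The fused kernels only handle a restricted set of inputs, so gate each
  // match: weight must be a frozen constant tensor that is contiguous and on
  // a CUDA device; when bias/z are constant tensors they must match weight's
  // dtype (and its size along dim 0) and also live on a CUDA device, per the
  // checks below.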
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    auto weight = toIValue(match.values_map.at(vmap.at("weight")));
    if (!weight.has_value() || !weight.value().isTensor()) {
      return false;
    }
    const at::Tensor& weight_t = weight.value().toTensor();
    if (!weight_t.device().is_cuda() || !weight_t.is_contiguous()) {
      return false;
    }

    // bias is optional
    if (vmap.find("bias") != vmap.end()) {
      auto bias = toIValue(match.values_map.at(vmap.at("bias")));
      if (bias.has_value() && bias.value().isTensor()) {
        const at::Tensor& bias_t = bias.value().toTensor();
        if (bias_t.dtype() != weight_t.dtype() || bias_t.ndimension() != 1 ||
            bias_t.size(0) != weight_t.size(0) || !bias_t.device().is_cuda()) {
          return false;
        }
      }
    }

    // z is optional
    if (vmap.find("z") != vmap.end()) {
      auto z = toIValue(match.values_map.at(vmap.at("z")));
      if (z.has_value() && z.value().isTensor()) {
        const at::Tensor& z_t = z.value().toTensor();
        if (z_t.dtype() != weight_t.dtype() ||
            z_t.size(0) != weight_t.size(0) || !z_t.is_contiguous() ||
            !z_t.device().is_cuda()) {
          return false;
        }
      }
    }
    return true;
  };

  // Normalize aten::_convolution and in-place operators first, so the
  // simpler replacement patterns above can match.
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  rewriter.runOnGraph(graph, filter);
  GRAPH_DEBUG("After fuseFrozenConvAddReluImpl: ", *graph);
#endif
}

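// Register this implementation at static-initialization time. The generic
// FuseFrozenConvAddRelu entry point dispatches through the function pointer
// returned by getFuseFrozenConvAddReluImpl(), so builds that omit this
// translation unit presumably leave the pass as a no-op.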
auto dummyInitializer = []() {
  getFuseFrozenConvAddReluImpl() = fuseFrozenConvAddReluImpl;
  return true;
}();

} // namespace

} // namespace jit
} // namespace torch