1 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
2 | |
3 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
4 | #include <torch/csrc/jit/passes/constant_propagation.h> |
5 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace graph_rewrite_helper { |
10 | |
11 | std::string getFuncName(Value* func_value) { |
12 | auto func = func_value->type()->expectRef<FunctionType>().function(); |
13 | const auto& qname = func->qualname(); |
14 | const auto& name = qname.qualifiedName(); |
15 | auto rdot_idx = name.rfind('.'); |
16 | if (rdot_idx != std::string::npos) { |
17 | return name.substr(rdot_idx + 1, name.length()); |
18 | } else { |
19 | return name; |
20 | } |
21 | } |
22 | |
23 | Value* getValue( |
24 | const std::string& name, |
25 | const std::unordered_map<const Value*, Value*>& match_vmap, |
26 | const std::unordered_map<std::string, Value*>& vmap) { |
27 | return match_vmap.at(vmap.at(name)); |
28 | } |
29 | |
30 | c10::optional<IValue> getIValue( |
31 | const std::string& name, |
32 | const std::unordered_map<const Value*, Value*>& match_vmap, |
33 | const std::unordered_map<std::string, Value*>& vmap) { |
34 | return toIValue(getValue(name, match_vmap, vmap)); |
35 | } |
36 | |
37 | std::unordered_map<std::string, c10::IValue> getConvParams( |
38 | const Match& match, |
39 | const std::unordered_map<std::string, Value*>& vmap) { |
40 | std::unordered_map<std::string, c10::IValue> calc_values; |
41 | const auto& match_vmap = match.values_map; |
42 | auto transposed_value = getIValue("transposed" , match_vmap, vmap).value(); |
43 | calc_values["transposed" ] = transposed_value; |
44 | auto output_padding_value = |
45 | getIValue("output_padding" , match_vmap, vmap).value(); |
46 | calc_values["output_padding" ] = output_padding_value; |
47 | auto stride_value = getIValue("stride" , match_vmap, vmap).value(); |
48 | calc_values["stride" ] = stride_value; |
49 | auto padding_value = getIValue("padding" , match_vmap, vmap).value(); |
50 | calc_values["padding" ] = padding_value; |
51 | auto dilation_value = getIValue("dilation" , match_vmap, vmap).value(); |
52 | calc_values["dilation" ] = dilation_value; |
53 | return calc_values; |
54 | } |
55 | |
56 | void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) { |
57 | // TODO: remove constant prop in the pass |
58 | ConstantPropagation(graph); |
59 | std::string convolution_deprecated = R"( |
60 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
61 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
62 | %deterministic:bool, %cudnn_enabled:bool): |
63 | %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, |
64 | %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled) |
65 | return (%r) )" ; |
66 | |
67 | std::string convolution = R"( |
68 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
69 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
70 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
71 | %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, |
72 | %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32) |
73 | return (%r) )" ; |
74 | |
75 | std::string conv2d_for_deprecated_conv = R"( |
76 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
77 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
78 | %deterministic:bool, %cudnn_enabled:bool): |
79 | %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) |
80 | return (%r) )" ; |
81 | std::string conv2d = R"( |
82 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
83 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
84 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
85 | %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) |
86 | return (%r) )" ; |
87 | |
88 | std::string conv1d_for_deprecated_conv = R"( |
89 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
90 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
91 | %deterministic:bool, %cudnn_enabled:bool): |
92 | %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) |
93 | return (%r) )" ; |
94 | std::string conv1d = R"( |
95 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
96 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
97 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
98 | %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) |
99 | return (%r) )" ; |
100 | |
101 | std::string conv3d_for_deprecated_conv = R"( |
102 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
103 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
104 | %deterministic:bool, %cudnn_enabled:bool): |
105 | %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) |
106 | return (%r) )" ; |
107 | std::string conv3d = R"( |
108 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
109 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
110 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
111 | %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) |
112 | return (%r) )" ; |
113 | |
114 | std::string conv_transpose1d_for_deprecated_conv = R"( |
115 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
116 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
117 | %deterministic:bool, %cudnn_enabled:bool): |
118 | %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) |
119 | return (%r) )" ; |
120 | |
121 | std::string conv_transpose1d = R"( |
122 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
123 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
124 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
125 | %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) |
126 | return (%r) )" ; |
127 | |
128 | std::string conv_transpose2d_for_deprecated_conv = R"( |
129 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
130 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
131 | %deterministic:bool, %cudnn_enabled:bool): |
132 | %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) |
133 | return (%r) )" ; |
134 | |
135 | std::string conv_transpose2d = R"( |
136 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
137 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
138 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
139 | %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) |
140 | return (%r) )" ; |
141 | |
142 | std::string conv_transpose3d_for_deprecated_conv = R"( |
143 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
144 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
145 | %deterministic:bool, %cudnn_enabled:bool): |
146 | %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) |
147 | return (%r) )" ; |
148 | |
149 | std::string conv_transpose3d = R"( |
150 | graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], |
151 | %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, |
152 | %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): |
153 | %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) |
154 | return (%r) )" ; |
155 | |
156 | // Filter the unsupported case |
157 | auto filter_conv1d = [](const Match& match, |
158 | const std::unordered_map<std::string, Value*>& vmap) { |
159 | auto calc_value_map = getConvParams(match, vmap); |
160 | if (calc_value_map["output_padding" ].toIntList().size() != 1 || |
161 | calc_value_map["stride" ].toIntList().size() != 1 || |
162 | calc_value_map["padding" ].toIntList().size() != 1 || |
163 | calc_value_map["dilation" ].toIntList().size() != 1) { |
164 | return false; |
165 | } |
166 | return !calc_value_map["transposed" ].toBool(); |
167 | }; |
168 | auto filter_conv2d = [](const Match& match, |
169 | const std::unordered_map<std::string, Value*>& vmap) { |
170 | auto calc_value_map = getConvParams(match, vmap); |
171 | if (calc_value_map["output_padding" ].toIntList().size() != 2 || |
172 | calc_value_map["stride" ].toIntList().size() != 2 || |
173 | calc_value_map["padding" ].toIntList().size() != 2 || |
174 | calc_value_map["dilation" ].toIntList().size() != 2) { |
175 | return false; |
176 | } |
177 | return !calc_value_map["transposed" ].toBool(); |
178 | }; |
179 | auto filter_conv3d = [](const Match& match, |
180 | const std::unordered_map<std::string, Value*>& vmap) { |
181 | auto calc_value_map = getConvParams(match, vmap); |
182 | if (calc_value_map["output_padding" ].toIntList().size() != 3 || |
183 | calc_value_map["stride" ].toIntList().size() != 3 || |
184 | calc_value_map["padding" ].toIntList().size() != 3 || |
185 | calc_value_map["dilation" ].toIntList().size() != 3) { |
186 | return false; |
187 | } |
188 | return !calc_value_map["transposed" ].toBool(); |
189 | }; |
190 | auto filter_conv_transpose1d = |
191 | [](const Match& match, |
192 | const std::unordered_map<std::string, Value*>& vmap) { |
193 | auto calc_value_map = getConvParams(match, vmap); |
194 | if (calc_value_map["output_padding" ].toIntList().size() != 1 || |
195 | calc_value_map["stride" ].toIntList().size() != 1 || |
196 | calc_value_map["padding" ].toIntList().size() != 1 || |
197 | calc_value_map["dilation" ].toIntList().size() != 1) { |
198 | return false; |
199 | } |
200 | return calc_value_map["transposed" ].toBool(); |
201 | }; |
202 | auto filter_conv_transpose2d = |
203 | [](const Match& match, |
204 | const std::unordered_map<std::string, Value*>& vmap) { |
205 | auto calc_value_map = getConvParams(match, vmap); |
206 | if (calc_value_map["output_padding" ].toIntList().size() != 2 || |
207 | calc_value_map["stride" ].toIntList().size() != 2 || |
208 | calc_value_map["padding" ].toIntList().size() != 2 || |
209 | calc_value_map["dilation" ].toIntList().size() != 2) { |
210 | return false; |
211 | } |
212 | return calc_value_map["transposed" ].toBool(); |
213 | }; |
214 | auto filter_conv_transpose3d = |
215 | [](const Match& match, |
216 | const std::unordered_map<std::string, Value*>& vmap) { |
217 | auto calc_value_map = getConvParams(match, vmap); |
218 | if (calc_value_map["output_padding" ].toIntList().size() != 3 || |
219 | calc_value_map["stride" ].toIntList().size() != 3 || |
220 | calc_value_map["padding" ].toIntList().size() != 3 || |
221 | calc_value_map["dilation" ].toIntList().size() != 3) { |
222 | return false; |
223 | } |
224 | return calc_value_map["transposed" ].toBool(); |
225 | }; |
226 | |
227 | SubgraphRewriter rewriter_conv1d; |
228 | rewriter_conv1d.RegisterRewritePattern(convolution, conv1d); |
229 | rewriter_conv1d.RegisterRewritePattern( |
230 | convolution_deprecated, conv1d_for_deprecated_conv); |
231 | rewriter_conv1d.runOnGraph(graph, filter_conv1d); |
232 | |
233 | SubgraphRewriter rewriter_conv2d; |
234 | rewriter_conv2d.RegisterRewritePattern(convolution, conv2d); |
235 | rewriter_conv2d.RegisterRewritePattern( |
236 | convolution_deprecated, conv2d_for_deprecated_conv); |
237 | rewriter_conv2d.runOnGraph(graph, filter_conv2d); |
238 | |
239 | SubgraphRewriter rewriter_conv3d; |
240 | rewriter_conv3d.RegisterRewritePattern(convolution, conv3d); |
241 | rewriter_conv3d.RegisterRewritePattern( |
242 | convolution_deprecated, conv3d_for_deprecated_conv); |
243 | rewriter_conv3d.runOnGraph(graph, filter_conv3d); |
244 | |
245 | SubgraphRewriter rewriter_conv_transpose1d; |
246 | rewriter_conv_transpose1d.RegisterRewritePattern( |
247 | convolution, conv_transpose1d); |
248 | rewriter_conv_transpose1d.RegisterRewritePattern( |
249 | convolution_deprecated, conv_transpose1d_for_deprecated_conv); |
250 | rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d); |
251 | |
252 | SubgraphRewriter rewriter_conv_transpose2d; |
253 | rewriter_conv_transpose2d.RegisterRewritePattern( |
254 | convolution, conv_transpose2d); |
255 | rewriter_conv_transpose2d.RegisterRewritePattern( |
256 | convolution_deprecated, conv_transpose2d_for_deprecated_conv); |
257 | rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d); |
258 | |
259 | SubgraphRewriter rewriter_conv_transpose3d; |
260 | rewriter_conv_transpose3d.RegisterRewritePattern( |
261 | convolution, conv_transpose3d); |
262 | rewriter_conv_transpose3d.RegisterRewritePattern( |
263 | convolution_deprecated, conv_transpose3d_for_deprecated_conv); |
264 | rewriter_conv_transpose3d.runOnGraph(graph, filter_conv_transpose3d); |
265 | } |
266 | |
267 | bool isClampFusable( |
268 | const Match& match, |
269 | const std::unordered_map<std::string, Value*>& vmap) { |
270 | const auto& match_vmap = match.values_map; |
271 | TORCH_CHECK( |
272 | vmap.find("dummy_min_max" ) != vmap.end(), |
273 | "Expected to find dummy_min_max Value in the subgraph to be replaced." ); |
274 | auto dummy_min_max = |
275 | graph_rewrite_helper::getIValue("dummy_min_max" , match_vmap, vmap); |
276 | |
277 | auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone(); |
278 | |
279 | // Also check if the output_min and output_max values are actually constant. |
280 | // If hardtanh's min/max Value's are not actually constants, we will end up |
281 | // rerouting those values to prepack op. And if they are not constants |
282 | // we will not be able to remove prepacking ops. |
283 | if (vmap.find("output_min" ) != vmap.end()) { |
284 | // aten::relu pattern does not have output_min/output_max. |
285 | // aten::hardtanh/_ does. |
286 | TORCH_CHECK( |
287 | vmap.find("output_max" ) != vmap.end(), |
288 | "Expected to find output_max as well given " |
289 | "output_min exist in pattern graph." ); |
290 | // If output_min/max are not constant, we get c10::nullopt. |
291 | auto output_min = |
292 | graph_rewrite_helper::getIValue("output_min" , match_vmap, vmap); |
293 | auto output_max = |
294 | graph_rewrite_helper::getIValue("output_max" , match_vmap, vmap); |
295 | is_fusable = |
296 | is_fusable && (output_min.has_value() && output_max.has_value()); |
297 | } |
298 | |
299 | return is_fusable; |
300 | } |
301 | |
302 | } // namespace graph_rewrite_helper |
303 | } // namespace jit |
304 | } // namespace torch |
305 | |