#include <torch/csrc/jit/passes/graph_rewrite_helper.h>

#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {
namespace graph_rewrite_helper {

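// Returns the unqualified name of the function referenced by `func_value`,
// i.e. the substring after the last '.' of its qualified name. For example
// (a hypothetical name), "__torch__.my_module.forward" yields "forward".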
std::string getFuncName(Value* func_value) {
  auto func = func_value->type()->expectRef<FunctionType>().function();
  const auto& qname = func->qualname();
  const auto& name = qname.qualifiedName();
  auto rdot_idx = name.rfind('.');
  if (rdot_idx != std::string::npos) {
    // substr with only a start index copies through the end of the string.
    return name.substr(rdot_idx + 1);
  } else {
    return name;
  }
}

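// `vmap` maps a variable name in the pattern graph to the pattern's Value,
// and `match_vmap` maps pattern Values to the corresponding Values in the
// matched graph. getValue chains the two lookups; getIValue additionally
// folds the matched Value to a constant, yielding c10::nullopt when it is
// not a constant.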
Value* getValue(
    const std::string& name,
    const std::unordered_map<const Value*, Value*>& match_vmap,
    const std::unordered_map<std::string, Value*>& vmap) {
  return match_vmap.at(vmap.at(name));
}

c10::optional<IValue> getIValue(
    const std::string& name,
    const std::unordered_map<const Value*, Value*>& match_vmap,
    const std::unordered_map<std::string, Value*>& vmap) {
  return toIValue(getValue(name, match_vmap, vmap));
}

// Extracts the constant convolution arguments that the filters below need in
// order to pick the right target op.
std::unordered_map<std::string, c10::IValue> getConvParams(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  std::unordered_map<std::string, c10::IValue> calc_values;
  const auto& match_vmap = match.values_map;
  calc_values["transposed"] =
      getIValue("transposed", match_vmap, vmap).value();
  calc_values["output_padding"] =
      getIValue("output_padding", match_vmap, vmap).value();
  calc_values["stride"] = getIValue("stride", match_vmap, vmap).value();
  calc_values["padding"] = getIValue("padding", match_vmap, vmap).value();
  calc_values["dilation"] = getIValue("dilation", match_vmap, vmap).value();
  return calc_values;
}

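// Rewrites aten::_convolution calls in `graph` (both the current signature,
// which takes %allow_tf32, and the deprecated one, which does not) into the
// most specific aten op: conv1d/conv2d/conv3d or conv_transpose1d/2d/3d,
// selected by the transposed flag and the length of the spatial int[] lists.
// Minimal usage sketch (the script::Module `m` is a hypothetical caller-side
// object, not part of this file):
//
//   auto graph = m.get_method("forward").graph();
//   replaceConvolutionWithAtenConv(graph);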
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
  // TODO: remove constant prop in the pass
  ConstantPropagation(graph);
  // The deprecated aten::_convolution signature lacks the trailing
  // %allow_tf32 argument; both forms are matched below.
  std::string convolution_deprecated = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
        return (%r) )";

  std::string convolution = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
        return (%r) )";

  std::string conv2d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
  std::string conv2d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv1d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
  std::string conv1d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv3d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
  std::string conv3d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv_transpose1d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose1d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose3d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose3d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

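  // Illustration of the rewrites registered below (schematic IR, not literal
  // syntax): a matched call
  //
  //   %r = aten::_convolution(%a, %w, %b, [1, 1], [0, 0], [1, 1],
  //       False, [0, 0], 1, False, False, True, True)
  //
  // has stride/padding/dilation lists of length 2 and transposed == False, so
  // filter_conv2d accepts it and it is rewritten to
  //
  //   %r = aten::conv2d(%a, %w, %b, [1, 1], [0, 0], [1, 1], 1)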
  // Filters that reject matches the target op cannot express: each int[]
  // argument must have the expected number of spatial dims, and the
  // transposed flag must agree with the target op.
  auto filter_conv1d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 1 ||
        calc_value_map["stride"].toIntList().size() != 1 ||
        calc_value_map["padding"].toIntList().size() != 1 ||
        calc_value_map["dilation"].toIntList().size() != 1) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
  auto filter_conv2d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 2 ||
        calc_value_map["stride"].toIntList().size() != 2 ||
        calc_value_map["padding"].toIntList().size() != 2 ||
        calc_value_map["dilation"].toIntList().size() != 2) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
  auto filter_conv3d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 3 ||
        calc_value_map["stride"].toIntList().size() != 3 ||
        calc_value_map["padding"].toIntList().size() != 3 ||
        calc_value_map["dilation"].toIntList().size() != 3) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
  auto filter_conv_transpose1d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 1 ||
            calc_value_map["stride"].toIntList().size() != 1 ||
            calc_value_map["padding"].toIntList().size() != 1 ||
            calc_value_map["dilation"].toIntList().size() != 1) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };
  auto filter_conv_transpose2d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 2 ||
            calc_value_map["stride"].toIntList().size() != 2 ||
            calc_value_map["padding"].toIntList().size() != 2 ||
            calc_value_map["dilation"].toIntList().size() != 2) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };
  auto filter_conv_transpose3d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 3 ||
            calc_value_map["stride"].toIntList().size() != 3 ||
            calc_value_map["padding"].toIntList().size() != 3 ||
            calc_value_map["dilation"].toIntList().size() != 3) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };

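  // Each rewriter below pairs the current and the deprecated _convolution
  // pattern with the same target op and runs under the matching filter, so a
  // single pass over the graph handles both signatures.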
  SubgraphRewriter rewriter_conv1d;
  rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
  rewriter_conv1d.RegisterRewritePattern(
      convolution_deprecated, conv1d_for_deprecated_conv);
  rewriter_conv1d.runOnGraph(graph, filter_conv1d);

  SubgraphRewriter rewriter_conv2d;
  rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
  rewriter_conv2d.RegisterRewritePattern(
      convolution_deprecated, conv2d_for_deprecated_conv);
  rewriter_conv2d.runOnGraph(graph, filter_conv2d);

  SubgraphRewriter rewriter_conv3d;
  rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
  rewriter_conv3d.RegisterRewritePattern(
      convolution_deprecated, conv3d_for_deprecated_conv);
  rewriter_conv3d.runOnGraph(graph, filter_conv3d);

  SubgraphRewriter rewriter_conv_transpose1d;
  rewriter_conv_transpose1d.RegisterRewritePattern(
      convolution, conv_transpose1d);
  rewriter_conv_transpose1d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose1d_for_deprecated_conv);
  rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d);

  SubgraphRewriter rewriter_conv_transpose2d;
  rewriter_conv_transpose2d.RegisterRewritePattern(
      convolution, conv_transpose2d);
  rewriter_conv_transpose2d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose2d_for_deprecated_conv);
  rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d);

  SubgraphRewriter rewriter_conv_transpose3d;
  rewriter_conv_transpose3d.RegisterRewritePattern(
      convolution, conv_transpose3d);
  rewriter_conv_transpose3d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose3d_for_deprecated_conv);
  rewriter_conv_transpose3d.runOnGraph(graph, filter_conv_transpose3d);
}

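// Returns true when a matched clamp-style activation can be fused into the
// preceding op. Schematically (using the pattern-variable names this helper
// expects), for a hardtanh match such as
//
//   %r = aten::hardtanh(%x, %output_min, %output_max)
//
// fusion additionally requires both bounds to fold to constant IValues;
// aten::relu patterns carry no such bounds.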
bool isClampFusable(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  const auto& match_vmap = match.values_map;
  TORCH_CHECK(
      vmap.find("dummy_min_max") != vmap.end(),
      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
  auto dummy_min_max =
      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);

  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();

  // Also check that the output_min and output_max values are actually
  // constant. If hardtanh's min/max Values are not constants, we would end up
  // rerouting those values to the prepack op, and then we could not remove
  // the prepacking ops.
  if (vmap.find("output_min") != vmap.end()) {
    // The aten::relu pattern does not have output_min/output_max;
    // aten::hardtanh/hardtanh_ does.
    TORCH_CHECK(
        vmap.find("output_max") != vmap.end(),
        "Expected to find output_max as well given "
        "output_min exists in the pattern graph.");
    // If output_min/max are not constant, we get c10::nullopt.
    auto output_min =
        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
    auto output_max =
        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
    is_fusable =
        is_fusable && (output_min.has_value() && output_max.has_value());
  }

  return is_fusable;
}

} // namespace graph_rewrite_helper
} // namespace jit
} // namespace torch