1 | #pragma once |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | #include <torch/csrc/jit/passes/quantization/helper.h> |
8 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
9 | #include <string> |
10 | #include <unordered_map> |
11 | #include <utility> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | |
16 | struct QuantFusionInfo { |
17 | std::string quantized_op_name; |
18 | std::string pattern; |
19 | std::string replacement; |
20 | std::vector<MatchFilter> filters = {}; |
21 | }; |
22 | |
23 | namespace { |
// Renders IR value names as a suffix of extra arguments: every name is
// prefixed with ", ", so {"%b", "%c"} becomes ", %b, %c" and an empty list
// becomes "". Callers splice the result right after a first argument, e.g.
//   "graph(%a_quant" + getExtraArgList(args) + "):"
std::string getExtraArgList(const std::vector<std::string>& extra_args) {
  std::string arg_list;
  for (const auto& arg : extra_args) {
    // Append in place; `acc + ", " + arg` (the previous std::accumulate form,
    // which also needed an unincluded <numeric>) rebuilt the whole string on
    // every step.
    arg_list += ", ";
    arg_list += arg;
  }
  return arg_list;
}
31 | |
32 | // Get the pattern we want to replace the match with |
33 | std::string getAtenOpPattern( |
34 | const std::string& , |
35 | const std::string& op_name, |
36 | const std::vector<std::string>& , |
37 | bool scalar_args = false) { |
38 | std::vector<std::string> = extra_op_args; |
39 | std::string aten_op_pattern = graph_header; |
40 | if (scalar_args) { |
41 | for (const auto& : _extra_op_args) { |
42 | aten_op_pattern |
43 | .append(R"( |
44 | )" ) |
45 | .append(extra_arg) |
46 | .append("_scalar = aten::item(" ) |
47 | .append(extra_arg) |
48 | .append(")" ); |
49 | } |
50 | |
51 | for (auto& : _extra_op_args) { |
52 | _extra_op_arg.append("_scalar" ); |
53 | } |
54 | } |
55 | const auto& = getExtraArgList(std::move(_extra_op_args)); |
56 | aten_op_pattern += R"( |
57 | %r = )" ; |
58 | aten_op_pattern += op_name + "(" + "%a_quant" + extra_op_arg_list + ")" ; |
59 | aten_op_pattern += R"( |
60 | return (%r) )" ; |
61 | return aten_op_pattern; |
62 | } |
63 | |
64 | // generate ops for quantize pattern for a scalar value |
65 | std::string getQuantizeForScalar(const std::string& value) { |
66 | // 6 is `torch.float` ScalarType, we are creating a float scalar |
67 | // tensor from a scalar value |
68 | std::string quantize_pattern = R"( |
69 | )" + |
70 | value + "_float_scalar_type : int = prim::Constant[value=6]()" ; |
71 | quantize_pattern += R"( |
72 | )" + |
73 | value + "_none : None = prim::Constant()" ; |
74 | quantize_pattern += R"( |
75 | )" + |
76 | value + "_tensor : Tensor = aten::scalar_tensor(" + value + ", " + value + |
77 | "_float_scalar_type" ; |
78 | for (const auto i : c10::irange(3)) { |
79 | (void)i; // Suppress unused variable warning |
80 | quantize_pattern += ", " + value + "_none" ; |
81 | } |
82 | quantize_pattern += ")" ; |
83 | quantize_pattern += |
84 | R"( |
85 | )" + |
86 | value + "_quant = aten::quantize_per_tensor(" + value + "_tensor" + |
87 | getExtraArgList( |
88 | {value + "_scale" , value + "_zero_point" , value + "_dtype" }) + |
89 | ")" ; |
90 | return quantize_pattern; |
91 | } |
92 | |
// Emits one IR line (preceded by a newline) that dequantizes `<value>_quant`
// into `<value>_dequant`; for "%a":
//   %a_dequant = aten::dequantize(%a_quant)
std::string getDequantize(const std::string& value) {
  std::string line(1, '\n');
  line.append(value).append("_dequant = aten::dequantize(");
  line.append(value).append("_quant)");
  return line;
}
98 | |
// Emits one IR line (preceded by a newline) that extracts the float scalar
// `<value>_scalar` from the tensor `<value>_dequant`; for "%a":
//   %a_scalar : float = aten::item(%a_dequant)
std::string getItem(const std::string& value) {
  const std::string source = value + "_dequant";
  const std::string target = value + "_scalar";
  return "\n" + target + " : float = aten::item(" + source + ")";
}
104 | |
105 | // Patterns for the ops that inherit parameters from input |
106 | std::string getInputTensorQParamOpPattern( |
107 | const std::string& op_name, |
108 | const std::vector<std::string>& ) { |
109 | const auto& = getExtraArgList(extra_op_args); |
110 | std::string op_pattern = "graph(%a_quant" + extra_op_arg_list + "):" + R"( |
111 | %a_dequant = aten::dequantize(%a_quant) |
112 | %r = )" + |
113 | op_name + "(" + "%a_dequant" + extra_op_arg_list + ")" + R"( |
114 | %r_scale : float = aten::q_scale(%a_quant) |
115 | %r_zero_point : int = aten::q_zero_point(%a_quant) |
116 | %r_dtype : int = prim::dtype(%a_quant) |
117 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
118 | return (%r_quant) )" ; |
119 | return op_pattern; |
120 | } |
121 | |
122 | // QuantFusionInfo for the ops that inherit parameters from input |
123 | QuantFusionInfo getInputTensorQParamOpFusionInfo( |
124 | const std::string& op_name, |
125 | const std::vector<std::string>& ) { |
126 | std::string op_pattern = |
127 | getInputTensorQParamOpPattern(op_name, extra_op_args); |
128 | const auto& = getExtraArgList(extra_op_args); |
129 | std::string = "graph(%a_quant" + extra_op_arg_list + "):" ; |
130 | std::string op_replacement = |
131 | getAtenOpPattern(graph_header, op_name, extra_op_args); |
132 | |
133 | return {op_name, std::move(op_pattern), std::move(op_replacement)}; |
134 | } |
135 | |
136 | // quant fusion for ops like `quantized::add_scalar`, `quantized::mul_scalar` |
137 | QuantFusionInfo getBinaryOpScalarFusionInfo( |
138 | const std::string& op_name, |
139 | const std::vector<std::string>& , |
140 | const std::string& quantized_op_name, |
141 | const std::vector<std::string>& , |
142 | const std::vector<MatchFilter>& filters = {}) { |
143 | std::string op_pattern = |
144 | getInputTensorQParamOpPattern(op_name, extra_op_args); |
145 | |
146 | const auto& = getExtraArgList(extra_op_args); |
147 | std::string = "graph(%a_quant" + extra_op_arg_list + "):" ; |
148 | std::string op_replacement = getAtenOpPattern( |
149 | graph_header, quantized_op_name, extra_quantized_op_args); |
150 | |
151 | return {op_name, std::move(op_pattern), std::move(op_replacement), filters}; |
152 | } |
153 | |
154 | QuantFusionInfo getClampOpFusionInfo( |
155 | const std::string& op_name, |
156 | const std::vector<std::string>& ) { |
157 | std::vector<std::string> = extra_op_args; |
158 | std::vector<std::string> input_qparams = {"_scale" , "_zero_point" , "_dtype" }; |
159 | for (const auto& arg : extra_op_args) { |
160 | for (const auto& qparam : input_qparams) { |
161 | header_args.push_back(arg + qparam); |
162 | } |
163 | } |
164 | for (const auto& qparam : input_qparams) { |
165 | header_args.push_back("%r" + qparam); |
166 | } |
167 | const auto& = getExtraArgList(std::move(header_args)); |
168 | std::string = "graph(%a_quant" + extra_header_arg_list + "):" ; |
169 | std::string op_pattern = graph_header; |
170 | for (const auto& arg : extra_op_args) { |
171 | op_pattern += getQuantizeForScalar(arg); |
172 | op_pattern += getDequantize(arg); |
173 | op_pattern += getItem(arg); |
174 | } |
175 | op_pattern += getDequantize("%a" ); |
176 | op_pattern += R"( |
177 | %r = )" ; |
178 | std::vector<std::string> ; |
179 | scalar_extra_args.reserve(extra_op_args.size()); |
180 | for (const auto& arg : extra_op_args) { |
181 | scalar_extra_args.push_back(arg + "_scalar" ); |
182 | } |
183 | op_pattern += op_name + "(" + "%a_dequant" + |
184 | getExtraArgList(std::move(scalar_extra_args)) + ")" ; |
185 | // IR pattern common to all ops that inherit qparam from input |
186 | op_pattern += R"( |
187 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
188 | return (%r_quant) )" ; |
189 | |
190 | std::string aten_op_pattern = |
191 | getAtenOpPattern(graph_header, op_name, extra_op_args); |
192 | |
193 | return {op_name, std::move(op_pattern), std::move(aten_op_pattern)}; |
194 | } |
195 | |
196 | // Patterns for the ops that has fixed quantization parameters |
197 | QuantFusionInfo getFixedQParamOpFusionInfo( |
198 | const std::string& op_name, |
199 | const std::vector<std::string>& , |
200 | bool is_symmetric) { |
201 | const auto& = getExtraArgList(extra_op_args); |
202 | std::string = "graph(%a_quant" + extra_op_arg_list + "):" ; |
203 | std::string op_pattern = graph_header; |
204 | op_pattern += R"( |
205 | %a_dequant = aten::dequantize(%a_quant) |
206 | %r = )" ; |
207 | op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")" ; |
208 | // IR pattern common to all ops with fixed quantization parameters for |
209 | // asymetric quantization |
210 | std::string asym_fixed_qparam_op_suffix = R"( |
211 | %r_scale : float = prim::Constant[value=0.00390625]() |
212 | %r_zero_point : int = prim::Constant[value=0]() |
213 | %r_dtype : int = prim::Constant[value=13]() |
214 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
215 | return (%r_quant) )" ; |
216 | |
217 | std::string sym_fixed_qparam_op_suffix = R"( |
218 | %r_scale : float = prim::Constant[value=0.0078125]() |
219 | %r_zero_point : int = prim::Constant[value=128]() |
220 | %r_dtype : int = prim::Constant[value=13]() |
221 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
222 | return (%r_quant) )" ; |
223 | op_pattern += |
224 | is_symmetric ? sym_fixed_qparam_op_suffix : asym_fixed_qparam_op_suffix; |
225 | |
226 | std::string aten_op_pattern = |
227 | getAtenOpPattern(graph_header, op_name, extra_op_args); |
228 | |
229 | return {op_name, std::move(op_pattern), std::move(aten_op_pattern)}; |
230 | } |
231 | |
232 | // filter that checks %b_scalar is a scalar |
233 | bool input_b_is_scalar( |
234 | const Match& match, |
235 | const std::unordered_map<std::string, Value*>& vmap) { |
236 | const auto& match_vmap = match.values_map; |
237 | auto b_scalar = match_vmap.at(vmap.at("b_scalar" )); |
238 | return isScalar(b_scalar); |
239 | } |
240 | |
241 | // Patterns for ops that require observation for output quantization parameters |
242 | // Example: |
243 | // |
244 | // before fusion: |
245 | // |
246 | // graph(%a_quant, %r_scale, %r_zero_point, %r_dtype): |
247 | // %a_dequant = aten::dequantize(%a_quant) |
248 | // %r = {op_name}(%a_dequant, {extra_args}) |
249 | // %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, |
250 | // %r_dtype) return (%r_quant) |
251 | // |
252 | // after fusion: |
253 | // |
254 | // graph(%a_quant, %r_scale, %r_zero_point, %r_dtype): |
255 | // %r_quant = {quantized_op_name}(%a_quant, {extra_args}, %r_scale, |
256 | // %r_zero_point) return (%r_quant) |
257 | QuantFusionInfo getObservedQParamOpFusionInfo( |
258 | const std::string& fp_op_name, |
259 | const std::string& q_op_name, |
260 | const std::vector<std::string>& , |
261 | const std::vector<std::string>& ) { |
262 | const auto& = getExtraArgList(fp_extra_args); |
263 | const auto& = getExtraArgList(q_extra_args); |
264 | |
265 | std::string op_pattern = "graph(%a_quant" + fp_extra_arg_list + |
266 | ", %r_scale, %r_zero_point, %r_dtype):" + R"( |
267 | %a_dequant = aten::dequantize(%a_quant) |
268 | %r = )" + |
269 | fp_op_name + "(" + "%a_dequant" + fp_extra_arg_list + ")" + R"( |
270 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
271 | return (%r_quant) )" ; |
272 | |
273 | std::string aten_op_pattern = "graph(%a_quant" + fp_extra_arg_list + |
274 | ", %r_scale, %r_zero_point, %r_dtype):" + R"( |
275 | %r_quant = )" + |
276 | q_op_name + "(%a_quant" + q_extra_arg_list + |
277 | ", %r_scale, %r_zero_point)" + R"( |
278 | return (%r_quant) )" ; |
279 | |
280 | return {q_op_name, std::move(op_pattern), std::move(aten_op_pattern)}; |
281 | } |
282 | |
283 | } // namespace |
284 | |
285 | std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() { |
286 | // aten::conv1d |
287 | std::string conv1d = R"( |
288 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
289 | %a_dequant = aten::dequantize(%a_quant) |
290 | %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params) |
291 | %w_dequant = aten::dequantize(%w_quant) |
292 | %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
293 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
294 | return (%r_quant) )" ; |
295 | |
296 | // aten::conv1d - aten::relu |
297 | std::string conv1d_relu = R"( |
298 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
299 | %a_dequant = aten::dequantize(%a_quant) |
300 | %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params) |
301 | %w_dequant = aten::dequantize(%w_quant) |
302 | %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
303 | %r = aten::relu(%conv_out) |
304 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
305 | return (%r_quant) )" ; |
306 | |
307 | // aten::conv1d - aten::relu_ |
308 | std::string conv1d_inplace_relu = R"( |
309 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
310 | %a_dequant = aten::dequantize(%a_quant) |
311 | %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params) |
312 | %w_dequant = aten::dequantize(%w_quant) |
313 | %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
314 | %r = aten::relu_(%conv_out) |
315 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
316 | return (%r_quant) )" ; |
317 | |
318 | // quantized::conv1d |
319 | std::string quantized_conv1d = R"( |
320 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
321 | %r_quant = quantized::conv1d(%a_quant, %packed_params, %r_scale, %r_zero_point) |
322 | return (%r_quant) )" ; |
323 | |
324 | // quantized::conv1d_relu |
325 | std::string quantized_conv1d_relu = R"( |
326 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
327 | %r_quant = quantized::conv1d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point) |
328 | return (%r_quant) )" ; |
329 | |
330 | // aten::conv2d |
331 | std::string conv2d = R"( |
332 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
333 | %a_dequant = aten::dequantize(%a_quant) |
334 | %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params) |
335 | %w_dequant = aten::dequantize(%w_quant) |
336 | %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
337 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
338 | return (%r_quant) )" ; |
339 | |
340 | // aten::conv2d - aten::relu |
341 | std::string conv2d_relu = R"( |
342 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
343 | %a_dequant = aten::dequantize(%a_quant) |
344 | %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params) |
345 | %w_dequant = aten::dequantize(%w_quant) |
346 | %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
347 | %r = aten::relu(%conv_out) |
348 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
349 | return (%r_quant) )" ; |
350 | |
351 | // aten::conv2d - aten::relu_ |
352 | std::string conv2d_inplace_relu = R"( |
353 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
354 | %a_dequant = aten::dequantize(%a_quant) |
355 | %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params) |
356 | %w_dequant = aten::dequantize(%w_quant) |
357 | %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
358 | %r = aten::relu_(%conv_out) |
359 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
360 | return (%r_quant) )" ; |
361 | |
362 | // quantized::conv2d |
363 | std::string quantized_conv2d = R"( |
364 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
365 | %r_quant = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point) |
366 | return (%r_quant) )" ; |
367 | |
368 | // quantized::conv2d_relu |
369 | std::string quantized_conv2d_relu = R"( |
370 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
371 | %r_quant = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point) |
372 | return (%r_quant) )" ; |
373 | |
374 | // aten::conv3d |
375 | std::string conv3d = R"( |
376 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
377 | %a_dequant = aten::dequantize(%a_quant) |
378 | %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params) |
379 | %w_dequant = aten::dequantize(%w_quant) |
380 | %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
381 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
382 | return (%r_quant) )" ; |
383 | |
384 | // aten::conv3d - aten::relu |
385 | std::string conv3d_relu = R"( |
386 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
387 | %a_dequant = aten::dequantize(%a_quant) |
388 | %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params) |
389 | %w_dequant = aten::dequantize(%w_quant) |
390 | %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
391 | %r = aten::relu(%conv_out) |
392 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
393 | return (%r_quant) )" ; |
394 | |
395 | // aten::conv3d - aten::relu_ |
396 | std::string conv3d_inplace_relu = R"( |
397 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
398 | %a_dequant = aten::dequantize(%a_quant) |
399 | %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params) |
400 | %w_dequant = aten::dequantize(%w_quant) |
401 | %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
402 | %r = aten::relu_(%conv_out) |
403 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
404 | return (%r_quant) )" ; |
405 | |
406 | // quantized::conv3d |
407 | std::string quantized_conv3d = R"( |
408 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
409 | %r_quant = quantized::conv3d(%a_quant, %packed_params, %r_scale, %r_zero_point) |
410 | return (%r_quant) )" ; |
411 | |
412 | // quantized::conv3d_relu |
413 | std::string quantized_conv3d_relu = R"( |
414 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): |
415 | %r_quant = quantized::conv3d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point) |
416 | return (%r_quant) )" ; |
417 | |
418 | // aten::conv_transpose1d |
419 | std::string conv_transpose1d = R"( |
420 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): |
421 | %a_dequant = aten::dequantize(%a_quant) |
422 | %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose1d_unpack(%packed_params) |
423 | %w_dequant = aten::dequantize(%w_quant) |
424 | %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) |
425 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
426 | return (%r_quant) )" ; |
427 | |
428 | // quantized::conv_transpose1d |
429 | std::string quantized_conv_transpose1d = R"( |
430 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): |
431 | %r_quant = quantized::conv_transpose1d(%a_quant, %packed_params, %r_scale, %r_zero_point) |
432 | return (%r_quant) )" ; |
433 | |
434 | // aten::conv_transpose2d |
435 | std::string conv_transpose2d = R"( |
436 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): |
437 | %a_dequant = aten::dequantize(%a_quant) |
438 | %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose2d_unpack(%packed_params) |
439 | %w_dequant = aten::dequantize(%w_quant) |
440 | %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) |
441 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
442 | return (%r_quant) )" ; |
443 | |
// quantized::conv_transpose2d
445 | std::string quantized_conv_transpose2d = R"( |
446 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): |
447 | %r_quant = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point) |
448 | return (%r_quant) )" ; |
449 | |
450 | std::string add_relu = R"( |
451 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
452 | %a_dequant = aten::dequantize(%a_quant) |
453 | %b_dequant = aten::dequantize(%b_quant) |
454 | %r_add = aten::add(%a_dequant, %b_dequant, %alpha) |
455 | %r_relu = aten::relu(%r_add) |
456 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
457 | return (%r) )" ; |
458 | |
459 | std::string add_inplace_relu = R"( |
460 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
461 | %a_dequant = aten::dequantize(%a_quant) |
462 | %b_dequant = aten::dequantize(%b_quant) |
463 | %r_add = aten::add(%a_dequant, %b_dequant, %alpha) |
464 | %r_relu = aten::relu_(%r_add) |
465 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
466 | return (%r) )" ; |
467 | |
468 | std::string inplace_add_relu = R"( |
469 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
470 | %a_dequant = aten::dequantize(%a_quant) |
471 | %b_dequant = aten::dequantize(%b_quant) |
472 | %r_add = aten::add_(%a_dequant, %b_dequant, %alpha) |
473 | %r_relu = aten::relu(%r_add) |
474 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
475 | return (%r) )" ; |
476 | |
477 | std::string inplace_add_inplace_relu = R"( |
478 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
479 | %a_dequant = aten::dequantize(%a_quant) |
480 | %b_dequant = aten::dequantize(%b_quant) |
481 | %r_add = aten::add_(%a_dequant, %b_dequant, %alpha) |
482 | %r_relu = aten::relu_(%r_add) |
483 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
484 | return (%r) )" ; |
485 | |
486 | std::string quantized_add_relu = R"( |
487 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
488 | %r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point) |
489 | return (%r) )" ; |
490 | |
491 | // aten::linear |
492 | std::string linear = R"( |
493 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype): |
494 | %a_dequant = aten::dequantize(%a_quant) |
495 | %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params) |
496 | %w_dequant = aten::dequantize(%w_quant) |
497 | %r = aten::linear(%a_dequant, %w_dequant, %b) |
498 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
499 | return (%r_quant) )" ; |
500 | |
501 | std::string linear_relu = R"( |
502 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype): |
503 | %a_dequant = aten::dequantize(%a_quant) |
504 | %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params) |
505 | %w_dequant = aten::dequantize(%w_quant) |
506 | %linear_out = aten::linear(%a_dequant, %w_dequant, %b) |
507 | %r = aten::relu(%linear_out) |
508 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
509 | return (%r_quant) )" ; |
510 | |
511 | std::string linear_inplace_relu = R"( |
512 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype): |
513 | %a_dequant = aten::dequantize(%a_quant) |
514 | %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params) |
515 | %w_dequant = aten::dequantize(%w_quant) |
516 | %linear_out = aten::linear(%a_dequant, %w_dequant, %b) |
517 | %r = aten::relu_(%linear_out) |
518 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
519 | return (%r_quant) )" ; |
520 | |
521 | // quantized::linear |
522 | std::string quantized_linear = R"( |
523 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype): |
524 | %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point) |
525 | return (%r) )" ; |
526 | |
527 | std::string quantized_linear_relu = R"( |
528 | graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype): |
529 | %r = quantized::linear_relu(%a_quant, %packed_params, %r_scale, %r_zero_point) |
530 | return (%r) )" ; |
531 | |
532 | std::string cat = R"( |
533 | graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype): |
534 | %input_dequant = aten::dequantize(%input_quant) |
535 | %r = aten::cat(%input_dequant, %dim) |
536 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
537 | return (%r_quant) )" ; |
538 | |
539 | std::string quantized_cat = R"( |
540 | graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype): |
541 | %r_quant = quantized::cat(%input_quant, %dim, %r_scale, %r_zero_point) |
542 | return (%r_quant) )" ; |
543 | |
544 | // aten::add |
545 | std::string add = R"( |
546 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
547 | %a_dequant = aten::dequantize(%a_quant) |
548 | %b_dequant = aten::dequantize(%b_quant) |
549 | %r_add = aten::add(%a_dequant, %b_dequant, %alpha) |
550 | %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype) |
551 | return (%r) )" ; |
552 | |
// TODO: add %dtype once https://github.com/pytorch/pytorch/issues/34351
// is fixed
555 | // quantized::add |
556 | std::string quantized_add = R"( |
557 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
558 | %r = quantized::add(%a_quant, %b_quant, %scale, %zero_point) |
559 | return (%r) )" ; |
560 | |
561 | // aten::add_ |
562 | std::string inplace_add = R"( |
563 | graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): |
564 | %a_dequant = aten::dequantize(%a_quant) |
565 | %b_dequant = aten::dequantize(%b_quant) |
566 | %r_add = aten::add_(%a_dequant, %b_dequant, %alpha) |
567 | %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype) |
568 | return (%r) )" ; |
569 | |
570 | auto add_scalar = getBinaryOpScalarFusionInfo( |
571 | "aten::add" , |
572 | {"%b_scalar" , "%alpha" }, |
573 | "quantized::add_scalar" , |
574 | {"%b_scalar" }, |
575 | {aten_add_alpha_is_one, input_b_is_scalar}); |
576 | |
577 | auto add_scalar_out = getBinaryOpScalarFusionInfo( |
578 | "aten::add_" , |
579 | {"%b_scalar" , "%alpha" }, |
580 | "quantized::add_scalar_out" , |
581 | {"%b_scalar" , "%a_quant" }, |
582 | {aten_add_alpha_is_one, input_b_is_scalar}); |
583 | |
584 | // quantized::add_scalar_relu -- fusing quantized::add_scalar |
585 | // and aten::relu |
586 | auto quantized_add_scalar_relu_pattern = R"( |
587 | graph(%a_quant, %b_scalar): |
588 | %r_add = quantized::add_scalar(%a_quant, %b_scalar) |
589 | %r = aten::relu(%r_add) |
590 | return (%r) )" ; |
591 | |
592 | auto quantized_add_scalar_inplace_relu_pattern = R"( |
593 | graph(%a_quant, %b_scalar): |
594 | %r_add = quantized::add_scalar(%a_quant, %b_scalar) |
595 | %r = aten::relu_(%r_add) |
596 | return (%r) )" ; |
597 | |
598 | auto quantized_add_scalar_relu_replacement = R"( |
599 | graph(%a_quant, %b_scalar): |
600 | %r = quantized::add_scalar_relu(%a_quant, %b_scalar) |
601 | return (%r) )" ; |
602 | |
// quantized::add_scalar_relu_out -- fusing quantized::add_scalar_out
604 | // and aten::relu |
605 | auto quantized_add_scalar_relu_out_pattern = R"( |
606 | graph(%a_quant, %b_scalar): |
607 | %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant) |
608 | %r = aten::relu(%r_add) |
609 | return (%r) )" ; |
610 | |
611 | auto quantized_add_scalar_inplace_relu_out_pattern = R"( |
612 | graph(%a_quant, %b_scalar): |
613 | %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant) |
614 | %r = aten::relu_(%r_add) |
615 | return (%r) )" ; |
616 | |
617 | auto quantized_add_scalar_relu_out_replacement = R"( |
618 | graph(%a_quant, %b_scalar): |
619 | %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant) |
620 | return (%r) )" ; |
621 | |
622 | // quantized::batch_norm |
623 | std::string batch_norm = R"( |
624 | graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type): |
625 | %a_dequant = aten::dequantize(%a_quant) |
626 | %r_bn = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7) |
627 | %r = aten::quantize_per_tensor(%r_bn, %scale, %zero_point, %scalar_type) |
628 | return (%r) )" ; |
629 | std::string quantized_batch_norm = R"( |
630 | graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type): |
631 | %r = quantized::batch_norm(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point) |
632 | return (%r) )" ; |
633 | |
634 | std::string batch_norm_relu = R"( |
635 | graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type): |
636 | %a_dequant = aten::dequantize(%a_quant) |
637 | %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7) |
638 | %relu = aten::relu(%bn_out) |
639 | %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type) |
640 | return (%r) )" ; |
641 | std::string batch_norm_inplace_relu = R"( |
642 | graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type): |
643 | %a_dequant = aten::dequantize(%a_quant) |
644 | %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7) |
645 | %relu = aten::relu_(%bn_out) |
646 | %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type) |
647 | return (%r) )" ; |
648 | |
649 | std::string quantized_batch_norm_relu = R"( |
650 | graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type): |
651 | %r = quantized::batch_norm_relu(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point) |
652 | return (%r) )" ; |
653 | |
654 | // aten::mul |
655 | std::string mul = R"( |
656 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
657 | %a_dequant = aten::dequantize(%a_quant) |
658 | %b_dequant = aten::dequantize(%b_quant) |
659 | %r_mul = aten::mul(%a_dequant, %b_dequant) |
660 | %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype) |
661 | return (%r) )" ; |
662 | |
663 | // aten::mul_ |
664 | std::string inplace_mul = R"( |
665 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
666 | %a_dequant = aten::dequantize(%a_quant) |
667 | %b_dequant = aten::dequantize(%b_quant) |
668 | %r_mul = aten::mul_(%a_dequant, %b_dequant) |
669 | %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype) |
670 | return (%r) )" ; |
671 | |
672 | // quantized::mul |
673 | std::string quantized_mul = R"( |
674 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
675 | %r = quantized::mul(%a_quant, %b_quant, %scale, %zero_point) |
676 | return (%r) )" ; |
677 | |
678 | auto mul_scalar = getBinaryOpScalarFusionInfo( |
679 | "aten::mul" , |
680 | {"%b_scalar" }, |
681 | "quantized::mul_scalar" , |
682 | {"%b_scalar" }, |
683 | {input_b_is_scalar}); |
684 | |
685 | auto mul_scalar_out = getBinaryOpScalarFusionInfo( |
686 | "aten::mul_" , |
687 | {"%b_scalar" }, |
688 | "quantized::mul_scalar_out" , |
689 | {"%b_scalar" , "%a_quant" }, |
690 | {input_b_is_scalar}); |
691 | |
692 | // quantized::mul_relu |
693 | std::string mul_relu = R"( |
694 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
695 | %a_dequant = aten::dequantize(%a_quant) |
696 | %b_dequant = aten::dequantize(%b_quant) |
697 | %r_mul = aten::mul(%a_dequant, %b_dequant) |
698 | %r_relu = aten::relu(%r_mul) |
699 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
700 | return (%r) )" ; |
701 | |
702 | std::string mul_inplace_relu = R"( |
703 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
704 | %a_dequant = aten::dequantize(%a_quant) |
705 | %b_dequant = aten::dequantize(%b_quant) |
706 | %r_mul = aten::mul(%a_dequant, %b_dequant) |
707 | %r_relu = aten::relu_(%r_mul) |
708 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
709 | return (%r) )" ; |
710 | |
711 | std::string inplace_mul_relu = R"( |
712 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
713 | %a_dequant = aten::dequantize(%a_quant) |
714 | %b_dequant = aten::dequantize(%b_quant) |
715 | %r_mul = aten::mul_(%a_dequant, %b_dequant) |
716 | %r_relu = aten::relu(%r_mul) |
717 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
718 | return (%r) )" ; |
719 | |
720 | std::string inplace_mul_inplace_relu = R"( |
721 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
722 | %a_dequant = aten::dequantize(%a_quant) |
723 | %b_dequant = aten::dequantize(%b_quant) |
724 | %r_mul = aten::mul_(%a_dequant, %b_dequant) |
725 | %r_relu = aten::relu_(%r_mul) |
726 | %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype) |
727 | return (%r) )" ; |
728 | |
729 | std::string quantized_mul_relu = R"( |
730 | graph(%a_quant, %b_quant, %scale, %zero_point, %dtype): |
731 | %r = quantized::mul_relu(%a_quant, %b_quant, %scale, %zero_point) |
732 | return (%r) )" ; |
733 | |
734 | // quantized::mul_scalar_relu -- fusing quantized::mul_scalar |
735 | // and aten::relu |
736 | auto quantized_mul_scalar_relu_pattern = R"( |
737 | graph(%a_quant, %b_scalar): |
738 | %r_mul = quantized::mul_scalar(%a_quant, %b_scalar) |
739 | %r = aten::relu(%r_mul) |
740 | return (%r) )" ; |
741 | |
742 | auto quantized_mul_scalar_inplace_relu_pattern = R"( |
743 | graph(%a_quant, %b_scalar): |
744 | %r_mul = quantized::mul_scalar(%a_quant, %b_scalar) |
745 | %r = aten::relu_(%r_mul) |
746 | return (%r) )" ; |
747 | |
748 | auto quantized_mul_scalar_relu_replacement = R"( |
749 | graph(%a_quant, %b_scalar): |
750 | %r = quantized::mul_scalar_relu(%a_quant, %b_scalar) |
751 | return (%r) )" ; |
752 | |
  // quantized::mul_scalar_relu_out -- fusing quantized::mul_scalar_out
  // and aten::relu
755 | auto quantized_mul_scalar_relu_out_pattern = R"( |
756 | graph(%a_quant, %b_scalar): |
757 | %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant) |
758 | %r = aten::relu(%r_mul) |
759 | return (%r) )" ; |
760 | |
761 | auto quantized_mul_scalar_inplace_relu_out_pattern = R"( |
762 | graph(%a_quant, %b_scalar): |
763 | %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant) |
764 | %r = aten::relu_(%r_mul) |
765 | return (%r) )" ; |
766 | |
767 | auto quantized_mul_scalar_relu_out_replacement = R"( |
768 | graph(%a_quant, %b_scalar): |
769 | %r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant) |
770 | return (%r) )" ; |
771 | |
772 | // quantized::elu |
773 | std::string elu = R"( |
774 | graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype): |
775 | %a_dequant = aten::dequantize(%a_quant) |
776 | %r = aten::elu(%a_dequant, %alpha, %scale, %input_scale) |
777 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
778 | return (%r_quant) )" ; |
779 | |
780 | std::string quantized_elu = R"( |
781 | graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype): |
782 | %r_quant = quantized::elu(%a_quant, %r_scale, %r_zero_point, %alpha, %scale, %input_scale) |
783 | return (%r_quant) )" ; |
784 | |
785 | std::string elu_ = R"( |
786 | graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype): |
787 | %a_dequant = aten::dequantize(%a_quant) |
788 | %r = aten::elu_(%a_dequant, %alpha, %scale, %input_scale) |
789 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
790 | return (%r_quant) )" ; |
791 | |
  // ============= General Ops that inherit quantization parameters from input
  // tensor =============
794 | auto avg_pool1d = getInputTensorQParamOpFusionInfo( |
795 | "aten::avg_pool1d" , |
796 | {"%kernel_size" , |
797 | "%stride" , |
798 | "%padding" , |
799 | "%ceil_mode" , |
800 | "%count_include_pad" }); |
801 | |
802 | auto avg_pool2d = getInputTensorQParamOpFusionInfo( |
803 | "aten::avg_pool2d" , |
804 | {"%kernel_size" , |
805 | "%stride" , |
806 | "%padding" , |
807 | "%ceil_mode" , |
808 | "%count_include_pad" , |
809 | "%divisor_override" }); |
810 | |
811 | std::string common_general_value_op = R"( |
812 | %r_scale : float = aten::q_scale(%a_quant) |
813 | %r_zero_point : int = aten::q_zero_point(%a_quant) |
814 | %r_dtype : int = prim::dtype(%a_quant) |
815 | %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) |
816 | return (%r_quant) )" ; |
817 | |
818 | auto avg_pool3d = getInputTensorQParamOpFusionInfo( |
819 | "aten::avg_pool3d" , |
820 | {"%kernel_size" , |
821 | "%stride" , |
822 | "%padding" , |
823 | "%ceil_mode" , |
824 | "%count_include_pad" , |
825 | "%divisor_override" }); |
826 | |
827 | auto adaptive_avg_pool1d = getInputTensorQParamOpFusionInfo( |
828 | "aten::adaptive_avg_pool1d" , {"%output_size" }); |
829 | |
830 | auto adaptive_avg_pool2d = getInputTensorQParamOpFusionInfo( |
831 | "aten::adaptive_avg_pool2d" , {"%output_size" }); |
832 | |
833 | auto adaptive_avg_pool3d = getInputTensorQParamOpFusionInfo( |
834 | "aten::adaptive_avg_pool3d" , {"%output_size" }); |
835 | |
836 | auto mean1 = getInputTensorQParamOpFusionInfo("aten::mean" , {"%dim" }); |
837 | |
838 | auto mean2 = getInputTensorQParamOpFusionInfo( |
839 | "aten::mean" , {"%dim" , "%keepdim" , "%out" }); |
840 | |
841 | auto upsample_nearest1d_vec = getInputTensorQParamOpFusionInfo( |
842 | "aten::upsample_nearest1d" , {"%output_size" , "%scale_factors" }); |
843 | |
844 | auto upsample_nearest2d_vec = getInputTensorQParamOpFusionInfo( |
845 | "aten::upsample_nearest2d" , {"%output_size" , "%scale_factors" }); |
846 | |
847 | auto upsample_nearest3d_vec = getInputTensorQParamOpFusionInfo( |
848 | "aten::upsample_nearest3d" , {"%output_size" , "%scale_factors" }); |
849 | |
850 | auto upsample_linear1d_vec = getInputTensorQParamOpFusionInfo( |
851 | "aten::upsample_linear1d" , |
852 | {"%output_size" , "%align_corners" , "%scale_factors" }); |
853 | |
854 | auto upsample_bilinear2d_vec = getInputTensorQParamOpFusionInfo( |
855 | "aten::upsample_bilinear2d" , |
856 | {"%output_size" , "%align_corners" , "%scale_factors" }); |
857 | |
858 | auto upsample_trilinear3d_vec = getInputTensorQParamOpFusionInfo( |
859 | "aten::upsample_trilinear3d" , |
860 | {"%output_size" , "%align_corners" , "%scale_factors" }); |
861 | |
862 | auto upsample_nearest1d = getInputTensorQParamOpFusionInfo( |
863 | "aten::upsample_nearest1d" , {"%output_size" , "%scales" }); |
864 | |
865 | auto upsample_nearest2d = getInputTensorQParamOpFusionInfo( |
866 | "aten::upsample_nearest2d" , {"%output_size" , "%scale_h" , "%scale_w" }); |
867 | |
868 | auto upsample_nearest3d = getInputTensorQParamOpFusionInfo( |
869 | "aten::upsample_nearest3d" , |
870 | {"%output_size" , "%scale_d" , "%scale_h" , "%scale_w" }); |
871 | |
872 | auto upsample_linear1d = getInputTensorQParamOpFusionInfo( |
873 | "aten::upsample_linear1d" , {"%output_size" , "%align_corners" , "%scales" }); |
874 | |
875 | auto upsample_bilinear2d = getInputTensorQParamOpFusionInfo( |
876 | "aten::upsample_bilinear2d" , |
877 | {"%output_size" , "%align_corners" , "%scale_h" , "%scale_w" }); |
878 | |
879 | auto upsample_trilinear3d = getInputTensorQParamOpFusionInfo( |
880 | "aten::upsample_trilinear3d" , |
881 | {"%output_size" , "%align_corners" , "%scale_d" , "%scale_h" , "%scale_w" }); |
882 | |
883 | auto clamp = getClampOpFusionInfo("aten::clamp" , {"%min" , "%max" }); |
884 | |
885 | auto hardtanh = getClampOpFusionInfo("aten::hardtanh" , {"%min" , "%max" }); |
886 | |
887 | auto hardtanh_ = getClampOpFusionInfo("aten::hardtanh_" , {"%min" , "%max" }); |
888 | |
889 | auto leaky_relu = |
890 | getInputTensorQParamOpFusionInfo("aten::leaky_relu" , {"%negative_slope" }); |
891 | |
892 | auto leaky_relu_ = getInputTensorQParamOpFusionInfo( |
893 | "aten::leaky_relu_" , {"%negative_slope" }); |
894 | |
895 | // Ops with fixed quantization parameters |
896 | auto hardsigmoid = getFixedQParamOpFusionInfo("aten::hardsigmoid" , {}, false); |
897 | |
898 | auto hardsigmoid_ = |
899 | getFixedQParamOpFusionInfo("aten::hardsigmoid_" , {}, false); |
900 | |
901 | auto sigmoid = getFixedQParamOpFusionInfo("aten::sigmoid" , {}, false); |
902 | |
903 | auto sigmoid_ = getFixedQParamOpFusionInfo("aten::sigmoid_" , {}, false); |
904 | |
905 | auto tanh = getFixedQParamOpFusionInfo("aten::tanh" , {}, true); |
906 | |
907 | auto tanh_ = getFixedQParamOpFusionInfo("aten::tanh_" , {}, true); |
908 | |
909 | auto hardswish = getObservedQParamOpFusionInfo( |
910 | "aten::hardswish" , "quantized::hardswish" , {}, {}); |
911 | |
912 | auto hardswish_ = getObservedQParamOpFusionInfo( |
913 | "aten::hardswish_" , "quantized::hardswish" , {}, {}); |
914 | |
915 | auto layer_norm = getObservedQParamOpFusionInfo( |
916 | "aten::layer_norm" , |
917 | "quantized::layer_norm" , |
918 | {"%normalized_shape" , "%weight" , "%bias" , "%eps" , "%cudnn_enabled" }, |
919 | {"%normalized_shape" , "%weight" , "%bias" , "%eps" }); |
920 | |
921 | auto group_norm = getObservedQParamOpFusionInfo( |
922 | "aten::group_norm" , |
923 | "quantized::group_norm" , |
924 | {"%num_groups" , "%weight" , "%bias" , "%eps" , "%cudnn_enabled" }, |
925 | {"%num_groups" , "%weight" , "%bias" , "%eps" }); |
926 | |
927 | auto instance_norm = getObservedQParamOpFusionInfo( |
928 | "aten::instance_norm" , |
929 | "quantized::instance_norm" , |
930 | {"%weight" , |
931 | "%bias" , |
932 | "%running_mean" , |
933 | "%running_var" , |
934 | "%use_input_stats" , |
935 | "%momentum" , |
936 | "%eps" , |
937 | "%cudnn_enabled" }, |
938 | {"%weight" , "%bias" , "%eps" }); |
939 | |
940 | return { |
941 | {"quantized::conv1d" , std::move(conv1d), std::move(quantized_conv1d)}, |
942 | {"quantized::conv1d_relu" , std::move(conv1d_relu), quantized_conv1d_relu}, |
943 | {"quantized::conv1d_relu" , |
944 | std::move(conv1d_inplace_relu), |
945 | std::move(quantized_conv1d_relu)}, |
946 | {"quantized::conv2d" , std::move(conv2d), std::move(quantized_conv2d)}, |
947 | {"quantized::conv2d_relu" , std::move(conv2d_relu), quantized_conv2d_relu}, |
948 | {"quantized::conv2d_relu" , |
949 | std::move(conv2d_inplace_relu), |
950 | std::move(quantized_conv2d_relu)}, |
951 | {"quantized::conv3d" , std::move(conv3d), std::move(quantized_conv3d)}, |
952 | {"quantized::conv3d_relu" , std::move(conv3d_relu), quantized_conv3d_relu}, |
953 | {"quantized::conv3d_relu" , |
954 | std::move(conv3d_inplace_relu), |
955 | std::move(quantized_conv3d_relu)}, |
956 | {"quantized::conv_transpose1d" , |
957 | std::move(conv_transpose1d), |
958 | std::move(quantized_conv_transpose1d)}, |
959 | {"quantized::conv_transpose2d" , |
960 | std::move(conv_transpose2d), |
961 | std::move(quantized_conv_transpose2d)}, |
962 | {"quantized::linear" , std::move(linear), std::move(quantized_linear)}, |
963 | {"quantized::linear_relu" , std::move(linear_relu), quantized_linear_relu}, |
964 | {"quantized::linear_relu" , |
965 | std::move(linear_inplace_relu), |
966 | std::move(quantized_linear_relu)}, |
967 | {"quantized::add_relu" , |
968 | std::move(add_relu), |
969 | quantized_add_relu, |
970 | {aten_add_alpha_is_one}}, |
971 | {"quantized::add_relu" , |
972 | std::move(add_inplace_relu), |
973 | quantized_add_relu, |
974 | {aten_add_alpha_is_one}}, |
975 | {"quantized::add_relu" , |
976 | std::move(inplace_add_relu), |
977 | quantized_add_relu, |
978 | {aten_add_alpha_is_one}}, |
979 | {"quantized::add_relu" , |
980 | std::move(inplace_add_inplace_relu), |
981 | std::move(quantized_add_relu), |
982 | {aten_add_alpha_is_one}}, |
983 | std::move(add_scalar), |
984 | std::move(add_scalar_out), |
985 | // note that these must come after quantized::add_scalar and |
986 | // quantized::add_scalar_out patterns |
987 | {"quantized::add_scalar_relu" , |
988 | quantized_add_scalar_relu_pattern, |
989 | quantized_add_scalar_relu_replacement}, |
990 | {"quantized::add_scalar_relu" , |
991 | quantized_add_scalar_inplace_relu_pattern, |
992 | quantized_add_scalar_relu_replacement}, |
993 | {"quantized::add_scalar_relu_out" , |
994 | quantized_add_scalar_relu_out_pattern, |
995 | quantized_add_scalar_relu_out_replacement}, |
996 | {"quantized::add_scalar_relu_out" , |
997 | quantized_add_scalar_inplace_relu_out_pattern, |
998 | quantized_add_scalar_relu_out_replacement}, |
999 | {"quantized::add" , |
1000 | std::move(add), |
1001 | quantized_add, |
1002 | {aten_add_alpha_is_one}}, |
1003 | {"quantized::add" , |
1004 | std::move(inplace_add), |
1005 | std::move(quantized_add), |
1006 | {aten_add_alpha_is_one}}, |
1007 | {"quantized::cat" , std::move(cat), std::move(quantized_cat)}, |
1008 | {"quantized::batch_norm" , |
1009 | std::move(batch_norm), |
1010 | std::move(quantized_batch_norm)}, |
1011 | {"quantized::batch_norm_relu" , |
1012 | std::move(batch_norm_relu), |
1013 | quantized_batch_norm_relu}, |
1014 | {"quantized::batch_norm_relu" , |
1015 | std::move(batch_norm_inplace_relu), |
1016 | std::move(quantized_batch_norm_relu)}, |
1017 | std::move(mul_scalar), |
1018 | std::move(mul_scalar_out), |
1019 | // note that these must come after quantized::mul_scalar and |
1020 | // quantized::mul_scalar_out patterns |
1021 | {"quantized::mul_scalar_relu" , |
1022 | quantized_mul_scalar_relu_pattern, |
1023 | quantized_mul_scalar_relu_replacement}, |
1024 | {"quantized::mul_scalar_relu" , |
1025 | quantized_mul_scalar_inplace_relu_pattern, |
1026 | quantized_mul_scalar_relu_replacement}, |
1027 | {"quantized::mul_scalar_relu_out" , |
1028 | quantized_mul_scalar_relu_out_pattern, |
1029 | quantized_mul_scalar_relu_out_replacement}, |
1030 | {"quantized::mul_scalar_relu_out" , |
1031 | quantized_mul_scalar_inplace_relu_out_pattern, |
1032 | quantized_mul_scalar_relu_out_replacement}, |
1033 | {"quantized::mul_relu" , std::move(mul_relu), quantized_mul_relu}, |
1034 | {"quantized::mul_relu" , std::move(mul_inplace_relu), quantized_mul_relu}, |
1035 | {"quantized::mul_relu" , std::move(inplace_mul_relu), quantized_mul_relu}, |
1036 | {"quantized::mul_relu" , |
1037 | std::move(inplace_mul_inplace_relu), |
1038 | std::move(quantized_mul_relu)}, |
1039 | {"quantized::mul" , std::move(mul), quantized_mul}, |
1040 | {"quantized::mul" , std::move(inplace_mul), std::move(quantized_mul)}, |
1041 | std::move(hardswish), |
1042 | std::move(hardswish_), |
1043 | std::move(layer_norm), |
1044 | std::move(group_norm), |
1045 | std::move(instance_norm), |
1046 | {"quantized::elu" , std::move(elu), quantized_elu}, |
1047 | {"quantized::elu_" , std::move(elu_), std::move(quantized_elu)}, |
1048 | std::move(avg_pool1d), |
1049 | std::move(avg_pool2d), |
1050 | std::move(avg_pool3d), |
1051 | std::move(adaptive_avg_pool1d), |
1052 | std::move(adaptive_avg_pool2d), |
1053 | std::move(adaptive_avg_pool3d), |
1054 | std::move(mean1), |
1055 | std::move(mean2), |
1056 | std::move(upsample_nearest1d), |
1057 | std::move(upsample_nearest2d), |
1058 | std::move(upsample_nearest3d), |
1059 | std::move(upsample_linear1d), |
1060 | std::move(upsample_bilinear2d), |
1061 | std::move(upsample_trilinear3d), |
1062 | std::move(upsample_nearest1d_vec), |
1063 | std::move(upsample_nearest2d_vec), |
1064 | std::move(upsample_nearest3d_vec), |
1065 | std::move(upsample_linear1d_vec), |
1066 | std::move(upsample_bilinear2d_vec), |
1067 | std::move(upsample_trilinear3d_vec), |
1068 | std::move(clamp), |
1069 | std::move(hardtanh), |
1070 | std::move(hardtanh_), |
1071 | std::move(leaky_relu), |
1072 | std::move(leaky_relu_), |
1073 | // fixed qparam ops |
1074 | std::move(hardsigmoid), |
1075 | std::move(hardsigmoid_), |
1076 | std::move(sigmoid), |
1077 | std::move(sigmoid_), |
1078 | std::move(tanh), |
1079 | std::move(tanh_), |
1080 | }; |
1081 | } |
1082 | |
1083 | inline std::vector<QuantFusionInfo> |
1084 | dynamic_quantized_linear_pattern_and_replacements() { |
1085 | std::string linear_dynamic = R"( |
1086 | graph(%packed_params, %a): |
1087 | %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params) |
1088 | %w_dequant = aten::dequantize(%w_quant) |
1089 | %r = aten::linear(%a, %w_dequant, %b) |
1090 | return (%r) )" ; |
1091 | |
1092 | // This pattern ignores reduce range |
1093 | // Set the reduce range to default to true, since qnnpack backend ignores this |
1094 | // argument. |
1095 | std::string quantized_linear_dynamic = R"( |
1096 | graph(%packed_params, %a): |
1097 | %reduce_range : bool = prim::Constant[value=1]() |
1098 | %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range) |
1099 | return (%r) )" ; |
1100 | |
1101 | return { |
1102 | {"quantized::linear_dynamic" , |
1103 | std::move(linear_dynamic), |
1104 | std::move(quantized_linear_dynamic)}, |
1105 | }; |
1106 | } |
1107 | |
1108 | std::vector<QuantFusionInfo> dynamic_quant_fusion_pattern_and_replacements() { |
1109 | std::string linear_dynamic = R"( |
1110 | graph(%packed_params, %a, %reduce_range, %a_dtype): |
1111 | %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range) |
1112 | %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype) |
1113 | %a_dequant = aten::dequantize(%a_quant) |
1114 | %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params) |
1115 | %w_dequant = aten::dequantize(%w_quant) |
1116 | %r = aten::linear(%a_dequant, %w_dequant, %b) |
1117 | return (%r) )" ; |
1118 | |
1119 | std::string quantized_linear_dynamic = R"( |
1120 | graph(%packed_params, %a, %reduce_range, %a_dtype): |
1121 | %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range) |
1122 | return (%r) )" ; |
1123 | |
1124 | std::string linear_dynamic_fp16 = R"( |
1125 | graph(%packed_params, %a): |
1126 | %w_unpacked : Tensor, %b : Tensor? = quantized::linear_unpack_fp16(%packed_params) |
1127 | %r = aten::linear(%a, %w_unpacked, %b) |
1128 | return (%r) )" ; |
1129 | |
1130 | std::string quantized_linear_dynamic_fp16 = R"( |
1131 | graph(%packed_params, %a): |
1132 | %r = quantized::linear_dynamic_fp16(%a, %packed_params) |
1133 | return (%r) )" ; |
1134 | |
1135 | return { |
1136 | {"quantized::linear_dynamic" , |
1137 | std::move(linear_dynamic), |
1138 | std::move(quantized_linear_dynamic)}, |
1139 | {"quantized::linear_dynamic_fp16" , |
1140 | std::move(linear_dynamic_fp16), |
1141 | std::move(quantized_linear_dynamic_fp16)}, |
1142 | }; |
1143 | } |
1144 | |
1145 | std::vector<QuantFusionInfo> linear_prepack_unpack_patterns() { |
1146 | std::string linear_with_quant = R"( |
1147 | graph(%a_dequant, %w_quant, %b): |
1148 | %w_dequant = aten::dequantize(%w_quant) |
1149 | %r = aten::linear(%a_dequant, %w_dequant, %b) |
1150 | return (%r) )" ; |
1151 | |
1152 | std::string linear_with_quant_prepack = R"( |
1153 | graph(%a_dequant, %w_quant, %b): |
1154 | %packed_params = quantized::linear_prepack(%w_quant, %b) |
1155 | %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack(%packed_params) |
1156 | %w_dequant = aten::dequantize(%w_quant_unpacked) |
1157 | %r = aten::linear(%a_dequant, %w_dequant, %b_unpacked) |
1158 | return (%r) )" ; |
1159 | std::string linear_fp16_with_cast = R"( |
1160 | graph(%w, %a_dq, %b): |
1161 | %fp16_tensor = aten::_saturate_weight_to_fp16(%w) |
1162 | %r = aten::linear(%a_dq, %fp16_tensor, %b) |
1163 | return (%r) )" ; |
1164 | std::string linear_fp16_with_prepack = R"( |
1165 | graph(%w, %a_dq, %b): |
1166 | %packed_params = quantized::linear_prepack_fp16(%w, %b) |
1167 | %w_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack_fp16(%packed_params) |
1168 | %r = aten::linear(%a_dq, %w_unpacked, %b_unpacked) |
1169 | return (%r) )" ; |
1170 | |
1171 | return { |
1172 | {"linear_prepack_unpack" , |
1173 | std::move(linear_with_quant), |
1174 | std::move(linear_with_quant_prepack)}, |
1175 | {"linear_fp16_prepack_unpack" , |
1176 | std::move(linear_fp16_with_cast), |
1177 | std::move(linear_fp16_with_prepack)}, |
1178 | }; |
1179 | } |
1180 | |
1181 | std::vector<QuantFusionInfo> conv_prepack_unpack_patterns() { |
1182 | std::string conv1d_with_quant = R"( |
1183 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): |
1184 | %w_dequant = aten::dequantize(%w_quant) |
1185 | %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
1186 | return (%r) )" ; |
1187 | |
1188 | std::string conv1d_with_quant_prepack = R"( |
1189 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): |
1190 | %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv1d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups) |
1191 | %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv1d_unpack(%packed_params) |
1192 | %w_dequant = aten::dequantize(%w_quant_unpacked) |
1193 | %r = aten::conv1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups) |
1194 | return (%r) )" ; |
1195 | |
1196 | std::string conv2d_with_quant = R"( |
1197 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): |
1198 | %w_dequant = aten::dequantize(%w_quant) |
1199 | %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
1200 | return (%r) )" ; |
1201 | |
1202 | std::string conv2d_with_quant_prepack = R"( |
1203 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): |
1204 | %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups) |
1205 | %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv2d_unpack(%packed_params) |
1206 | %w_dequant = aten::dequantize(%w_quant_unpacked) |
1207 | %r = aten::conv2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups) |
1208 | return (%r) )" ; |
1209 | |
1210 | std::string conv3d_with_quant = R"( |
1211 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): |
1212 | %w_dequant = aten::dequantize(%w_quant) |
1213 | %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) |
1214 | return (%r) )" ; |
1215 | |
1216 | std::string conv3d_with_quant_prepack = R"( |
1217 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): |
1218 | %packed_params : __torch__.torch.classes.quantized.Conv3dPackedParamsBase = quantized::conv3d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups) |
1219 | %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv3d_unpack(%packed_params) |
1220 | %w_dequant = aten::dequantize(%w_quant_unpacked) |
1221 | %r = aten::conv3d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups) |
1222 | return (%r) )" ; |
1223 | |
1224 | std::string conv_transpose1d_with_quant = R"( |
1225 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): |
1226 | %w_dequant = aten::dequantize(%w_quant) |
1227 | %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) |
1228 | return (%r) )" ; |
1229 | |
1230 | std::string conv_transpose1d_with_quant_prepack = R"( |
1231 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): |
1232 | %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose1d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups) |
1233 | %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose1d_unpack(%packed_params) |
1234 | %w_dequant = aten::dequantize(%w_quant_unpacked) |
1235 | %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation) |
1236 | return (%r) )" ; |
1237 | |
1238 | std::string conv_transpose2d_with_quant = R"( |
1239 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): |
1240 | %w_dequant = aten::dequantize(%w_quant) |
1241 | %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) |
1242 | return (%r) )" ; |
1243 | |
1244 | std::string conv_transpose2d_with_quant_prepack = R"( |
1245 | graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): |
1246 | %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose2d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups) |
1247 | %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose2d_unpack(%packed_params) |
1248 | %w_dequant = aten::dequantize(%w_quant_unpacked) |
1249 | %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation) |
1250 | return (%r) )" ; |
1251 | |
1252 | return { |
1253 | {"conv1d_prepack_unpack" , |
1254 | std::move(conv1d_with_quant), |
1255 | std::move(conv1d_with_quant_prepack)}, |
1256 | {"conv2d_prepack_unpack" , |
1257 | std::move(conv2d_with_quant), |
1258 | std::move(conv2d_with_quant_prepack)}, |
1259 | {"conv3d_prepack_unpack" , |
1260 | std::move(conv3d_with_quant), |
1261 | std::move(conv3d_with_quant_prepack)}, |
1262 | {"conv_transpose1d_prepack_unpack" , |
1263 | std::move(conv_transpose1d_with_quant), |
1264 | std::move(conv_transpose1d_with_quant_prepack)}, |
1265 | {"conv_transpose2d_prepack_unpack" , |
1266 | std::move(conv_transpose2d_with_quant), |
1267 | std::move(conv_transpose2d_with_quant_prepack)}}; |
1268 | } |
1269 | |
1270 | } // namespace jit |
1271 | } // namespace torch |
1272 | |