1#pragma once
2
3#include <c10/util/irange.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <torch/csrc/jit/ir/subgraph_matcher.h>
6#include <torch/csrc/jit/jit_log.h>
7#include <torch/csrc/jit/passes/quantization/helper.h>
8#include <torch/csrc/jit/passes/subgraph_rewrite.h>
9#include <string>
10#include <unordered_map>
11#include <utility>
12
13namespace torch {
14namespace jit {
15
16struct QuantFusionInfo {
17 std::string quantized_op_name;
18 std::string pattern;
19 std::string replacement;
20 std::vector<MatchFilter> filters = {};
21};
22
23namespace {
// Formats `extra_args` as a trailing argument list: every element is emitted
// with a leading ", ", so the result can be appended directly after a first
// argument (e.g. {"%b", "%c"} -> ", %b, %c"). An empty vector yields "".
//
// The vector is taken by const reference so lvalue callers do not pay for a
// copy; rvalues and braced lists still bind as before.
std::string getExtraArgList(const std::vector<std::string>& extra_args) {
  // Pre-compute the final length so the buffer is allocated once.
  std::string::size_type total_size = 0;
  for (const auto& arg : extra_args) {
    total_size += arg.size() + 2; // 2 characters for the ", " separator
  }

  // Append in place instead of folding with operator+, which would copy the
  // growing string on every step before C++20.
  std::string arg_list;
  arg_list.reserve(total_size);
  for (const auto& arg : extra_args) {
    arg_list += ", ";
    arg_list += arg;
  }
  return arg_list;
}
31
32// Get the pattern we want to replace the match with
33std::string getAtenOpPattern(
34 const std::string& graph_header,
35 const std::string& op_name,
36 const std::vector<std::string>& extra_op_args,
37 bool scalar_args = false) {
38 std::vector<std::string> _extra_op_args = extra_op_args;
39 std::string aten_op_pattern = graph_header;
40 if (scalar_args) {
41 for (const auto& extra_arg : _extra_op_args) {
42 aten_op_pattern
43 .append(R"(
44 )")
45 .append(extra_arg)
46 .append("_scalar = aten::item(")
47 .append(extra_arg)
48 .append(")");
49 }
50
51 for (auto& _extra_op_arg : _extra_op_args) {
52 _extra_op_arg.append("_scalar");
53 }
54 }
55 const auto& extra_op_arg_list = getExtraArgList(std::move(_extra_op_args));
56 aten_op_pattern += R"(
57 %r = )";
58 aten_op_pattern += op_name + "(" + "%a_quant" + extra_op_arg_list + ")";
59 aten_op_pattern += R"(
60 return (%r) )";
61 return aten_op_pattern;
62}
63
64// generate ops for quantize pattern for a scalar value
65std::string getQuantizeForScalar(const std::string& value) {
66 // 6 is `torch.float` ScalarType, we are creating a float scalar
67 // tensor from a scalar value
68 std::string quantize_pattern = R"(
69 )" +
70 value + "_float_scalar_type : int = prim::Constant[value=6]()";
71 quantize_pattern += R"(
72 )" +
73 value + "_none : None = prim::Constant()";
74 quantize_pattern += R"(
75 )" +
76 value + "_tensor : Tensor = aten::scalar_tensor(" + value + ", " + value +
77 "_float_scalar_type";
78 for (const auto i : c10::irange(3)) {
79 (void)i; // Suppress unused variable warning
80 quantize_pattern += ", " + value + "_none";
81 }
82 quantize_pattern += ")";
83 quantize_pattern +=
84 R"(
85 )" +
86 value + "_quant = aten::quantize_per_tensor(" + value + "_tensor" +
87 getExtraArgList(
88 {value + "_scale", value + "_zero_point", value + "_dtype"}) +
89 ")";
90 return quantize_pattern;
91}
92
// Emits one IR line: `<value>_dequant = aten::dequantize(<value>_quant)`,
// preceded by a line break.
std::string getDequantize(const std::string& value) {
  std::string ir = "\n ";
  ir += value;
  ir += "_dequant = aten::dequantize(";
  ir += value;
  ir += "_quant)";
  return ir;
}
98
// Emits one IR line: `<value>_scalar : float = aten::item(<value>_dequant)`,
// preceded by a line break.
std::string getItem(const std::string& value) {
  std::string ir = "\n ";
  ir += value;
  ir += "_scalar : float = aten::item(";
  ir += value;
  ir += "_dequant)";
  return ir;
}
104
105// Patterns for the ops that inherit parameters from input
106std::string getInputTensorQParamOpPattern(
107 const std::string& op_name,
108 const std::vector<std::string>& extra_op_args) {
109 const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
110 std::string op_pattern = "graph(%a_quant" + extra_op_arg_list + "):" + R"(
111 %a_dequant = aten::dequantize(%a_quant)
112 %r = )" +
113 op_name + "(" + "%a_dequant" + extra_op_arg_list + ")" + R"(
114 %r_scale : float = aten::q_scale(%a_quant)
115 %r_zero_point : int = aten::q_zero_point(%a_quant)
116 %r_dtype : int = prim::dtype(%a_quant)
117 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
118 return (%r_quant) )";
119 return op_pattern;
120}
121
122// QuantFusionInfo for the ops that inherit parameters from input
123QuantFusionInfo getInputTensorQParamOpFusionInfo(
124 const std::string& op_name,
125 const std::vector<std::string>& extra_op_args) {
126 std::string op_pattern =
127 getInputTensorQParamOpPattern(op_name, extra_op_args);
128 const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
129 std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
130 std::string op_replacement =
131 getAtenOpPattern(graph_header, op_name, extra_op_args);
132
133 return {op_name, std::move(op_pattern), std::move(op_replacement)};
134}
135
136// quant fusion for ops like `quantized::add_scalar`, `quantized::mul_scalar`
137QuantFusionInfo getBinaryOpScalarFusionInfo(
138 const std::string& op_name,
139 const std::vector<std::string>& extra_op_args,
140 const std::string& quantized_op_name,
141 const std::vector<std::string>& extra_quantized_op_args,
142 const std::vector<MatchFilter>& filters = {}) {
143 std::string op_pattern =
144 getInputTensorQParamOpPattern(op_name, extra_op_args);
145
146 const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
147 std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
148 std::string op_replacement = getAtenOpPattern(
149 graph_header, quantized_op_name, extra_quantized_op_args);
150
151 return {op_name, std::move(op_pattern), std::move(op_replacement), filters};
152}
153
154QuantFusionInfo getClampOpFusionInfo(
155 const std::string& op_name,
156 const std::vector<std::string>& extra_op_args) {
157 std::vector<std::string> header_args = extra_op_args;
158 std::vector<std::string> input_qparams = {"_scale", "_zero_point", "_dtype"};
159 for (const auto& arg : extra_op_args) {
160 for (const auto& qparam : input_qparams) {
161 header_args.push_back(arg + qparam);
162 }
163 }
164 for (const auto& qparam : input_qparams) {
165 header_args.push_back("%r" + qparam);
166 }
167 const auto& extra_header_arg_list = getExtraArgList(std::move(header_args));
168 std::string graph_header = "graph(%a_quant" + extra_header_arg_list + "):";
169 std::string op_pattern = graph_header;
170 for (const auto& arg : extra_op_args) {
171 op_pattern += getQuantizeForScalar(arg);
172 op_pattern += getDequantize(arg);
173 op_pattern += getItem(arg);
174 }
175 op_pattern += getDequantize("%a");
176 op_pattern += R"(
177 %r = )";
178 std::vector<std::string> scalar_extra_args;
179 scalar_extra_args.reserve(extra_op_args.size());
180 for (const auto& arg : extra_op_args) {
181 scalar_extra_args.push_back(arg + "_scalar");
182 }
183 op_pattern += op_name + "(" + "%a_dequant" +
184 getExtraArgList(std::move(scalar_extra_args)) + ")";
185 // IR pattern common to all ops that inherit qparam from input
186 op_pattern += R"(
187 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
188 return (%r_quant) )";
189
190 std::string aten_op_pattern =
191 getAtenOpPattern(graph_header, op_name, extra_op_args);
192
193 return {op_name, std::move(op_pattern), std::move(aten_op_pattern)};
194}
195
196// Patterns for the ops that has fixed quantization parameters
197QuantFusionInfo getFixedQParamOpFusionInfo(
198 const std::string& op_name,
199 const std::vector<std::string>& extra_op_args,
200 bool is_symmetric) {
201 const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
202 std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
203 std::string op_pattern = graph_header;
204 op_pattern += R"(
205 %a_dequant = aten::dequantize(%a_quant)
206 %r = )";
207 op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")";
208 // IR pattern common to all ops with fixed quantization parameters for
209 // asymetric quantization
210 std::string asym_fixed_qparam_op_suffix = R"(
211 %r_scale : float = prim::Constant[value=0.00390625]()
212 %r_zero_point : int = prim::Constant[value=0]()
213 %r_dtype : int = prim::Constant[value=13]()
214 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
215 return (%r_quant) )";
216
217 std::string sym_fixed_qparam_op_suffix = R"(
218 %r_scale : float = prim::Constant[value=0.0078125]()
219 %r_zero_point : int = prim::Constant[value=128]()
220 %r_dtype : int = prim::Constant[value=13]()
221 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
222 return (%r_quant) )";
223 op_pattern +=
224 is_symmetric ? sym_fixed_qparam_op_suffix : asym_fixed_qparam_op_suffix;
225
226 std::string aten_op_pattern =
227 getAtenOpPattern(graph_header, op_name, extra_op_args);
228
229 return {op_name, std::move(op_pattern), std::move(aten_op_pattern)};
230}
231
232// filter that checks %b_scalar is a scalar
233bool input_b_is_scalar(
234 const Match& match,
235 const std::unordered_map<std::string, Value*>& vmap) {
236 const auto& match_vmap = match.values_map;
237 auto b_scalar = match_vmap.at(vmap.at("b_scalar"));
238 return isScalar(b_scalar);
239}
240
241// Patterns for ops that require observation for output quantization parameters
242// Example:
243//
244// before fusion:
245//
246// graph(%a_quant, %r_scale, %r_zero_point, %r_dtype):
247// %a_dequant = aten::dequantize(%a_quant)
248// %r = {op_name}(%a_dequant, {extra_args})
249// %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point,
250// %r_dtype) return (%r_quant)
251//
252// after fusion:
253//
254// graph(%a_quant, %r_scale, %r_zero_point, %r_dtype):
255// %r_quant = {quantized_op_name}(%a_quant, {extra_args}, %r_scale,
256// %r_zero_point) return (%r_quant)
257QuantFusionInfo getObservedQParamOpFusionInfo(
258 const std::string& fp_op_name,
259 const std::string& q_op_name,
260 const std::vector<std::string>& fp_extra_args,
261 const std::vector<std::string>& q_extra_args) {
262 const auto& fp_extra_arg_list = getExtraArgList(fp_extra_args);
263 const auto& q_extra_arg_list = getExtraArgList(q_extra_args);
264
265 std::string op_pattern = "graph(%a_quant" + fp_extra_arg_list +
266 ", %r_scale, %r_zero_point, %r_dtype):" + R"(
267 %a_dequant = aten::dequantize(%a_quant)
268 %r = )" +
269 fp_op_name + "(" + "%a_dequant" + fp_extra_arg_list + ")" + R"(
270 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
271 return (%r_quant) )";
272
273 std::string aten_op_pattern = "graph(%a_quant" + fp_extra_arg_list +
274 ", %r_scale, %r_zero_point, %r_dtype):" + R"(
275 %r_quant = )" +
276 q_op_name + "(%a_quant" + q_extra_arg_list +
277 ", %r_scale, %r_zero_point)" + R"(
278 return (%r_quant) )";
279
280 return {q_op_name, std::move(op_pattern), std::move(aten_op_pattern)};
281}
282
283} // namespace
284
285std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
286 // aten::conv1d
287 std::string conv1d = R"(
288graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
289 %a_dequant = aten::dequantize(%a_quant)
290 %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
291 %w_dequant = aten::dequantize(%w_quant)
292 %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
293 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
294 return (%r_quant) )";
295
296 // aten::conv1d - aten::relu
297 std::string conv1d_relu = R"(
298graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
299 %a_dequant = aten::dequantize(%a_quant)
300 %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
301 %w_dequant = aten::dequantize(%w_quant)
302 %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
303 %r = aten::relu(%conv_out)
304 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
305 return (%r_quant) )";
306
307 // aten::conv1d - aten::relu_
308 std::string conv1d_inplace_relu = R"(
309graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
310 %a_dequant = aten::dequantize(%a_quant)
311 %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
312 %w_dequant = aten::dequantize(%w_quant)
313 %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
314 %r = aten::relu_(%conv_out)
315 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
316 return (%r_quant) )";
317
318 // quantized::conv1d
319 std::string quantized_conv1d = R"(
320graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
321 %r_quant = quantized::conv1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
322 return (%r_quant) )";
323
324 // quantized::conv1d_relu
325 std::string quantized_conv1d_relu = R"(
326graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
327 %r_quant = quantized::conv1d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
328 return (%r_quant) )";
329
330 // aten::conv2d
331 std::string conv2d = R"(
332graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
333 %a_dequant = aten::dequantize(%a_quant)
334 %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
335 %w_dequant = aten::dequantize(%w_quant)
336 %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
337 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
338 return (%r_quant) )";
339
340 // aten::conv2d - aten::relu
341 std::string conv2d_relu = R"(
342graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
343 %a_dequant = aten::dequantize(%a_quant)
344 %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
345 %w_dequant = aten::dequantize(%w_quant)
346 %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
347 %r = aten::relu(%conv_out)
348 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
349 return (%r_quant) )";
350
351 // aten::conv2d - aten::relu_
352 std::string conv2d_inplace_relu = R"(
353graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
354 %a_dequant = aten::dequantize(%a_quant)
355 %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
356 %w_dequant = aten::dequantize(%w_quant)
357 %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
358 %r = aten::relu_(%conv_out)
359 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
360 return (%r_quant) )";
361
362 // quantized::conv2d
363 std::string quantized_conv2d = R"(
364graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
365 %r_quant = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
366 return (%r_quant) )";
367
368 // quantized::conv2d_relu
369 std::string quantized_conv2d_relu = R"(
370graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
371 %r_quant = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
372 return (%r_quant) )";
373
374 // aten::conv3d
375 std::string conv3d = R"(
376graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
377 %a_dequant = aten::dequantize(%a_quant)
378 %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
379 %w_dequant = aten::dequantize(%w_quant)
380 %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
381 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
382 return (%r_quant) )";
383
384 // aten::conv3d - aten::relu
385 std::string conv3d_relu = R"(
386graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
387 %a_dequant = aten::dequantize(%a_quant)
388 %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
389 %w_dequant = aten::dequantize(%w_quant)
390 %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
391 %r = aten::relu(%conv_out)
392 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
393 return (%r_quant) )";
394
395 // aten::conv3d - aten::relu_
396 std::string conv3d_inplace_relu = R"(
397graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
398 %a_dequant = aten::dequantize(%a_quant)
399 %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
400 %w_dequant = aten::dequantize(%w_quant)
401 %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
402 %r = aten::relu_(%conv_out)
403 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
404 return (%r_quant) )";
405
406 // quantized::conv3d
407 std::string quantized_conv3d = R"(
408graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
409 %r_quant = quantized::conv3d(%a_quant, %packed_params, %r_scale, %r_zero_point)
410 return (%r_quant) )";
411
412 // quantized::conv3d_relu
413 std::string quantized_conv3d_relu = R"(
414graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
415 %r_quant = quantized::conv3d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
416 return (%r_quant) )";
417
418 // aten::conv_transpose1d
419 std::string conv_transpose1d = R"(
420graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
421 %a_dequant = aten::dequantize(%a_quant)
422 %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
423 %w_dequant = aten::dequantize(%w_quant)
424 %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
425 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
426 return (%r_quant) )";
427
428 // quantized::conv_transpose1d
429 std::string quantized_conv_transpose1d = R"(
430graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
431 %r_quant = quantized::conv_transpose1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
432 return (%r_quant) )";
433
434 // aten::conv_transpose2d
435 std::string conv_transpose2d = R"(
436graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
437 %a_dequant = aten::dequantize(%a_quant)
438 %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
439 %w_dequant = aten::dequantize(%w_quant)
440 %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
441 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
442 return (%r_quant) )";
443
  // quantized::conv_transpose2d
445 std::string quantized_conv_transpose2d = R"(
446graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
447 %r_quant = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
448 return (%r_quant) )";
449
450 std::string add_relu = R"(
451graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
452 %a_dequant = aten::dequantize(%a_quant)
453 %b_dequant = aten::dequantize(%b_quant)
454 %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
455 %r_relu = aten::relu(%r_add)
456 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
457 return (%r) )";
458
459 std::string add_inplace_relu = R"(
460graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
461 %a_dequant = aten::dequantize(%a_quant)
462 %b_dequant = aten::dequantize(%b_quant)
463 %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
464 %r_relu = aten::relu_(%r_add)
465 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
466 return (%r) )";
467
468 std::string inplace_add_relu = R"(
469graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
470 %a_dequant = aten::dequantize(%a_quant)
471 %b_dequant = aten::dequantize(%b_quant)
472 %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
473 %r_relu = aten::relu(%r_add)
474 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
475 return (%r) )";
476
477 std::string inplace_add_inplace_relu = R"(
478graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
479 %a_dequant = aten::dequantize(%a_quant)
480 %b_dequant = aten::dequantize(%b_quant)
481 %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
482 %r_relu = aten::relu_(%r_add)
483 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
484 return (%r) )";
485
486 std::string quantized_add_relu = R"(
487graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
488 %r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point)
489 return (%r) )";
490
491 // aten::linear
492 std::string linear = R"(
493graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
494 %a_dequant = aten::dequantize(%a_quant)
495 %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
496 %w_dequant = aten::dequantize(%w_quant)
497 %r = aten::linear(%a_dequant, %w_dequant, %b)
498 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
499 return (%r_quant) )";
500
501 std::string linear_relu = R"(
502graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
503 %a_dequant = aten::dequantize(%a_quant)
504 %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
505 %w_dequant = aten::dequantize(%w_quant)
506 %linear_out = aten::linear(%a_dequant, %w_dequant, %b)
507 %r = aten::relu(%linear_out)
508 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
509 return (%r_quant) )";
510
511 std::string linear_inplace_relu = R"(
512graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
513 %a_dequant = aten::dequantize(%a_quant)
514 %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
515 %w_dequant = aten::dequantize(%w_quant)
516 %linear_out = aten::linear(%a_dequant, %w_dequant, %b)
517 %r = aten::relu_(%linear_out)
518 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
519 return (%r_quant) )";
520
521 // quantized::linear
522 std::string quantized_linear = R"(
523graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
524 %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
525 return (%r) )";
526
527 std::string quantized_linear_relu = R"(
528graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
529 %r = quantized::linear_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
530 return (%r) )";
531
532 std::string cat = R"(
533graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype):
534 %input_dequant = aten::dequantize(%input_quant)
535 %r = aten::cat(%input_dequant, %dim)
536 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
537 return (%r_quant) )";
538
539 std::string quantized_cat = R"(
540graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype):
541 %r_quant = quantized::cat(%input_quant, %dim, %r_scale, %r_zero_point)
542 return (%r_quant) )";
543
544 // aten::add
545 std::string add = R"(
546graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
547 %a_dequant = aten::dequantize(%a_quant)
548 %b_dequant = aten::dequantize(%b_quant)
549 %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
550 %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
551 return (%r) )";
552
553 // TODO: add %dtype after when https://github.com/pytorch/pytorch/issues/34351
554 // is fixed
555 // quantized::add
556 std::string quantized_add = R"(
557graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
558 %r = quantized::add(%a_quant, %b_quant, %scale, %zero_point)
559 return (%r) )";
560
561 // aten::add_
562 std::string inplace_add = R"(
563graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
564 %a_dequant = aten::dequantize(%a_quant)
565 %b_dequant = aten::dequantize(%b_quant)
566 %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
567 %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
568 return (%r) )";
569
570 auto add_scalar = getBinaryOpScalarFusionInfo(
571 "aten::add",
572 {"%b_scalar", "%alpha"},
573 "quantized::add_scalar",
574 {"%b_scalar"},
575 {aten_add_alpha_is_one, input_b_is_scalar});
576
577 auto add_scalar_out = getBinaryOpScalarFusionInfo(
578 "aten::add_",
579 {"%b_scalar", "%alpha"},
580 "quantized::add_scalar_out",
581 {"%b_scalar", "%a_quant"},
582 {aten_add_alpha_is_one, input_b_is_scalar});
583
584 // quantized::add_scalar_relu -- fusing quantized::add_scalar
585 // and aten::relu
586 auto quantized_add_scalar_relu_pattern = R"(
587graph(%a_quant, %b_scalar):
588 %r_add = quantized::add_scalar(%a_quant, %b_scalar)
589 %r = aten::relu(%r_add)
590 return (%r) )";
591
592 auto quantized_add_scalar_inplace_relu_pattern = R"(
593graph(%a_quant, %b_scalar):
594 %r_add = quantized::add_scalar(%a_quant, %b_scalar)
595 %r = aten::relu_(%r_add)
596 return (%r) )";
597
598 auto quantized_add_scalar_relu_replacement = R"(
599graph(%a_quant, %b_scalar):
600 %r = quantized::add_scalar_relu(%a_quant, %b_scalar)
601 return (%r) )";
602
603 // quantized::add_scalar_relu_out -- fusing quantized::add_scalarOut
604 // and aten::relu
605 auto quantized_add_scalar_relu_out_pattern = R"(
606graph(%a_quant, %b_scalar):
607 %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
608 %r = aten::relu(%r_add)
609 return (%r) )";
610
611 auto quantized_add_scalar_inplace_relu_out_pattern = R"(
612graph(%a_quant, %b_scalar):
613 %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
614 %r = aten::relu_(%r_add)
615 return (%r) )";
616
617 auto quantized_add_scalar_relu_out_replacement = R"(
618graph(%a_quant, %b_scalar):
619 %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
620 return (%r) )";
621
622 // quantized::batch_norm
623 std::string batch_norm = R"(
624graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
625 %a_dequant = aten::dequantize(%a_quant)
626 %r_bn = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
627 %r = aten::quantize_per_tensor(%r_bn, %scale, %zero_point, %scalar_type)
628 return (%r) )";
629 std::string quantized_batch_norm = R"(
630graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
631 %r = quantized::batch_norm(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point)
632 return (%r) )";
633
634 std::string batch_norm_relu = R"(
635graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
636 %a_dequant = aten::dequantize(%a_quant)
637 %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
638 %relu = aten::relu(%bn_out)
639 %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type)
640 return (%r) )";
641 std::string batch_norm_inplace_relu = R"(
642graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
643 %a_dequant = aten::dequantize(%a_quant)
644 %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
645 %relu = aten::relu_(%bn_out)
646 %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type)
647 return (%r) )";
648
649 std::string quantized_batch_norm_relu = R"(
650graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
651 %r = quantized::batch_norm_relu(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point)
652 return (%r) )";
653
654 // aten::mul
655 std::string mul = R"(
656graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
657 %a_dequant = aten::dequantize(%a_quant)
658 %b_dequant = aten::dequantize(%b_quant)
659 %r_mul = aten::mul(%a_dequant, %b_dequant)
660 %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype)
661 return (%r) )";
662
663 // aten::mul_
664 std::string inplace_mul = R"(
665graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
666 %a_dequant = aten::dequantize(%a_quant)
667 %b_dequant = aten::dequantize(%b_quant)
668 %r_mul = aten::mul_(%a_dequant, %b_dequant)
669 %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype)
670 return (%r) )";
671
672 // quantized::mul
673 std::string quantized_mul = R"(
674graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
675 %r = quantized::mul(%a_quant, %b_quant, %scale, %zero_point)
676 return (%r) )";
677
678 auto mul_scalar = getBinaryOpScalarFusionInfo(
679 "aten::mul",
680 {"%b_scalar"},
681 "quantized::mul_scalar",
682 {"%b_scalar"},
683 {input_b_is_scalar});
684
685 auto mul_scalar_out = getBinaryOpScalarFusionInfo(
686 "aten::mul_",
687 {"%b_scalar"},
688 "quantized::mul_scalar_out",
689 {"%b_scalar", "%a_quant"},
690 {input_b_is_scalar});
691
692 // quantized::mul_relu
693 std::string mul_relu = R"(
694graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
695 %a_dequant = aten::dequantize(%a_quant)
696 %b_dequant = aten::dequantize(%b_quant)
697 %r_mul = aten::mul(%a_dequant, %b_dequant)
698 %r_relu = aten::relu(%r_mul)
699 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
700 return (%r) )";
701
702 std::string mul_inplace_relu = R"(
703graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
704 %a_dequant = aten::dequantize(%a_quant)
705 %b_dequant = aten::dequantize(%b_quant)
706 %r_mul = aten::mul(%a_dequant, %b_dequant)
707 %r_relu = aten::relu_(%r_mul)
708 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
709 return (%r) )";
710
711 std::string inplace_mul_relu = R"(
712graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
713 %a_dequant = aten::dequantize(%a_quant)
714 %b_dequant = aten::dequantize(%b_quant)
715 %r_mul = aten::mul_(%a_dequant, %b_dequant)
716 %r_relu = aten::relu(%r_mul)
717 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
718 return (%r) )";
719
720 std::string inplace_mul_inplace_relu = R"(
721graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
722 %a_dequant = aten::dequantize(%a_quant)
723 %b_dequant = aten::dequantize(%b_quant)
724 %r_mul = aten::mul_(%a_dequant, %b_dequant)
725 %r_relu = aten::relu_(%r_mul)
726 %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
727 return (%r) )";
728
729 std::string quantized_mul_relu = R"(
730graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
731 %r = quantized::mul_relu(%a_quant, %b_quant, %scale, %zero_point)
732 return (%r) )";
733
734 // quantized::mul_scalar_relu -- fusing quantized::mul_scalar
735 // and aten::relu
736 auto quantized_mul_scalar_relu_pattern = R"(
737graph(%a_quant, %b_scalar):
738 %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
739 %r = aten::relu(%r_mul)
740 return (%r) )";
741
742 auto quantized_mul_scalar_inplace_relu_pattern = R"(
743graph(%a_quant, %b_scalar):
744 %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
745 %r = aten::relu_(%r_mul)
746 return (%r) )";
747
748 auto quantized_mul_scalar_relu_replacement = R"(
749graph(%a_quant, %b_scalar):
750 %r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
751 return (%r) )";
752
753 // quantized::mul_scalar_relu_out -- fusing quantized::mul_scalarOut
754 // and aten::relu
755 auto quantized_mul_scalar_relu_out_pattern = R"(
756graph(%a_quant, %b_scalar):
757 %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
758 %r = aten::relu(%r_mul)
759 return (%r) )";
760
761 auto quantized_mul_scalar_inplace_relu_out_pattern = R"(
762graph(%a_quant, %b_scalar):
763 %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
764 %r = aten::relu_(%r_mul)
765 return (%r) )";
766
767 auto quantized_mul_scalar_relu_out_replacement = R"(
768graph(%a_quant, %b_scalar):
769 %r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
770 return (%r) )";
771
772 // quantized::elu
773 std::string elu = R"(
774graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
775 %a_dequant = aten::dequantize(%a_quant)
776 %r = aten::elu(%a_dequant, %alpha, %scale, %input_scale)
777 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
778 return (%r_quant) )";
779
780 std::string quantized_elu = R"(
781graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
782 %r_quant = quantized::elu(%a_quant, %r_scale, %r_zero_point, %alpha, %scale, %input_scale)
783 return (%r_quant) )";
784
785 std::string elu_ = R"(
786graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
787 %a_dequant = aten::dequantize(%a_quant)
788 %r = aten::elu_(%a_dequant, %alpha, %scale, %input_scale)
789 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
790 return (%r_quant) )";
791
 // ============= General Ops that inherit quantization parameters from input
 // tensor =============
794 auto avg_pool1d = getInputTensorQParamOpFusionInfo(
795 "aten::avg_pool1d",
796 {"%kernel_size",
797 "%stride",
798 "%padding",
799 "%ceil_mode",
800 "%count_include_pad"});
801
802 auto avg_pool2d = getInputTensorQParamOpFusionInfo(
803 "aten::avg_pool2d",
804 {"%kernel_size",
805 "%stride",
806 "%padding",
807 "%ceil_mode",
808 "%count_include_pad",
809 "%divisor_override"});
810
811 std::string common_general_value_op = R"(
812 %r_scale : float = aten::q_scale(%a_quant)
813 %r_zero_point : int = aten::q_zero_point(%a_quant)
814 %r_dtype : int = prim::dtype(%a_quant)
815 %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
816 return (%r_quant) )";
817
818 auto avg_pool3d = getInputTensorQParamOpFusionInfo(
819 "aten::avg_pool3d",
820 {"%kernel_size",
821 "%stride",
822 "%padding",
823 "%ceil_mode",
824 "%count_include_pad",
825 "%divisor_override"});
826
827 auto adaptive_avg_pool1d = getInputTensorQParamOpFusionInfo(
828 "aten::adaptive_avg_pool1d", {"%output_size"});
829
830 auto adaptive_avg_pool2d = getInputTensorQParamOpFusionInfo(
831 "aten::adaptive_avg_pool2d", {"%output_size"});
832
833 auto adaptive_avg_pool3d = getInputTensorQParamOpFusionInfo(
834 "aten::adaptive_avg_pool3d", {"%output_size"});
835
836 auto mean1 = getInputTensorQParamOpFusionInfo("aten::mean", {"%dim"});
837
838 auto mean2 = getInputTensorQParamOpFusionInfo(
839 "aten::mean", {"%dim", "%keepdim", "%out"});
840
841 auto upsample_nearest1d_vec = getInputTensorQParamOpFusionInfo(
842 "aten::upsample_nearest1d", {"%output_size", "%scale_factors"});
843
844 auto upsample_nearest2d_vec = getInputTensorQParamOpFusionInfo(
845 "aten::upsample_nearest2d", {"%output_size", "%scale_factors"});
846
847 auto upsample_nearest3d_vec = getInputTensorQParamOpFusionInfo(
848 "aten::upsample_nearest3d", {"%output_size", "%scale_factors"});
849
850 auto upsample_linear1d_vec = getInputTensorQParamOpFusionInfo(
851 "aten::upsample_linear1d",
852 {"%output_size", "%align_corners", "%scale_factors"});
853
854 auto upsample_bilinear2d_vec = getInputTensorQParamOpFusionInfo(
855 "aten::upsample_bilinear2d",
856 {"%output_size", "%align_corners", "%scale_factors"});
857
858 auto upsample_trilinear3d_vec = getInputTensorQParamOpFusionInfo(
859 "aten::upsample_trilinear3d",
860 {"%output_size", "%align_corners", "%scale_factors"});
861
862 auto upsample_nearest1d = getInputTensorQParamOpFusionInfo(
863 "aten::upsample_nearest1d", {"%output_size", "%scales"});
864
865 auto upsample_nearest2d = getInputTensorQParamOpFusionInfo(
866 "aten::upsample_nearest2d", {"%output_size", "%scale_h", "%scale_w"});
867
868 auto upsample_nearest3d = getInputTensorQParamOpFusionInfo(
869 "aten::upsample_nearest3d",
870 {"%output_size", "%scale_d", "%scale_h", "%scale_w"});
871
872 auto upsample_linear1d = getInputTensorQParamOpFusionInfo(
873 "aten::upsample_linear1d", {"%output_size", "%align_corners", "%scales"});
874
875 auto upsample_bilinear2d = getInputTensorQParamOpFusionInfo(
876 "aten::upsample_bilinear2d",
877 {"%output_size", "%align_corners", "%scale_h", "%scale_w"});
878
879 auto upsample_trilinear3d = getInputTensorQParamOpFusionInfo(
880 "aten::upsample_trilinear3d",
881 {"%output_size", "%align_corners", "%scale_d", "%scale_h", "%scale_w"});
882
883 auto clamp = getClampOpFusionInfo("aten::clamp", {"%min", "%max"});
884
885 auto hardtanh = getClampOpFusionInfo("aten::hardtanh", {"%min", "%max"});
886
887 auto hardtanh_ = getClampOpFusionInfo("aten::hardtanh_", {"%min", "%max"});
888
889 auto leaky_relu =
890 getInputTensorQParamOpFusionInfo("aten::leaky_relu", {"%negative_slope"});
891
892 auto leaky_relu_ = getInputTensorQParamOpFusionInfo(
893 "aten::leaky_relu_", {"%negative_slope"});
894
895 // Ops with fixed quantization parameters
896 auto hardsigmoid = getFixedQParamOpFusionInfo("aten::hardsigmoid", {}, false);
897
898 auto hardsigmoid_ =
899 getFixedQParamOpFusionInfo("aten::hardsigmoid_", {}, false);
900
901 auto sigmoid = getFixedQParamOpFusionInfo("aten::sigmoid", {}, false);
902
903 auto sigmoid_ = getFixedQParamOpFusionInfo("aten::sigmoid_", {}, false);
904
905 auto tanh = getFixedQParamOpFusionInfo("aten::tanh", {}, true);
906
907 auto tanh_ = getFixedQParamOpFusionInfo("aten::tanh_", {}, true);
908
909 auto hardswish = getObservedQParamOpFusionInfo(
910 "aten::hardswish", "quantized::hardswish", {}, {});
911
912 auto hardswish_ = getObservedQParamOpFusionInfo(
913 "aten::hardswish_", "quantized::hardswish", {}, {});
914
915 auto layer_norm = getObservedQParamOpFusionInfo(
916 "aten::layer_norm",
917 "quantized::layer_norm",
918 {"%normalized_shape", "%weight", "%bias", "%eps", "%cudnn_enabled"},
919 {"%normalized_shape", "%weight", "%bias", "%eps"});
920
921 auto group_norm = getObservedQParamOpFusionInfo(
922 "aten::group_norm",
923 "quantized::group_norm",
924 {"%num_groups", "%weight", "%bias", "%eps", "%cudnn_enabled"},
925 {"%num_groups", "%weight", "%bias", "%eps"});
926
927 auto instance_norm = getObservedQParamOpFusionInfo(
928 "aten::instance_norm",
929 "quantized::instance_norm",
930 {"%weight",
931 "%bias",
932 "%running_mean",
933 "%running_var",
934 "%use_input_stats",
935 "%momentum",
936 "%eps",
937 "%cudnn_enabled"},
938 {"%weight", "%bias", "%eps"});
939
940 return {
941 {"quantized::conv1d", std::move(conv1d), std::move(quantized_conv1d)},
942 {"quantized::conv1d_relu", std::move(conv1d_relu), quantized_conv1d_relu},
943 {"quantized::conv1d_relu",
944 std::move(conv1d_inplace_relu),
945 std::move(quantized_conv1d_relu)},
946 {"quantized::conv2d", std::move(conv2d), std::move(quantized_conv2d)},
947 {"quantized::conv2d_relu", std::move(conv2d_relu), quantized_conv2d_relu},
948 {"quantized::conv2d_relu",
949 std::move(conv2d_inplace_relu),
950 std::move(quantized_conv2d_relu)},
951 {"quantized::conv3d", std::move(conv3d), std::move(quantized_conv3d)},
952 {"quantized::conv3d_relu", std::move(conv3d_relu), quantized_conv3d_relu},
953 {"quantized::conv3d_relu",
954 std::move(conv3d_inplace_relu),
955 std::move(quantized_conv3d_relu)},
956 {"quantized::conv_transpose1d",
957 std::move(conv_transpose1d),
958 std::move(quantized_conv_transpose1d)},
959 {"quantized::conv_transpose2d",
960 std::move(conv_transpose2d),
961 std::move(quantized_conv_transpose2d)},
962 {"quantized::linear", std::move(linear), std::move(quantized_linear)},
963 {"quantized::linear_relu", std::move(linear_relu), quantized_linear_relu},
964 {"quantized::linear_relu",
965 std::move(linear_inplace_relu),
966 std::move(quantized_linear_relu)},
967 {"quantized::add_relu",
968 std::move(add_relu),
969 quantized_add_relu,
970 {aten_add_alpha_is_one}},
971 {"quantized::add_relu",
972 std::move(add_inplace_relu),
973 quantized_add_relu,
974 {aten_add_alpha_is_one}},
975 {"quantized::add_relu",
976 std::move(inplace_add_relu),
977 quantized_add_relu,
978 {aten_add_alpha_is_one}},
979 {"quantized::add_relu",
980 std::move(inplace_add_inplace_relu),
981 std::move(quantized_add_relu),
982 {aten_add_alpha_is_one}},
983 std::move(add_scalar),
984 std::move(add_scalar_out),
985 // note that these must come after quantized::add_scalar and
986 // quantized::add_scalar_out patterns
987 {"quantized::add_scalar_relu",
988 quantized_add_scalar_relu_pattern,
989 quantized_add_scalar_relu_replacement},
990 {"quantized::add_scalar_relu",
991 quantized_add_scalar_inplace_relu_pattern,
992 quantized_add_scalar_relu_replacement},
993 {"quantized::add_scalar_relu_out",
994 quantized_add_scalar_relu_out_pattern,
995 quantized_add_scalar_relu_out_replacement},
996 {"quantized::add_scalar_relu_out",
997 quantized_add_scalar_inplace_relu_out_pattern,
998 quantized_add_scalar_relu_out_replacement},
999 {"quantized::add",
1000 std::move(add),
1001 quantized_add,
1002 {aten_add_alpha_is_one}},
1003 {"quantized::add",
1004 std::move(inplace_add),
1005 std::move(quantized_add),
1006 {aten_add_alpha_is_one}},
1007 {"quantized::cat", std::move(cat), std::move(quantized_cat)},
1008 {"quantized::batch_norm",
1009 std::move(batch_norm),
1010 std::move(quantized_batch_norm)},
1011 {"quantized::batch_norm_relu",
1012 std::move(batch_norm_relu),
1013 quantized_batch_norm_relu},
1014 {"quantized::batch_norm_relu",
1015 std::move(batch_norm_inplace_relu),
1016 std::move(quantized_batch_norm_relu)},
1017 std::move(mul_scalar),
1018 std::move(mul_scalar_out),
1019 // note that these must come after quantized::mul_scalar and
1020 // quantized::mul_scalar_out patterns
1021 {"quantized::mul_scalar_relu",
1022 quantized_mul_scalar_relu_pattern,
1023 quantized_mul_scalar_relu_replacement},
1024 {"quantized::mul_scalar_relu",
1025 quantized_mul_scalar_inplace_relu_pattern,
1026 quantized_mul_scalar_relu_replacement},
1027 {"quantized::mul_scalar_relu_out",
1028 quantized_mul_scalar_relu_out_pattern,
1029 quantized_mul_scalar_relu_out_replacement},
1030 {"quantized::mul_scalar_relu_out",
1031 quantized_mul_scalar_inplace_relu_out_pattern,
1032 quantized_mul_scalar_relu_out_replacement},
1033 {"quantized::mul_relu", std::move(mul_relu), quantized_mul_relu},
1034 {"quantized::mul_relu", std::move(mul_inplace_relu), quantized_mul_relu},
1035 {"quantized::mul_relu", std::move(inplace_mul_relu), quantized_mul_relu},
1036 {"quantized::mul_relu",
1037 std::move(inplace_mul_inplace_relu),
1038 std::move(quantized_mul_relu)},
1039 {"quantized::mul", std::move(mul), quantized_mul},
1040 {"quantized::mul", std::move(inplace_mul), std::move(quantized_mul)},
1041 std::move(hardswish),
1042 std::move(hardswish_),
1043 std::move(layer_norm),
1044 std::move(group_norm),
1045 std::move(instance_norm),
1046 {"quantized::elu", std::move(elu), quantized_elu},
1047 {"quantized::elu_", std::move(elu_), std::move(quantized_elu)},
1048 std::move(avg_pool1d),
1049 std::move(avg_pool2d),
1050 std::move(avg_pool3d),
1051 std::move(adaptive_avg_pool1d),
1052 std::move(adaptive_avg_pool2d),
1053 std::move(adaptive_avg_pool3d),
1054 std::move(mean1),
1055 std::move(mean2),
1056 std::move(upsample_nearest1d),
1057 std::move(upsample_nearest2d),
1058 std::move(upsample_nearest3d),
1059 std::move(upsample_linear1d),
1060 std::move(upsample_bilinear2d),
1061 std::move(upsample_trilinear3d),
1062 std::move(upsample_nearest1d_vec),
1063 std::move(upsample_nearest2d_vec),
1064 std::move(upsample_nearest3d_vec),
1065 std::move(upsample_linear1d_vec),
1066 std::move(upsample_bilinear2d_vec),
1067 std::move(upsample_trilinear3d_vec),
1068 std::move(clamp),
1069 std::move(hardtanh),
1070 std::move(hardtanh_),
1071 std::move(leaky_relu),
1072 std::move(leaky_relu_),
1073 // fixed qparam ops
1074 std::move(hardsigmoid),
1075 std::move(hardsigmoid_),
1076 std::move(sigmoid),
1077 std::move(sigmoid_),
1078 std::move(tanh),
1079 std::move(tanh_),
1080 };
1081}
1082
1083inline std::vector<QuantFusionInfo>
1084dynamic_quantized_linear_pattern_and_replacements() {
1085 std::string linear_dynamic = R"(
1086graph(%packed_params, %a):
1087 %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
1088 %w_dequant = aten::dequantize(%w_quant)
1089 %r = aten::linear(%a, %w_dequant, %b)
1090 return (%r) )";
1091
1092 // This pattern ignores reduce range
1093 // Set the reduce range to default to true, since qnnpack backend ignores this
1094 // argument.
1095 std::string quantized_linear_dynamic = R"(
1096graph(%packed_params, %a):
1097 %reduce_range : bool = prim::Constant[value=1]()
1098 %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range)
1099 return (%r) )";
1100
1101 return {
1102 {"quantized::linear_dynamic",
1103 std::move(linear_dynamic),
1104 std::move(quantized_linear_dynamic)},
1105 };
1106}
1107
1108std::vector<QuantFusionInfo> dynamic_quant_fusion_pattern_and_replacements() {
1109 std::string linear_dynamic = R"(
1110graph(%packed_params, %a, %reduce_range, %a_dtype):
1111 %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
1112 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
1113 %a_dequant = aten::dequantize(%a_quant)
1114 %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
1115 %w_dequant = aten::dequantize(%w_quant)
1116 %r = aten::linear(%a_dequant, %w_dequant, %b)
1117 return (%r) )";
1118
1119 std::string quantized_linear_dynamic = R"(
1120graph(%packed_params, %a, %reduce_range, %a_dtype):
1121 %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range)
1122 return (%r) )";
1123
1124 std::string linear_dynamic_fp16 = R"(
1125graph(%packed_params, %a):
1126 %w_unpacked : Tensor, %b : Tensor? = quantized::linear_unpack_fp16(%packed_params)
1127 %r = aten::linear(%a, %w_unpacked, %b)
1128 return (%r) )";
1129
1130 std::string quantized_linear_dynamic_fp16 = R"(
1131graph(%packed_params, %a):
1132 %r = quantized::linear_dynamic_fp16(%a, %packed_params)
1133 return (%r) )";
1134
1135 return {
1136 {"quantized::linear_dynamic",
1137 std::move(linear_dynamic),
1138 std::move(quantized_linear_dynamic)},
1139 {"quantized::linear_dynamic_fp16",
1140 std::move(linear_dynamic_fp16),
1141 std::move(quantized_linear_dynamic_fp16)},
1142 };
1143}
1144
1145std::vector<QuantFusionInfo> linear_prepack_unpack_patterns() {
1146 std::string linear_with_quant = R"(
1147graph(%a_dequant, %w_quant, %b):
1148 %w_dequant = aten::dequantize(%w_quant)
1149 %r = aten::linear(%a_dequant, %w_dequant, %b)
1150 return (%r) )";
1151
1152 std::string linear_with_quant_prepack = R"(
1153graph(%a_dequant, %w_quant, %b):
1154 %packed_params = quantized::linear_prepack(%w_quant, %b)
1155 %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack(%packed_params)
1156 %w_dequant = aten::dequantize(%w_quant_unpacked)
1157 %r = aten::linear(%a_dequant, %w_dequant, %b_unpacked)
1158 return (%r) )";
1159 std::string linear_fp16_with_cast = R"(
1160graph(%w, %a_dq, %b):
1161 %fp16_tensor = aten::_saturate_weight_to_fp16(%w)
1162 %r = aten::linear(%a_dq, %fp16_tensor, %b)
1163 return (%r) )";
1164 std::string linear_fp16_with_prepack = R"(
1165graph(%w, %a_dq, %b):
1166 %packed_params = quantized::linear_prepack_fp16(%w, %b)
1167 %w_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack_fp16(%packed_params)
1168 %r = aten::linear(%a_dq, %w_unpacked, %b_unpacked)
1169 return (%r) )";
1170
1171 return {
1172 {"linear_prepack_unpack",
1173 std::move(linear_with_quant),
1174 std::move(linear_with_quant_prepack)},
1175 {"linear_fp16_prepack_unpack",
1176 std::move(linear_fp16_with_cast),
1177 std::move(linear_fp16_with_prepack)},
1178 };
1179}
1180
1181std::vector<QuantFusionInfo> conv_prepack_unpack_patterns() {
1182 std::string conv1d_with_quant = R"(
1183graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1184 %w_dequant = aten::dequantize(%w_quant)
1185 %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
1186 return (%r) )";
1187
1188 std::string conv1d_with_quant_prepack = R"(
1189graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1190 %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv1d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
1191 %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv1d_unpack(%packed_params)
1192 %w_dequant = aten::dequantize(%w_quant_unpacked)
1193 %r = aten::conv1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
1194 return (%r) )";
1195
1196 std::string conv2d_with_quant = R"(
1197graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1198 %w_dequant = aten::dequantize(%w_quant)
1199 %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
1200 return (%r) )";
1201
1202 std::string conv2d_with_quant_prepack = R"(
1203graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1204 %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
1205 %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv2d_unpack(%packed_params)
1206 %w_dequant = aten::dequantize(%w_quant_unpacked)
1207 %r = aten::conv2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
1208 return (%r) )";
1209
1210 std::string conv3d_with_quant = R"(
1211graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1212 %w_dequant = aten::dequantize(%w_quant)
1213 %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
1214 return (%r) )";
1215
1216 std::string conv3d_with_quant_prepack = R"(
1217graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1218 %packed_params : __torch__.torch.classes.quantized.Conv3dPackedParamsBase = quantized::conv3d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
1219 %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv3d_unpack(%packed_params)
1220 %w_dequant = aten::dequantize(%w_quant_unpacked)
1221 %r = aten::conv3d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
1222 return (%r) )";
1223
1224 std::string conv_transpose1d_with_quant = R"(
1225graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1226 %w_dequant = aten::dequantize(%w_quant)
1227 %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
1228 return (%r) )";
1229
1230 std::string conv_transpose1d_with_quant_prepack = R"(
1231graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1232 %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose1d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
1233 %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
1234 %w_dequant = aten::dequantize(%w_quant_unpacked)
1235 %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
1236 return (%r) )";
1237
1238 std::string conv_transpose2d_with_quant = R"(
1239graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1240 %w_dequant = aten::dequantize(%w_quant)
1241 %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
1242 return (%r) )";
1243
1244 std::string conv_transpose2d_with_quant_prepack = R"(
1245graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1246 %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose2d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
1247 %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
1248 %w_dequant = aten::dequantize(%w_quant_unpacked)
1249 %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
1250 return (%r) )";
1251
1252 return {
1253 {"conv1d_prepack_unpack",
1254 std::move(conv1d_with_quant),
1255 std::move(conv1d_with_quant_prepack)},
1256 {"conv2d_prepack_unpack",
1257 std::move(conv2d_with_quant),
1258 std::move(conv2d_with_quant_prepack)},
1259 {"conv3d_prepack_unpack",
1260 std::move(conv3d_with_quant),
1261 std::move(conv3d_with_quant_prepack)},
1262 {"conv_transpose1d_prepack_unpack",
1263 std::move(conv_transpose1d_with_quant),
1264 std::move(conv_transpose1d_with_quant_prepack)},
1265 {"conv_transpose2d_prepack_unpack",
1266 std::move(conv_transpose2d_with_quant),
1267 std::move(conv_transpose2d_with_quant_prepack)}};
1268}
1269
1270} // namespace jit
1271} // namespace torch
1272