1#include <torch/csrc/jit/passes/quantization/helper.h>
2
3#include <torch/csrc/jit/api/function_impl.h>
4#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
5
6#include <utility>
7
8namespace torch {
9namespace jit {
10
11using graph_rewrite_helper::getFuncName;
12
// Identifies one argument slot of an operator.
//   func_name - operator name without namespace (e.g. "conv2d")
//   arg_index - position of the argument in the node's input list; for a
//               prim::CallFunction node input 0 is the callee, so real
//               arguments start at 1.
struct FuncArg {
  std::string func_name;
  // Default-initialized so that a default-constructed FuncArg never carries
  // an indeterminate index; aggregate initialization is unaffected.
  int arg_index = 0;
};

using AtenFuncArgs = std::vector<FuncArg>;
using CallFuncArgs = std::vector<FuncArg>;
20
// Allow-lists of quantizable operators.
//
// Lists named `*_call_funcs` are compared against the name of the function a
// prim::CallFunction node invokes; lists named `*_aten_funcs` are compared
// against aten:: node kinds (see isFunctionNode).

// Ops considered quantizable for any non-dynamic QuantType.
std::vector<std::string> _static_quantizable_call_funcs{
    "conv2d",        "linear",     "batch_norm",
    "hardswish",     "elu",        "celu",
    "layer_norm",    "group_norm", "instance_norm",
    "embedding_bag",
};

std::vector<std::string> _static_quantizable_aten_funcs{
    "conv1d",           "conv2d",           "conv3d",
    "conv_transpose1d", "conv_transpose2d", "linear",
    "hardswish",        "hardswish_",       "elu",
    "elu_",             "celu",             "celu_",
    "batch_norm",       "layer_norm",       "group_norm",
    "instance_norm",    "embedding_bag",
};

// Ops considered quantizable for QuantType::DYNAMIC.
std::vector<std::string> _dynamic_quantizable_call_funcs{"linear"};

std::vector<std::string> _dynamic_quantizable_aten_funcs{"linear"};

// Ops that only get their weight quantized under static quantization.
std::vector<std::string> _static_weight_only_quant_aten_funcs{"embedding_bag"};
std::vector<std::string> _static_weight_only_quant_call_funcs{"embedding_bag"};
69
// prim::CallFunctions that don't require observation and take a single input
// Tensor, e.g. `prim::CallFunction(%dropout, %input_tensor, ...)`.
// The observed property of %input_tensor is propagated to the output of the
// `prim::CallFunction`.
// These ops don't compute on the values of the Tensor; the operation only
// depends on the shape of the Tensor.
std::vector<std::string> _single_input_general_shape_call_funcs{
    "_max_pool1d", "_max_pool2d", "_max_pool3d", "dropout", "relu",
};

// The aten counterparts: ops that don't require observation, take a single
// input Tensor, and only depend on the shape of that Tensor,
// e.g. `aten::flatten(%input_tensor, ...)`.
std::vector<std::string> _single_input_general_shape_aten_funcs{
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "flatten",
    "max",
    "min",
    "dropout",
    "reshape",
    // Non-inplace resize is deprecated
    "resize_",
    "chunk",
    "view",
    "transpose",
    "contiguous",
    "permute",
    "repeat",
    "repeat_interleave",
    "relu",
    "relu_",
    "squeeze",
    "squeeze_",
    "unsqueeze",
    "unsqueeze_",
    "detach",
    "detach_",
    "stack",
    "__getitem__",
};
119
// prim::CallFunctions for ops that don't require observation and take a
// single input Tensor. Unlike the shape-only ops, these do compute on the
// values of the Tensor.
// TODO: [Need verify] looks like we can quantize simple functionals that just
// call into aten functions
std::vector<std::string> _single_input_general_value_call_funcs{
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "interpolate",
    "upsample",
    "upsample_bilinear",
    "upsample_nearest",
    "hardtanh",
    "leaky_relu",
};

// aten ops that don't require observation, take a single input Tensor and
// compute on its values, e.g. `aten::avg_pool2d(%input_tensor, ...)`.
std::vector<std::string> _single_input_general_value_aten_funcs{
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "mean",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_linear1d",
    "upsample_bilinear2d",
    "upsample_trilinear3d",
    "upsample_bicubic2d",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
    "hardtanh",
    "hardtanh_",
    "leaky_relu",
    "leaky_relu_",
};

// clamp-like aten ops (used by isClamp / getClampScalarInputUse).
std::vector<std::string> _clamp_funcs{
    "hardtanh",
    "hardtanh_",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
};
173
// Scale / zero-point constants used to build the fixed per-tensor affine
// quantization parameters below.
const float _asym_scale = 1.0f / 256.0f;
const int _asym_zero_point = 0;
const float _sym_scale = 2.0f / 256.0f;
const int _sym_zero_point = 128;
// quantization parameters for ops with range 0 to 1
// for example: aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
// With scale 1/256 and zero point 0, quint8 codes 0..255 dequantize to
// 0..255/256.
std::tuple<c10::QScheme, QParamVector> _per_tensor_asym_qparam =
    std::make_tuple(
        c10::kPerTensorAffine,
        QParamVector(
            {std::make_pair(".scale", IValue(_asym_scale)),
             std::make_pair(".zero_point", IValue(_asym_zero_point)),
             std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// quantization parameters for ops with range -1 to 1
// for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp
// With scale 2/256 and zero point 128, quint8 codes 0..255 dequantize to
// -1..127/128.
std::tuple<c10::QScheme, QParamVector> _per_tensor_sym_qparam = std::make_tuple(
    c10::kPerTensorAffine,
    QParamVector(
        {std::make_pair(".scale", IValue(_sym_scale)),
         std::make_pair(".zero_point", IValue(_sym_zero_point)),
         std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// Map from aten op symbol to the quantization parameters
// for the ops with fixed quantization parameters.
// NOTE: the entries copy the two tuples above. Namespace-scope variables in a
// single translation unit are dynamically initialized in definition order, so
// those tuples must stay defined before this map.
std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
    _fixed_qparams_map = {
        {Symbol::aten("hardsigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("hardsigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("tanh"), _per_tensor_sym_qparam},
        {Symbol::aten("tanh_"), _per_tensor_sym_qparam},
};
208
209// Special checks for ops that do not require observers for all input tensors.
210// For each operator in this list observers are inserted for the input based
211// on the index specified.
212AtenFuncArgs _observe_inputs_aten_func = {};
213CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};
214
// aten functions that read tensor metadata rather than tensor data.
std::vector<std::string> _tensor_info_funcs{"size", "len", "dim", "numel"};

// aten functions whose output is quantized or not depending on whether the
// input tensor is quantized.
std::vector<std::string> _propagate_quant_single_input_ops{"cat"};

// Binary ops such as `aten::add` follow slightly different rules:
//  - if both inputs are Tensors, the output is quantized only when both
//    inputs are quantized;
//  - if the second input is a Scalar, only the first input decides whether
//    the output is quantized.
std::vector<std::string> _propagate_quant_binary_ops{
    "add", "add_", "mul", "mul_"};
233// Check if `use` is an aten function of name `func_name` and if value
234// `v` is the nth argument (if provided) of the function.
235bool matchAtenFuncToUse(
236 const Use& use,
237 const std::string& func_name,
238 c10::optional<int> n) {
239 Node* node = use.user;
240 return node->kind() == Symbol::aten(func_name) &&
241 (!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
242}
243
244bool matchCallFuncToUse(
245 const Use& use,
246 const std::string& func_name,
247 c10::optional<int> n) {
248 Node* node = use.user;
249 return node->kind() == prim::CallFunction &&
250 getFuncName(node->inputs()[0]) == func_name &&
251 (!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
252}
253
254// Check any use of `v` matches the aten function call
255// or CallFunction patterns
256bool matchArgPattern(
257 Value* v,
258 const AtenFuncArgs& aten_func_args,
259 const CallFuncArgs& call_func_args) {
260 for (const Use& u : v->uses()) {
261 for (const auto& func_arg : aten_func_args) {
262 if (matchAtenFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
263 return true;
264 }
265 }
266
267 for (const auto& func_arg : call_func_args) {
268 if (matchCallFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
269 return true;
270 }
271 }
272 }
273 return false;
274}
275
276// TODO add other op signatures.
277bool isWeight(Value* v) {
278 bool result = matchArgPattern(
279 v,
280 // ate::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
281 // %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
282 AtenFuncArgs(
283 {{"conv1d", 1},
284 {"conv2d", 1},
285 {"conv3d", 1},
286 {"conv_transpose1d", 1},
287 {"conv_transpose2d", 1},
288 {"linear", 1},
289 {"embedding_bag", 0}}),
290 // embedding_bag - prim::CallFunction(%func, %input.1, %weight,
291 // %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
292 // %per_sample_weights.1, %include_last_offset)
293 CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
294 return result;
295}
296
297bool isBiasOfConvOrLinear(Value* v) {
298 bool result = matchArgPattern(
299 v,
300 AtenFuncArgs(
301 {{"conv1d", 2},
302 {"conv2d", 2},
303 {"conv3d", 2},
304 {"conv_transpose1d", 2},
305 {"conv_transpose2d", 2},
306 {"linear", 2}}),
307 CallFuncArgs({{"linear", 3}}));
308 return result;
309}
310
311bool isEmbeddingBagNonInput(Value* v) {
312 bool result = matchArgPattern(
313 v,
314 AtenFuncArgs({{"embedding_bag", 2}, {"embedding_bag", 6}}),
315 CallFuncArgs({}));
316 return result;
317}
318
319c10::optional<Use> getClampScalarInputUse(Value* v) {
320 for (const auto& use : v->uses()) {
321 for (const auto& aten_func : _clamp_funcs) {
322 if (matchAtenFuncToUse(use, aten_func, 1) ||
323 matchAtenFuncToUse(use, aten_func, 2)) {
324 return use;
325 }
326 }
327 }
328 return c10::nullopt;
329}
330
331void cloneMethod(
332 Module& module,
333 const std::string& orig_method_name,
334 const std::string& new_method_name) {
335 const Function& method = module.get_method(orig_method_name).function();
336 auto graph = toGraphFunction(method).graph()->copy();
337 const auto& schema = method.getSchema();
338 const auto this_method_name =
339 c10::QualifiedName(*module.type()->name(), new_method_name);
340 auto copied = module._ivalue()->compilation_unit()->create_function(
341 this_method_name, std::move(graph));
342 module.type()->addMethod(copied);
343 copied->setSchema(schema);
344}
345
346std::vector<Value*> getPassThroughInputs(Value* v) {
347 Node* n = v->node();
348 if (isSingleInputGeneralCallFunction(n)) {
349 return {n->input(1)};
350 } else if (
351 isSingleInputGeneralAtenFunction(n) ||
352 (n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
353 return {n->input(0)};
354 } else if (n->kind() == prim::If && n->outputs().size() == 1) {
355 std::vector<Value*> inputs;
356 for (Block* subblock : n->blocks()) {
357 if (alwaysRaisesException(subblock)) {
358 continue;
359 }
360 auto* output = subblock->outputs()[0];
361 inputs.push_back(output);
362 }
363 return inputs;
364 } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
365 // only propagate dequantize for Tensor
366 if (v->type()->isSubtypeOf(*TensorType::get())) {
367 return {n->input(0)};
368 } else {
369 return {};
370 }
371 } else if (
372 n->kind() == prim::ListConstruct &&
373 v->type()->isSubtypeOf(*ListType::ofTensors())) {
374 std::vector<Value*> inputs;
375 for (auto* v : n->inputs()) {
376 inputs.push_back(v);
377 }
378 return inputs;
379 } else if (n->kind() == prim::TupleConstruct) {
380 std::vector<Value*> inputs;
381 for (auto* input : n->inputs()) {
382 if (input->type()->isSubtypeOf(*TensorType::get())) {
383 inputs.push_back(input);
384 }
385 }
386 return inputs;
387 } else if (n->kind() == Symbol::aten("append")) {
388 std::vector<Value*> inputs;
389 for (auto* input : n->inputs()) {
390 inputs.push_back(input);
391 }
392 return inputs;
393 }
394
395 return {};
396}
397
398std::vector<NodeKind> toAtenSymbol(const std::vector<std::string>& func_names) {
399 std::vector<NodeKind> symbols;
400 std::transform(
401 func_names.begin(),
402 func_names.end(),
403 std::back_inserter(symbols),
404 Symbol::aten);
405 return symbols;
406}
407
408bool isAtenFunc(Node* n, const std::vector<NodeKind>& aten_funcs) {
409 return std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
410 aten_funcs.end();
411}
412
413bool isAtenFunc(Node* n, const std::vector<std::string>& aten_funcs) {
414 const auto& symbols = toAtenSymbol(aten_funcs);
415 return isAtenFunc(n, symbols);
416}
417
418// TODO: factor out isCallFunc
419bool isFunctionNode(
420 Node* n,
421 const std::vector<std::string>& call_funcs,
422 const std::vector<std::string>& aten_funcs) {
423 bool is_func_node = isAtenFunc(n, aten_funcs);
424 if (n->kind() == prim::CallFunction) {
425 auto func_name = getFuncName(n->inputs()[0]);
426 is_func_node |=
427 std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
428 call_funcs.end();
429 }
430 return is_func_node;
431}
432
433bool isSingleInputGeneralShapeAtenFunction(Node* n) {
434 return isAtenFunc(n, _single_input_general_shape_aten_funcs);
435}
436
437bool isSingleInputGeneralValueAtenFunction(Node* n) {
438 return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
439 isBinaryOpWithScalarInput(n);
440}
441
442bool isSingleInputGeneralCallFunction(Node* n) {
443 static std::vector<std::string> single_input_general_call_funcs;
444 std::copy(
445 _single_input_general_shape_call_funcs.begin(),
446 _single_input_general_shape_call_funcs.end(),
447 std::back_inserter(single_input_general_call_funcs));
448 std::copy(
449 _single_input_general_value_call_funcs.begin(),
450 _single_input_general_value_call_funcs.end(),
451 std::back_inserter(single_input_general_call_funcs));
452 return isFunctionNode(
453 n,
454 /* call_funcs = */ single_input_general_call_funcs,
455 /* aten_funcs = */ {});
456}
457
458bool isSingleInputGeneralAtenFunction(Node* n) {
459 static std::vector<NodeKind> fixed_qparams_aten_funcs;
460 std::transform(
461 _fixed_qparams_map.begin(),
462 _fixed_qparams_map.end(),
463 std::back_inserter(fixed_qparams_aten_funcs),
464 [](auto pair) { return pair.first; });
465
466 return isSingleInputGeneralValueAtenFunction(n) ||
467 isSingleInputGeneralShapeAtenFunction(n) ||
468 isAtenFunc(n, fixed_qparams_aten_funcs);
469}
470
471bool isClamp(Node* n) {
472 return isAtenFunc(n, _clamp_funcs);
473}
474
475bool isTensorInfoNode(Node* n) {
476 return isAtenFunc(n, _tensor_info_funcs);
477}
478
479bool isPropagateQuantSingleInputOp(Node* n) {
480 return isAtenFunc(n, _propagate_quant_single_input_ops);
481}
482
483bool isPropagateQuantBinaryOp(Node* n) {
484 return isAtenFunc(n, _propagate_quant_binary_ops);
485}
486
487bool isPropagateQuantOp(Node* n) {
488 return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
489}
490
491bool isBinaryOpWithScalarInput(Node* n) {
492 return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
493}
494
495c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
496 static std::vector<NodeKind> fixed_qparam_funcs;
497 std::transform(
498 _fixed_qparams_map.begin(),
499 _fixed_qparams_map.end(),
500 std::back_inserter(fixed_qparam_funcs),
501 [](const auto& pair) { return pair.first; });
502 if (isAtenFunc(n, fixed_qparam_funcs)) {
503 return _fixed_qparams_map.at(n->kind());
504 }
505 return c10::nullopt;
506}
507
508bool userDefinedCallFunction(Node* n) {
509 return n->kind() == prim::CallFunction &&
510 !isSingleInputGeneralCallFunction(n) &&
511 !isFunctionNode(n, _static_quantizable_call_funcs, {});
512}
513
514bool isWeightOnlyStaticQuantOp(Node* n) {
515 return isFunctionNode(
516 n,
517 _static_weight_only_quant_call_funcs,
518 _static_weight_only_quant_aten_funcs);
519}
520
521bool nodeQuantizable(Node* n, QuantType quant_type) {
522 bool is_dynamic = quant_type == QuantType::DYNAMIC;
523 return isFunctionNode(
524 n,
525 /* call_funcs = */
526 is_dynamic ? _dynamic_quantizable_call_funcs
527 : _static_quantizable_call_funcs,
528 /* aten_funcs = */
529 is_dynamic ? _dynamic_quantizable_aten_funcs
530 : _static_quantizable_aten_funcs);
531}
532
533bool useQuantizable(const Use& use, QuantType quant_type) {
534 if (quant_type == QuantType::STATIC) {
535 for (const auto& func_input : _observe_inputs_aten_func) {
536 if (matchAtenFuncToUse(use, func_input.func_name, c10::nullopt)) {
537 return use.offset == static_cast<size_t>(func_input.arg_index);
538 }
539 }
540
541 for (const auto& func_input : _observe_inputs_call_func) {
542 if (matchCallFuncToUse(use, func_input.func_name, c10::nullopt)) {
543 return use.offset == static_cast<size_t>(func_input.arg_index);
544 }
545 }
546 }
547
548 return nodeQuantizable(use.user, quant_type);
549}
550
551std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
552 auto* func_node = n->input(0)->node();
553 auto func = func_node->output()->type()->expectRef<FunctionType>().function();
554 auto graphFunc = tryToGraphFunction(*func);
555 TORCH_CHECK(graphFunc, "Quantization only works for graph function");
556 return graphFunc->graph();
557}
558
559// Block helper functions
560bool alwaysRaisesException(Block* block) {
561 for (Node* n : block->nodes()) {
562 if (n->kind() == prim::RaiseException) {
563 return true;
564 }
565 if (n->kind() == prim::If) {
566 bool exception = true;
567 for (Block* b : n->blocks()) {
568 exception &= alwaysRaisesException(b);
569 }
570 if (exception) {
571 return true;
572 }
573 }
574 }
575 return false;
576}
577
578// Check if a value in the graph is a Scalar value
579bool isScalar(Value* v) {
580 auto iv = toIValue(v);
581 return v->type()->isSubtypeOf(*NumberType::get()) ||
582 (v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
583 iv->toTensor().dim() == 0);
584}
585
586// =================== Graph/Module analysis helper functions ============
587// Check if value is the input of the graph
588bool hitGraphInput(Value* value) {
589 Graph* graph = value->owningGraph();
590 const auto& inputs = graph->inputs();
591 return std::find(inputs.begin(), inputs.end(), value) != inputs.end();
592}
593
594// Get the module access path for a Value representing a module instance
595// by tracing back the GetAttr nodes and recording all the attribute
596// names along the way.
597// Assuming 'self.sub.basic_block.conv1',
598// Input1: Value instance of conv1
599// Input2: Value instance of self
600// Output: ['sub', 'basic_block', 'conv1']
601std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
602 std::vector<std::string> path;
603 // Iterator to traverse back the GetAttr calls
604 Value* iter = instance;
605 // trace back the instance to recover the path of the submodule
606 while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
607 Node* get_attr = iter->node();
608 // record the name of GetAttr
609 path.push_back(get_attr->s(attr::name));
610 // trace back the chain of GetAttr
611 iter = get_attr->inputs()[0];
612 }
613 TORCH_CHECK(
614 iter == self,
615 "Can't handle the access pattern of GetAttr "
616 " in getModuleAccessPath, traced back to:",
617 iter->debugName(),
618 " which is not self:",
619 self->debugName());
620 std::reverse(path.begin(), path.end());
621 return path;
622}
623
624// Assuming self.foo.bar.conv1,
625// Input1: Module instance of self
626// Input2: ['foo', 'bar', 'conv1']
627// Output: Module instance of conv1
628Module findChildModule(
629 const Module& module,
630 const std::vector<std::string>& path) {
631 Module m = module;
632 for (const auto& p : path) {
633 m = m.attr(p).toModule();
634 }
635 return m;
636}
637
638Module getInvokedModule(Module& module, Node* n, Value* self) {
639 auto* instance = n->inputs()[0];
640 auto path = getModuleAccessPath(instance, self);
641 return findChildModule(module, path);
642}
643
644c10::optional<Module> getInvokedModuleOpt(
645 const Module& module,
646 Node* n,
647 Value* self) {
648 auto* instance = n->inputs()[0];
649 auto path = getModuleAccessPath(instance, self);
650 Module m = module;
651 for (const auto& p : path) {
652 if (m.attr(p).isModule()) {
653 m = m.attr(p).toModule();
654 } else {
655 return c10::nullopt;
656 }
657 }
658 return m;
659}
660
661// ==================== filter functions for matches ==============
662bool is_int_constant(
663 const Match& match,
664 const std::unordered_map<std::string, Value*>& vmap,
665 const std::string& vname,
666 int value) {
667 const auto& match_vmap = match.values_map;
668 auto v = toIValue(match_vmap.at(vmap.at(vname)));
669 return v && v->isInt() && v->toInt() == value;
670}
671
672bool is_functional(
673 const Match& match,
674 const std::unordered_map<std::string, Value*>& vmap,
675 const std::string& vname,
676 const std::string& functional) {
677 const auto& match_vmap = match.values_map;
678 Value* v = match_vmap.at(vmap.at(vname));
679 return v->type()->cast<FunctionType>() && getFuncName(v) == functional;
680}
681
// Strips every ".___torch_mangle_<digits>" segment from a qualified name,
// e.g. "a.___torch_mangle_3.B" -> "a.B".
std::string removeTorchMangle(const std::string& orig_name) {
  static const std::regex kMangleSegment("\\.___torch_mangle_\\d+");
  return std::regex_replace(orig_name, kMangleSegment, "");
}
687
688c10::optional<std::string> getModuleName(Value* value) {
689 auto type = value->type()->cast<ClassType>();
690 if (type && type->name()) {
691 return removeTorchMangle(type->name()->qualifiedName());
692 }
693 return c10::nullopt;
694}
695
696bool is_module(
697 const Match& match,
698 const std::unordered_map<std::string, Value*>& vmap,
699 const std::string& vname,
700 const std::string& module_qualified_name) {
701 const auto& match_vmap = match.values_map;
702 Value* v = match_vmap.at(vmap.at(vname));
703 auto module_name = getModuleName(v);
704 if (module_name.has_value()) {
705 return module_name.value() == module_qualified_name;
706 }
707 return false;
708};
709
710bool aten_add_alpha_is_one(
711 const Match& match,
712 const std::unordered_map<std::string, Value*>& vmap) {
713 return is_int_constant(match, vmap, "alpha", 1);
714}
715
716bool is_functional_relu(
717 const Match& match,
718 const std::unordered_map<std::string, Value*>& vmap) {
719 return is_functional(match, vmap, "relu", "relu");
720}
721
722bool is_relu_module(
723 const Match& match,
724 const std::unordered_map<std::string, Value*>& vmap) {
725 return is_module(
726 match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
727}
728
729bool is_linear_module(
730 const Match& match,
731 const std::unordered_map<std::string, Value*>& vmap) {
732 return is_module(
733 match, vmap, "linear", "__torch__.torch.nn.modules.linear.Linear");
734}
735
736bool is_conv1d_module(
737 const Match& match,
738 const std::unordered_map<std::string, Value*>& vmap) {
739 return is_module(
740 match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv1d");
741}
742
743bool is_conv2d_module(
744 const Match& match,
745 const std::unordered_map<std::string, Value*>& vmap) {
746 return is_module(
747 match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv2d");
748}
749
750bool is_conv3d_module(
751 const Match& match,
752 const std::unordered_map<std::string, Value*>& vmap) {
753 return is_module(
754 match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
755}
756
757bool is_conv_transpose1d_module(
758 const Match& match,
759 const std::unordered_map<std::string, Value*>& vmap) {
760 return is_module(
761 match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
762}
763
764bool is_conv_transpose2d_module(
765 const Match& match,
766 const std::unordered_map<std::string, Value*>& vmap) {
767 return is_module(
768 match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
769}
770
771bool is_batchnorm2d_module(
772 const Match& match,
773 const std::unordered_map<std::string, Value*>& vmap) {
774 bool regnorm = is_module(
775 match,
776 vmap,
777 "batchnorm",
778 "__torch__.torch.nn.modules.batchnorm.BatchNorm2d");
779 bool naivenorm = is_module(
780 match,
781 vmap,
782 "batchnorm",
783 "__torch__.mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm");
784 return (regnorm || naivenorm);
785}
786
787bool is_batchnorm3d_module(
788 const Match& match,
789 const std::unordered_map<std::string, Value*>& vmap) {
790 return is_module(
791 match,
792 vmap,
793 "batchnorm",
794 "__torch__.torch.nn.modules.batchnorm.BatchNorm3d");
795}
796
797} // namespace jit
798} // namespace torch
799