1 | #include <torch/csrc/jit/passes/quantization/helper.h> |
2 | |
3 | #include <torch/csrc/jit/api/function_impl.h> |
4 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | using graph_rewrite_helper::getFuncName; |
12 | |
// Pairing of a function name with the index of one of its arguments.
// Used to mark that a particular argument position of a particular op
// should get special treatment (e.g. treated as weight, bias, or the
// only observed input).
struct FuncArg {
  std::string func_name;
  int arg_index;
};

// Same shape, but named separately for aten ops vs. prim::CallFunction
// callees to keep call sites self-documenting.
using AtenFuncArgs = std::vector<FuncArg>;
using CallFuncArgs = std::vector<FuncArg>;
20 | |
// Lists of allowed quantizable operators
// Functions invoked through prim::CallFunction that support static
// quantization.
std::vector<std::string> _static_quantizable_call_funcs = {
    "conv2d",
    "linear",
    "batch_norm",
    "hardswish",
    "elu",
    "celu",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

// aten ops that support static quantization (includes the inplace `_`
// variants where quantized kernels exist).
std::vector<std::string> _static_quantizable_aten_funcs = {
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "linear",
    "hardswish",
    "hardswish_",
    "elu",
    "elu_",
    "celu",
    "celu_",
    "batch_norm",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

// Functions invoked through prim::CallFunction that support dynamic
// quantization.
std::vector<std::string> _dynamic_quantizable_call_funcs = {
    "linear",
};

// aten ops that support dynamic quantization.
std::vector<std::string> _dynamic_quantizable_aten_funcs = {
    "linear",
};

// Ops whose weight (only) is quantized under static quantization.
std::vector<std::string> _static_weight_only_quant_aten_funcs = {
    "embedding_bag",
};
std::vector<std::string> _static_weight_only_quant_call_funcs = {
    "embedding_bag",
};
69 | |
// These are the prim::CallFunction callees that don't require observation and
// have a single input Tensor
// example: `prim::CallFunction(%dropout, %input_tensor, ...)`
// so we propagate the observed property from %input_tensor to the
// output of the `prim::CallFunction`
// Also these ops don't do computation on the value of the Tensor, the
// operation only depends on the shape of the Tensor
std::vector<std::string> _single_input_general_shape_call_funcs = {
    "_max_pool1d",
    "_max_pool2d",
    "_max_pool3d",
    "dropout",
    "relu",
};

// Similar to the prim::CallFunctions above, these are aten ops that don't
// require observation and have a single input Tensor
// Also these ops don't do computation on the value of the Tensor, the
// operation only depends on the shape of the Tensor
// e.g. `aten::flatten(%input_tensor, ...)`
std::vector<std::string> _single_input_general_shape_aten_funcs = {
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "flatten",
    "max",
    "min",
    "dropout",
    "reshape",
    // Non-inplace resize is deprecated
    "resize_",
    "chunk",
    "view",
    "transpose",
    "contiguous",
    "permute",
    "repeat",
    "repeat_interleave",
    "relu",
    "relu_",
    "squeeze",
    "squeeze_",
    "unsqueeze",
    "unsqueeze_",
    "detach",
    "detach_",
    "stack",
    "__getitem__",
};
119 | |
// These are prim::CallFunction callees for ops that don't require observation
// and have a single input Tensor
// Unlike the "shape" lists above, these ops do computation on the value of
// the Tensor
// TODO: [Need verify] looks like we can quantize simple functionals that just
// call into aten functions
std::vector<std::string> _single_input_general_value_call_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "interpolate",
    "upsample",
    "upsample_bilinear",
    "upsample_nearest",
    "hardtanh",
    "leaky_relu",
};

// These are aten functions for ops that don't require observation and
// have a single input Tensor
// Also these ops do computation on the value of the Tensor
// e.g. `aten::avg_pool2d(%input_tensor, ...)`
std::vector<std::string> _single_input_general_value_aten_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "mean",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_linear1d",
    "upsample_bilinear2d",
    "upsample_trilinear3d",
    "upsample_bicubic2d",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
    "hardtanh",
    "hardtanh_",
    "leaky_relu",
    "leaky_relu_",
};

// clamp-like ops that take scalar min (arg 1) / max (arg 2) arguments.
std::vector<std::string> _clamp_funcs = {
    "hardtanh",
    "hardtanh_",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
};
173 | |
// Fixed quantization parameters for activations with a known output range.
// Asymmetric: outputs in [0, 1] mapped onto quint8 -> scale 1/256, zp 0.
const float _asym_scale = 1.0f / 256.0f;
const int _asym_zero_point = 0;
// Symmetric: outputs in [-1, 1] mapped onto quint8 -> scale 2/256, zp 128.
const float _sym_scale = 2.0f / 256.0f;
const int _sym_zero_point = 128;
// quantization parameters for ops with range 0 to 1
// for example: aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_asym_qparam =
    std::make_tuple(
        c10::kPerTensorAffine,
        QParamVector(
            {std::make_pair(".scale", IValue(_asym_scale)),
             std::make_pair(".zero_point", IValue(_asym_zero_point)),
             std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// quantization parameters for ops with range -1 to 1
// for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_sym_qparam = std::make_tuple(
    c10::kPerTensorAffine,
    QParamVector(
        {std::make_pair(".scale", IValue(_sym_scale)),
         std::make_pair(".zero_point", IValue(_sym_zero_point)),
         std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// Map from aten op symbol to the quantization parameters
// for the ops with fixed quantization parameters
std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
    _fixed_qparams_map = {
        {Symbol::aten("hardsigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("hardsigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("tanh"), _per_tensor_sym_qparam},
        {Symbol::aten("tanh_"), _per_tensor_sym_qparam},
};
208 | |
// Special checks for ops that do not require observers for all input tensors.
// For each operator in this list observers are inserted for the input based
// on the index specified.
AtenFuncArgs _observe_inputs_aten_func = {};
// Functional batch_norm: only argument index 1 (the input tensor) is
// observed; the remaining inputs (weight, bias, running stats) are not.
CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};

// Aten functions for getting tensor information
std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim", "numel"};

// Aten functions whose output will be quantized or not quantized depending
// on input tensor
std::vector<std::string> _propagate_quant_single_input_ops = {"cat"};

// Rules are slightly different for binary ops like `aten::add`, for these ops,
// if both of the inputs are Tensor, we'll quantize the output only if both of
// the inputs are quantized
// if the second input is a Scalar, we'll only look at the first input to decide
// if we need to quantize the output
std::vector<std::string> _propagate_quant_binary_ops = {
    "add",
    "add_",
    "mul",
    "mul_"};
232 | |
233 | // Check if `use` is an aten function of name `func_name` and if value |
234 | // `v` is the nth argument (if provided) of the function. |
235 | bool matchAtenFuncToUse( |
236 | const Use& use, |
237 | const std::string& func_name, |
238 | c10::optional<int> n) { |
239 | Node* node = use.user; |
240 | return node->kind() == Symbol::aten(func_name) && |
241 | (!n.has_value() || static_cast<size_t>(n.value()) == use.offset); |
242 | } |
243 | |
244 | bool matchCallFuncToUse( |
245 | const Use& use, |
246 | const std::string& func_name, |
247 | c10::optional<int> n) { |
248 | Node* node = use.user; |
249 | return node->kind() == prim::CallFunction && |
250 | getFuncName(node->inputs()[0]) == func_name && |
251 | (!n.has_value() || static_cast<size_t>(n.value()) == use.offset); |
252 | } |
253 | |
254 | // Check any use of `v` matches the aten function call |
255 | // or CallFunction patterns |
256 | bool matchArgPattern( |
257 | Value* v, |
258 | const AtenFuncArgs& aten_func_args, |
259 | const CallFuncArgs& call_func_args) { |
260 | for (const Use& u : v->uses()) { |
261 | for (const auto& func_arg : aten_func_args) { |
262 | if (matchAtenFuncToUse(u, func_arg.func_name, func_arg.arg_index)) { |
263 | return true; |
264 | } |
265 | } |
266 | |
267 | for (const auto& func_arg : call_func_args) { |
268 | if (matchCallFuncToUse(u, func_arg.func_name, func_arg.arg_index)) { |
269 | return true; |
270 | } |
271 | } |
272 | } |
273 | return false; |
274 | } |
275 | |
// TODO add other op signatures.
// Returns true if `v` is used as the weight argument of one of the
// listed conv/linear/embedding_bag ops.
bool isWeight(Value* v) {
  bool result = matchArgPattern(
      v,
      // aten::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
      // %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
      // i.e. weight is argument 0 for embedding_bag, argument 1 for the rest.
      AtenFuncArgs(
          {{"conv1d", 1},
           {"conv2d", 1},
           {"conv3d", 1},
           {"conv_transpose1d", 1},
           {"conv_transpose2d", 1},
           {"linear", 1},
           {"embedding_bag", 0}}),
      // embedding_bag - prim::CallFunction(%func, %input.1, %weight,
      // %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
      // %per_sample_weights.1, %include_last_offset)
      // (input 0 of a CallFunction is the function object, so weight is at 2.)
      CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
  return result;
}
296 | |
297 | bool isBiasOfConvOrLinear(Value* v) { |
298 | bool result = matchArgPattern( |
299 | v, |
300 | AtenFuncArgs( |
301 | {{"conv1d" , 2}, |
302 | {"conv2d" , 2}, |
303 | {"conv3d" , 2}, |
304 | {"conv_transpose1d" , 2}, |
305 | {"conv_transpose2d" , 2}, |
306 | {"linear" , 2}}), |
307 | CallFuncArgs({{"linear" , 3}})); |
308 | return result; |
309 | } |
310 | |
311 | bool isEmbeddingBagNonInput(Value* v) { |
312 | bool result = matchArgPattern( |
313 | v, |
314 | AtenFuncArgs({{"embedding_bag" , 2}, {"embedding_bag" , 6}}), |
315 | CallFuncArgs({})); |
316 | return result; |
317 | } |
318 | |
319 | c10::optional<Use> getClampScalarInputUse(Value* v) { |
320 | for (const auto& use : v->uses()) { |
321 | for (const auto& aten_func : _clamp_funcs) { |
322 | if (matchAtenFuncToUse(use, aten_func, 1) || |
323 | matchAtenFuncToUse(use, aten_func, 2)) { |
324 | return use; |
325 | } |
326 | } |
327 | } |
328 | return c10::nullopt; |
329 | } |
330 | |
// Clone the graph and schema of `orig_method_name` on `module` into a new
// method named `new_method_name` registered on the same module type.
void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name) {
  const Function& method = module.get_method(orig_method_name).function();
  // Deep-copy the graph so the clone can be mutated independently of the
  // original method.
  auto graph = toGraphFunction(method).graph()->copy();
  const auto& schema = method.getSchema();
  const auto this_method_name =
      c10::QualifiedName(*module.type()->name(), new_method_name);
  auto copied = module._ivalue()->compilation_unit()->create_function(
      this_method_name, std::move(graph));
  // Register the new function on the type, then attach the original schema.
  module.type()->addMethod(copied);
  copied->setSchema(schema);
}
345 | |
// For a value `v` produced by a node that merely passes Tensors through
// (single-input general ops, If outputs, list/tuple construct/unpack,
// append), return the input values whose quantized/observed property
// should propagate to `v`. Returns an empty vector when `v`'s node is
// not such a pass-through op.
std::vector<Value*> getPassThroughInputs(Value* v) {
  Node* n = v->node();
  if (isSingleInputGeneralCallFunction(n)) {
    // input 0 of a prim::CallFunction is the function object itself;
    // input 1 is the first real (Tensor) argument
    return {n->input(1)};
  } else if (
      isSingleInputGeneralAtenFunction(n) ||
      (n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
    // for aten::sort only output 0 (the values) carries through the
    // input Tensor's quantization; output 1 holds indices
    return {n->input(0)};
  } else if (n->kind() == prim::If && n->outputs().size() == 1) {
    // propagate through each branch of the If, skipping branches that
    // always raise (they never actually produce the output)
    std::vector<Value*> inputs;
    for (Block* subblock : n->blocks()) {
      if (alwaysRaisesException(subblock)) {
        continue;
      }
      auto* output = subblock->outputs()[0];
      inputs.push_back(output);
    }
    return inputs;
  } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
    // only propagate dequantize for Tensor
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      return {n->input(0)};
    } else {
      return {};
    }
  } else if (
      n->kind() == prim::ListConstruct &&
      v->type()->isSubtypeOf(*ListType::ofTensors())) {
    // a Tensor[] propagates from every element
    std::vector<Value*> inputs;
    for (auto* v : n->inputs()) {
      inputs.push_back(v);
    }
    return inputs;
  } else if (n->kind() == prim::TupleConstruct) {
    // only the Tensor elements of the tuple are relevant
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      if (input->type()->isSubtypeOf(*TensorType::get())) {
        inputs.push_back(input);
      }
    }
    return inputs;
  } else if (n->kind() == Symbol::aten("append")) {
    // appending propagates both from the list and the appended element
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      inputs.push_back(input);
    }
    return inputs;
  }

  return {};
}
397 | |
398 | std::vector<NodeKind> toAtenSymbol(const std::vector<std::string>& func_names) { |
399 | std::vector<NodeKind> symbols; |
400 | std::transform( |
401 | func_names.begin(), |
402 | func_names.end(), |
403 | std::back_inserter(symbols), |
404 | Symbol::aten); |
405 | return symbols; |
406 | } |
407 | |
408 | bool isAtenFunc(Node* n, const std::vector<NodeKind>& aten_funcs) { |
409 | return std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) != |
410 | aten_funcs.end(); |
411 | } |
412 | |
413 | bool isAtenFunc(Node* n, const std::vector<std::string>& aten_funcs) { |
414 | const auto& symbols = toAtenSymbol(aten_funcs); |
415 | return isAtenFunc(n, symbols); |
416 | } |
417 | |
418 | // TODO: factor out isCallFunc |
419 | bool isFunctionNode( |
420 | Node* n, |
421 | const std::vector<std::string>& call_funcs, |
422 | const std::vector<std::string>& aten_funcs) { |
423 | bool is_func_node = isAtenFunc(n, aten_funcs); |
424 | if (n->kind() == prim::CallFunction) { |
425 | auto func_name = getFuncName(n->inputs()[0]); |
426 | is_func_node |= |
427 | std::find(call_funcs.begin(), call_funcs.end(), func_name) != |
428 | call_funcs.end(); |
429 | } |
430 | return is_func_node; |
431 | } |
432 | |
// True if `n` is an aten op whose output depends only on the shape of
// its single Tensor input (see _single_input_general_shape_aten_funcs).
bool isSingleInputGeneralShapeAtenFunction(Node* n) {
  return isAtenFunc(n, _single_input_general_shape_aten_funcs);
}

// True if `n` is an aten op with a single Tensor input whose output
// depends on the input values; binary ops with a Scalar second operand
// also count.
bool isSingleInputGeneralValueAtenFunction(Node* n) {
  return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
      isBinaryOpWithScalarInput(n);
}
441 | |
442 | bool isSingleInputGeneralCallFunction(Node* n) { |
443 | static std::vector<std::string> single_input_general_call_funcs; |
444 | std::copy( |
445 | _single_input_general_shape_call_funcs.begin(), |
446 | _single_input_general_shape_call_funcs.end(), |
447 | std::back_inserter(single_input_general_call_funcs)); |
448 | std::copy( |
449 | _single_input_general_value_call_funcs.begin(), |
450 | _single_input_general_value_call_funcs.end(), |
451 | std::back_inserter(single_input_general_call_funcs)); |
452 | return isFunctionNode( |
453 | n, |
454 | /* call_funcs = */ single_input_general_call_funcs, |
455 | /* aten_funcs = */ {}); |
456 | } |
457 | |
458 | bool isSingleInputGeneralAtenFunction(Node* n) { |
459 | static std::vector<NodeKind> fixed_qparams_aten_funcs; |
460 | std::transform( |
461 | _fixed_qparams_map.begin(), |
462 | _fixed_qparams_map.end(), |
463 | std::back_inserter(fixed_qparams_aten_funcs), |
464 | [](auto pair) { return pair.first; }); |
465 | |
466 | return isSingleInputGeneralValueAtenFunction(n) || |
467 | isSingleInputGeneralShapeAtenFunction(n) || |
468 | isAtenFunc(n, fixed_qparams_aten_funcs); |
469 | } |
470 | |
// True if `n` is one of the clamp-like ops (clamp/hardtanh).
bool isClamp(Node* n) {
  return isAtenFunc(n, _clamp_funcs);
}

// True if `n` only queries tensor metadata (size/len/dim/numel) and
// therefore never needs observation.
bool isTensorInfoNode(Node* n) {
  return isAtenFunc(n, _tensor_info_funcs);
}

// True for ops like aten::cat whose output quantization follows their
// single input.
bool isPropagateQuantSingleInputOp(Node* n) {
  return isAtenFunc(n, _propagate_quant_single_input_ops);
}

// True for add/mul (and inplace variants) whose output quantization is
// derived from the inputs' quantization.
bool isPropagateQuantBinaryOp(Node* n) {
  return isAtenFunc(n, _propagate_quant_binary_ops);
}

// True if `n` propagates quantization from input(s) to output.
bool isPropagateQuantOp(Node* n) {
  return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
}

// True if `n` is add/mul (or inplace variant) whose second input is a
// Scalar, e.g. `aten::add(%x, 2)`.
bool isBinaryOpWithScalarInput(Node* n) {
  return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
}
494 | |
495 | c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) { |
496 | static std::vector<NodeKind> fixed_qparam_funcs; |
497 | std::transform( |
498 | _fixed_qparams_map.begin(), |
499 | _fixed_qparams_map.end(), |
500 | std::back_inserter(fixed_qparam_funcs), |
501 | [](const auto& pair) { return pair.first; }); |
502 | if (isAtenFunc(n, fixed_qparam_funcs)) { |
503 | return _fixed_qparams_map.at(n->kind()); |
504 | } |
505 | return c10::nullopt; |
506 | } |
507 | |
// A prim::CallFunction that is neither a recognized single-input
// pass-through function nor a statically quantizable function is
// treated as user-defined.
bool userDefinedCallFunction(Node* n) {
  return n->kind() == prim::CallFunction &&
      !isSingleInputGeneralCallFunction(n) &&
      !isFunctionNode(n, _static_quantizable_call_funcs, {});
}

// True if `n` is an op (currently embedding_bag) that only has its
// weight quantized under static quantization.
bool isWeightOnlyStaticQuantOp(Node* n) {
  return isFunctionNode(
      n,
      _static_weight_only_quant_call_funcs,
      _static_weight_only_quant_aten_funcs);
}
520 | |
521 | bool nodeQuantizable(Node* n, QuantType quant_type) { |
522 | bool is_dynamic = quant_type == QuantType::DYNAMIC; |
523 | return isFunctionNode( |
524 | n, |
525 | /* call_funcs = */ |
526 | is_dynamic ? _dynamic_quantizable_call_funcs |
527 | : _static_quantizable_call_funcs, |
528 | /* aten_funcs = */ |
529 | is_dynamic ? _dynamic_quantizable_aten_funcs |
530 | : _static_quantizable_aten_funcs); |
531 | } |
532 | |
533 | bool useQuantizable(const Use& use, QuantType quant_type) { |
534 | if (quant_type == QuantType::STATIC) { |
535 | for (const auto& func_input : _observe_inputs_aten_func) { |
536 | if (matchAtenFuncToUse(use, func_input.func_name, c10::nullopt)) { |
537 | return use.offset == static_cast<size_t>(func_input.arg_index); |
538 | } |
539 | } |
540 | |
541 | for (const auto& func_input : _observe_inputs_call_func) { |
542 | if (matchCallFuncToUse(use, func_input.func_name, c10::nullopt)) { |
543 | return use.offset == static_cast<size_t>(func_input.arg_index); |
544 | } |
545 | } |
546 | } |
547 | |
548 | return nodeQuantizable(use.user, quant_type); |
549 | } |
550 | |
// Returns the graph implementing the callee of the prim::CallFunction
// node `n`. Throws via TORCH_CHECK if the callee is not a graph function.
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
  // input 0 of a prim::CallFunction is the function object being called
  auto* func_node = n->input(0)->node();
  auto func = func_node->output()->type()->expectRef<FunctionType>().function();
  auto graphFunc = tryToGraphFunction(*func);
  TORCH_CHECK(graphFunc, "Quantization only works for graph function");
  return graphFunc->graph();
}
558 | |
559 | // Block helper functions |
560 | bool alwaysRaisesException(Block* block) { |
561 | for (Node* n : block->nodes()) { |
562 | if (n->kind() == prim::RaiseException) { |
563 | return true; |
564 | } |
565 | if (n->kind() == prim::If) { |
566 | bool exception = true; |
567 | for (Block* b : n->blocks()) { |
568 | exception &= alwaysRaisesException(b); |
569 | } |
570 | if (exception) { |
571 | return true; |
572 | } |
573 | } |
574 | } |
575 | return false; |
576 | } |
577 | |
578 | // Check if a value in the graph is a Scalar value |
579 | bool isScalar(Value* v) { |
580 | auto iv = toIValue(v); |
581 | return v->type()->isSubtypeOf(*NumberType::get()) || |
582 | (v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() && |
583 | iv->toTensor().dim() == 0); |
584 | } |
585 | |
586 | // =================== Graph/Module analysis helper functions ============ |
587 | // Check if value is the input of the graph |
588 | bool hitGraphInput(Value* value) { |
589 | Graph* graph = value->owningGraph(); |
590 | const auto& inputs = graph->inputs(); |
591 | return std::find(inputs.begin(), inputs.end(), value) != inputs.end(); |
592 | } |
593 | |
// Get the module access path for a Value representing a module instance
// by tracing back the GetAttr nodes and recording all the attribute
// names along the way.
// Assuming 'self.sub.basic_block.conv1',
// Input1: Value instance of conv1
// Input2: Value instance of self
// Output: ['sub', 'basic_block', 'conv1']
// Throws via TORCH_CHECK if the GetAttr chain does not terminate at `self`.
std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
  std::vector<std::string> path;
  // Iterator to traverse back the GetAttr calls
  Value* iter = instance;
  // trace back the instance to recover the path of the submodule
  while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
    Node* get_attr = iter->node();
    // record the name of GetAttr
    path.push_back(get_attr->s(attr::name));
    // trace back the chain of GetAttr
    iter = get_attr->inputs()[0];
  }
  TORCH_CHECK(
      iter == self,
      "Can't handle the access pattern of GetAttr "
      " in getModuleAccessPath, traced back to:",
      iter->debugName(),
      " which is not self:",
      self->debugName());
  // names were collected innermost-first; reverse to get root-to-leaf order
  std::reverse(path.begin(), path.end());
  return path;
}
623 | |
624 | // Assuming self.foo.bar.conv1, |
625 | // Input1: Module instance of self |
626 | // Input2: ['foo', 'bar', 'conv1'] |
627 | // Output: Module instance of conv1 |
628 | Module findChildModule( |
629 | const Module& module, |
630 | const std::vector<std::string>& path) { |
631 | Module m = module; |
632 | for (const auto& p : path) { |
633 | m = m.attr(p).toModule(); |
634 | } |
635 | return m; |
636 | } |
637 | |
// Resolve the submodule that the call node `n` is invoked on, by tracing
// the GetAttr chain of its first input back to `self` and walking that
// path down from `module`.
Module getInvokedModule(Module& module, Node* n, Value* self) {
  // the first input of the call is the module instance it is invoked on
  auto* instance = n->inputs()[0];
  auto path = getModuleAccessPath(instance, self);
  return findChildModule(module, path);
}
643 | |
// Same as getInvokedModule, but returns c10::nullopt instead of failing
// when some attribute along the access path is not a module.
c10::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self) {
  // the first input of the call is the instance it is invoked on
  auto* instance = n->inputs()[0];
  auto path = getModuleAccessPath(instance, self);
  Module m = module;
  for (const auto& p : path) {
    if (m.attr(p).isModule()) {
      m = m.attr(p).toModule();
    } else {
      return c10::nullopt;
    }
  }
  return m;
}
660 | |
661 | // ==================== filter functions for matches ============== |
662 | bool is_int_constant( |
663 | const Match& match, |
664 | const std::unordered_map<std::string, Value*>& vmap, |
665 | const std::string& vname, |
666 | int value) { |
667 | const auto& match_vmap = match.values_map; |
668 | auto v = toIValue(match_vmap.at(vmap.at(vname))); |
669 | return v && v->isInt() && v->toInt() == value; |
670 | } |
671 | |
672 | bool is_functional( |
673 | const Match& match, |
674 | const std::unordered_map<std::string, Value*>& vmap, |
675 | const std::string& vname, |
676 | const std::string& functional) { |
677 | const auto& match_vmap = match.values_map; |
678 | Value* v = match_vmap.at(vmap.at(vname)); |
679 | return v->type()->cast<FunctionType>() && getFuncName(v) == functional; |
680 | } |
681 | |
// Strip every ".___torch_mangle_<N>" segment from a qualified name.
std::string removeTorchMangle(const std::string& orig_name) {
  // compile the pattern once; regex construction is expensive
  static const std::regex mangle_re("\\.___torch_mangle_\\d+");
  std::string demangled = std::regex_replace(orig_name, mangle_re, "");
  return demangled;
}
687 | |
// Returns the demangled qualified class name of `value`'s type, or
// c10::nullopt if the value is not a named ClassType.
c10::optional<std::string> getModuleName(Value* value) {
  auto type = value->type()->cast<ClassType>();
  if (type && type->name()) {
    return removeTorchMangle(type->name()->qualifiedName());
  }
  return c10::nullopt;
}
695 | |
696 | bool is_module( |
697 | const Match& match, |
698 | const std::unordered_map<std::string, Value*>& vmap, |
699 | const std::string& vname, |
700 | const std::string& module_qualified_name) { |
701 | const auto& match_vmap = match.values_map; |
702 | Value* v = match_vmap.at(vmap.at(vname)); |
703 | auto module_name = getModuleName(v); |
704 | if (module_name.has_value()) { |
705 | return module_name.value() == module_qualified_name; |
706 | } |
707 | return false; |
708 | }; |
709 | |
// Filter: matches only when the `alpha` argument of aten::add is the
// integer constant 1.
bool aten_add_alpha_is_one(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_int_constant(match, vmap, "alpha", 1);
}

// Filter: matches when the value bound to "relu" is a function named
// "relu".
bool is_functional_relu(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_functional(match, vmap, "relu", "relu");
}

// Filter: matches when the value bound to "relu" is an nn.ReLU module.
bool is_relu_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
}
728 | |
// Filter: matches when the value bound to "linear" is an nn.Linear module.
bool is_linear_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "linear", "__torch__.torch.nn.modules.linear.Linear");
}

// Filter: matches when the value bound to "conv" is an nn.Conv1d module.
bool is_conv1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv1d");
}

// Filter: matches when the value bound to "conv" is an nn.Conv2d module.
bool is_conv2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv2d");
}

// Filter: matches when the value bound to "conv" is an nn.Conv3d module.
bool is_conv3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
}

// Filter: matches when the value bound to "conv" is an nn.ConvTranspose1d
// module.
bool is_conv_transpose1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
}

// Filter: matches when the value bound to "conv" is an nn.ConvTranspose2d
// module.
bool is_conv_transpose2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
}
769 | } |
770 | |
771 | bool is_batchnorm2d_module( |
772 | const Match& match, |
773 | const std::unordered_map<std::string, Value*>& vmap) { |
774 | bool regnorm = is_module( |
775 | match, |
776 | vmap, |
777 | "batchnorm" , |
778 | "__torch__.torch.nn.modules.batchnorm.BatchNorm2d" ); |
779 | bool naivenorm = is_module( |
780 | match, |
781 | vmap, |
782 | "batchnorm" , |
783 | "__torch__.mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm" ); |
784 | return (regnorm || naivenorm); |
785 | } |
786 | |
787 | bool is_batchnorm3d_module( |
788 | const Match& match, |
789 | const std::unordered_map<std::string, Value*>& vmap) { |
790 | return is_module( |
791 | match, |
792 | vmap, |
793 | "batchnorm" , |
794 | "__torch__.torch.nn.modules.batchnorm.BatchNorm3d" ); |
795 | } |
796 | |
797 | } // namespace jit |
798 | } // namespace torch |
799 | |