1#pragma once
2#include <torch/csrc/jit/api/module.h>
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/ir/subgraph_matcher.h>
5#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
6#include <torch/csrc/jit/passes/quantization/quantization_type.h>
7
8#include <functional>
9#include <regex>
10
11namespace torch {
12namespace jit {
13
using graph_rewrite_helper::getFuncName;

// Pairs of (module instance, method name) — used to enumerate the
// methods of a module (and its submodules) that a pass must visit.
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
// Ordered list of (quantization parameter name, value) pairs,
// for example _scale, _zero_point,
// _scalar_type and _axis (for per-channel quantization).
using QParamVector = std::vector<std::pair<std::string, IValue>>;
22
// =========== helper functions for Value =========
// Check if a value is weight, since we need to use weight observer
// for weight
TORCH_API bool isWeight(Value* v);

// Check if a value is bias for conv and linear, which we do not
// quantize
TORCH_API bool isBiasOfConvOrLinear(Value* v);

// Check if a value is a non-input argument of an embedding_bag op
// (e.g. offsets/indices), which should not be observed/quantized.
TORCH_API bool isEmbeddingBagNonInput(Value* v);

// Get the use as scalar input of clamp ops for the input value
// (returns c10::nullopt when `v` has no such use).
c10::optional<Use> getClampScalarInputUse(Value* v);

// For a given value `v`, get the list of values that we need to check
// if they are observed/quantized or not, if so, we can say the
// `v` is also observed/quantized, since we can derive
// the quantization parameters for `v` given the list of values
TORCH_API std::vector<Value*> getPassThroughInputs(Value* v);

// Clones the method by the name of orig_method_name into new_method_name method
TORCH_API void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name);

// Check if a value in the graph is a Scalar value
TORCH_API bool isScalar(Value* v);

// Check if value is the input of the graph
TORCH_API bool hitGraphInput(Value* value);

// Converts a mangled name, such as
// __torch__.torch.ao.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
// into an unmangled name, such as
// __torch__.torch.ao.nn.quantized.modules.conv.Conv2d
TORCH_API std::string removeTorchMangle(const std::string& orig_name);

// Return the module name that corresponds to the value,
// or c10::nullopt if `value` is not a module instance.
TORCH_API c10::optional<std::string> getModuleName(Value* value);
63
// =========== helper functions for Node =========
// Check if the node is a single-input aten function that only affects
// the shape of the Tensor (e.g. aten::flatten).
TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);

// Check if the node is a single-input aten function whose output
// quantization parameters can be derived from its input.
TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);

// CallFunction analogue of the aten checks above.
TORCH_API bool isSingleInputGeneralCallFunction(Node* n);

// True if either the shape or the value variant above matches.
TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);

// Check if the node is a clamp-style op (e.g. aten::clamp/hardtanh).
TORCH_API bool isClamp(Node* n);

// Check if the node will produce the same result regardless of whether
// the input tensor is quantized or not, example: aten::size
TORCH_API bool isTensorInfoNode(Node* n);

// Check if this is the propagate op that has single input, e.g. aten::cat
TORCH_API bool isPropagateQuantSingleInputOp(Node* n);

// Check if this is the propagate op that has two inputs, e.g. aten::add
TORCH_API bool isPropagateQuantBinaryOp(Node* n);

// Check if this is the node that we'll quantize or not quantize depending on
// whether the input of the node is quantized, example: aten::cat
TORCH_API bool isPropagateQuantOp(Node* n);

// Check if the node is a binary op like aten::add and aten::mul and
// if the input 1 is a scalar, these ops will be quantized to
// quantized::{op}_scalar
TORCH_API bool isBinaryOpWithScalarInput(Node* n);

// Return the fixed (qscheme, qparams) for ops with predetermined
// quantization parameters (e.g. sigmoid/tanh), or c10::nullopt.
TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
    Node* n);

// We don't want to analyze the graph for some `builtin` CallFunctions
// like `linear` because we want to preserve the op boundary
TORCH_API bool userDefinedCallFunction(Node* n);

// Check if the node has scalar input
TORCH_API bool hasScalarInput(Node* n);

// Check if a node is quantizable
TORCH_API bool nodeQuantizable(
    Node* n,
    QuantType quant_type = QuantType::STATIC);

// Nodes which only require quantization of weight value, eg. embedding_bag
bool isWeightOnlyStaticQuantOp(Node* n);

// Check if a use of the value is quantizable, this depends on
// both the use node and the offset
TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);

// Given a CallFunction node, extract the graph of the called function
TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);

// Check if `use` is a CallFunction of name `func_name` and, when
// `nth_arg` is provided, that the used value sits at that argument
// position (i.e. use.offset == nth_arg).
bool matchCallFuncToUse(
    const Use& use,
    const std::string& func_name,
    c10::optional<int> nth_arg);

// Check if `use` is an aten function of name `func_name` and, when
// `nth_arg` is provided, that the used value sits at that argument
// position (i.e. use.offset == nth_arg).
bool matchAtenFuncToUse(
    const Use& use,
    const std::string& func_name,
    c10::optional<int> nth_arg);
132
// =========== helper functions for Block =========
// Checks if a block will always raise an Exception
// (every execution path through the block ends in a raise).
TORCH_API bool alwaysRaisesException(Block* block);
136
// =========== helper functions for Module ==========
// Resolve the chain of attribute accesses from `self` to `instance`
// into a list of attribute names.
// TODO: remove
TORCH_API std::vector<std::string> getModuleAccessPath(
    Value* instance,
    Value* self);
// Walk `path` (a list of submodule attribute names) starting from
// `module` and return the submodule it leads to.
// TODO: remove
TORCH_API Module
findChildModule(const Module& module, const std::vector<std::string>& path);

// Given an CallMethod node, get the module instance corresponding
// to the instance Value
// TODO: refactor all current uses of this function to the Opt one
TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self);

// Given an CallMethod node, get the module instance corresponding
// to the instance Value if the instance is a module, otherwise return
// c10::nullopt
c10::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self);
158
159// ==================== filter functions for matches ==============
160// filter to check Value `vname` is a constant of int value `value`
161bool is_int_constant(
162 const Match& match,
163 const std::unordered_map<std::string, Value*>& vmap,
164 const std::string& vname,
165 int value);
166
167// filter to check if the %alpha argument of aten::add is constant 1
168bool aten_add_alpha_is_one(
169 const Match& match,
170 const std::unordered_map<std::string, Value*>& vmap);
171
172// filter to check if the functional in CallFunction is relu
173bool is_functional_relu(
174 const Match& match,
175 const std::unordered_map<std::string, Value*>& vmap);
176
177// filter to check if the module is torch.nn.ReLU
178bool is_relu_module(
179 const Match& match,
180 const std::unordered_map<std::string, Value*>& vmap);
181
182bool is_linear_module(
183 const Match& match,
184 const std::unordered_map<std::string, Value*>& vmap);
185
186// TODO: add a macro to declare the filters
187bool is_conv1d_module(
188 const Match& match,
189 const std::unordered_map<std::string, Value*>& vmap);
190
191bool is_conv2d_module(
192 const Match& match,
193 const std::unordered_map<std::string, Value*>& vmap);
194
195bool is_conv3d_module(
196 const Match& match,
197 const std::unordered_map<std::string, Value*>& vmap);
198
199bool is_conv_transpose1d_module(
200 const Match& match,
201 const std::unordered_map<std::string, Value*>& vmap);
202
203bool is_conv_transpose2d_module(
204 const Match& match,
205 const std::unordered_map<std::string, Value*>& vmap);
206
207bool is_batchnorm2d_module(
208 const Match& match,
209 const std::unordered_map<std::string, Value*>& vmap);
210
211bool is_batchnorm3d_module(
212 const Match& match,
213 const std::unordered_map<std::string, Value*>& vmap);
214
215} // namespace jit
216} // namespace torch
217