1 | #include <c10/util/irange.h> |
2 | #include <torch/csrc/jit/passes/quantization/insert_observers.h> |
3 | |
4 | #include <torch/csrc/jit/frontend/schema_matching.h> |
5 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | #include <torch/csrc/jit/passes/constant_pooling.h> |
8 | #include <torch/csrc/jit/passes/constant_propagation.h> |
9 | #include <torch/csrc/jit/passes/fuse_linear.h> |
10 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
11 | #include <torch/csrc/jit/passes/inline_fork_wait.h> |
12 | #include <torch/csrc/jit/passes/quantization/helper.h> |
13 | #include <torch/csrc/jit/passes/remove_mutation.h> |
14 | |
15 | #include <memory> |
16 | #include <regex> |
17 | #include <stack> |
18 | #include <string> |
19 | #include <utility> |
20 | |
21 | namespace torch { |
22 | namespace jit { |
23 | |
24 | using ModuleQConfigMap = std::unordered_map<ModulePtr, c10::optional<QConfig>>; |
25 | |
26 | namespace { |
27 | |
28 | struct OptionalQConfigHash { |
29 | inline size_t operator()(const c10::optional<QConfig>& qconfig_opt) const { |
30 | if (qconfig_opt.has_value()) { |
31 | const auto& m1 = std::get<0>(*qconfig_opt); |
32 | const auto& m2 = std::get<1>(*qconfig_opt); |
33 | constexpr int CONST = 7; |
34 | return std::hash<Module>()(m1) + CONST * std::hash<Module>()(m2); |
35 | } |
36 | return 0; |
37 | } |
38 | }; |
39 | using QConfigTypePtrMap = |
40 | std::unordered_map<c10::optional<QConfig>, TypePtr, OptionalQConfigHash>; |
41 | using NameModuleVector = std::vector<std::pair<std::string, Module>>; |
42 | using OptionalModuleVector = std::vector<c10::optional<Module>>; |
43 | using ModuleMethodVector = std::vector<std::pair<Module, std::string>>; |
44 | using graph_rewrite_helper::PatternInfo; |
45 | using graph_rewrite_helper::replaceConvolutionWithAtenConv; |
46 | |
47 | // helper functions |
48 | void fillQConfigMap( |
49 | const Module& module, |
50 | const QConfigDict& qconfig_dict, |
51 | ModuleQConfigMap& map, |
52 | const std::string& key = "" , |
53 | const c10::optional<QConfig>& parent_qconfig = c10::nullopt) { |
54 | c10::optional<QConfig> qconfig; |
55 | if (qconfig_dict.find(key) != qconfig_dict.end()) { |
56 | GRAPH_DEBUG("Got module config for key:" , key); |
57 | qconfig = qconfig_dict.at(key); |
58 | } else { |
59 | GRAPH_DEBUG("Inheriting qconfig from parent module:" , key); |
60 | qconfig = parent_qconfig; |
61 | } |
62 | map[module._ivalue()] = qconfig; |
63 | |
64 | for (const NameModule& s : module.named_children()) { |
65 | std::string child_key; |
66 | if (key.empty()) { |
67 | child_key = s.name; |
68 | } else { |
69 | child_key = key + "." + s.name; |
70 | } |
71 | fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig); |
72 | } |
73 | } |
74 | |
75 | Module getObserverModuleFor(Value* v, const QConfig& qconfig) { |
76 | return isWeight(v) ? std::get<1>(qconfig) : std::get<0>(qconfig); |
77 | } |
78 | |
79 | // helper classes |
class ModuleCloneHelper {
 public:
  /** Clone according to module qconfig map. This handles the case
   * where we have two module instances sharing the same ClassType
   * but configured with different QConfig: each (type, qconfig) pair
   * gets its own cloned ClassType.
   * code is copied and modified from
   * https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp
   * inplace option means if the copy of the Tensor is deepcopy or not
   * if inplace is true, the cloned module will share the tensors with
   * original model instead of deepcopy them
   */
  Module clone(
      const Module& module,
      const ModuleQConfigMap& module_qconfig_map,
      bool inplace = false) {
    // type_remap: (original type, qconfig) -> cloned type, shared across
    // the whole recursion so equal pairs reuse one cloned ClassType.
    std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
    // memo de-duplicates deep-copied IValues across the recursion.
    IValue::HashAliasedIValueMap memo;
    return clone_impl(
        module, module_qconfig_map, type_remap, inplace, std::move(memo));
  }

 private:
  // Recursive worker for clone(): clones `module` (and its submodules),
  // producing a new ivalue whose ClassType is selected by the module's
  // qconfig via `type_remap`.
  Module clone_impl(
      const Module& module,
      const ModuleQConfigMap& module_qconfig_map,
      std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
      bool inplace,
      IValue::HashAliasedIValueMap memo) {
    auto qconfig = module_qconfig_map.at(module._ivalue());
    auto type = module.type();
    // Create a new _ivalue in the same compilation unit.
    // Since now we have shared ClassType, we need to preserve the shared
    // ClassType during cloning, so we first use type and qconfig to check if
    // the type is already cloned, if so, we'll create a new module with the
    // cloned ClassType, if not, we'll create a new module and a new ClassType.
    bool type_already_cloned = type_remap.find(type) != type_remap.end() &&
        type_remap.at(type).find(qconfig) != type_remap.at(type).end();
    Module r;
    if (type_already_cloned) {
      // if we cloned the class type before, we'll reuse it
      Module new_module(
          module._ivalue()->compilation_unit(),
          type_remap.at(type).at(qconfig)->cast<ClassType>());
      r = new_module;
    } else {
      // fresh clone: create a brand-new ClassType (the `true` argument)
      // and remember it under (type, qconfig) for later reuse
      Module new_module(
          *type->name(), module._ivalue()->compilation_unit(), true);
      r = new_module;
      type_remap[type][module_qconfig_map.at(module._ivalue())] = r.type();
    }
    // Copy slots. If a slot is a module - recursively clone it.
    size_t N = type->numAttributes();
    for (const auto i : c10::irange(N)) {
      IValue s = module._ivalue()->getSlot(i);
      std::string attr_name = type->getAttributeName(i);
      TypePtr attr_type = type->getAttribute(i);
      if (attr_type->is_module()) {
        const Module& orig = Module(s.toObject());
        Module cloned =
            clone_impl(orig, module_qconfig_map, type_remap, inplace, memo);

        // NOTE: why do we need to manually setattr on object instead of using
        // register_module here? because the attr can be a module interface
        // type and hold a Module object still. register_module will not let us
        // correctly set up the type for this attr, so we had to do this
        // manually. In the case it's an interface type, the type will be shared
        // by the new cloned instance in the same compilation unit bc it only
        // contains a list of functionSchema
        r.type()->addOrCheckAttribute(
            attr_name,
            attr_type->cast<ClassType>() ? cloned.type() : attr_type);
        r._ivalue()->setAttr(attr_name, cloned._ivalue());
      } else {
        // we'll deepcopy the IValue in non inplace option
        r.register_attribute(
            type->getAttributeName(i),
            type->getAttribute(i),
            inplace ? s : s.deepcopy(memo),
            type->is_parameter(i),
            type->is_buffer(i));
      }
    }

    // only clone the methods and constants if the ClassType is not cloned
    // before
    if (!type_already_cloned) {
      for (size_t i = 0; i < type->numConstants(); ++i) {
        r.type()->addConstant(type->getConstantName(i), type->getConstant(i));
      }
      // Clone methods remapping the types to the cloned ones.
      for (auto& fn : type->methods()) {
        clone_method(module, r, *fn, module_qconfig_map, type_remap);
      }
      // Execute __setstate__(__getstate__()) to initialize custom class
      // members.
      if (auto setstate_method = r.find_method("__setstate__")) {
        auto getstate_method = r.find_method("__getstate__");
        TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
        auto state = (*getstate_method)(Stack{});
        (*setstate_method)(Stack{std::move(state)});
      }
    }
    return r;
  }

  // Rewrite module instance types inside `block` (at prim::CallMethod /
  // prim::GetAttr sites, nested blocks, and graph attributes) to the
  // cloned types chosen by each instance's qconfig.
  void remapTypes(
      Block* block,
      Value* self,
      const Module& source,
      Module& target,
      const ModuleQConfigMap& module_qconfig_map,
      const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
          type_remap_fn) {
    // remap of %self will be done outside of the function
    // and we don't support the case when people pass in
    // module as argument of the method because in that case
    // we need to do more comprehensive analysis to decide the
    // QConfig for the module
    for (size_t i = 1; i < block->inputs().size(); ++i) {
      TORCH_CHECK(
          !block->inputs()[i]->type()->cast<ClassType>(),
          "We don't support quantizing methods that has Object as arguments");
    }
    for (Node* node : block->nodes()) {
      // remapping type for module instance
      if (node->kind() == prim::CallMethod || node->kind() == prim::GetAttr) {
        Value* instance = node->inputs()[0];
        auto child_opt = getInvokedModuleOpt(source, node, self);
        if (child_opt.has_value()) {
          auto qconfig = module_qconfig_map.at(child_opt->_ivalue());
          instance->setType(type_remap_fn(instance->type(), qconfig));
        }
      }
      // We don't remap output and the remapping of module type
      // will be done in CallMethod, we don't support type remapping
      // for modules returned from methods or functions
      for (Block* sub_block : node->blocks()) {
        remapTypes(
            sub_block, self, source, target, module_qconfig_map, type_remap_fn);
      }
      // Subgraph attributes (e.g. fork subgraphs) also need remapping.
      for (Symbol name : node->attributeNames()) {
        if (node->kindOf(name) == AttributeKind::g) {
          remapTypes(
              node->g(name).get(),
              source,
              target,
              module_qconfig_map,
              type_remap_fn);
        } else if (node->kindOf(name) == AttributeKind::gs) {
          for (const auto& g : node->gs(name)) {
            remapTypes(
                g.get(), source, target, module_qconfig_map, type_remap_fn);
          }
        }
      }
    }
  }

  // Graph-level convenience overload: starts the remap from the graph's
  // top-level block, using the graph's first input (%self) as the module
  // instance.
  void remapTypes(
      Graph* graph,
      const Module& source,
      Module& target,
      const ModuleQConfigMap& module_qconfig_map,
      const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
          type_remap_fn) {
    remapTypes(
        graph->block(),
        graph->inputs()[0],
        source,
        target,
        module_qconfig_map,
        type_remap_fn);
  }

  // Copy `method` from `source` into `target`'s type, remapping
  // (type, qconfig) pairs found in `type_remap` to the cloned types, in
  // both the copied graph and the function schema.
  void clone_method(
      const Module& source,
      Module& target,
      const Function& method,
      const ModuleQConfigMap& module_qconfig_map,
      const std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
    auto type_remap_fn = [&](TypePtr type_ptr,
                             const c10::optional<QConfig>& qconfig) {
      // Only remap types that were cloned for this qconfig; everything
      // else passes through unchanged.
      if (type_remap.find(type_ptr) != type_remap.end()) {
        const auto& qconfig_map = type_remap.at(type_ptr);
        if (qconfig_map.find(qconfig) != qconfig_map.end()) {
          return qconfig_map.at(qconfig);
        }
      }
      return type_ptr;
    };
    auto graph = toGraphFunction(method).graph()->copy();
    remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
    // remap self
    graph->inputs()[0]->setType(target.type());
    // we only support %self being Module in the arguments of function
    auto schema_type_remap_fn = [&](TypePtr type_ptr) {
      return type_remap_fn(
          std::move(type_ptr), module_qconfig_map.at(source._ivalue()));
    };
    auto schema =
        method.getSchema().cloneWithRemappedTypes(schema_type_remap_fn);
    const auto this_method_name =
        c10::QualifiedName(*target.type()->name(), method.name());
    auto copied = target._ivalue()->compilation_unit()->create_function(
        this_method_name, std::move(graph));
    target.type()->addMethod(copied);
    copied->setSchema(std::move(schema));
  }
};
289 | |
290 | class InsertObserversHelper { |
291 | public: |
292 | explicit InsertObserversHelper( |
293 | const ModuleQConfigMap& map, |
294 | QuantType quant_type) |
295 | : module_qconfig_map_(map), quant_type_(quant_type) {} |
296 | |
297 | // TODO: replace (module, method_name) with graph? |
298 | // preprocess to clean up the graph from tracing |
299 | void preprocess(Module& module, const std::string& method_name); |
300 | |
301 | // Fill the map between the caller input/output to input/output |
302 | // of called graph, this is used to navigate through the graph |
303 | // to find the observer for a given value |
304 | void fillBoundaryValueMap(Module& module, const std::string& method_name); |
305 | |
306 | // analyze the graph and record necessary information that can |
307 | // be used in insert observers |
308 | void analyze(Module& module, const std::string& method_name); |
309 | |
310 | void removeActivationObservers(); |
311 | |
312 | /** |
313 | * Recursively insert observers for the method, also we'll process |
314 | * the nodes in the graph in the order of execution of these nodes |
315 | * since we need the context information to decide whether we want to |
   * observe/quantize a value or not, we don't want to observe a value multiple
317 | * times. |
318 | * |
   * argument: is_entry_point means whether the current method is the forward
320 | * method of the top level module. |
321 | * |
322 | * Since we want to insert observers in the call site instead of in the called |
323 | * graph, we'll postpone inserting observer to caller as much as possible, if |
324 | * we know the current method is the outer most method, then |
325 | * we will insert all observers in the graph instead of postpone this to the |
326 | * parent, note that this assumes we don't have recursive method |
327 | * calls |
328 | * |
329 | * returns a tuple of vectors of observer modules for input and output, these |
330 | * are used for inserting observers for the input/output values |
331 | * since we need to insert these values at call site. |
332 | * And a vector of indexes of outputs that indicates whether the output value |
333 | * is already observed or not, this is used for propagating the observed |
334 | * property of a value through CallMethods, because we should skip inserting |
335 | * observers for ops that don't require observation |
336 | */ |
337 | std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>> |
338 | insertObservers( |
339 | Module& module, |
340 | const std::string& method_name, |
341 | bool is_entry_point = false, |
342 | std::unordered_set<Value*> graph_observed_values = |
343 | std::unordered_set<Value*>()); |
344 | |
345 | void setInsertResetObserverMethod( |
346 | bool insert_reset_observer_method, |
347 | const std::string& method_name) { |
348 | insert_reset_observer_method_ = insert_reset_observer_method; |
349 | reset_observer_method_name_ = "reset_observers_" + method_name; |
350 | } |
351 | |
352 | private: |
353 | std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>> |
354 | insertObserversFor( |
355 | Block* block, |
356 | script::Module& module, |
357 | // this is a reference because when we insert observer for a value |
358 | // in one block it is also observed in another block, we don't want to |
359 | // insert multiple observers for the same value |
360 | std::unordered_set<Value*>& block_observed_values, |
361 | bool is_entry_point = false, |
362 | bool is_user_defined_function = false); |
363 | |
364 | // Record v as "ready for observation" by storing it in values_to_observe. |
365 | // If v is a part of a delayed observation pattern, record v's descendant |
366 | // (per delay rules) instead. The observers are inserted at a later stage |
367 | // by reading the state created by this function. |
368 | void recordObserved( |
369 | Value* v, |
370 | const Module& observer_module, |
371 | std::unordered_map<Value*, Module>& values_to_observe, |
372 | std::unordered_set<Value*>& block_observed_values); |
373 | |
374 | ModuleMethodVector getInvokedMethods( |
375 | Module& module, |
376 | const std::string& method_name); |
377 | |
378 | bool valueNeedsToBeQuantized(Value* v, const QConfig& qconfig); |
379 | |
380 | bool isObserved( |
381 | Value* v, |
382 | const std::unordered_set<Value*>& block_observed_values) { |
383 | return block_observed_values.count(v) || observed_values_.count(v); |
384 | } |
385 | |
386 | // Fill the map from value to the corresponding observer module |
387 | // this map is used in insertObservers to actually insert |
388 | // observers to the module |
389 | void fillValueObserverMap(Module& module, const std::string& method_name); |
390 | |
391 | // Clone observer module and add it to the original module, |
392 | // and insert a call to observer forward function |
393 | void insertObserverFor( |
394 | Value* v, |
395 | Module& module, |
396 | const Module& observer_module, |
397 | NameModuleVector& observer_name_and_modules); |
398 | |
399 | void insertObserverResetMinMax( |
400 | Module& module, |
401 | const NameModuleVector& observer_name_and_modules); |
402 | |
403 | // Uses the state created by fillBoundaryValueMap and fillValueObserverMap |
404 | // to return an observer configured for a value, if it is needed. |
405 | c10::optional<Module> getObserverFor(Value* v); |
406 | |
  // Uses the state created by fillPassThroughValueMap to propagate observed
408 | // property which should pass through from inputs to outputs. |
409 | void propagateObservedProperty( |
410 | Value* output, |
411 | std::unordered_set<Value*>& block_observed_values); |
412 | |
413 | // for cat/add/mul we will only observe their output if their input |
414 | // are observed |
415 | bool shouldObserve( |
416 | Node* n, |
417 | const std::unordered_set<Value*>& block_observed_values, |
418 | QuantType quant_type) { |
419 | // Check whether node output uses can be quantized, eg cat followed by |
420 | // linear op |
421 | for (Value* v : n->outputs()) { |
422 | for (const auto& use : v->uses()) { |
423 | if (useQuantizable(use, quant_type)) { |
424 | return true; |
425 | } |
426 | } |
427 | } |
428 | if (isPropagateQuantSingleInputOp(n)) { |
429 | return isObserved(n->input(0), block_observed_values); |
430 | } else if (isPropagateQuantBinaryOp(n)) { |
431 | // This checks both of the input should be tensor and observed. |
432 | // There is one check that we didn't do here, which is |
433 | // !isScalar(isObserved(n->input(1), block_observed_values) |
434 | // to make sure input 1 is not a scalar, because scalar tensor input |
435 | // for add/mul won't be observed with current rule, we can omit |
436 | // this check here |
437 | return isObserved(n->input(0), block_observed_values) && |
438 | isObserved(n->input(1), block_observed_values); |
439 | } |
440 | return true; |
441 | } |
442 | |
443 | void delayObservingValuesInPattern(Graph& graph, const PatternInfo& pattern); |
444 | |
445 | // Find and mark known patterns such as conv-relu (and others) where |
446 | // we should not insert observers in the middle of the pattern. |
447 | void addValuesToDelayObservation( |
448 | const Module& module, |
449 | const std::string& method_name); |
450 | |
451 | // Fill the map from values to the list of values that can pass the observed |
452 | // property to it |
453 | void fillPassThroughValueMap(const std::shared_ptr<Graph>& graph); |
454 | |
  // Whether a reset-observer helper method should be generated
  // (configured via setInsertResetObserverMethod).
  bool insertResetObserverMethod() {
    return insert_reset_observer_method_;
  }
458 | |
459 | const ModuleQConfigMap& module_qconfig_map_; |
460 | |
461 | // Values we want to delay observation, used to delay the observation for |
462 | // values in the middle of the ops that are supposed to be fused, e.g. |
463 | // the output value of conv in the conv - relu pattern |
464 | // the key is the intermediate output, e.g. output of conv |
465 | // the value is the value we want to observe, e.g. output of relu |
466 | // |
467 | // example, assuming we want to delay conv-relu: |
468 | // %x1 = conv(%x0) |
469 | // %x2 = relu(%x1) |
470 | // |
471 | // delay_observation_map_ = { |
472 | // %x1: %x2, |
473 | // } |
474 | std::unordered_map<Value*, Value*> delay_observation_map_; |
475 | |
476 | std::unordered_set<Graph*> visited_graph_of_observer_map_; |
477 | |
478 | // Map of value to observer module configured for that value. |
479 | std::unordered_map<Value*, Module> observer_for_value_; |
480 | |
481 | // Map from values from callsite into the values in the CallMethod graph |
482 | // key of the map is the value from caller graph, and the value of the map |
483 | // is the list of values in the callee graph (the graph |
484 | // corresponding to the called method), |
485 | // the reason it is a set is that a value in the caller graph |
486 | // can both correspond to the output of one callee graph and input of another |
487 | // callee graph. |
488 | // |
489 | // example: |
490 | // // top level module |
491 | // %x1 = conv(%x0) |
492 | // %x2 = prim::CallFunction(%foo, %x1) |
493 | // |
494 | // // graph of %foo |
495 | // %y2 = conv(%y1) |
496 | // return %y2 |
497 | // |
498 | // boundary_value_map = { |
499 | // // current module's output values to corresponding return values from |
500 | // subgraph %x2: %y2, |
501 | // // current module's input values to corresponding input value to subgraph |
502 | // %x1: %y1, |
503 | // } |
504 | std::unordered_map<Value*, std::unordered_set<Value*>> boundary_value_map_; |
505 | |
506 | std::unordered_set<Value*> observed_values_; |
507 | |
508 | // This is used for the observed values to pass through the ops like flatten, |
509 | // so that output value of flatten does not need to be observed |
510 | // key is the output of the op, value is a vector of values that need |
511 | // to be observed in order to pass the observed property to the output |
512 | // |
513 | // example: |
514 | // %x1 = flatten(%x0) // pass_through |
515 | // %x2 = conv(%x1) // not pass_through |
516 | // |
517 | // pass_through_value_map_ = { |
518 | // %x1: [%x0], |
519 | // } |
520 | std::unordered_map<Value*, std::vector<Value*>> pass_through_value_map_; |
521 | |
522 | // Unique id generator for observer module, used for generating |
523 | // unique observer names when we insert observer module, we |
524 | // record the current unique id used to avoid incrementing from 0 |
525 | // every time to find a unique id. |
526 | int uid_ = 0; |
527 | // Set of observer forward call nodes |
528 | std::unordered_set<Node*> observer_nodes_; |
529 | // Map from block to a vector of observer name and observer modules we |
530 | // want to add to the module instance that has the block |
531 | std::unordered_map<Block*, NameModuleVector> block_observer_map_; |
532 | |
533 | // Type of quantization for this pass. |
534 | QuantType quant_type_ = QuantType::STATIC; |
535 | // These are the IR patterns we match to skip inserting observers. |
536 | // They are compiled once on construction and used repeatedly within |
537 | // the pass. |
538 | |
539 | // nn.Linear + nn.ReLU |
540 | const PatternInfo nn_linear_nn_relu = PatternInfo::parse_from_str( |
541 | R"( |
542 | graph(%input, %linear, %relu): |
543 | %first_output = prim::CallMethod[name="forward"](%linear, %input) |
544 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
545 | return (%second_output) )" , |
546 | {is_linear_module, is_relu_module}); |
547 | |
548 | // nn.Linear + F.relu |
549 | const PatternInfo nn_linear_f_relu = PatternInfo::parse_from_str( |
550 | R"( |
551 | graph(%input, %linear, %relu, %inplace): |
552 | %first_output = prim::CallMethod[name="forward"](%linear, %input) |
553 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
554 | return (%second_output) )" , |
555 | {is_linear_module, is_functional_relu}); |
556 | |
557 | // nn.Linear + aten::relu |
558 | const PatternInfo nn_linear_aten_relu = PatternInfo::parse_from_str( |
559 | R"( |
560 | graph(%input, %linear, %relu): |
561 | %first_output = prim::CallMethod[name="forward"](%linear, %input) |
562 | %second_output = aten::relu(%first_output) |
563 | return (%second_output) )" , |
564 | {is_linear_module}); |
565 | |
566 | // nn.Linear + aten::relu_ |
567 | const PatternInfo nn_linear_aten_relu_ = PatternInfo::parse_from_str( |
568 | R"( |
569 | graph(%input, %linear, %relu): |
570 | %first_output = prim::CallMethod[name="forward"](%linear, %input) |
571 | %second_output = aten::relu_(%first_output) |
572 | return (%second_output) )" , |
573 | {is_linear_module}); |
574 | |
575 | // aten::linear + nn.ReLU |
576 | const PatternInfo aten_linear_nn_relu = PatternInfo::parse_from_str( |
577 | R"( |
578 | graph(%input, %weight, %bias, %relu): |
579 | %first_output = aten::linear(%input, %weight, %bias) |
580 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
581 | return (%second_output) )" , |
582 | {is_relu_module}); |
583 | |
584 | // aten::linear + F.relu |
585 | const PatternInfo aten_linear_f_relu = PatternInfo::parse_from_str( |
586 | R"( |
587 | graph(%input, %weight, %bias, %relu, %inplace): |
588 | %first_output = aten::linear(%input, %weight, %bias) |
589 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
590 | return (%second_output) )" , |
591 | {is_functional_relu}); |
592 | |
593 | // aten::linear + aten::relu |
594 | const PatternInfo aten_linear_aten_relu = PatternInfo::parse_from_str( |
595 | R"( |
596 | graph(%input, %weight, %bias): |
597 | %first_output = aten::linear(%input, %weight, %bias) |
598 | %second_output = aten::relu(%first_output) |
599 | return (%second_output) )" ); |
600 | |
601 | // aten::linear + aten::relu_ |
602 | const PatternInfo aten_linear_aten_relu_ = PatternInfo::parse_from_str( |
603 | R"( |
604 | graph(%input, %weight, %bias): |
605 | %first_output = aten::linear(%input, %weight, %bias) |
606 | %second_output = aten::relu_(%first_output) |
607 | return (%second_output) )" ); |
608 | |
609 | const PatternInfo nn_conv1d_f_relu = PatternInfo::parse_from_str( |
610 | R"( |
611 | graph(%self, %input, %conv, %relu, %inplace): |
612 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
613 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
614 | return (%second_output) )" , |
615 | {is_conv1d_module, is_functional_relu}); |
616 | |
617 | const PatternInfo nn_conv1d_nn_relu = PatternInfo::parse_from_str( |
618 | R"( |
619 | graph(%self, %input, %conv, %relu): |
620 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
621 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
622 | return (%second_output) )" , |
623 | {is_conv1d_module, is_relu_module}); |
624 | |
625 | const PatternInfo nn_conv1d_aten_relu = PatternInfo::parse_from_str( |
626 | R"( |
627 | graph(%self, %input, %conv): |
628 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
629 | %second_output = aten::relu(%first_output) |
630 | return (%second_output) )" , |
631 | {is_conv1d_module}); |
632 | |
633 | const PatternInfo nn_conv1d_aten_relu_ = PatternInfo::parse_from_str( |
634 | R"( |
635 | graph(%self, %input, %conv): |
636 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
637 | %second_output = aten::relu_(%first_output) |
638 | return (%second_output) )" , |
639 | {is_conv1d_module}); |
640 | |
641 | const PatternInfo nn_conv2d_f_relu = PatternInfo::parse_from_str( |
642 | R"( |
643 | graph(%self, %input, %conv, %relu, %inplace): |
644 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
645 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
646 | return (%second_output) )" , |
647 | {is_conv2d_module, is_functional_relu}); |
648 | |
649 | const PatternInfo nn_conv2d_nn_relu = PatternInfo::parse_from_str( |
650 | R"( |
651 | graph(%self, %input, %conv, %relu): |
652 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
653 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
654 | return (%second_output) )" , |
655 | {is_conv2d_module, is_relu_module}); |
656 | |
657 | const PatternInfo nn_conv2d_aten_relu = PatternInfo::parse_from_str( |
658 | R"( |
659 | graph(%self, %input, %conv): |
660 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
661 | %second_output = aten::relu(%first_output) |
662 | return (%second_output) )" , |
663 | {is_conv2d_module}); |
664 | |
665 | const PatternInfo nn_conv2d_aten_relu_ = PatternInfo::parse_from_str( |
666 | R"( |
667 | graph(%self, %input, %conv): |
668 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
669 | %second_output = aten::relu_(%first_output) |
670 | return (%second_output) )" , |
671 | {is_conv2d_module}); |
672 | |
673 | const PatternInfo nn_conv3d_f_relu = PatternInfo::parse_from_str( |
674 | R"( |
675 | graph(%self, %input, %conv, %relu, %inplace): |
676 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
677 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
678 | return (%second_output) )" , |
679 | {is_conv3d_module, is_functional_relu}); |
680 | |
681 | const PatternInfo nn_conv3d_nn_relu = PatternInfo::parse_from_str( |
682 | R"( |
683 | graph(%self, %input, %conv, %relu): |
684 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
685 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
686 | return (%second_output) )" , |
687 | {is_conv3d_module, is_relu_module}); |
688 | |
689 | const PatternInfo nn_conv3d_aten_relu = PatternInfo::parse_from_str( |
690 | R"( |
691 | graph(%self, %conv, %input): |
692 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
693 | %second_output = aten::relu(%first_output) |
694 | return (%second_output) )" , |
695 | {is_conv3d_module}); |
696 | |
697 | const PatternInfo nn_conv3d_aten_relu_ = PatternInfo::parse_from_str( |
698 | R"( |
699 | graph(%self, %input, %conv): |
700 | %first_output = prim::CallMethod[name="forward"](%conv, %input) |
701 | %second_output = aten::relu_(%first_output) |
702 | return (%second_output) )" , |
703 | {is_conv3d_module}); |
704 | |
705 | const PatternInfo add_nn_relu = PatternInfo::parse_from_str( |
706 | R"( |
707 | graph(%self, %a, %b, %alpha, %relu): |
708 | %first_output = aten::add(%a, %b, %alpha) |
709 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
710 | return (%second_output) )" , |
711 | {aten_add_alpha_is_one, is_relu_module}); |
712 | |
713 | const PatternInfo add_f_relu = PatternInfo::parse_from_str( |
714 | R"( |
715 | graph(%self, %a, %b, %alpha, %relu, %inplace): |
716 | %first_output = aten::add(%a, %b, %alpha) |
717 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
718 | return (%second_output) )" , |
719 | {aten_add_alpha_is_one, is_functional_relu}); |
720 | |
721 | const PatternInfo inplace_add_nn_relu = PatternInfo::parse_from_str( |
722 | R"( |
723 | graph(%self, %a, %b, %alpha, %relu): |
724 | %first_output = aten::add_(%a, %b, %alpha) |
725 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
726 | return (%second_output) )" , |
727 | {aten_add_alpha_is_one, is_relu_module}); |
728 | |
729 | const PatternInfo inplace_add_f_relu = PatternInfo::parse_from_str( |
730 | R"( |
731 | graph(%self, %a, %b, %alpha, %relu, %inplace): |
732 | %first_output = aten::add_(%a, %b, %alpha) |
733 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
734 | return (%second_output) )" , |
735 | {aten_add_alpha_is_one, is_functional_relu}); |
736 | |
737 | const PatternInfo add_aten_relu = PatternInfo::parse_from_str(R"( |
738 | graph(%self, %a, %b, %alpha): |
739 | %first_output = aten::add(%a, %b, %alpha) |
740 | %second_output = aten::relu(%first_output) |
741 | return (%second_output) )" ); |
742 | |
743 | const PatternInfo add_aten_relu_ = PatternInfo::parse_from_str(R"( |
744 | graph(%self, %a, %b, %alpha): |
745 | %first_output = aten::add(%a, %b, %alpha) |
746 | %second_output = aten::relu_(%first_output) |
747 | return (%second_output) )" ); |
748 | |
749 | const PatternInfo inplace_add_aten_relu = PatternInfo::parse_from_str(R"( |
750 | graph(%self, %a, %b, %alpha): |
751 | %first_output = aten::add_(%a, %b, %alpha) |
752 | %second_output = aten::relu(%first_output) |
753 | return (%second_output) )" ); |
754 | |
755 | const PatternInfo inplace_add_aten_relu_ = PatternInfo::parse_from_str(R"( |
756 | graph(%self, %a, %b, %alpha): |
757 | %first_output = aten::add_(%a, %b, %alpha) |
758 | %second_output = aten::relu_(%first_output) |
759 | return (%second_output) )" ); |
760 | |
761 | const PatternInfo nn_bn2d_nn_relu = PatternInfo::parse_from_str( |
762 | R"( |
763 | graph(%self, %input, %batchnorm, %relu): |
764 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
765 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
766 | return (%second_output) )" , |
767 | {is_batchnorm2d_module, is_relu_module}); |
768 | |
769 | const PatternInfo nn_bn2d_f_relu = PatternInfo::parse_from_str( |
770 | R"( |
771 | graph(%self, %input, %batchnorm, %relu, %inplace): |
772 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
773 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
774 | return (%second_output) )" , |
775 | {is_batchnorm2d_module, is_functional_relu}); |
776 | |
777 | const PatternInfo nn_bn2d_aten_relu = PatternInfo::parse_from_str( |
778 | R"( |
779 | graph(%self, %input, %batchnorm): |
780 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
781 | %second_output = aten::relu(%first_output) |
782 | return (%second_output) )" , |
783 | {is_batchnorm2d_module}); |
784 | |
785 | const PatternInfo nn_bn2d_aten_relu_ = PatternInfo::parse_from_str( |
786 | R"( |
787 | graph(%self, %input, %batchnorm): |
788 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
789 | %second_output = aten::relu_(%first_output) |
790 | return (%second_output) )" , |
791 | {is_batchnorm2d_module}); |
792 | |
793 | const PatternInfo nn_bn3d_nn_relu = PatternInfo::parse_from_str( |
794 | R"( |
795 | graph(%self, %input, %batchnorm, %relu): |
796 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
797 | %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output) |
798 | return (%second_output) )" , |
799 | {is_batchnorm3d_module, is_relu_module}); |
800 | |
801 | const PatternInfo nn_bn3d_f_relu = PatternInfo::parse_from_str( |
802 | R"( |
803 | graph(%self, %input, %batchnorm, %relu, %inplace): |
804 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
805 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
806 | return (%second_output) )" , |
807 | {is_batchnorm3d_module, is_functional_relu}); |
808 | |
809 | const PatternInfo nn_bn3d_aten_relu = PatternInfo::parse_from_str( |
810 | R"( |
811 | graph(%self, %input, %batchnorm): |
812 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
813 | %second_output = aten::relu(%first_output) |
814 | return (%second_output) )" , |
815 | {is_batchnorm3d_module}); |
816 | |
817 | const PatternInfo nn_bn3d_aten_relu_ = PatternInfo::parse_from_str( |
818 | R"( |
819 | graph(%self, %input, %batchnorm): |
820 | %first_output = prim::CallMethod[name="forward"](%batchnorm, %input) |
821 | %second_output = aten::relu_(%first_output) |
822 | return (%second_output) )" , |
823 | {is_batchnorm3d_module}); |
824 | |
825 | const PatternInfo mul_nn_relu = PatternInfo::parse_from_str( |
826 | R"( |
827 | graph(%self, %a, %b, %relu): |
828 | %first_output = aten::mul(%a, %b) |
829 | %second_output = prim::CallMethod[name="forward"](%relu, %first_output) |
830 | return (%second_output) )" , |
831 | {is_relu_module}); |
832 | |
833 | const PatternInfo mul_f_relu = PatternInfo::parse_from_str( |
834 | R"( |
835 | graph(%self, %a, %b, %relu, %inplace): |
836 | %first_output = aten::mul(%a, %b) |
837 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
838 | return (%second_output) )" , |
839 | {is_functional_relu}); |
840 | |
841 | const PatternInfo inplace_mul_nn_relu = PatternInfo::parse_from_str( |
842 | R"( |
843 | graph(%self, %a, %b, %relu): |
844 | %first_output = aten::mul_(%a, %b) |
845 | %second_output = prim::CallMethod[name="forward"](%relu, %first_output) |
846 | return (%second_output) )" , |
847 | {is_relu_module}); |
848 | |
849 | const PatternInfo inplace_mul_f_relu = PatternInfo::parse_from_str( |
850 | R"( |
851 | graph(%self, %a, %b, %relu, %inplace): |
852 | %first_output = aten::mul_(%a, %b) |
853 | %second_output = prim::CallFunction(%relu, %first_output, %inplace) |
854 | return (%second_output) )" , |
855 | {is_functional_relu}); |
856 | |
857 | const PatternInfo mul_aten_relu = PatternInfo::parse_from_str(R"( |
858 | graph(%self, %a, %b): |
859 | %first_output = aten::mul(%a, %b) |
860 | %second_output = aten::relu(%first_output) |
861 | return (%second_output) )" ); |
862 | |
863 | const PatternInfo mul_aten_relu_ = PatternInfo::parse_from_str(R"( |
864 | graph(%self, %a, %b): |
865 | %first_output = aten::mul(%a, %b) |
866 | %second_output = aten::relu_(%first_output) |
867 | return (%second_output) )" ); |
868 | |
869 | const PatternInfo inplace_mul_aten_relu = PatternInfo::parse_from_str(R"( |
870 | graph(%self, %a, %b): |
871 | %first_output = aten::mul_(%a, %b) |
872 | %second_output = aten::relu(%first_output) |
873 | return (%second_output) )" ); |
874 | |
875 | const PatternInfo inplace_mul_aten_relu_ = PatternInfo::parse_from_str(R"( |
876 | graph(%self, %a, %b): |
877 | %first_output = aten::mul_(%a, %b) |
878 | %second_output = aten::relu_(%first_output) |
879 | return (%second_output) )" ); |
880 | |
881 | const std::vector<std::reference_wrapper<const PatternInfo>> delay_patterns = |
882 | { |
883 | nn_linear_f_relu, nn_linear_nn_relu, |
884 | nn_linear_aten_relu, nn_linear_aten_relu_, |
885 | aten_linear_f_relu, aten_linear_nn_relu, |
886 | aten_linear_aten_relu, aten_linear_aten_relu_, |
887 | |
888 | nn_conv1d_f_relu, nn_conv1d_nn_relu, |
889 | nn_conv1d_aten_relu, nn_conv1d_aten_relu_, |
890 | nn_conv2d_f_relu, nn_conv2d_nn_relu, |
891 | nn_conv2d_aten_relu, nn_conv2d_aten_relu_, |
892 | nn_conv3d_f_relu, nn_conv3d_nn_relu, |
893 | nn_conv3d_aten_relu, nn_conv3d_aten_relu_, |
894 | |
895 | add_nn_relu, add_f_relu, |
896 | inplace_add_nn_relu, inplace_add_f_relu, |
897 | add_aten_relu, add_aten_relu_, |
898 | inplace_add_aten_relu, inplace_add_aten_relu_, |
899 | |
900 | nn_bn2d_nn_relu, nn_bn2d_f_relu, |
901 | nn_bn2d_aten_relu, nn_bn2d_aten_relu_, |
902 | nn_bn3d_nn_relu, nn_bn3d_f_relu, |
903 | nn_bn3d_aten_relu, nn_bn3d_aten_relu_, |
904 | |
905 | mul_nn_relu, mul_f_relu, |
906 | inplace_mul_nn_relu, inplace_mul_f_relu, |
907 | mul_aten_relu, mul_aten_relu_, |
908 | inplace_mul_aten_relu, inplace_mul_aten_relu_, |
909 | }; |
910 | |
911 | bool insert_reset_observer_method_{false}; |
912 | std::string reset_observer_method_name_; |
913 | }; |
914 | |
915 | ModuleMethodVector InsertObserversHelper::getInvokedMethods( |
916 | Module& module, |
917 | const std::string& method_name) { |
918 | ModuleMethodVector invoked_methods; |
919 | Method method = module.get_method(method_name); |
920 | auto graph = method.graph(); |
921 | |
922 | std::stack<Block*> blocks_to_visit; |
923 | blocks_to_visit.push(graph->block()); |
924 | while (!blocks_to_visit.empty()) { |
925 | Block* b = blocks_to_visit.top(); |
926 | blocks_to_visit.pop(); |
927 | for (Node* n : b->nodes()) { |
928 | // Skip observer nodes |
929 | if (observer_nodes_.count(n)) { |
930 | continue; |
931 | } |
932 | if (n->kind() == prim::CallMethod) { |
933 | auto m_opt = getInvokedModuleOpt(module, n, graph->inputs()[0]); |
934 | if (m_opt.has_value()) { |
935 | invoked_methods.emplace_back(*m_opt, n->s(attr::name)); |
936 | } |
937 | } |
938 | |
939 | for (Block* subblock : n->blocks()) { |
940 | blocks_to_visit.push(subblock); |
941 | } |
942 | } |
943 | } |
944 | return invoked_methods; |
945 | } |
946 | |
// Inserts an observer for value `v` in its owning graph:
//  1. deep-copies `observer_module` and registers the copy on `module`
//     under a fresh "_observer_<uid>" attribute,
//  2. inserts a GetAttr node for that attribute right after `v`'s node,
//  3. inserts a call to the observer's `forward(v)` and redirects all
//     uses of `v` to the call's output.
// The (attribute name, observer) pair is appended to
// `observer_name_and_modules`. No-op if `v` was already observed.
void InsertObserversHelper::insertObserverFor(
    Value* v,
    Module& module,
    const Module& observer_module,
    NameModuleVector& observer_name_and_modules) {
  if (observed_values_.count(v)) {
    return;
  }
  GRAPH_DEBUG("Inserting observer for:" , v->debugName());
  // Each observed value gets its own observer instance so statistics are
  // collected per value, not shared.
  Module observer = observer_module.deepcopy();
  std::string observer_name = "_observer_" + c10::to_string(uid_++);
  // Keep bumping the uid until the attribute name is free on `module`.
  while (module.hasattr(observer_name)) {
    observer_name = "_observer_" + c10::to_string(uid_++);
  }
  module.register_module(observer_name, observer);
  observer_name_and_modules.emplace_back(observer_name, observer);

  auto* g = v->owningGraph();
  // Get handle of observer module
  Node* observer_instance =
      g->createGetAttr(g->inputs()[0], observer_name)->insertAfter(v->node());
  observer_instance->output()->setDebugName(observer_name);

  {
    WithInsertPoint guard(observer_instance->next());
    // Match arguments to types of observer's arguments
    MatchedSchema forward_matched_schema = matchSchema(
        observer.get_method("forward" ).function().getSchema(),
        v->node()->sourceRange(),
        *g,
        {observer_instance->output(), v},
        {});
    // Insert call to observer's forward
    Node* call = g->insertMethodCall("forward" , forward_matched_schema)->node();
    call->output()->copyMetadata(v);

    // Replace v with the output of observer
    v->replaceAllUsesWith(call->output());
    // The above also replaced the input to `call`, so switch it back to
    // the correct value
    call->replaceInput(1, v);
    observer_nodes_.emplace(call);
    observed_values_.insert(call->output());
  }
}
992 | |
// Ensures `module` has a method named `reset_observer_method_name_` and
// appends to it a call to `reset_min_max_vals()` for every observer in
// `observer_name_and_modules`. If the method does not exist yet, it is
// created with signature `(self) -> None` and registered on the module's
// type. No-op when there are no new observers.
void InsertObserversHelper::insertObserverResetMinMax(
    Module& module,
    const NameModuleVector& observer_name_and_modules) {
  if (observer_name_and_modules.empty()) {
    return;
  }
  auto reset_min_max_opt = module.find_method(reset_observer_method_name_);
  if (!reset_min_max_opt.has_value()) {
    // Build the minimal graph `def <name>(self): return None` by hand.
    std::shared_ptr<Graph> reset_observer_graph = std::make_shared<Graph>();
    Value* module_value = reset_observer_graph->addInput("self" );
    Node* output_node = reset_observer_graph->createNone();
    reset_observer_graph->insertNode(output_node);
    reset_observer_graph->registerOutput(output_node->output());
    module_value->setType(module._ivalue()->type());
    const auto method_name = c10::QualifiedName(
        *(module.type()->name()), reset_observer_method_name_);
    auto reset_observer_fn =
        module._ivalue()->compilation_unit()->create_function(
            method_name, std::move(reset_observer_graph));
    // The function needs an explicit schema before it can be added as a
    // method on the module type.
    auto self_arg = c10::Argument("self" , module.type());
    auto output_arg = c10::Argument("none" , output_node->output()->type());
    auto schema = c10::FunctionSchema(
        reset_observer_method_name_,
        "" ,
        {std::move(self_arg)},
        {std::move(output_arg)});
    reset_observer_fn->setSchema(std::move(schema));
    module.type()->addMethod(reset_observer_fn);
  }
  auto reset_min_max_graph =
      module.get_method(reset_observer_method_name_).graph();
  Value* self = reset_min_max_graph->inputs()[0];

  // Append `self.<observer_name>.reset_min_max_vals()` for each new
  // observer attribute.
  for (const auto& pair : observer_name_and_modules) {
    const auto& observer_name = pair.first;
    const auto& observer = pair.second;
    Value* observer_value =
        reset_min_max_graph->insertGetAttr(self, observer_name);
    MatchedSchema reset_minmax_schema = matchSchema(
        observer.get_method("reset_min_max_vals" ).function().getSchema(),
        observer_value->node()->sourceRange(),
        *reset_min_max_graph,
        {observer_value},
        {});
    reset_min_max_graph->insertMethodCall(
        "reset_min_max_vals" , reset_minmax_schema);
  }
}
1041 | |
1042 | void InsertObserversHelper::delayObservingValuesInPattern( |
1043 | Graph& graph, |
1044 | const PatternInfo& pattern) { |
1045 | const Graph& pattern_graph = *pattern.pattern_graph; |
1046 | const std::unordered_map<std::string, Value*>& vmap = pattern.vmap; |
1047 | |
1048 | const auto& matches = findPatternMatches(pattern_graph, graph); |
1049 | for (const auto& match : matches) { |
1050 | if (!std::all_of( |
1051 | pattern.filters.begin(), |
1052 | pattern.filters.end(), |
1053 | [&](const MatchFilter& f) { return f(match, vmap); })) { |
1054 | continue; |
1055 | } |
1056 | auto first_output = match.values_map.at(vmap.at("first_output" )); |
1057 | auto second_output = match.values_map.at(vmap.at("second_output" )); |
1058 | GRAPH_DEBUG( |
1059 | "Delay observation for value in function pattern:" , |
1060 | first_output->debugName(), |
1061 | " to " , |
1062 | second_output->debugName()); |
1063 | delay_observation_map_[first_output] = second_output; |
1064 | } |
1065 | } |
1066 | |
1067 | void InsertObserversHelper::addValuesToDelayObservation( |
1068 | const Module& module, |
1069 | const std::string& method_name) { |
1070 | Method method = module.get_method(method_name); |
1071 | auto graph = method.graph(); |
1072 | |
1073 | for (const auto& pattern : delay_patterns) { |
1074 | delayObservingValuesInPattern(*graph, pattern); |
1075 | } |
1076 | } |
1077 | |
1078 | void InsertObserversHelper::fillPassThroughValueMap( |
1079 | const std::shared_ptr<Graph>& graph) { |
1080 | std::stack<Block*> blocks_to_visit; |
1081 | blocks_to_visit.push(graph->block()); |
1082 | while (!blocks_to_visit.empty()) { |
1083 | Block* b = blocks_to_visit.top(); |
1084 | blocks_to_visit.pop(); |
1085 | for (Node* n : b->nodes()) { |
1086 | if (userDefinedCallFunction(n)) { |
1087 | auto g = getCallFunctionGraph(n); |
1088 | blocks_to_visit.push(g->block()); |
1089 | } |
1090 | for (auto* output : n->outputs()) { |
1091 | for (auto* input : getPassThroughInputs(output)) { |
1092 | pass_through_value_map_[output].push_back(input); |
1093 | } |
1094 | } |
1095 | for (Block* subblock : n->blocks()) { |
1096 | blocks_to_visit.push(subblock); |
1097 | } |
1098 | } |
1099 | } |
1100 | } |
1101 | |
// Recursively populates boundary_value_map_, which links values across
// call and block boundaries:
//  - a call node's output -> the corresponding return value of the callee
//  - a call node's input  -> the corresponding input of the callee graph
//  - a prim::If output    -> the corresponding output of each subblock
// These links are later used to propagate observer decisions between
// caller and callee (see getObserverFor).
void InsertObserversHelper::fillBoundaryValueMap(
    Module& module,
    const std::string& method_name) {
  // Process callees first so their maps are filled before the caller's.
  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_method);
    const auto& invoked_method_name = std::get<1>(invoked_method);
    fillBoundaryValueMap(invoked_module, invoked_method_name);
  }

  auto graph = module.get_method(method_name).graph();
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  auto* self = graph->inputs()[0];
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
        std::shared_ptr<Graph> g;
        // offset of input for the caller node, since the first
        // input of CallFunction is the function node and the graph
        // for CallFunction start with actual input
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        size_t input_offset;
        if (n->kind() == prim::CallMethod) {
          auto m_opt = getInvokedModuleOpt(module, n, self);
          if (!m_opt.has_value()) {
            continue;
          }
          auto m = *m_opt;
          g = m.get_method(n->s(attr::name)).graph();
          input_offset = 0;
        } else {
          g = getCallFunctionGraph(n);
          input_offset = 1;
        }
        // add mapping from callsite value to value in called graph
        for (auto i = 0U; i < g->outputs().size(); ++i) {
          auto* return_val = g->outputs()[i];
          GRAPH_DEBUG(
              "Boundary Map[return]:" ,
              n->output(i)->debugName(),
              " -> " ,
              return_val->debugName());
          boundary_value_map_[n->output(i)].insert(return_val);
        }
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto caller_input_index = i + input_offset;
          auto* caller_input = n->input(caller_input_index);
          auto* input_val = g->inputs()[i];
          GRAPH_DEBUG(
              "Boundary Map[input]:" ,
              caller_input->debugName(),
              " -> " ,
              input_val->debugName());
          boundary_value_map_[caller_input].insert(input_val);
        }
      } else if (n->kind() == prim::If) {
        // Each `if` output aliases the matching output of every subblock.
        for (Block* subblock : n->blocks()) {
          blocks_to_visit.push(subblock);
          for (Value* v : n->outputs()) {
            Value* subblock_output = subblock->outputs()[v->offset()];
            GRAPH_DEBUG(
                "Boundary Map[if_output]:" ,
                v->debugName(),
                " -> " ,
                subblock_output->debugName());
            boundary_value_map_[v].insert(subblock_output);
          }
        }
      } else {
        for (Block* subblock : n->blocks()) {
          blocks_to_visit.push(subblock);
        }
      }
    }
  }
}
1180 | |
1181 | void InsertObserversHelper::preprocess( |
1182 | Module& module, |
1183 | const std::string& method_name) { |
1184 | // run preprocess for child module before parent, since preprocess |
1185 | // mutates the graph and it might affect passes like fillBoundaryValueMap |
1186 | for (auto& invoked_method : getInvokedMethods(module, method_name)) { |
1187 | auto& invoked_module = std::get<0>(invoked_method); |
1188 | const auto& invoked_method_name = std::get<1>(invoked_method); |
1189 | preprocess(invoked_module, invoked_method_name); |
1190 | } |
1191 | |
1192 | Method method = module.get_method(method_name); |
1193 | auto graph = method.graph(); |
1194 | // Inline fork-wait calls |
1195 | InlineForkWait(graph); |
1196 | // fuse decomposed linear into aten::linear |
1197 | FuseLinear(graph); |
1198 | replaceConvolutionWithAtenConv(graph); |
1199 | RemoveListMutation(graph); |
1200 | } |
1201 | |
1202 | void InsertObserversHelper::analyze( |
1203 | Module& module, |
1204 | const std::string& method_name) { |
1205 | for (auto& invoked_method : getInvokedMethods(module, method_name)) { |
1206 | auto& invoked_module = std::get<0>(invoked_method); |
1207 | const auto& invoked_method_name = std::get<1>(invoked_method); |
1208 | analyze(invoked_module, invoked_method_name); |
1209 | } |
1210 | |
1211 | // fill out various internal state which will be later used in |
1212 | // insertObservers to insert the correct observer |
1213 | addValuesToDelayObservation(module, method_name); |
1214 | fillValueObserverMap(module, method_name); |
1215 | Method method = module.get_method(method_name); |
1216 | auto graph = method.graph(); |
1217 | fillPassThroughValueMap(graph); |
1218 | } |
1219 | |
1220 | bool InsertObserversHelper::valueNeedsToBeQuantized( |
1221 | Value* v, |
1222 | const QConfig& qconfig) { |
1223 | if (isBiasOfConvOrLinear(v) || |
1224 | !(v->type()->isSubtypeOf(*TensorType::get()) || |
1225 | v->type()->isSubtypeOf(*ListType::ofTensors())) || |
1226 | isEmbeddingBagNonInput(v)) { |
1227 | return false; |
1228 | } |
1229 | // For dynamic quantization we only insert observers at the input |
1230 | // of the quantizable function. |
1231 | if (quant_type_ == QuantType::STATIC) { |
1232 | // Check whether producer is quantizable |
1233 | if (!isWeightOnlyStaticQuantOp(v->node()) && |
1234 | (nodeQuantizable(v->node()) || isPropagateQuantOp(v->node()))) { |
1235 | return true; |
1236 | } |
1237 | } |
1238 | if (quant_type_ == QuantType::DYNAMIC) { |
1239 | // Check the dtype of the observer module. |
1240 | Module observer_module = getObserverModuleFor(v, qconfig); |
1241 | auto scalar_type = observer_module.attr("dtype" ); |
1242 | // For inputs with Fp16 type that are not-weights we don't observer them for |
1243 | // dynamic quantization. |
1244 | if (scalar_type == at::ScalarType::Half && !isWeight(v)) { |
1245 | return false; |
1246 | } |
1247 | } |
1248 | // Check whether node input value is quantizable |
1249 | for (const auto& use : v->uses()) { |
1250 | if (useQuantizable(use, quant_type_)) { |
1251 | return true; |
1252 | } |
1253 | } |
1254 | return false; |
1255 | } |
1256 | |
1257 | void InsertObserversHelper::removeActivationObservers() { |
1258 | std::vector<std::unordered_map<Value*, Module>::iterator> |
1259 | values_to_be_removed; |
1260 | for (auto it = observer_for_value_.begin(); it != observer_for_value_.end(); |
1261 | it++) { |
1262 | if (!isWeight(it->first)) { |
1263 | values_to_be_removed.push_back(it); |
1264 | } |
1265 | } |
1266 | for (auto it : values_to_be_removed) { |
1267 | observer_for_value_.erase(it); |
1268 | } |
1269 | } |
1270 | |
1271 | void InsertObserversHelper::fillValueObserverMap( |
1272 | Module& module, |
1273 | const std::string& method_name) { |
1274 | Method method = module.get_method(method_name); |
1275 | auto graph = method.graph(); |
1276 | |
1277 | if (visited_graph_of_observer_map_.count(graph.get())) { |
1278 | return; |
1279 | } |
1280 | visited_graph_of_observer_map_.insert(graph.get()); |
1281 | |
1282 | std::stack<Block*> blocks_to_visit; |
1283 | auto qconfig_opt = module_qconfig_map_.at(module._ivalue()); |
1284 | if (!qconfig_opt) { |
1285 | return; |
1286 | } |
1287 | auto qconfig = *qconfig_opt; |
1288 | for (auto* v : graph->inputs()) { |
1289 | if (valueNeedsToBeQuantized(v, qconfig)) { |
1290 | GRAPH_DEBUG("Recording observer for " , v->debugName()); |
1291 | GRAPH_DUMP("In graph:" , v->owningGraph()); |
1292 | observer_for_value_[v] = getObserverModuleFor(v, qconfig); |
1293 | } |
1294 | } |
1295 | |
1296 | blocks_to_visit.push(graph->block()); |
1297 | while (!blocks_to_visit.empty()) { |
1298 | Block* b = blocks_to_visit.top(); |
1299 | blocks_to_visit.pop(); |
1300 | for (Node* n : b->nodes()) { |
1301 | for (Value* v : n->outputs()) { |
1302 | if (valueNeedsToBeQuantized(v, qconfig)) { |
1303 | GRAPH_DEBUG("Recording observer for " , v->debugName()); |
1304 | GRAPH_DUMP("In graph:" , v->owningGraph()); |
1305 | observer_for_value_[v] = getObserverModuleFor(v, qconfig); |
1306 | } |
1307 | } |
1308 | |
1309 | for (Block* subblock : n->blocks()) { |
1310 | blocks_to_visit.push(subblock); |
1311 | } |
1312 | } |
1313 | } |
1314 | } |
1315 | |
// Returns the observer module that should observe `v`, if any.
// Checks the direct observer_for_value_ mapping first; otherwise follows
// boundary_value_map_ (caller/callee and if-block links) recursively.
// Throws (TORCH_CHECK) if the linked values resolve to different observer
// configurations.
c10::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
  if (observer_for_value_.count(v)) {
    auto observer = observer_for_value_.at(v);
    GRAPH_DEBUG("Got observer module config for:" , v->debugName());
    return observer;
  }
  c10::optional<Module> result;
  if (boundary_value_map_.count(v)) {
    for (Value* next : boundary_value_map_.at(v)) {
      GRAPH_DEBUG(
          "Going through boundary map:" ,
          v->debugName(),
          " --> " ,
          next->debugName());
      GRAPH_DUMP("From graph:" , v->owningGraph());
      GRAPH_DUMP("To graph:" , next->owningGraph());
      auto observer_opt = getObserverFor(next);
      if (observer_opt) {
        // Need to make sure all values are
        // configured with same observer
        if (result) {
          TORCH_CHECK(
              *observer_opt == *result,
              "Expecting all values in the graph only configured with one observer" );
        } else {
          result = observer_opt;
        }
      }
    }
  }
  GRAPH_DEBUG(
      "Observer module config for " , v->debugName(), ":" , result.has_value());
  return result;
}
1350 | |
1351 | std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>> |
1352 | InsertObserversHelper::insertObservers( |
1353 | Module& module, |
1354 | const std::string& method_name, |
1355 | bool is_entry_point, |
1356 | std::unordered_set<Value*> graph_observed_values) { |
1357 | auto graph = module.get_method(method_name).graph(); |
1358 | return insertObserversFor( |
1359 | graph->block(), module, graph_observed_values, is_entry_point); |
1360 | } |
1361 | |
1362 | void InsertObserversHelper::recordObserved( |
1363 | Value* v, |
1364 | const Module& observer_module, |
1365 | std::unordered_map<Value*, Module>& values_to_observe, |
1366 | std::unordered_set<Value*>& block_observed_values) { |
1367 | Value* to_observe = v; |
1368 | if (delay_observation_map_.count(v)) { |
1369 | to_observe = delay_observation_map_.at(v); |
1370 | } |
1371 | values_to_observe[to_observe] = observer_module; |
1372 | block_observed_values.insert(to_observe); |
1373 | } |
1374 | |
1375 | std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>> |
1376 | InsertObserversHelper::insertObserversFor( |
1377 | Block* block, |
1378 | script::Module& module, |
1379 | std::unordered_set<Value*>& block_observed_values, |
1380 | bool is_entry_point, |
1381 | bool is_user_defined_function) { |
1382 | // input/output values, used to skip inserting observers |
1383 | // for input and output of the block and the owning graph, |
1384 | // we have to insert the observers at call site because |
1385 | // the graph itself can be shared |
1386 | std::unordered_set<Value*> inputs_outputs; |
1387 | // list of observer modules for input values |
1388 | std::vector<c10::optional<Module>> block_input_observers; |
1389 | // list of observer modules for output values |
1390 | std::vector<c10::optional<Module>> block_output_observers; |
1391 | |
1392 | // if the current block is the block for entry point graph(the forward graph |
1393 | // of the top level module), we can insert observers in the block directly |
1394 | if (!is_entry_point) { |
1395 | auto* graph = block->owningGraph(); |
1396 | // graph inputs/outputs |
1397 | for (auto list : {graph->inputs(), graph->outputs()}) { |
1398 | for (auto* v : list) { |
1399 | inputs_outputs.insert(v); |
1400 | } |
1401 | } |
1402 | // block outputs |
1403 | for (auto* v : block->outputs()) { |
1404 | inputs_outputs.insert(v); |
1405 | } |
1406 | |
1407 | for (auto* v : block->inputs()) { |
1408 | block_input_observers.emplace_back(getObserverFor(v)); |
1409 | } |
1410 | |
1411 | for (auto* v : block->outputs()) { |
1412 | // we need explictly skip the values that are already observed |
1413 | // this might happen in subblocks for `if` since |
1414 | // these subblock has access to all values before the `if` node |
1415 | if (!isObserved(v, block_observed_values)) { |
1416 | block_output_observers.emplace_back(getObserverFor(v)); |
1417 | } else { |
1418 | block_output_observers.emplace_back(c10::nullopt); |
1419 | } |
1420 | } |
1421 | } |
1422 | |
1423 | // This means the block is been processed before, we just |
1424 | // need to attach observer modules and construct the information |
1425 | // needed by call site here |
1426 | bool visited = block_observer_map_.count(block); |
1427 | if (visited) { |
1428 | // instance clone of observer module and setAttr |
1429 | for (const auto& observer_attrs : block_observer_map_.at(block)) { |
1430 | const auto& name = std::get<0>(observer_attrs); |
1431 | const auto& observer = std::get<1>(observer_attrs); |
1432 | module._ivalue()->setAttr(name, observer.deepcopy()._ivalue()); |
1433 | } |
1434 | } |
1435 | // NB: Why do we need to process the graph even if it's visited? |
1436 | // Reason is `block_observed_values` can |
1437 | // change depending on where the method is called, and |
1438 | // outputs that's been observed(third item of the returned result) |
1439 | // can change depending on that, so for each graph we'll need to go through |
1440 | // the whole process of inserting observers, the observers inserted in this |
1441 | // block won't change, but the information we return to the caller will change |
1442 | // based on `block_observed_values` |
1443 | |
1444 | std::stack<Block*> blocks_to_visit; |
1445 | blocks_to_visit.push(block); |
1446 | auto* self = block->owningGraph()->inputs()[0]; |
1447 | // We first construct a map from value to the module, then |
1448 | // insert observers for them later, this is to avoid interference |
1449 | // of the inserted observers with the analysis to decide where |
1450 | // to insert observers, also we only insert observers for |
1451 | // "intermediate values" that is not the input/output of the |
1452 | // graph |
1453 | std::unordered_map<Value*, Module> values_to_observe; |
1454 | |
1455 | for (auto* v : block->inputs()) { |
1456 | if (!inputs_outputs.count(v) && !values_to_observe.count(v)) { |
1457 | if (auto observer_opt = getObserverFor(v)) { |
1458 | recordObserved( |
1459 | v, *observer_opt, values_to_observe, block_observed_values); |
1460 | } |
1461 | } |
1462 | } |
1463 | while (!blocks_to_visit.empty()) { |
1464 | Block* b = blocks_to_visit.top(); |
1465 | blocks_to_visit.pop(); |
1466 | for (Node* n : b->nodes()) { |
1467 | if (observer_nodes_.count(n)) { |
1468 | continue; |
1469 | } |
1470 | if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) { |
1471 | script::Module m; |
1472 | std::shared_ptr<Graph> g; |
1473 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1474 | size_t input_offset; |
1475 | bool is_udf_for_subblock = is_user_defined_function; |
1476 | if (n->kind() == prim::CallMethod) { |
1477 | auto m_opt = getInvokedModuleOpt(module, n, self); |
1478 | if (!m_opt.has_value()) { |
1479 | continue; |
1480 | } |
1481 | m = *m_opt; |
1482 | g = m.get_method(n->s(attr::name)).graph(); |
1483 | input_offset = 0; |
1484 | } else { // CallFunction |
1485 | m = module; |
1486 | g = getCallFunctionGraph(n); |
1487 | input_offset = 1; |
1488 | is_udf_for_subblock = true; |
1489 | } |
1490 | |
1491 | std::unordered_set<Value*> callee_observed_inputs; |
1492 | for (auto i = 0U; i < g->inputs().size(); ++i) { |
1493 | auto* node_input = n->input(i + input_offset); |
1494 | if (isObserved(node_input, block_observed_values)) { |
1495 | callee_observed_inputs.insert(g->inputs()[i]); |
1496 | } |
1497 | } |
1498 | auto* subblock = g->block(); |
1499 | auto info_from_callee = insertObserversFor( |
1500 | subblock, m, callee_observed_inputs, false, is_udf_for_subblock); |
1501 | auto input_observers = std::get<0>(info_from_callee); |
1502 | auto output_observers = std::get<1>(info_from_callee); |
1503 | auto callee_observed_outputs = std::get<2>(info_from_callee); |
1504 | for (auto idx : callee_observed_outputs) { |
1505 | block_observed_values.insert(n->outputs()[idx]); |
1506 | } |
1507 | for (auto i = 0U; i < g->inputs().size(); ++i) { |
1508 | auto* node_input = n->input(i + input_offset); |
1509 | if (input_observers[i] && !inputs_outputs.count(node_input) && |
1510 | !isObserved(node_input, block_observed_values)) { |
1511 | recordObserved( |
1512 | node_input, |
1513 | *input_observers[i], |
1514 | values_to_observe, |
1515 | block_observed_values); |
1516 | } |
1517 | } |
1518 | for (auto i = 0U; i < n->outputs().size(); ++i) { |
1519 | if (output_observers[i] && !inputs_outputs.count(n->output(i)) && |
1520 | !isObserved(n->output(i), block_observed_values)) { |
1521 | recordObserved( |
1522 | n->output(i), |
1523 | *output_observers[i], |
1524 | values_to_observe, |
1525 | block_observed_values); |
1526 | } |
1527 | } |
1528 | } else if (n->kind() == prim::If) { |
1529 | // a vector recoding whether each output is observed or not |
1530 | std::vector<bool> aggregated_output_observe_state; |
1531 | for (Block* subblock : n->blocks()) { |
1532 | if (alwaysRaisesException(subblock)) { |
1533 | continue; |
1534 | } |
1535 | // subblock has access to all the values in the scope of prim::If, |
1536 | // so subblock_observed_values == block_observed_values |
1537 | auto info_from_subblock = |
1538 | insertObserversFor(subblock, module, block_observed_values); |
1539 | // subblock for prim::If doesn't have inputs |
1540 | auto output_observers = std::get<1>(info_from_subblock); |
1541 | auto subblock_observed_outputs = std::get<2>(info_from_subblock); |
1542 | |
1543 | // We'll insert output observer for each subblock, and in the end |
1544 | // we will check if output of subblocks are quantized consistently |
1545 | for (size_t i = 0; i < subblock->outputs().size(); ++i) { |
1546 | Value* output = subblock->outputs()[i]; |
1547 | if (output_observers[i] && !inputs_outputs.count(output) && |
1548 | !isObserved(output, block_observed_values)) { |
1549 | recordObserved( |
1550 | output, |
1551 | *output_observers[i], |
1552 | values_to_observe, |
1553 | block_observed_values); |
1554 | } |
1555 | } |
1556 | for (auto idx : subblock_observed_outputs) { |
1557 | block_observed_values.insert(subblock->outputs()[idx]); |
1558 | } |
1559 | std::vector<bool> subblock_output_observe_state; |
1560 | for (size_t i = 0; i < subblock->outputs().size(); ++i) { |
1561 | Value* output = subblock->outputs()[i]; |
1562 | subblock_output_observe_state.push_back( |
1563 | isObserved(output, block_observed_values)); |
1564 | } |
1565 | if (!aggregated_output_observe_state.empty()) { |
1566 | TORCH_CHECK( |
1567 | aggregated_output_observe_state == |
1568 | subblock_output_observe_state, |
1569 | "branches for `if` should return values that are observed " |
1570 | "consistently, if node:" , |
1571 | *n); |
1572 | } else { |
1573 | aggregated_output_observe_state = subblock_output_observe_state; |
1574 | } |
1575 | } |
1576 | // mark the output of if as observed |
1577 | for (size_t i = 0; i < n->outputs().size(); ++i) { |
1578 | if (aggregated_output_observe_state[i]) { |
1579 | block_observed_values.insert(n->output(i)); |
1580 | } |
1581 | } |
1582 | } else if (n->kind() == prim::Loop) { |
1583 | TORCH_WARN_ONCE( |
1584 | "prim::Loop is not yet supported in quantization, " |
1585 | "please make sure nothing needs to be quantized in the " |
1586 | "loop" ); |
1587 | } |
1588 | for (Value* v : n->outputs()) { |
1589 | propagateObservedProperty(v, block_observed_values); |
1590 | if (!inputs_outputs.count(v) && !isObserved(v, block_observed_values)) { |
1591 | auto observer_opt = getObserverFor(v); |
1592 | // If the node is one of the propagate quant node, e.g. |
1593 | // aten::cat, we should observe its output only |
1594 | // if the input of the node is observed |
1595 | if (observer_opt && |
1596 | shouldObserve(n, block_observed_values, quant_type_)) { |
1597 | recordObserved( |
1598 | v, *observer_opt, values_to_observe, block_observed_values); |
1599 | } |
1600 | } |
1601 | } |
1602 | } |
1603 | } |
1604 | std::vector<size_t> output_idxs; |
1605 | for (auto i = 0U; i < block->outputs().size(); ++i) { |
1606 | if (isObserved(block->outputs()[i], block_observed_values)) { |
1607 | output_idxs.push_back(i); |
1608 | } |
1609 | } |
1610 | if (!visited) { |
1611 | NameModuleVector observer_name_and_modules; |
1612 | for (const auto& item : values_to_observe) { |
1613 | auto* v = item.first; |
1614 | auto observer = item.second; |
1615 | TORCH_CHECK( |
1616 | !is_user_defined_function, |
1617 | "Inserting observers for user defined functions is not " |
1618 | "supported right now" ); |
1619 | insertObserverFor(v, module, observer, observer_name_and_modules); |
1620 | } |
1621 | if (insertResetObserverMethod()) { |
1622 | insertObserverResetMinMax(module, observer_name_and_modules); |
1623 | } |
1624 | block_observer_map_[block] = observer_name_and_modules; |
1625 | } |
1626 | return std::make_tuple( |
1627 | block_input_observers, block_output_observers, output_idxs); |
1628 | } |
1629 | |
1630 | void InsertObserversHelper::propagateObservedProperty( |
1631 | Value* output, |
1632 | std::unordered_set<Value*>& block_observed_values) { |
1633 | if (pass_through_value_map_.count(output)) { |
1634 | // since the vector is always non-empty, we will |
1635 | // not return the initial value |
1636 | bool all_observed = true; |
1637 | for (Value* v : pass_through_value_map_.at(output)) { |
1638 | all_observed &= |
1639 | observed_values_.count(v) || block_observed_values.count(v); |
1640 | } |
1641 | if (all_observed) { |
1642 | GRAPH_DEBUG("Pass through observed property in node:" , *output->node()); |
1643 | // This is to propagate observed property through |
1644 | // all ops that doesn't require observation |
1645 | block_observed_values.insert(output); |
1646 | } |
1647 | } |
1648 | } |
1649 | |
1650 | } // namespace |
1651 | |
1652 | Module InsertObservers( |
1653 | Module& input_module, |
1654 | const std::string& method_name, |
1655 | const QConfigDict& qconfig_dict, |
1656 | bool inplace, |
1657 | QuantType quant_type) { |
1658 | ModuleQConfigMap map_before_clone; |
1659 | fillQConfigMap(input_module, qconfig_dict, map_before_clone); |
1660 | ModuleCloneHelper mh; |
1661 | Module module = mh.clone(input_module, map_before_clone, inplace); |
1662 | SwapFunctionalLinear(module); |
1663 | ModuleQConfigMap module_qconfig_map; |
1664 | // Since the types are changed after clone, we need to fill |
1665 | // the qconfig map again |
1666 | fillQConfigMap(module, qconfig_dict, module_qconfig_map); |
1667 | GRAPH_DEBUG("Quant type:" , quant_type); |
1668 | InsertObserversHelper helper(module_qconfig_map, quant_type); |
1669 | helper.preprocess(module, method_name); |
1670 | helper.fillBoundaryValueMap(module, method_name); |
1671 | // analyze needs to run after fillBoundaryValueMap |
1672 | // since we need to know the boundary value mapping to trace |
1673 | // through the calls |
1674 | helper.analyze(module, method_name); |
1675 | helper.insertObservers(module, method_name, /* is_entry_point */ true); |
1676 | return module; |
1677 | } |
1678 | |
1679 | Module InsertObserversForOnDevicePTQ( |
1680 | Module& input_module, |
1681 | const std::string& method_name, |
1682 | const QConfigDict& qconfig_dict, |
1683 | bool inplace, |
1684 | QuantType quant_type) { |
1685 | ModuleQConfigMap map_before_clone; |
1686 | fillQConfigMap(input_module, qconfig_dict, map_before_clone); |
1687 | ModuleCloneHelper mh; |
1688 | Module cloned_module = mh.clone(input_module, map_before_clone, inplace); |
1689 | std::shared_ptr<Graph> g = cloned_module.get_method(method_name).graph(); |
1690 | SwapFunctionalLinear(g); |
1691 | std::string observer_method_name = "observe_" + method_name; |
1692 | cloneMethod(cloned_module, method_name, observer_method_name); |
1693 | ModuleQConfigMap module_qconfig_map; |
1694 | // Since the types are changed after clone, we need to fill |
1695 | // the qconfig map again |
1696 | fillQConfigMap(cloned_module, qconfig_dict, module_qconfig_map); |
1697 | GRAPH_DEBUG("Quant type:" , quant_type); |
1698 | InsertObserversHelper helper(module_qconfig_map, quant_type); |
1699 | // Removes list mutation part is not clear. Is it needed |
1700 | helper.preprocess(cloned_module, observer_method_name); |
1701 | // Since we expect the graph to be inlined this should not have any use |
1702 | // However, this function does handle if blocks |
1703 | // Although as far as I understood If blocks are not really handled |
1704 | // in JIT quantization. Should we just protect against this. That is if we |
1705 | // find observable value inside If block? Also side effect of inlining is that |
1706 | // you will have multiple getattrs for the same attribute and thus potentially |
1707 | // multiple observers observing the same value. This will also lead to |
1708 | // increased size of the packed param struct. I dont expect this to be a |
1709 | // commong pattern but something to be aware fo Note that current quant |
1710 | // workflow does not prevent this anyway since during inset quant dequant |
1711 | // things are inlined anyway |
1712 | helper.fillBoundaryValueMap(cloned_module, observer_method_name); |
1713 | // analyze needs to run after fillBoundaryValueMap |
1714 | // since we need to know the boundary value mapping to trace |
1715 | // through the calls |
1716 | helper.analyze(cloned_module, observer_method_name); |
1717 | // Remove activation observer if quant_type is dynamic |
1718 | if (quant_type == QuantType::DYNAMIC) { |
1719 | helper.removeActivationObservers(); |
1720 | } |
1721 | helper.setInsertResetObserverMethod(true, method_name); |
1722 | helper.insertObservers( |
1723 | cloned_module, observer_method_name, /* is_entry_point */ true); |
1724 | return cloned_module; |
1725 | } |
1726 | } // namespace jit |
1727 | } // namespace torch |
1728 | |