#include <c10/util/irange.h>
#include <torch/csrc/jit/passes/quantization/insert_observers.h>

#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>

#include <memory>
#include <regex>
#include <stack>
#include <string>
#include <utility>

namespace torch {
namespace jit {

using ModuleQConfigMap = std::unordered_map<ModulePtr, c10::optional<QConfig>>;

namespace {

struct OptionalQConfigHash {
  inline size_t operator()(const c10::optional<QConfig>& qconfig_opt) const {
    if (qconfig_opt.has_value()) {
      const auto& m1 = std::get<0>(*qconfig_opt);
      const auto& m2 = std::get<1>(*qconfig_opt);
      constexpr int CONST = 7;
      return std::hash<Module>()(m1) + CONST * std::hash<Module>()(m2);
    }
    return 0;
  }
};
using QConfigTypePtrMap =
    std::unordered_map<c10::optional<QConfig>, TypePtr, OptionalQConfigHash>;
using NameModuleVector = std::vector<std::pair<std::string, Module>>;
using OptionalModuleVector = std::vector<c10::optional<Module>>;
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
using graph_rewrite_helper::PatternInfo;
using graph_rewrite_helper::replaceConvolutionWithAtenConv;

// helper functions
void fillQConfigMap(
    const Module& module,
    const QConfigDict& qconfig_dict,
    ModuleQConfigMap& map,
    const std::string& key = "",
    const c10::optional<QConfig>& parent_qconfig = c10::nullopt) {
  c10::optional<QConfig> qconfig;
  if (qconfig_dict.find(key) != qconfig_dict.end()) {
    GRAPH_DEBUG("Got module config for key:", key);
    qconfig = qconfig_dict.at(key);
  } else {
    GRAPH_DEBUG("Inheriting qconfig from parent module:", key);
    qconfig = parent_qconfig;
  }
  map[module._ivalue()] = qconfig;

  for (const NameModule& s : module.named_children()) {
    std::string child_key;
    if (key.empty()) {
      child_key = s.name;
    } else {
      child_key = key + "." + s.name;
    }
    fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
  }
}
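// For illustration, with a hypothetical qconfig_dict of
//   {"": global_qconfig, "sub.fc": fc_qconfig}
// and a module with child "sub" that itself contains "sub.fc",
// the map is filled as:
//   top module -> global_qconfig (exact match on "")
//   sub        -> global_qconfig (inherited from parent)
//   sub.fc     -> fc_qconfig     (exact match on "sub.fc")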

Module getObserverModuleFor(Value* v, const QConfig& qconfig) {
  return isWeight(v) ? std::get<1>(qconfig) : std::get<0>(qconfig);
}

// helper classes
class ModuleCloneHelper {
 public:
  /** Clone a module according to the module qconfig map. This handles the
   * case where two module instances share the same ClassType but are
   * configured with different QConfigs.
   * The code is copied and modified from
   * https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp
   * The `inplace` option controls whether Tensors are deep-copied:
   * if `inplace` is true, the cloned module shares its tensors with the
   * original module instead of deep-copying them.
   */
  Module clone(
      const Module& module,
      const ModuleQConfigMap& module_qconfig_map,
      bool inplace = false) {
    std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
    IValue::HashAliasedIValueMap memo;
    return clone_impl(
        module, module_qconfig_map, type_remap, inplace, std::move(memo));
  }
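  // A minimal usage sketch (names hypothetical): given two instances of the
  // same ClassType that map to different QConfigs in module_qconfig_map,
  //   ModuleCloneHelper helper;
  //   Module cloned = helper.clone(module, module_qconfig_map);
  // produces a clone in which the two instances no longer share a ClassType,
  // so they can later be observed and quantized independently.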

 private:
  Module clone_impl(
      const Module& module,
      const ModuleQConfigMap& module_qconfig_map,
      std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
      bool inplace,
      IValue::HashAliasedIValueMap memo) {
    auto qconfig = module_qconfig_map.at(module._ivalue());
    auto type = module.type();
    // Create a new _ivalue in the same compilation unit.
    // Since we have shared ClassTypes, we need to preserve the sharing
    // during cloning: we first use the (type, qconfig) pair to check whether
    // the type has already been cloned; if so, we create a new module with
    // the cloned ClassType; if not, we create both a new module and a new
    // ClassType.
    bool type_already_cloned = type_remap.find(type) != type_remap.end() &&
        type_remap.at(type).find(qconfig) != type_remap.at(type).end();
    Module r;
    if (type_already_cloned) {
      // if we cloned the class type before, we'll reuse it
      Module new_module(
          module._ivalue()->compilation_unit(),
          type_remap.at(type).at(qconfig)->cast<ClassType>());
      r = new_module;
    } else {
      Module new_module(
          *type->name(), module._ivalue()->compilation_unit(), true);
      r = new_module;
      type_remap[type][module_qconfig_map.at(module._ivalue())] = r.type();
    }
    // Copy slots. If a slot is a module - recursively clone it.
    size_t N = type->numAttributes();
    for (const auto i : c10::irange(N)) {
      IValue s = module._ivalue()->getSlot(i);
      std::string attr_name = type->getAttributeName(i);
      TypePtr attr_type = type->getAttribute(i);
      if (attr_type->is_module()) {
        const Module& orig = Module(s.toObject());
        Module cloned =
            clone_impl(orig, module_qconfig_map, type_remap, inplace, memo);

        // NOTE: why do we need to manually setattr on the object instead of
        // using register_module here? Because the attr can have a module
        // interface type and still hold a Module object. register_module
        // does not let us correctly set up the type for this attr, so we
        // have to do it manually. If it is an interface type, the type will
        // be shared by the new cloned instance in the same compilation unit,
        // because it only contains a list of FunctionSchema.
        r.type()->addOrCheckAttribute(
            attr_name,
            attr_type->cast<ClassType>() ? cloned.type() : attr_type);
        r._ivalue()->setAttr(attr_name, cloned._ivalue());
      } else {
        // deep-copy the IValue unless the inplace option is set
        r.register_attribute(
            type->getAttributeName(i),
            type->getAttribute(i),
            inplace ? s : s.deepcopy(memo),
            type->is_parameter(i),
            type->is_buffer(i));
      }
    }

    // only clone the methods and constants if the ClassType was not cloned
    // before
    if (!type_already_cloned) {
      for (size_t i = 0; i < type->numConstants(); ++i) {
        r.type()->addConstant(type->getConstantName(i), type->getConstant(i));
      }
      // Clone methods, remapping the types to the cloned ones.
      for (auto& fn : type->methods()) {
        clone_method(module, r, *fn, module_qconfig_map, type_remap);
      }
      // Execute __setstate__(__getstate__()) to initialize custom class
      // members.
      if (auto setstate_method = r.find_method("__setstate__")) {
        auto getstate_method = r.find_method("__getstate__");
        TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
        auto state = (*getstate_method)(Stack{});
        (*setstate_method)(Stack{std::move(state)});
      }
    }
    return r;
  }
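  // After cloning, type_remap might look like this (hypothetical types):
  //   { OrigConvType : { qconfig_a -> ClonedConvType_a,
  //                      qconfig_b -> ClonedConvType_b } }
  // i.e. one cloned ClassType per distinct (original type, qconfig) pair.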

  void remapTypes(
      Block* block,
      Value* self,
      const Module& source,
      Module& target,
      const ModuleQConfigMap& module_qconfig_map,
      const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
          type_remap_fn) {
    // remapping of %self is done outside of this function.
    // We don't support the case where a module is passed as an argument of
    // the method, because in that case we would need a more comprehensive
    // analysis to decide the QConfig for the module.
    for (size_t i = 1; i < block->inputs().size(); ++i) {
      TORCH_CHECK(
          !block->inputs()[i]->type()->cast<ClassType>(),
          "We don't support quantizing methods that have Objects as arguments");
    }
    for (Node* node : block->nodes()) {
      // remapping type for module instance
      if (node->kind() == prim::CallMethod || node->kind() == prim::GetAttr) {
        Value* instance = node->inputs()[0];
        auto child_opt = getInvokedModuleOpt(source, node, self);
        if (child_opt.has_value()) {
          auto qconfig = module_qconfig_map.at(child_opt->_ivalue());
          instance->setType(type_remap_fn(instance->type(), qconfig));
        }
      }
      // We don't remap outputs; the remapping of module types is done in
      // CallMethod. We don't support type remapping for modules returned
      // from methods or functions.
      for (Block* sub_block : node->blocks()) {
        remapTypes(
            sub_block, self, source, target, module_qconfig_map, type_remap_fn);
      }
      for (Symbol name : node->attributeNames()) {
        if (node->kindOf(name) == AttributeKind::g) {
          remapTypes(
              node->g(name).get(),
              source,
              target,
              module_qconfig_map,
              type_remap_fn);
        } else if (node->kindOf(name) == AttributeKind::gs) {
          for (const auto& g : node->gs(name)) {
            remapTypes(
                g.get(), source, target, module_qconfig_map, type_remap_fn);
          }
        }
      }
    }
  }

  void remapTypes(
      Graph* graph,
      const Module& source,
      Module& target,
      const ModuleQConfigMap& module_qconfig_map,
      const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
          type_remap_fn) {
    remapTypes(
        graph->block(),
        graph->inputs()[0],
        source,
        target,
        module_qconfig_map,
        type_remap_fn);
  }

  void clone_method(
      const Module& source,
      Module& target,
      const Function& method,
      const ModuleQConfigMap& module_qconfig_map,
      const std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
    auto type_remap_fn = [&](TypePtr type_ptr,
                             const c10::optional<QConfig>& qconfig) {
      if (type_remap.find(type_ptr) != type_remap.end()) {
        const auto& qconfig_map = type_remap.at(type_ptr);
        if (qconfig_map.find(qconfig) != qconfig_map.end()) {
          return qconfig_map.at(qconfig);
        }
      }
      return type_ptr;
    };
    auto graph = toGraphFunction(method).graph()->copy();
    remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
    // remap self
    graph->inputs()[0]->setType(target.type());
    // we only support %self being a Module in the arguments of the function
    auto schema_type_remap_fn = [&](TypePtr type_ptr) {
      return type_remap_fn(
          std::move(type_ptr), module_qconfig_map.at(source._ivalue()));
    };
    auto schema =
        method.getSchema().cloneWithRemappedTypes(schema_type_remap_fn);
    const auto this_method_name =
        c10::QualifiedName(*target.type()->name(), method.name());
    auto copied = target._ivalue()->compilation_unit()->create_function(
        this_method_name, std::move(graph));
    target.type()->addMethod(copied);
    copied->setSchema(std::move(schema));
  }
};

class InsertObserversHelper {
 public:
  explicit InsertObserversHelper(
      const ModuleQConfigMap& map,
      QuantType quant_type)
      : module_qconfig_map_(map), quant_type_(quant_type) {}

  // TODO: replace (module, method_name) with graph?
  // preprocess to clean up the graph from tracing
  void preprocess(Module& module, const std::string& method_name);

  // Fill the map from the caller's input/output values to the input/output
  // values of the called graph; this is used to navigate through the graph
  // to find the observer for a given value
  void fillBoundaryValueMap(Module& module, const std::string& method_name);

  // analyze the graph and record necessary information that can
  // be used in insert observers
  void analyze(Module& module, const std::string& method_name);

  void removeActivationObservers();

  /**
   * Recursively insert observers for the method. We process the nodes in the
   * graph in their order of execution, since we need the context information
   * to decide whether we want to observe/quantize a value or not, and we
   * don't want to observe a value multiple times.
   *
   * argument: is_entry_point means whether the current method is the forward
   * method of the top level module.
   *
   * Since we want to insert observers at the call site instead of in the
   * called graph, we postpone inserting observers to the caller as much as
   * possible. If we know the current method is the outermost method, then
   * we insert all observers in the graph instead of postponing this to the
   * parent; note that this assumes we don't have recursive method calls.
   *
   * Returns a tuple of vectors of observer modules for input and output;
   * these are used for inserting observers for the input/output values,
   * since we need to insert those at the call site.
   * It also returns a vector of indices of outputs that indicates whether
   * each output value is already observed or not; this is used for
   * propagating the observed property of a value through CallMethods,
   * because we should skip inserting observers for ops that don't require
   * observation.
   */
  std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
  insertObservers(
      Module& module,
      const std::string& method_name,
      bool is_entry_point = false,
      std::unordered_set<Value*> graph_observed_values =
          std::unordered_set<Value*>());

  void setInsertResetObserverMethod(
      bool insert_reset_observer_method,
      const std::string& method_name) {
    insert_reset_observer_method_ = insert_reset_observer_method;
    reset_observer_method_name_ = "reset_observers_" + method_name;
  }
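  // e.g. calling setInsertResetObserverMethod(true, "forward") makes the
  // pass generate a "reset_observers_forward" method on the module (see
  // insertObserverResetMinMax below).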

 private:
  std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
  insertObserversFor(
      Block* block,
      script::Module& module,
      // this is a reference because when we insert an observer for a value
      // in one block, it is also observed in another block; we don't want to
      // insert multiple observers for the same value
      std::unordered_set<Value*>& block_observed_values,
      bool is_entry_point = false,
      bool is_user_defined_function = false);

  // Record v as "ready for observation" by storing it in values_to_observe.
  // If v is a part of a delayed observation pattern, record v's descendant
  // (per delay rules) instead. The observers are inserted at a later stage
  // by reading the state created by this function.
  void recordObserved(
      Value* v,
      const Module& observer_module,
      std::unordered_map<Value*, Module>& values_to_observe,
      std::unordered_set<Value*>& block_observed_values);

  ModuleMethodVector getInvokedMethods(
      Module& module,
      const std::string& method_name);

  bool valueNeedsToBeQuantized(Value* v, const QConfig& qconfig);

  bool isObserved(
      Value* v,
      const std::unordered_set<Value*>& block_observed_values) {
    return block_observed_values.count(v) || observed_values_.count(v);
  }

  // Fill the map from a value to the corresponding observer module;
  // this map is used in insertObservers to actually insert
  // observers into the module
  void fillValueObserverMap(Module& module, const std::string& method_name);

  // Clone the observer module, add it to the original module,
  // and insert a call to the observer's forward function
  void insertObserverFor(
      Value* v,
      Module& module,
      const Module& observer_module,
      NameModuleVector& observer_name_and_modules);

  void insertObserverResetMinMax(
      Module& module,
      const NameModuleVector& observer_name_and_modules);

  // Uses the state created by fillBoundaryValueMap and fillValueObserverMap
  // to return an observer configured for a value, if it is needed.
  c10::optional<Module> getObserverFor(Value* v);

  // Uses the state created by fillPassThroughValueMap to propagate the
  // observed property which should pass through from inputs to outputs.
  void propagateObservedProperty(
      Value* output,
      std::unordered_set<Value*>& block_observed_values);

  // for cat/add/mul we will only observe their output if their inputs
  // are observed
  bool shouldObserve(
      Node* n,
      const std::unordered_set<Value*>& block_observed_values,
      QuantType quant_type) {
    // Check whether the node's output uses can be quantized, e.g. cat
    // followed by a linear op
    for (Value* v : n->outputs()) {
      for (const auto& use : v->uses()) {
        if (useQuantizable(use, quant_type)) {
          return true;
        }
      }
    }
    if (isPropagateQuantSingleInputOp(n)) {
      return isObserved(n->input(0), block_observed_values);
    } else if (isPropagateQuantBinaryOp(n)) {
      // This checks that both inputs are tensors and observed.
      // There is one check we don't do here: making sure input 1 is not a
      // scalar. Because a scalar tensor input for add/mul won't be observed
      // with the current rules anyway, we can omit that check here.
      return isObserved(n->input(0), block_observed_values) &&
          isObserved(n->input(1), block_observed_values);
    }
    return true;
  }
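  // Example: for `%y = aten::cat(...)` whose output feeds an aten::linear,
  // the use is quantizable, so %y is observed; for `%y = aten::add(%a, %b)`
  // where neither %a nor %b was observed (and assuming no quantizable
  // downstream use), %y is skipped.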

  void delayObservingValuesInPattern(Graph& graph, const PatternInfo& pattern);

  // Find and mark known patterns such as conv-relu (and others) where
  // we should not insert observers in the middle of the pattern.
  void addValuesToDelayObservation(
      const Module& module,
      const std::string& method_name);

  // Fill the map from a value to the list of values that can pass the
  // observed property to it
  void fillPassThroughValueMap(const std::shared_ptr<Graph>& graph);

  bool insertResetObserverMethod() {
    return insert_reset_observer_method_;
  }

  const ModuleQConfigMap& module_qconfig_map_;

  // Values whose observation we want to delay; used to delay observation
  // for values in the middle of ops that are supposed to be fused, e.g.
  // the output value of conv in the conv - relu pattern.
  // The key is the intermediate output, e.g. the output of conv;
  // the value is the value we want to observe, e.g. the output of relu.
  //
  // example, assuming we want to delay conv-relu:
  // %x1 = conv(%x0)
  // %x2 = relu(%x1)
  //
  // delay_observation_map_ = {
  //   %x1: %x2,
  // }
  std::unordered_map<Value*, Value*> delay_observation_map_;

  std::unordered_set<Graph*> visited_graph_of_observer_map_;

  // Map of value to the observer module configured for that value.
  std::unordered_map<Value*, Module> observer_for_value_;

  // Map from values at the callsite to values in the CallMethod graph.
  // The key of the map is a value in the caller graph, and the value of the
  // map is the set of values in the callee graph (the graph
  // corresponding to the called method).
  // The reason it is a set is that a value in the caller graph
  // can correspond both to the output of one callee graph and to the input
  // of another callee graph.
  //
  // example:
  // // top level module
  // %x1 = conv(%x0)
  // %x2 = prim::CallFunction(%foo, %x1)
  //
  // // graph of %foo
  // %y2 = conv(%y1)
  // return %y2
  //
  // boundary_value_map = {
  //   // current module's output values to corresponding return values
  //   // from the subgraph
  //   %x2: %y2,
  //   // current module's input values to corresponding input values of
  //   // the subgraph
  //   %x1: %y1,
  // }
  std::unordered_map<Value*, std::unordered_set<Value*>> boundary_value_map_;

  std::unordered_set<Value*> observed_values_;

  // This is used for observed values to pass through ops like flatten,
  // so that the output value of flatten does not need to be observed.
  // The key is the output of the op; the value is a vector of values that
  // need to be observed in order to pass the observed property to the
  // output.
  //
  // example:
  // %x1 = flatten(%x0) // pass_through
  // %x2 = conv(%x1) // not pass_through
  //
  // pass_through_value_map_ = {
  //   %x1: [%x0],
  // }
  std::unordered_map<Value*, std::vector<Value*>> pass_through_value_map_;

  // Unique id generator for observer modules, used for generating
  // unique observer names when we insert observer modules; we
  // record the current unique id used, to avoid incrementing from 0
  // every time when searching for a unique id.
  int uid_ = 0;
  // Set of observer forward call nodes
  std::unordered_set<Node*> observer_nodes_;
  // Map from a block to a vector of observer names and observer modules we
  // want to add to the module instance that owns the block
  std::unordered_map<Block*, NameModuleVector> block_observer_map_;

  // Type of quantization for this pass.
  QuantType quant_type_ = QuantType::STATIC;
  // These are the IR patterns we match to skip inserting observers.
  // They are compiled once on construction and used repeatedly within
  // the pass.

  // nn.Linear + nn.ReLU
  const PatternInfo nn_linear_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%input, %linear, %relu):
    %first_output = prim::CallMethod[name="forward"](%linear, %input)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_linear_module, is_relu_module});

  // nn.Linear + F.relu
  const PatternInfo nn_linear_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%input, %linear, %relu, %inplace):
    %first_output = prim::CallMethod[name="forward"](%linear, %input)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_linear_module, is_functional_relu});

  // nn.Linear + aten::relu
  const PatternInfo nn_linear_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%input, %linear, %relu):
    %first_output = prim::CallMethod[name="forward"](%linear, %input)
    %second_output = aten::relu(%first_output)
    return (%second_output) )",
      {is_linear_module});

  // nn.Linear + aten::relu_
  const PatternInfo nn_linear_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%input, %linear, %relu):
    %first_output = prim::CallMethod[name="forward"](%linear, %input)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )",
      {is_linear_module});

  // aten::linear + nn.ReLU
  const PatternInfo aten_linear_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%input, %weight, %bias, %relu):
    %first_output = aten::linear(%input, %weight, %bias)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_relu_module});

  // aten::linear + F.relu
  const PatternInfo aten_linear_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%input, %weight, %bias, %relu, %inplace):
    %first_output = aten::linear(%input, %weight, %bias)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_functional_relu});

  // aten::linear + aten::relu
  const PatternInfo aten_linear_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%input, %weight, %bias):
    %first_output = aten::linear(%input, %weight, %bias)
    %second_output = aten::relu(%first_output)
    return (%second_output) )");

  // aten::linear + aten::relu_
  const PatternInfo aten_linear_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%input, %weight, %bias):
    %first_output = aten::linear(%input, %weight, %bias)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )");

  const PatternInfo nn_conv1d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %relu, %inplace):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_conv1d_module, is_functional_relu});

  const PatternInfo nn_conv1d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %relu):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_conv1d_module, is_relu_module});

  const PatternInfo nn_conv1d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = aten::relu(%first_output)
    return (%second_output) )",
      {is_conv1d_module});

  const PatternInfo nn_conv1d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )",
      {is_conv1d_module});

  const PatternInfo nn_conv2d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %relu, %inplace):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_conv2d_module, is_functional_relu});

  const PatternInfo nn_conv2d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %relu):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_conv2d_module, is_relu_module});

  const PatternInfo nn_conv2d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = aten::relu(%first_output)
    return (%second_output) )",
      {is_conv2d_module});

  const PatternInfo nn_conv2d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )",
      {is_conv2d_module});

  const PatternInfo nn_conv3d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %relu, %inplace):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_conv3d_module, is_functional_relu});

  const PatternInfo nn_conv3d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %relu):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_conv3d_module, is_relu_module});

  const PatternInfo nn_conv3d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = aten::relu(%first_output)
    return (%second_output) )",
      {is_conv3d_module});

  const PatternInfo nn_conv3d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv):
    %first_output = prim::CallMethod[name="forward"](%conv, %input)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )",
      {is_conv3d_module});

  const PatternInfo add_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %alpha, %relu):
    %first_output = aten::add(%a, %b, %alpha)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {aten_add_alpha_is_one, is_relu_module});

  const PatternInfo add_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %alpha, %relu, %inplace):
    %first_output = aten::add(%a, %b, %alpha)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {aten_add_alpha_is_one, is_functional_relu});

  const PatternInfo inplace_add_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %alpha, %relu):
    %first_output = aten::add_(%a, %b, %alpha)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {aten_add_alpha_is_one, is_relu_module});

  const PatternInfo inplace_add_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %alpha, %relu, %inplace):
    %first_output = aten::add_(%a, %b, %alpha)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {aten_add_alpha_is_one, is_functional_relu});

  const PatternInfo add_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
    %first_output = aten::add(%a, %b, %alpha)
    %second_output = aten::relu(%first_output)
    return (%second_output) )");

  const PatternInfo add_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
    %first_output = aten::add(%a, %b, %alpha)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )");

  const PatternInfo inplace_add_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
    %first_output = aten::add_(%a, %b, %alpha)
    %second_output = aten::relu(%first_output)
    return (%second_output) )");

  const PatternInfo inplace_add_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
    %first_output = aten::add_(%a, %b, %alpha)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )");

  const PatternInfo nn_bn2d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_batchnorm2d_module, is_relu_module});

  const PatternInfo nn_bn2d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu, %inplace):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_batchnorm2d_module, is_functional_relu});

  const PatternInfo nn_bn2d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = aten::relu(%first_output)
    return (%second_output) )",
      {is_batchnorm2d_module});

  const PatternInfo nn_bn2d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )",
      {is_batchnorm2d_module});

  const PatternInfo nn_bn3d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
    return (%second_output) )",
      {is_batchnorm3d_module, is_relu_module});

  const PatternInfo nn_bn3d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu, %inplace):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_batchnorm3d_module, is_functional_relu});

  const PatternInfo nn_bn3d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = aten::relu(%first_output)
    return (%second_output) )",
      {is_batchnorm3d_module});

  const PatternInfo nn_bn3d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
    %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )",
      {is_batchnorm3d_module});

  const PatternInfo mul_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu):
    %first_output = aten::mul(%a, %b)
    %second_output = prim::CallMethod[name="forward"](%relu, %first_output)
    return (%second_output) )",
      {is_relu_module});

  const PatternInfo mul_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu, %inplace):
    %first_output = aten::mul(%a, %b)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_functional_relu});

  const PatternInfo inplace_mul_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu):
    %first_output = aten::mul_(%a, %b)
    %second_output = prim::CallMethod[name="forward"](%relu, %first_output)
    return (%second_output) )",
      {is_relu_module});

  const PatternInfo inplace_mul_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu, %inplace):
    %first_output = aten::mul_(%a, %b)
    %second_output = prim::CallFunction(%relu, %first_output, %inplace)
    return (%second_output) )",
      {is_functional_relu});

  const PatternInfo mul_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
    %first_output = aten::mul(%a, %b)
    %second_output = aten::relu(%first_output)
    return (%second_output) )");

  const PatternInfo mul_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
    %first_output = aten::mul(%a, %b)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )");

  const PatternInfo inplace_mul_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
    %first_output = aten::mul_(%a, %b)
    %second_output = aten::relu(%first_output)
    return (%second_output) )");

  const PatternInfo inplace_mul_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
    %first_output = aten::mul_(%a, %b)
    %second_output = aten::relu_(%first_output)
    return (%second_output) )");

  const std::vector<std::reference_wrapper<const PatternInfo>> delay_patterns =
      {
          nn_linear_f_relu, nn_linear_nn_relu,
          nn_linear_aten_relu, nn_linear_aten_relu_,
          aten_linear_f_relu, aten_linear_nn_relu,
          aten_linear_aten_relu, aten_linear_aten_relu_,

          nn_conv1d_f_relu, nn_conv1d_nn_relu,
          nn_conv1d_aten_relu, nn_conv1d_aten_relu_,
          nn_conv2d_f_relu, nn_conv2d_nn_relu,
          nn_conv2d_aten_relu, nn_conv2d_aten_relu_,
          nn_conv3d_f_relu, nn_conv3d_nn_relu,
          nn_conv3d_aten_relu, nn_conv3d_aten_relu_,

          add_nn_relu, add_f_relu,
          inplace_add_nn_relu, inplace_add_f_relu,
          add_aten_relu, add_aten_relu_,
          inplace_add_aten_relu, inplace_add_aten_relu_,

          nn_bn2d_nn_relu, nn_bn2d_f_relu,
          nn_bn2d_aten_relu, nn_bn2d_aten_relu_,
          nn_bn3d_nn_relu, nn_bn3d_f_relu,
          nn_bn3d_aten_relu, nn_bn3d_aten_relu_,

          mul_nn_relu, mul_f_relu,
          inplace_mul_nn_relu, inplace_mul_f_relu,
          mul_aten_relu, mul_aten_relu_,
          inplace_mul_aten_relu, inplace_mul_aten_relu_,
      };

  bool insert_reset_observer_method_{false};
  std::string reset_observer_method_name_;
};

ModuleMethodVector InsertObserversHelper::getInvokedMethods(
    Module& module,
    const std::string& method_name) {
  ModuleMethodVector invoked_methods;
  Method method = module.get_method(method_name);
  auto graph = method.graph();

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      // Skip observer nodes
      if (observer_nodes_.count(n)) {
        continue;
      }
      if (n->kind() == prim::CallMethod) {
        auto m_opt = getInvokedModuleOpt(module, n, graph->inputs()[0]);
        if (m_opt.has_value()) {
          invoked_methods.emplace_back(*m_opt, n->s(attr::name));
        }
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return invoked_methods;
}

void InsertObserversHelper::insertObserverFor(
    Value* v,
    Module& module,
    const Module& observer_module,
    NameModuleVector& observer_name_and_modules) {
  if (observed_values_.count(v)) {
    return;
  }
  GRAPH_DEBUG("Inserting observer for:", v->debugName());
  Module observer = observer_module.deepcopy();
  std::string observer_name = "_observer_" + c10::to_string(uid_++);
  while (module.hasattr(observer_name)) {
    observer_name = "_observer_" + c10::to_string(uid_++);
  }
  module.register_module(observer_name, observer);
  observer_name_and_modules.emplace_back(observer_name, observer);

  auto* g = v->owningGraph();
  // Get handle of observer module
  Node* observer_instance =
      g->createGetAttr(g->inputs()[0], observer_name)->insertAfter(v->node());
  observer_instance->output()->setDebugName(observer_name);

  {
    WithInsertPoint guard(observer_instance->next());
    // Match arguments to types of observer's arguments
    MatchedSchema forward_matched_schema = matchSchema(
        observer.get_method("forward").function().getSchema(),
        v->node()->sourceRange(),
        *g,
        {observer_instance->output(), v},
        {});
    // Insert call to observer's forward
    Node* call = g->insertMethodCall("forward", forward_matched_schema)->node();
    call->output()->copyMetadata(v);

    // Replace v with the output of observer
    v->replaceAllUsesWith(call->output());
    // The above also replaced the input to `call`, so switch it back to
    // the correct value
    call->replaceInput(1, v);
    observer_nodes_.emplace(call);
    observed_values_.insert(call->output());
  }
}
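// Schematically, observing %x rewrites the IR roughly as follows (debug
// names are illustrative):
//   before: %x = aten::conv2d(...)
//           %y = aten::relu(%x)
//   after:  %x = aten::conv2d(...)
//           %_observer_0 = prim::GetAttr[name="_observer_0"](%self)
//           %x.obs = prim::CallMethod[name="forward"](%_observer_0, %x)
//           %y = aten::relu(%x.obs)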

void InsertObserversHelper::insertObserverResetMinMax(
    Module& module,
    const NameModuleVector& observer_name_and_modules) {
  if (observer_name_and_modules.empty()) {
    return;
  }
  auto reset_min_max_opt = module.find_method(reset_observer_method_name_);
  if (!reset_min_max_opt.has_value()) {
    std::shared_ptr<Graph> reset_observer_graph = std::make_shared<Graph>();
    Value* module_value = reset_observer_graph->addInput("self");
    Node* output_node = reset_observer_graph->createNone();
    reset_observer_graph->insertNode(output_node);
    reset_observer_graph->registerOutput(output_node->output());
    module_value->setType(module._ivalue()->type());
    const auto method_name = c10::QualifiedName(
        *(module.type()->name()), reset_observer_method_name_);
    auto reset_observer_fn =
        module._ivalue()->compilation_unit()->create_function(
            method_name, std::move(reset_observer_graph));
    auto self_arg = c10::Argument("self", module.type());
    auto output_arg = c10::Argument("none", output_node->output()->type());
    auto schema = c10::FunctionSchema(
        reset_observer_method_name_,
        "",
        {std::move(self_arg)},
        {std::move(output_arg)});
    reset_observer_fn->setSchema(std::move(schema));
    module.type()->addMethod(reset_observer_fn);
  }
  auto reset_min_max_graph =
      module.get_method(reset_observer_method_name_).graph();
  Value* self = reset_min_max_graph->inputs()[0];

  for (const auto& pair : observer_name_and_modules) {
    const auto& observer_name = pair.first;
    const auto& observer = pair.second;
    Value* observer_value =
        reset_min_max_graph->insertGetAttr(self, observer_name);
    MatchedSchema reset_minmax_schema = matchSchema(
        observer.get_method("reset_min_max_vals").function().getSchema(),
        observer_value->node()->sourceRange(),
        *reset_min_max_graph,
        {observer_value},
        {});
    reset_min_max_graph->insertMethodCall(
        "reset_min_max_vals", reset_minmax_schema);
  }
}
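// The generated method is roughly equivalent to this TorchScript sketch
// (observer names are illustrative):
//   def reset_observers_forward(self):
//       self._observer_0.reset_min_max_vals()
//       self._observer_1.reset_min_max_vals()
//       return None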

void InsertObserversHelper::delayObservingValuesInPattern(
    Graph& graph,
    const PatternInfo& pattern) {
  const Graph& pattern_graph = *pattern.pattern_graph;
  const std::unordered_map<std::string, Value*>& vmap = pattern.vmap;

  const auto& matches = findPatternMatches(pattern_graph, graph);
  for (const auto& match : matches) {
    if (!std::all_of(
            pattern.filters.begin(),
            pattern.filters.end(),
            [&](const MatchFilter& f) { return f(match, vmap); })) {
      continue;
    }
    auto first_output = match.values_map.at(vmap.at("first_output"));
    auto second_output = match.values_map.at(vmap.at("second_output"));
    GRAPH_DEBUG(
        "Delay observation for value in function pattern:",
        first_output->debugName(),
        " to ",
        second_output->debugName());
    delay_observation_map_[first_output] = second_output;
  }
}

void InsertObserversHelper::addValuesToDelayObservation(
    const Module& module,
    const std::string& method_name) {
  Method method = module.get_method(method_name);
  auto graph = method.graph();

  for (const auto& pattern : delay_patterns) {
    delayObservingValuesInPattern(*graph, pattern);
  }
}

void InsertObserversHelper::fillPassThroughValueMap(
    const std::shared_ptr<Graph>& graph) {
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (userDefinedCallFunction(n)) {
        auto g = getCallFunctionGraph(n);
        blocks_to_visit.push(g->block());
      }
      for (auto* output : n->outputs()) {
        for (auto* input : getPassThroughInputs(output)) {
          pass_through_value_map_[output].push_back(input);
        }
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
}

void InsertObserversHelper::fillBoundaryValueMap(
    Module& module,
    const std::string& method_name) {
  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_method);
    const auto& invoked_method_name = std::get<1>(invoked_method);
    fillBoundaryValueMap(invoked_module, invoked_method_name);
  }

  auto graph = module.get_method(method_name).graph();
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  auto* self = graph->inputs()[0];
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
        std::shared_ptr<Graph> g;
        // input offset for the caller node: the first input of CallFunction
        // is the function node itself, while the graph of the called
        // function starts with the actual inputs
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        size_t input_offset;
        if (n->kind() == prim::CallMethod) {
          auto m_opt = getInvokedModuleOpt(module, n, self);
          if (!m_opt.has_value()) {
            continue;
          }
          auto m = *m_opt;
          g = m.get_method(n->s(attr::name)).graph();
          input_offset = 0;
        } else {
          g = getCallFunctionGraph(n);
          input_offset = 1;
        }
        // add mapping from callsite value to value in called graph
        for (auto i = 0U; i < g->outputs().size(); ++i) {
          auto* return_val = g->outputs()[i];
          GRAPH_DEBUG(
              "Boundary Map[return]:",
              n->output(i)->debugName(),
              " -> ",
              return_val->debugName());
          boundary_value_map_[n->output(i)].insert(return_val);
        }
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto caller_input_index = i + input_offset;
          auto* caller_input = n->input(caller_input_index);
          auto* input_val = g->inputs()[i];
          GRAPH_DEBUG(
              "Boundary Map[input]:",
              caller_input->debugName(),
              " -> ",
              input_val->debugName());
          boundary_value_map_[caller_input].insert(input_val);
        }
      } else if (n->kind() == prim::If) {
        for (Block* subblock : n->blocks()) {
          blocks_to_visit.push(subblock);
          for (Value* v : n->outputs()) {
            Value* subblock_output = subblock->outputs()[v->offset()];
            GRAPH_DEBUG(
                "Boundary Map[if_output]:",
                v->debugName(),
                " -> ",
                subblock_output->debugName());
            boundary_value_map_[v].insert(subblock_output);
          }
        }
      } else {
        for (Block* subblock : n->blocks()) {
          blocks_to_visit.push(subblock);
        }
      }
    }
  }
}

void InsertObserversHelper::preprocess(
    Module& module,
    const std::string& method_name) {
  // run preprocess for child modules before the parent, since preprocess
  // mutates the graph and that might affect passes like fillBoundaryValueMap
  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_method);
    const auto& invoked_method_name = std::get<1>(invoked_method);
    preprocess(invoked_module, invoked_method_name);
  }

  Method method = module.get_method(method_name);
  auto graph = method.graph();
  // Inline fork-wait calls
  InlineForkWait(graph);
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);
  replaceConvolutionWithAtenConv(graph);
  RemoveListMutation(graph);
}
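// For example, FuseLinear rewrites a decomposed linear such as
//   %wt = aten::t(%weight)
//   %out = aten::addmm(%bias, %input, %wt, ...)
// into %out = aten::linear(%input, %weight, %bias), so the aten::linear
// based delay patterns above can match it.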

void InsertObserversHelper::analyze(
    Module& module,
    const std::string& method_name) {
  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_method);
    const auto& invoked_method_name = std::get<1>(invoked_method);
    analyze(invoked_module, invoked_method_name);
  }

  // fill out various pieces of internal state which will later be used in
  // insertObservers to insert the correct observers
  addValuesToDelayObservation(module, method_name);
  fillValueObserverMap(module, method_name);
  Method method = module.get_method(method_name);
  auto graph = method.graph();
  fillPassThroughValueMap(graph);
}

bool InsertObserversHelper::valueNeedsToBeQuantized(
    Value* v,
    const QConfig& qconfig) {
  if (isBiasOfConvOrLinear(v) ||
      !(v->type()->isSubtypeOf(*TensorType::get()) ||
        v->type()->isSubtypeOf(*ListType::ofTensors())) ||
      isEmbeddingBagNonInput(v)) {
    return false;
  }
  // For dynamic quantization we only insert observers at the input
  // of the quantizable function.
  if (quant_type_ == QuantType::STATIC) {
    // Check whether producer is quantizable
    if (!isWeightOnlyStaticQuantOp(v->node()) &&
        (nodeQuantizable(v->node()) || isPropagateQuantOp(v->node()))) {
      return true;
    }
  }
  if (quant_type_ == QuantType::DYNAMIC) {
    // Check the dtype of the observer module.
    Module observer_module = getObserverModuleFor(v, qconfig);
    auto scalar_type = observer_module.attr("dtype");
    // For non-weight inputs with fp16 dtype, we don't observe them for
    // dynamic quantization.
    if (scalar_type == at::ScalarType::Half && !isWeight(v)) {
      return false;
    }
  }
  // Check whether the value is used by a quantizable node
  for (const auto& use : v->uses()) {
    if (useQuantizable(use, quant_type_)) {
      return true;
    }
  }
  return false;
}
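// Examples of the rules above: the bias of a conv/linear is never observed;
// under static quantization the output of a quantizable producer such as
// aten::conv2d is observed; under dynamic quantization only the values
// feeding a quantizable op's input get an observer.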

void InsertObserversHelper::removeActivationObservers() {
  std::vector<std::unordered_map<Value*, Module>::iterator>
      values_to_be_removed;
  for (auto it = observer_for_value_.begin(); it != observer_for_value_.end();
       it++) {
    if (!isWeight(it->first)) {
      values_to_be_removed.push_back(it);
    }
  }
  for (auto it : values_to_be_removed) {
    observer_for_value_.erase(it);
  }
}

void InsertObserversHelper::fillValueObserverMap(
    Module& module,
    const std::string& method_name) {
  Method method = module.get_method(method_name);
  auto graph = method.graph();

  if (visited_graph_of_observer_map_.count(graph.get())) {
    return;
  }
  visited_graph_of_observer_map_.insert(graph.get());

  std::stack<Block*> blocks_to_visit;
  auto qconfig_opt = module_qconfig_map_.at(module._ivalue());
  if (!qconfig_opt) {
    return;
  }
  auto qconfig = *qconfig_opt;
  for (auto* v : graph->inputs()) {
    if (valueNeedsToBeQuantized(v, qconfig)) {
      GRAPH_DEBUG("Recording observer for ", v->debugName());
      GRAPH_DUMP("In graph:", v->owningGraph());
      observer_for_value_[v] = getObserverModuleFor(v, qconfig);
    }
  }

  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      for (Value* v : n->outputs()) {
        if (valueNeedsToBeQuantized(v, qconfig)) {
          GRAPH_DEBUG("Recording observer for ", v->debugName());
          GRAPH_DUMP("In graph:", v->owningGraph());
          observer_for_value_[v] = getObserverModuleFor(v, qconfig);
        }
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
}

c10::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
  if (observer_for_value_.count(v)) {
    auto observer = observer_for_value_.at(v);
    GRAPH_DEBUG("Got observer module config for:", v->debugName());
    return observer;
  }
  c10::optional<Module> result;
  if (boundary_value_map_.count(v)) {
    for (Value* next : boundary_value_map_.at(v)) {
      GRAPH_DEBUG(
          "Going through boundary map:",
          v->debugName(),
          " --> ",
          next->debugName());
      GRAPH_DUMP("From graph:", v->owningGraph());
      GRAPH_DUMP("To graph:", next->owningGraph());
      auto observer_opt = getObserverFor(next);
      if (observer_opt) {
        // Need to make sure all boundary values are
        // configured with the same observer
        if (result) {
          TORCH_CHECK(
              *observer_opt == *result,
              "Expecting all values in the graph to be configured with only one observer");
        } else {
          result = observer_opt;
        }
      }
    }
  }
  GRAPH_DEBUG(
      "Observer module config for ", v->debugName(), ":", result.has_value());
  return result;
}
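// e.g. if caller value %x2 maps to callee output %y2 in boundary_value_map_
// and %y2 has an observer recorded in observer_for_value_, then
// getObserverFor(%x2) resolves (recursively) to that same observer.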

std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
InsertObserversHelper::insertObservers(
    Module& module,
    const std::string& method_name,
    bool is_entry_point,
    std::unordered_set<Value*> graph_observed_values) {
  auto graph = module.get_method(method_name).graph();
  return insertObserversFor(
      graph->block(), module, graph_observed_values, is_entry_point);
}

void InsertObserversHelper::recordObserved(
    Value* v,
    const Module& observer_module,
    std::unordered_map<Value*, Module>& values_to_observe,
    std::unordered_set<Value*>& block_observed_values) {
  Value* to_observe = v;
  if (delay_observation_map_.count(v)) {
    to_observe = delay_observation_map_.at(v);
  }
  values_to_observe[to_observe] = observer_module;
  block_observed_values.insert(to_observe);
}
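// e.g. with delay_observation_map_ = {%x1: %x2} from a conv-relu match,
// recordObserved(%x1, observer, ...) records %x2 (the relu output) instead
// of the intermediate conv output %x1.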

std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
InsertObserversHelper::insertObserversFor(
    Block* block,
    script::Module& module,
    std::unordered_set<Value*>& block_observed_values,
    bool is_entry_point,
    bool is_user_defined_function) {
  // input/output values, used to skip inserting observers
  // for the inputs and outputs of the block and the owning graph;
  // we have to insert those observers at the call site because
  // the graph itself can be shared
  std::unordered_set<Value*> inputs_outputs;
  // list of observer modules for input values
  std::vector<c10::optional<Module>> block_input_observers;
  // list of observer modules for output values
  std::vector<c10::optional<Module>> block_output_observers;

  // if the current block is the block of the entry point graph (the forward
  // graph of the top level module), we can insert observers in the block
  // directly
  if (!is_entry_point) {
    auto* graph = block->owningGraph();
    // graph inputs/outputs
    for (auto list : {graph->inputs(), graph->outputs()}) {
      for (auto* v : list) {
        inputs_outputs.insert(v);
      }
    }
    // block outputs
    for (auto* v : block->outputs()) {
      inputs_outputs.insert(v);
    }

    for (auto* v : block->inputs()) {
      block_input_observers.emplace_back(getObserverFor(v));
    }

    for (auto* v : block->outputs()) {
      // we need to explicitly skip the values that are already observed;
      // this can happen in subblocks of `if`, since
      // those subblocks have access to all values before the `if` node
      if (!isObserved(v, block_observed_values)) {
        block_output_observers.emplace_back(getObserverFor(v));
      } else {
        block_output_observers.emplace_back(c10::nullopt);
      }
    }
  }

  // This means the block has been processed before; we just
  // need to attach observer modules and construct the information
  // needed by the call site here
  bool visited = block_observer_map_.count(block);
  if (visited) {
    // instance clone of observer module and setAttr
    for (const auto& observer_attrs : block_observer_map_.at(block)) {
      const auto& name = std::get<0>(observer_attrs);
      const auto& observer = std::get<1>(observer_attrs);
      module._ivalue()->setAttr(name, observer.deepcopy()._ivalue());
    }
  }
  // NB: Why do we need to process the graph even if it has been visited?
  // The reason is that `block_observed_values` can
  // change depending on where the method is called, and the
  // outputs that have been observed (the third item of the returned result)
  // can change depending on that. So for each graph we need to go through
  // the whole process of inserting observers: the observers inserted in this
  // block won't change, but the information we return to the caller will
  // change based on `block_observed_values`

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(block);
  auto* self = block->owningGraph()->inputs()[0];
  // We first construct a map from values to their observer modules, then
  // insert the observers later; this avoids interference
  // of the inserted observers with the analysis that decides where
  // to insert observers. Also, we only insert observers for
  // "intermediate values", i.e. values that are not inputs/outputs of the
  // graph
  std::unordered_map<Value*, Module> values_to_observe;

  for (auto* v : block->inputs()) {
    if (!inputs_outputs.count(v) && !values_to_observe.count(v)) {
      if (auto observer_opt = getObserverFor(v)) {
        recordObserved(
            v, *observer_opt, values_to_observe, block_observed_values);
      }
    }
  }
1463 while (!blocks_to_visit.empty()) {
1464 Block* b = blocks_to_visit.top();
1465 blocks_to_visit.pop();
1466 for (Node* n : b->nodes()) {
1467 if (observer_nodes_.count(n)) {
1468 continue;
1469 }
1470 if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
1471 script::Module m;
1472 std::shared_ptr<Graph> g;
1473 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1474 size_t input_offset;
1475 bool is_udf_for_subblock = is_user_defined_function;
1476 if (n->kind() == prim::CallMethod) {
1477 auto m_opt = getInvokedModuleOpt(module, n, self);
1478 if (!m_opt.has_value()) {
1479 continue;
1480 }
1481 m = *m_opt;
1482 g = m.get_method(n->s(attr::name)).graph();
1483 input_offset = 0;
1484 } else { // CallFunction
1485 m = module;
1486 g = getCallFunctionGraph(n);
1487 input_offset = 1;
1488 is_udf_for_subblock = true;
1489 }
1490
        std::unordered_set<Value*> callee_observed_inputs;
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto* node_input = n->input(i + input_offset);
          if (isObserved(node_input, block_observed_values)) {
            callee_observed_inputs.insert(g->inputs()[i]);
          }
        }
        auto* subblock = g->block();
        auto info_from_callee = insertObserversFor(
            subblock, m, callee_observed_inputs, false, is_udf_for_subblock);
        auto input_observers = std::get<0>(info_from_callee);
        auto output_observers = std::get<1>(info_from_callee);
        auto callee_observed_outputs = std::get<2>(info_from_callee);
        for (auto idx : callee_observed_outputs) {
          block_observed_values.insert(n->outputs()[idx]);
        }
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto* node_input = n->input(i + input_offset);
          if (input_observers[i] && !inputs_outputs.count(node_input) &&
              !isObserved(node_input, block_observed_values)) {
            recordObserved(
                node_input,
                *input_observers[i],
                values_to_observe,
                block_observed_values);
          }
        }
        for (auto i = 0U; i < n->outputs().size(); ++i) {
          if (output_observers[i] && !inputs_outputs.count(n->output(i)) &&
              !isObserved(n->output(i), block_observed_values)) {
            recordObserved(
                n->output(i),
                *output_observers[i],
                values_to_observe,
                block_observed_values);
          }
        }
      } else if (n->kind() == prim::If) {
        // a vector recording whether each output is observed or not
        std::vector<bool> aggregated_output_observe_state;
        for (Block* subblock : n->blocks()) {
          if (alwaysRaisesException(subblock)) {
            continue;
          }
          // subblock has access to all the values in the scope of prim::If,
          // so subblock_observed_values == block_observed_values
          auto info_from_subblock =
              insertObserversFor(subblock, module, block_observed_values);
          // subblock for prim::If doesn't have inputs
          auto output_observers = std::get<1>(info_from_subblock);
          auto subblock_observed_outputs = std::get<2>(info_from_subblock);

          // We'll insert an output observer for each subblock, and in the
          // end we check that the outputs of the subblocks are observed
          // consistently
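          // For example (a sketch), the consistency check below rejects:
          //   if cond:
          //     x = self.conv(a)  # branch output observed
          //   else:
          //     x = a             # branch output not observed
          // since the two branches would return values with inconsistent
          // observed states.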
          for (size_t i = 0; i < subblock->outputs().size(); ++i) {
            Value* output = subblock->outputs()[i];
            if (output_observers[i] && !inputs_outputs.count(output) &&
                !isObserved(output, block_observed_values)) {
              recordObserved(
                  output,
                  *output_observers[i],
                  values_to_observe,
                  block_observed_values);
            }
          }
          for (auto idx : subblock_observed_outputs) {
            block_observed_values.insert(subblock->outputs()[idx]);
          }
          std::vector<bool> subblock_output_observe_state;
          for (size_t i = 0; i < subblock->outputs().size(); ++i) {
            Value* output = subblock->outputs()[i];
            subblock_output_observe_state.push_back(
                isObserved(output, block_observed_values));
          }
          if (!aggregated_output_observe_state.empty()) {
            TORCH_CHECK(
                aggregated_output_observe_state ==
                    subblock_output_observe_state,
                "branches for `if` should return values that are observed "
                "consistently, if node:",
                *n);
          } else {
            aggregated_output_observe_state = subblock_output_observe_state;
          }
        }
        // mark the outputs of the if node as observed
        for (size_t i = 0; i < n->outputs().size(); ++i) {
          if (aggregated_output_observe_state[i]) {
            block_observed_values.insert(n->output(i));
          }
        }
      } else if (n->kind() == prim::Loop) {
        TORCH_WARN_ONCE(
            "prim::Loop is not yet supported in quantization, "
            "please make sure nothing needs to be quantized in the "
            "loop");
      }
      for (Value* v : n->outputs()) {
        propagateObservedProperty(v, block_observed_values);
        if (!inputs_outputs.count(v) && !isObserved(v, block_observed_values)) {
          auto observer_opt = getObserverFor(v);
          // If the node is one of the propagate-quant nodes, e.g.
          // aten::cat, we should observe its output only
          // if the inputs of the node are observed
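          // For example (a sketch): for `z = torch.cat([x, y])` we record
          // an observer for z only when x and y are already observed, so
          // that quantization parameters can propagate from the inputs to
          // the output.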
          if (observer_opt &&
              shouldObserve(n, block_observed_values, quant_type_)) {
            recordObserved(
                v, *observer_opt, values_to_observe, block_observed_values);
          }
        }
      }
    }
  }
  std::vector<size_t> output_idxs;
  for (auto i = 0U; i < block->outputs().size(); ++i) {
    if (isObserved(block->outputs()[i], block_observed_values)) {
      output_idxs.push_back(i);
    }
  }
  if (!visited) {
    NameModuleVector observer_name_and_modules;
    for (const auto& item : values_to_observe) {
      auto* v = item.first;
      auto observer = item.second;
      TORCH_CHECK(
          !is_user_defined_function,
          "Inserting observers for user defined functions is not "
          "supported right now");
      insertObserverFor(v, module, observer, observer_name_and_modules);
    }
    if (insertResetObserverMethod()) {
      insertObserverResetMinMax(module, observer_name_and_modules);
    }
    block_observer_map_[block] = observer_name_and_modules;
  }
  return std::make_tuple(
      block_input_observers, block_output_observers, output_idxs);
}

void InsertObserversHelper::propagateObservedProperty(
    Value* output,
    std::unordered_set<Value*>& block_observed_values) {
  if (pass_through_value_map_.count(output)) {
    // since the vector is always non-empty, we will
    // not return the initial value
    bool all_observed = true;
    for (Value* v : pass_through_value_map_.at(output)) {
      all_observed &=
          observed_values_.count(v) || block_observed_values.count(v);
    }
    if (all_observed) {
      GRAPH_DEBUG("Pass through observed property in node:", *output->node());
      // This is to propagate the observed property through
      // all ops that don't require observation
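      // For example (a sketch, assuming a pure shape op such as aten::view
      // is registered as pass-through for this pass):
      //   y = x.view(-1)
      // y is marked observed here exactly when x already is, so no extra
      // observer needs to be inserted for y.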
      block_observed_values.insert(output);
    }
  }
}

} // namespace

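// Typical invocation (a sketch; `my_qconfig` is a hypothetical
// (activation observer, weight observer) pair and the "" key selects the
// top-level module):
//   QConfigDict qconfig_dict;
//   qconfig_dict[""] = my_qconfig;
//   Module observed = InsertObservers(
//       module, "forward", qconfig_dict, /* inplace */ false,
//       QuantType::STATIC);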
Module InsertObservers(
    Module& input_module,
    const std::string& method_name,
    const QConfigDict& qconfig_dict,
    bool inplace,
    QuantType quant_type) {
  ModuleQConfigMap map_before_clone;
  fillQConfigMap(input_module, qconfig_dict, map_before_clone);
  ModuleCloneHelper mh;
  Module module = mh.clone(input_module, map_before_clone, inplace);
  SwapFunctionalLinear(module);
  ModuleQConfigMap module_qconfig_map;
  // Since the types are changed after clone, we need to fill
  // the qconfig map again
  fillQConfigMap(module, qconfig_dict, module_qconfig_map);
  GRAPH_DEBUG("Quant type:", quant_type);
  InsertObserversHelper helper(module_qconfig_map, quant_type);
  helper.preprocess(module, method_name);
  helper.fillBoundaryValueMap(module, method_name);
  // analyze needs to run after fillBoundaryValueMap
  // since we need to know the boundary value mapping to trace
  // through the calls
  helper.analyze(module, method_name);
  helper.insertObservers(module, method_name, /* is_entry_point */ true);
  return module;
}

Module InsertObserversForOnDevicePTQ(
    Module& input_module,
    const std::string& method_name,
    const QConfigDict& qconfig_dict,
    bool inplace,
    QuantType quant_type) {
  ModuleQConfigMap map_before_clone;
  fillQConfigMap(input_module, qconfig_dict, map_before_clone);
  ModuleCloneHelper mh;
  Module cloned_module = mh.clone(input_module, map_before_clone, inplace);
  std::shared_ptr<Graph> g = cloned_module.get_method(method_name).graph();
  SwapFunctionalLinear(g);
  std::string observer_method_name = "observe_" + method_name;
  cloneMethod(cloned_module, method_name, observer_method_name);
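  // e.g. for method_name == "forward" this creates a clone named
  // "observe_forward"; observers are then inserted into the clone below,
  // so the original "forward" method is left unmodified for on-device PTQ.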
  ModuleQConfigMap module_qconfig_map;
  // Since the types are changed after clone, we need to fill
  // the qconfig map again
  fillQConfigMap(cloned_module, qconfig_dict, module_qconfig_map);
  GRAPH_DEBUG("Quant type:", quant_type);
  InsertObserversHelper helper(module_qconfig_map, quant_type);
  // It is not clear whether the list-mutation removal part of preprocess
  // is needed here.
  helper.preprocess(cloned_module, observer_method_name);
  // Since we expect the graph to be inlined, this should not have any use.
  // However, this function does handle if blocks, although as far as I
  // understand if blocks are not really handled in JIT quantization, so
  // perhaps we should just guard against finding an observable value
  // inside an if block. Also, a side effect of inlining is that you will
  // have multiple getattrs for the same attribute, and thus potentially
  // multiple observers observing the same value. This will also increase
  // the size of the packed param struct. I don't expect this to be a
  // common pattern, but it is something to be aware of. Note that the
  // current quant workflow does not prevent this anyway, since things are
  // inlined during insert quant dequant.
  helper.fillBoundaryValueMap(cloned_module, observer_method_name);
  // analyze needs to run after fillBoundaryValueMap
  // since we need to know the boundary value mapping to trace
  // through the calls
  helper.analyze(cloned_module, observer_method_name);
  // Remove activation observers if quant_type is dynamic
  if (quant_type == QuantType::DYNAMIC) {
    helper.removeActivationObservers();
  }
  helper.setInsertResetObserverMethod(true, method_name);
  helper.insertObservers(
      cloned_module, observer_method_name, /* is_entry_point */ true);
  return cloned_module;
}
} // namespace jit
} // namespace torch