#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>

#include <algorithm>

namespace torch {
namespace jit {

namespace {

bool isEligibleNode(Node* n) {
  return n->kind() == prim::TracedModuleForward ||
      n->kind() == prim::TracedFork;
}

// This pass does several things:
// 1) It looks at TracedModuleForward nodes and resolves the type of `self`
//    for that (to-be) method call. It adds an input of that type to the
//    block, and adds the TracedAttr value corresponding to that `self`
//    value as a Node input. This ensures `self` is an explicit Use on
//    the node, a property we take advantage of downstream.
// 2) It converts all references to prim::TracedAttr values into prim::GetAttr
//    calls in the tightest scope possible. Concretely, for each use of
//    a prim::TracedAttr value, we compare the scope of that attribute
//    to the scope of the Use. We emit GetAttr nodes for all atoms
//    that are not shared between the two. For example, if an
//    attribute `f.param` is referenced in scope `f`, we emit a
//    GetAttr[name="param"](%self) node in the `f` block, where
//    `self` is the previously-added `self` argument to the block.
// 3) It destroys all the prim::TracedAttr nodes, as they should have
//    no more uses.
//
// A quick example:
//
// Input graph:
//
//   graph(%self : ClassType<Module>,
//         %x : Float(3, 4)):
//     %1 : bool = prim::TracedAttr[scope="__module.training"]()
//     %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]()
//     %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]()
//     %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
//      = prim::TracedModuleForward[scope="__module.f"](),
//       block0():
//         %6 : Float(3, 4) = aten::mm(%x, %3),
//         -> ()
//     return (%6)
//
// The diff after step (1):
//
//   -  = prim::TracedModuleForward[scope="__module.f"](),
//   -    block0():
//   +  = prim::TracedModuleForward[scope="__module.f"](%2),
//   +    block0(%self : ClassType<Module>):
//
// The diff after step (2):
//
//   graph(%self.1 : ClassType<Module>,
//         %x : Float(3, 4)):
//   +   %9 : ClassType<Module> = prim::GetAttr[name="f"](%self.1)
//       %1 : bool = prim::TracedAttr[scope="__module.training"]()
//       <....>
//       %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
//   -    = prim::TracedModuleForward[scope="__module.f"](%2),
//   +    = prim::TracedModuleForward[scope="__module.f"](%9),
//         block0(%self : ClassType<Module>):
//   -       %6 : Float(3, 4) = aten::mm(%x, %3),
//   +       %8 : Tensor = prim::GetAttr[name="param"](%self)
//   +       %6 : Float(3, 4) = aten::mm(%x, %8),
//           -> ()
//       return (%6)
//
// The diff after step (3):
//
//   - %1 : bool = prim::TracedAttr[scope="__module.training"]()
//   - %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]()
//   - %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]()
//   - %4 : bool = prim::TracedAttr[scope="__module.f.training"]()
struct ConvertTracedAttrReferences {
  void run(const std::shared_ptr<Graph>& graph) {
    // Build a table mapping--for each TracedAttr node--the
    // qualified name of the attribute to the Value* output
    // of the Node.
    buildAttrMap(graph);
    // Step 1
    addSelfArgToTracedForwardNodes(graph->block());
    // Step 2
    convertAttrReferencesToLocalGetAttrs(
        graph->block(), "__module", graph->inputs()[0]);
    // Step 3
    destroyTracedAttrNodes(graph);
  }

 private:
  void buildAttrMap(const std::shared_ptr<Graph>& graph) {
    for (Node* n : graph->nodes()) {
      if (n->kind() == prim::TracedAttr) {
        attr_qualname_to_value[n->s(attr::scope)] = n->output();
      }
    }
  }

  void addSelfArgToTracedForwardNodes(Block* b) {
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::TracedModuleForward) {
        n->addInput(attr_qualname_to_value.at(n->s(attr::scope)));
        n->blocks()[0]->addInput("self")->setType(
            attr_qualname_to_value.at(n->s(attr::scope))->type());
        addSelfArgToTracedForwardNodes(n->blocks()[0]);
      }
      if (n->kind() == prim::TracedFork) {
        addSelfArgToTracedForwardNodes(n->blocks()[0]);
      }
    }
  }

  // This is a recursive function that descends into all blocks in the Graph
  // (NB: not just TracedModuleForward blocks). Each descent has a
  // corresponding `prefix`, i.e. the qualified name of the scope this
  // Block represents (or the scope in which this block resides for
  // non-TracedModuleForward nodes). We use this prefix to decide
  // whether to emit a GetAttr node for an attribute reference, or
  // to defer that emission to the caller (in the case where an attribute
  // reference does not reside in the `prefix` scope).
  std::vector<Value*> convertAttrReferencesToLocalGetAttrs(
      Block* b,
      const c10::QualifiedName& prefix,
      Value* self) {
    // Store away Value*'s which are references to TracedAttr's which are
    // not in the `prefix` scope. We pass this back to the caller, who
    // should add these Values as explicit inputs as well as inductively
    // make the same decision on those Values.
    std::vector<Value*> unresolved_tracedattrs;
    // To ensure we don't emit redundant GetAttr Nodes in a given scope,
    // we maintain this map of original TracedAttr Value* to the Value*
    // corresponding to the GetAttr for that attribute.
    // We don't rely on CSE here because we currently can't reason about
    // the correctness of CSE over GetAttr Nodes.
    std::unordered_map<Value*, Value*> local_remaps;

    for (Node* n : b->nodes()) {
      // The only difference between these two branches is that for
      // TracedModuleForward we advance the scope, but for other
      // Nodes with Blocks we don't.
      if (n->kind() == prim::TracedModuleForward) {
        auto sub_unresolved = convertAttrReferencesToLocalGetAttrs(
            n->blocks()[0], n->s(attr::scope), n->blocks()[0]->inputs()[0]);
        for (Value* v : sub_unresolved) {
          n->addInput(v);
        }
      } else if (!n->blocks().empty()) {
        for (Block* sub_block : n->blocks()) {
          auto sub_unresolved =
              convertAttrReferencesToLocalGetAttrs(sub_block, prefix, self);
          for (Value* v : sub_unresolved) {
            n->addInput(v);
          }
        }
      }

      for (size_t inp_idx = 0; inp_idx < n->inputs().size(); ++inp_idx) {
        Value* inp = n->input(inp_idx);

        // Short circuit: if we've already emitted a new Value for this
        // attribute, just use that.
        if (local_remaps.count(inp)) {
          n->replaceInput(inp_idx, local_remaps[inp]);
          continue;
        }

        WithInsertPoint guard(b->param_node()->next());
        replaceTracedAttrInputOnNode(
            n, inp_idx, prefix, self, local_remaps, unresolved_tracedattrs);
      } // for (Value *inp : n->inputs())
    } // for (Node *n : b->nodes())
    return unresolved_tracedattrs;
  }

  void replaceTracedAttrInputOnNode(
      Node* n,
      size_t inp_idx,
      const c10::QualifiedName& prefix,
      Value* self,
      std::unordered_map<Value*, Value*>& local_remaps,
      std::vector<Value*>& unresolved_tracedattrs) {
    auto inp = n->inputs()[inp_idx];
    auto inp_node = inp->node();
    auto prefix_atoms = prefix.atoms();
    if (inp_node->kind() == prim::TracedAttr) {
      auto attr_qualname = c10::QualifiedName(inp_node->s(attr::scope));
      if (prefix.isPrefixOf(attr_qualname)) {
        // Prefix case: the attribute resides in this scope or a
        // sub-scope. Continually emit GetAttr nodes until we've reached
        // the proper attribute.
        auto attr_atoms = attr_qualname.atoms();
        Value* replaced_value = self;
        for (const auto i : c10::irange(attr_atoms.size())) {
          if (i < prefix_atoms.size()) {
            TORCH_INTERNAL_ASSERT(attr_atoms[i] == prefix_atoms[i]);
          } else {
            replaced_value = n->owningBlock()->owningGraph()->insertGetAttr(
                replaced_value, attr_atoms[i]);
          } // if (i < prefix_atoms.size())
        } // for (const auto i : c10::irange(attr_atoms.size()))
        n->replaceInput(inp_idx, replaced_value);
        local_remaps[inp] = replaced_value;
      } else {
        // Non-prefix case: this is a use of an attribute somewhere
        // higher in the Module hierarchy. Add a captured input to
        // the block for this attribute and add to the vector of
        // Value*'s for the caller to handle.
        Value* remapped = n->owningBlock()->addInput()->copyMetadata(inp);
        n->replaceInput(inp_idx, remapped);
        unresolved_tracedattrs.push_back(inp);
        local_remaps[inp] = remapped;
      } // if (prefix.isPrefixOf(attr_qualname))
    } // if (inp_node->kind() == prim::TracedAttr)
  }

  // The previous pass should have deleted all uses of TracedAttr
  // nodes. Let's explicitly delete them here.
  void destroyTracedAttrNodes(const std::shared_ptr<Graph>& graph) {
    for (auto& kv : attr_qualname_to_value) {
      kv.second->node()->destroy();
    }
  }

  // For each prim::TracedAttr, record the `scope` value mapped
  // to the Value* in the graph for that attribute.
  std::unordered_map<std::string, Value*> attr_qualname_to_value;
};

// Iterate through all the nodes in program order and--for each use--
// if the Value referenced is not in a scope that dominates the node,
// add block and Node outputs to lift it into a scope in which
// it dominates the Use.
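//
// For example (an illustrative sketch, not output from an actual trace): if
// a Value is defined inside a TracedModuleForward block but used by a Node
// in the enclosing block, the Value is registered as an output of that
// block, the TracedModuleForward node gains a corresponding output, and the
// outer Use is rewritten to refer to that new node output.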
struct MakeDefsDominateUses {
  MakeDefsDominateUses() = default;

  void run(Block* b) {
    processNode(b->param_node(), b);
    for (Node* n : b->nodes()) {
      processNode(n, b);
    }
    processNode(b->return_node(), b);
  }

 private:
  void processNode(Node* n, Block* b) {
    for (size_t i = 0; i < n->inputs().size(); ++i) {
      Value* inp = n->inputs()[i];

      // Already lifted to this level by a previously processed Use, switch to
      // remapped value
      if (remap.count(inp)) {
        n->replaceInput(i, remap[inp]);
        inp = remap[inp];
      }

      // This conditional isn't strictly necessary, but saves a lot of
      // computation in the common case that we're using a local value.
      if (inp->node()->owningBlock() != b) {
        // Find the common ancestor block between this node and the node that
        // produced this input. For this input Use to be valid, the Value's
        // def must be present in this common ancestor node.
        Block* common_ancestor = n->findCommonAncestorBlockWith(inp->node());

        Value* v_itr = inp;
        Block* b_itr = inp->node()->owningBlock();

        // Starting from the initial def for this input, iterate to
        // wider and wider blocks, adding Block outputs and Node outputs
        // along the way. Then, log the lifted values in the remap table
        // so we can make subsequent Uses refer to the lifted value, if
        // the domination condition is met.
        while (b_itr != common_ancestor) {
          b_itr->registerOutput(v_itr);
          Value* remapped =
              b_itr->owningNode()->addOutput()->setType(v_itr->type());
          v_itr = remapped;
          b_itr = b_itr->owningNode()->owningBlock();
        }
        // From now on, references to `inp` will be replaced with
        // references to `v_itr`, the lifted Value.
        remap[inp] = v_itr;
        n->replaceInput(i, remap[inp]);
      }
    }

    if (isEligibleNode(n)) {
      run(n->blocks()[0]);
    }
  }

  // This holds the mapping between a Value* we would see in a Use
  // and the lifted value, if present. We use this to ensure that
  // Uses refer to a Value* that is in a dominating scope.
  using RemappingTable = std::unordered_map<Value*, Value*>;
  RemappingTable remap;
};

// For all blocks except graph->block(), convert multiple block
// returns to a single TupleConstruct. This is required for turning the
// blocks into Methods (and, in the case that self is nullptr, for
// properly inlining the blocks).
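//
// For example (an illustrative sketch): a TracedModuleForward block that
// returns (%a, %b) is rewritten so that the block returns a single
// TupleConstruct(%a, %b), and the node's two outputs are replaced by the
// outputs of a TupleUnpack node inserted after it. A block with no outputs
// gets a single None output instead.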
void convertReturnsToTuples(Block* b) {
  for (Node* n : b->nodes()) {
    if (n->kind() == prim::TracedFork) {
      convertReturnsToTuples(n->blocks()[0]);
    } else if (n->kind() == prim::TracedModuleForward) {
      TORCH_INTERNAL_ASSERT(n->blocks().size() == 1);
      convertReturnsToTuples(n->blocks()[0]);

      Graph* g = b->owningGraph();
      Block* sub_block = n->blocks()[0];
      if (sub_block->outputs().size() > 1) {
        {
          // Make block returns go through a Tuple
          WithInsertPoint guard(sub_block->return_node());
          Node* return_tup =
              g->insertNode(g->createTuple(sub_block->outputs()));
          while (!sub_block->outputs().empty()) {
            sub_block->eraseOutput(0);
          }
          sub_block->registerOutput(return_tup->output());
        }

        // Make the node's outputs a single tuple.
        std::vector<TypePtr> types;
        for (size_t i = 0; i < n->outputs().size(); ++i) {
          types.push_back(n->output(i)->type());
        }
        Value* tup_output = n->addOutput()->setType(TupleType::create(types));
        Node* tup_unpack = g->createTupleUnpack(tup_output)->insertAfter(n);
        for (size_t i = 0; i < tup_unpack->outputs().size(); ++i) {
          auto rev_idx = tup_unpack->outputs().size() - i - 1;
          n->output(rev_idx)->replaceAllUsesWith(tup_unpack->output(rev_idx));
          n->eraseOutput(rev_idx);
        }
      } else if (sub_block->outputs().empty()) {
        WithInsertPoint guard(sub_block->return_node());
        sub_block->registerOutput(g->insertNode(g->createNone())->output());
        n->addOutput()->setType(NoneType::get());
      }
    }
  }
}

// Lambda lift Values (i.e. add Graph inputs for the purpose of
// referencing values that dominate the block) and convert
// the block to a Graph. blocks()[0] on each TracedModuleForward then
// appears as a Graph attribute attr::Subgraph.
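//
// For example (an illustrative sketch): if a block's body references a Value
// %x defined in the enclosing scope, the extracted Graph gains an input
// corresponding to %x and the owning node gains %x as an explicit input, so
// the resulting attr::Subgraph is self-contained.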
void lambdaLiftBlocksAndConvertToGraph(Block* b) {
  for (Node* n : b->nodes()) {
    if (isEligibleNode(n)) {
      lambdaLiftBlocksAndConvertToGraph(n->blocks()[0]);

      auto graph = std::make_shared<Graph>();
      std::unordered_map<Value*, Value*> remaps;
      graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) {
        if (!remaps.count(v)) {
          remaps[v] = graph->addInput()->copyMetadata(v);
          n->addInput(v);
        }
        return remaps[v];
      });
      LintGraph(graph);
      n->g_(attr::Subgraph, graph);
      n->eraseBlock(0);
    }
  }
}

// Find a unique name under which to add this method.
// We try {method_name}, {method_name}1, {method_name}2, ...
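//
// For example (illustrative): if the class already defines `forward` and
// `forward1`, mangleMethodName("forward", mod_type) returns "forward2".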
std::string mangleMethodName(
    const std::string& method_name,
    const ClassTypePtr& mod_type) {
  for (size_t method_idx = 0;; method_idx++) {
    auto mangled = method_name;
    if (method_idx != 0) {
      mangled += c10::to_string(method_idx);
    }
    bool found = false;
    for (Function* fn : mod_type->methods()) {
      if (fn->name() == mangled) {
        found = true;
        break;
      }
    }
    if (!found) {
      return mangled;
    }
  }
  TORCH_INTERNAL_ASSERT(false);
}

// Register each attr::Subgraph Graph as a Function in the class
// compilation unit and register that Function as a method on the
// corresponding Module in the Module hierarchy. Note that we unique
// the methods by naming them forward, forward1, forward2, ...
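//
// For example (an illustrative sketch): a lambda-lifted node
//
//   %y : Tensor = prim::TracedModuleForward[scope="__module.f"](%f, %x)
//
// is replaced with a call to the newly registered method on the submodule:
//
//   %y : Tensor = prim::CallMethod[name="forward"](%f, %x)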
void createMethodCalls(const std::shared_ptr<Graph>& g) {
  for (auto node_itr = g->nodes().begin(); node_itr != g->nodes().end();) {
    Node* n = *node_itr++;
    if (n->kind() == prim::TracedFork) {
      createMethodCalls(n->g(attr::Subgraph));
    } else if (n->kind() == prim::TracedModuleForward) {
      WithInsertPoint ip(n);

      ClassTypePtr callee_mod_type = n->input(0)->type()->expect<ClassType>();

      createMethodCalls(n->g(attr::Subgraph));

      auto mangled_method_name = mangleMethodName("forward", callee_mod_type);
      auto qualname = c10::QualifiedName(
          callee_mod_type->name().value(), mangled_method_name);
      Function* f = callee_mod_type->compilation_unit()->create_function(
          qualname, n->g(attr::Subgraph));
      callee_mod_type->addMethod(f);

      std::vector<NamedValue> nvs;
      for (Value* i : n->inputs()) {
        nvs.emplace_back(i->node()->sourceRange(), i);
      }
      auto schema = matchSchema(f->getSchema(), n->sourceRange(), *g, nvs, {});
      Value* retval = g->insertMethodCall(f->qualname().name(), schema);
      n->output()->replaceAllUsesWith(retval);
      n->destroy();
    }
  }
}

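// Inline the bodies of TracedModuleForward blocks directly into the
// enclosing graph. This is used on the no-Module path, where scoped blocks
// are flattened rather than turned into method calls.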
void inlineScopeBlocks(Block* b) {
  for (auto n_itr = b->nodes().begin(); n_itr != b->nodes().end();) {
    Node* n = *n_itr++;
    for (Block* sub_b : n->blocks()) {
      inlineScopeBlocks(sub_b);
    }
    if (n->kind() == prim::TracedModuleForward) {
      // Convert the block to a graph so we can inline it
      auto graph = std::make_shared<Graph>();
      std::unordered_map<Value*, Value*> remaps;
      graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) {
        remaps[v] = graph->block()->addInput()->copyMetadata(v);
        n->addInput(v);
        return remaps[v];
      });

      WithInsertPoint insert_point(n);
      AT_ASSERT(n->inputs().size() == graph->inputs().size());
      auto new_outputs = insertGraph(*n->owningGraph(), *graph, n->inputs());
      const auto& old_outputs = n->outputs();

      AT_ASSERT(new_outputs.size() == old_outputs.size());
      for (const auto i : c10::irange(old_outputs.size())) {
        old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
      }
      n->destroy();
    }
  }
}

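// Replace each prim::TracedFork node with a real prim::fork node, copying
// over its attributes, inputs, and output metadata.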
void convertTracedForksToRealForks(const std::shared_ptr<Graph>& g) {
  for (auto itr = g->nodes().begin(); itr != g->nodes().end();) {
    Node* n = *itr++;
    if (n->kind() == prim::TracedFork) {
      WithInsertPoint guard(n);
      Node* new_fork_node =
          g->insertNode(g->create(prim::fork, n->outputs().size()))
              ->copyAttributes(*n);
      for (Value* i : n->inputs()) {
        new_fork_node->addInput(i);
      }
      for (size_t i = 0; i < new_fork_node->outputs().size(); ++i) {
        new_fork_node->outputs()[i]->copyMetadata(n->outputs()[i]);
        n->outputs()[i]->replaceAllUsesWith(new_fork_node->outputs()[i]);
      }
      n->destroy();
    }
  }
}

// Run cleanup passes (inlining when enabled, TracedFork conversion, tuple
// lowering, dead code elimination, and linting) on the graph and on any
// TracedFork subgraphs.
void runCleanupPasses(const std::shared_ptr<Graph>& g) {
  for (Node* n : g->nodes()) {
    if (n->kind() == prim::TracedFork) {
      auto subgraph = n->g(attr::Subgraph);
      if (getInlineEverythingMode()) {
        Inline(*subgraph);
      }
      convertTracedForksToRealForks(subgraph);
      LowerSimpleTuples(subgraph);
      EliminateDeadCode(subgraph);
      LintGraph(subgraph);
    }
  }
  if (getInlineEverythingMode()) {
    Inline(*g);
  }
  convertTracedForksToRealForks(g);
  LowerSimpleTuples(g);
  EliminateDeadCode(g);
  LintGraph(g);
}

void runCleanupPasses(Module* m) {
  auto methods = m->get_methods();
  for (auto module : m->children()) {
    runCleanupPasses(&module);
  }
  for (auto& method : methods) {
    runCleanupPasses(method.graph());
  }
}

} // namespace

void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) {
  if (self) {
    ConvertTracedAttrReferences().run(graph);
  } else {
    for (Node* n : graph->nodes()) {
      TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr);
    }
  }
  MakeDefsDominateUses().run(graph->block());
  convertReturnsToTuples(graph->block());
  if (!self) {
    // We have no Module, so we're just going to inline everything.
    // This should give us a totally flat graph.
    inlineScopeBlocks(graph->block());
    // For TracedFork nodes
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    runCleanupPasses(graph);
  } else {
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    createMethodCalls(graph);
    runCleanupPasses(self);
    // `graph` isn't referenced in `self` yet, so we need to run
    // this separately
    runCleanupPasses(graph);
  }
}

} // namespace jit
} // namespace torch