1 | #include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/jit/frontend/schema_matching.h> |
5 | #include <torch/csrc/jit/passes/canonicalize.h> |
6 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
7 | #include <torch/csrc/jit/passes/inliner.h> |
8 | #include <torch/csrc/jit/passes/lower_tuples.h> |
9 | |
10 | #include <algorithm> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | namespace { |
16 | |
17 | bool isEligibleNode(Node* n) { |
18 | return n->kind() == prim::TracedModuleForward || |
19 | n->kind() == prim::TracedFork; |
20 | } |
21 | |
22 | // This pass does several things: |
23 | // 1) It looks at TracedModuleForward nodes and resolves the type of `self` |
24 | // for that (to-be) method call. It adds an input of that type to the |
25 | // block, and adds the TracedAttr value corresponding to that `self` |
26 | // value as a Node input. This ensures `self` is an explicit Use on |
//    the node, a property we take advantage of downstream (see the
//    worked example below).
28 | // 2) Convert all references to prim::TracedAttr values to prim::GetAttr |
29 | // calls in the tightest scope possible. Concretely, for each use of |
30 | // a prim::TracedAttr value, we compare the scope of that attribute |
31 | // to the scope of the Use. We emit GetAttr nodes for all atoms |
32 | // that are not shared between the two. For example, if an |
33 | // attribute `f.param` is referenced in scope `f`, we emit a |
34 | // GetAttr[name="param"](%self) node in the `f` block, where |
35 | // `self` is the previously-added `self` argument to the block. |
36 | // 3) Destroy all the prim::TracedAttr nodes, as they should have |
37 | // no more uses. |
38 | // |
39 | // A quick example: |
40 | // |
41 | // |
42 | // Input graph: |
43 | // |
44 | // graph(%self : ClassType<Module>, |
45 | // %x : Float(3, 4)): |
46 | // %1 : bool = prim::TracedAttr[scope="__module.training"]() |
47 | // %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]() |
48 | // %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]() |
49 | // %4 : bool = prim::TracedAttr[scope="__module.f.training"]() |
50 | // = prim::TracedModuleForward[scope="__module.f"](), |
51 | // block0(): |
52 | // %6 : Float(3, 4) = aten::mm(%x, %3), |
53 | // -> () |
54 | // return (%6) |
55 | // |
56 | // The diff after step (1) |
57 | // |
58 | // - = prim::TracedModuleForward[scope="__module.f"](), |
59 | // - block0(): |
60 | // + = prim::TracedModuleForward[scope="__module.f"](%2), |
61 | // + block0(%self : ClassType<Module>): |
62 | // |
63 | // The diff after step (2) |
64 | // |
65 | // graph(%self.1 : ClassType<Module>, |
66 | // %x : Float(3, 4)): |
67 | // + %9 : ClassType<Module> = prim::GetAttr[name="f"](%self.1) |
68 | // %1 : bool = prim::TracedAttr[scope="__module.training"]() |
69 | // <....> |
70 | // %4 : bool = prim::TracedAttr[scope="__module.f.training"]() |
71 | // - = prim::TracedModuleForward[scope="__module.f"](%2), |
72 | // + = prim::TracedModuleForward[scope="__module.f"](%9), |
73 | // block0(%self : ClassType<Module>): |
74 | // - %6 : Float(3, 4) = aten::mm(%x, %3), |
75 | // + %8 : Tensor = prim::GetAttr[name="param"](%self) |
76 | // + %6 : Float(3, 4) = aten::mm(%x, %8), |
77 | // -> () |
78 | // return (%6) |
79 | // |
80 | // The diff after step (3) |
81 | // |
82 | // - %1 : bool = prim::TracedAttr[scope="__module.training"]() |
83 | // - %2 : ClassType<Module> = prim::TracedAttr[scope="__module.f"]() |
84 | // - %3 : Float(4, 4) = prim::TracedAttr[scope="__module.f.param"]() |
85 | // - %4 : bool = prim::TracedAttr[scope="__module.f.training"]() |
86 | struct ConvertTracedAttrReferences { |
87 | void run(const std::shared_ptr<Graph>& graph) { |
88 | // Build a table mapping--for each TracedAttr node--the |
89 | // qualified name of the attribute to the Value* output |
90 | // of the Node. |
91 | buildAttrMap(graph); |
92 | // Step 1 |
93 | addSelfArgToTracedForwardNodes(graph->block()); |
94 | // Step 2 |
95 | convertAttrReferencesToLocalGetAttrs( |
96 | graph->block(), "__module" , graph->inputs()[0]); |
97 | // Step 3 |
98 | destroyTracedAttrNodes(graph); |
99 | } |
100 | |
101 | private: |
102 | void buildAttrMap(const std::shared_ptr<Graph>& graph) { |
103 | for (Node* n : graph->nodes()) { |
104 | if (n->kind() == prim::TracedAttr) { |
105 | attr_qualname_to_value[n->s(attr::scope)] = n->output(); |
106 | } |
107 | } |
108 | } |
109 | |
110 | void addSelfArgToTracedForwardNodes(Block* b) { |
111 | for (Node* n : b->nodes()) { |
112 | if (n->kind() == prim::TracedModuleForward) { |
113 | n->addInput(attr_qualname_to_value.at(n->s(attr::scope))); |
114 | n->blocks()[0]->addInput("self" )->setType( |
115 | attr_qualname_to_value.at(n->s(attr::scope))->type()); |
116 | addSelfArgToTracedForwardNodes(n->blocks()[0]); |
117 | } |
118 | if (n->kind() == prim::TracedFork) { |
119 | addSelfArgToTracedForwardNodes(n->blocks()[0]); |
120 | } |
121 | } |
122 | } |
123 | |
124 | // This is a recursive function that descends down all blocks in the Graph |
125 | // (NB: not just TracedModuleForward blocks). Each descension has a |
126 | // corresponding `prefix`, i.e. the qualified name of the scope this |
127 | // Block represents (or the scope in which this block resides for |
128 | // non-TracedModuleForward nodes). We use this prefix to make decisions |
129 | // about whether to emit a GetAttr node for an attribute reference, or |
130 | // to defer that emission to the caller (in the case where an attribute |
131 | // reference does not reside in the `prefix` scope). |
132 | std::vector<Value*> convertAttrReferencesToLocalGetAttrs( |
133 | Block* b, |
134 | const c10::QualifiedName& prefix, |
135 | Value* self) { |
136 | // Store away Value*'s which are references to TracedAttr's which are |
137 | // not in the `prefix` scope. We pass this back to the caller, who |
138 | // should add these Values as explicit inputs as well as inductively |
139 | // make the same decision on those Values. |
140 | std::vector<Value*> unresolved_tracedattrs; |
141 | // To ensure we don't emit redundant GetAttr Nodes in a given scope, |
142 | // we maintain this map of original TracedAttr Value* to the Value* |
143 | // corresponding to the GetAttr for that attribute. |
144 | // We don't rely on CSE here because we currently can't reason about |
145 | // the correctness of CSE over GetAttr Nodes (i think) |
146 | std::unordered_map<Value*, Value*> local_remaps; |
147 | |
148 | for (Node* n : b->nodes()) { |
149 | // The only difference between these two branches is for |
150 | // TracedModuleForward we advance the scope, but for other |
151 | // Nodes with Blocks we don't |
152 | if (n->kind() == prim::TracedModuleForward) { |
153 | auto sub_unresolved = convertAttrReferencesToLocalGetAttrs( |
154 | n->blocks()[0], n->s(attr::scope), n->blocks()[0]->inputs()[0]); |
155 | for (Value* v : sub_unresolved) { |
156 | n->addInput(v); |
157 | } |
158 | } else if (!n->blocks().empty()) { |
159 | for (Block* sub_block : n->blocks()) { |
160 | auto sub_unresolved = |
161 | convertAttrReferencesToLocalGetAttrs(sub_block, prefix, self); |
162 | for (Value* v : sub_unresolved) { |
163 | n->addInput(v); |
164 | } |
165 | } |
166 | } |
167 | |
168 | for (size_t inp_idx = 0; inp_idx < n->inputs().size(); ++inp_idx) { |
169 | Value* inp = n->input(inp_idx); |
170 | |
171 | // Short circuit: if we've already emitted a new Value for this |
172 | // attribute, just use that. |
173 | if (local_remaps.count(inp)) { |
174 | n->replaceInput(inp_idx, local_remaps[inp]); |
175 | continue; |
176 | } |
177 | |
178 | WithInsertPoint guard(b->param_node()->next()); |
179 | replaceTracedAttrInputOnNode( |
180 | n, inp_idx, prefix, self, local_remaps, unresolved_tracedattrs); |
181 | } // for (Value *inp : n->inputs()) |
182 | } // for (Node *n : b->nodes()) |
183 | return unresolved_tracedattrs; |
184 | } |
185 | |
186 | void replaceTracedAttrInputOnNode( |
187 | Node* n, |
188 | size_t inp_idx, |
189 | const c10::QualifiedName& prefix, |
190 | Value* self, |
191 | std::unordered_map<Value*, Value*>& local_remaps, |
192 | std::vector<Value*>& unresolved_tracedattrs) { |
193 | auto inp = n->inputs()[inp_idx]; |
194 | auto inp_node = inp->node(); |
195 | auto prefix_atoms = prefix.atoms(); |
196 | if (inp_node->kind() == prim::TracedAttr) { |
197 | auto attr_qualname = c10::QualifiedName(inp_node->s(attr::scope)); |
198 | if (prefix.isPrefixOf(attr_qualname)) { |
199 | // Prefix case: the attribute resides in this scope or a |
200 | // sub-scope. Continually emit GetAttr nodes until we've reached |
201 | // the proper attribute. |
202 | auto attr_atoms = attr_qualname.atoms(); |
203 | Value* replaced_value = self; |
204 | for (const auto i : c10::irange(attr_atoms.size())) { |
205 | if (i < prefix_atoms.size()) { |
206 | TORCH_INTERNAL_ASSERT(attr_atoms[i] == prefix_atoms[i]); |
207 | } else { |
208 | replaced_value = n->owningBlock()->owningGraph()->insertGetAttr( |
209 | replaced_value, attr_atoms[i]); |
210 | } // if (i < prefix_atoms.size()) |
211 | } // for(const auto i : c10::irange(attr_atoms.size())) |
212 | n->replaceInput(inp_idx, replaced_value); |
213 | local_remaps[inp] = replaced_value; |
214 | } else { |
215 | // Non-prefix case: this is a use of an attribute somewhere |
216 | // higher in the Module hierarchy. Add a captured input to |
217 | // the block for this attribute and add to the vector of |
218 | // Value*'s for the caller to handle. |
219 | Value* remapped = n->owningBlock()->addInput()->copyMetadata(inp); |
220 | n->replaceInput(inp_idx, remapped); |
221 | unresolved_tracedattrs.push_back(inp); |
222 | local_remaps[inp] = remapped; |
223 | } // if (prefix.isPrefixOf(attr_qualname)) |
224 | } // if (inp_node->kind() == prim::TracedAttr) |
225 | } |
226 | |
227 | // The previous pass should have deleted all uses of TracedAttr |
228 | // nodes. Let's explicitly delete them here. |
229 | void destroyTracedAttrNodes(const std::shared_ptr<Graph>& graph) { |
230 | for (auto& kv : attr_qualname_to_value) { |
231 | kv.second->node()->destroy(); |
232 | } |
233 | } |
234 | |
235 | // For each prim::TracedAttr, record the `scope` value mapped |
236 | // to the Value* in the graph for that attribute. |
237 | std::unordered_map<std::string, Value*> attr_qualname_to_value; |
238 | }; |
239 | |
240 | // Iterate through all the nodes in program order and--for each use-- |
241 | // if the Value referenced is not in a scope that dominates the node, |
242 | // add block and Node outputs to lift it into a scope in which |
243 | // it dominates the Use. |
244 | struct MakeDefsDominateUses { |
245 | MakeDefsDominateUses() = default; |
246 | |
247 | void run(Block* b) { |
248 | processNode(b->param_node(), b); |
249 | for (Node* n : b->nodes()) { |
250 | processNode(n, b); |
251 | } |
252 | processNode(b->return_node(), b); |
253 | } |
254 | |
255 | private: |
256 | void processNode(Node* n, Block* b) { |
257 | for (size_t i = 0; i < n->inputs().size(); ++i) { |
258 | Value* inp = n->inputs()[i]; |
259 | |
260 | // Already lifted to this level by a previously processed Use, switch to |
261 | // remapped value |
262 | if (remap.count(inp)) { |
263 | n->replaceInput(i, remap[inp]); |
264 | inp = remap[inp]; |
265 | } |
266 | |
267 | // This conditional isn't strictly necessary, but saves a lot of |
268 | // computation in the common case that we're using a local value. |
269 | if (inp->node()->owningBlock() != b) { |
270 | // Find the common ancestor block between this node and the node that |
271 | // produced this input. For this input Use to be valid, the Value's |
272 | // def must be present in this common ancestor node. |
273 | Block* common_ancestor = n->findCommonAncestorBlockWith(inp->node()); |
274 | |
275 | Value* v_itr = inp; |
276 | Block* b_itr = inp->node()->owningBlock(); |
277 | |
278 | // Starting from the initial def for this input, iterate to |
279 | // wider and wider blocks, adding Block outputs and Node outputs |
280 | // along the way. Then, log the lifted values in the remap table |
281 | // so we can make subsequent Uses refer to the lifted value, if |
282 | // the domination condition is met. |
283 | while (b_itr != common_ancestor) { |
284 | b_itr->registerOutput(v_itr); |
285 | Value* remapped = |
286 | b_itr->owningNode()->addOutput()->setType(v_itr->type()); |
287 | v_itr = remapped; |
288 | b_itr = b_itr->owningNode()->owningBlock(); |
289 | } |
290 | // From now on, references to `inp` will be replaced with |
291 | // references to `v_iter`, the lifted Value |
292 | remap[inp] = v_itr; |
293 | n->replaceInput(i, remap[inp]); |
294 | } |
295 | } |
296 | |
297 | if (isEligibleNode(n)) { |
298 | run(n->blocks()[0]); |
299 | } |
300 | } |
301 | |
302 | // This holds the mapping between a Value* we would see in a Use |
303 | // and the lifted value, if present. We use this to ensure that |
304 | // Uses refer to a Value* that is in a dominating scope. |
305 | using RemappingTable = std::unordered_map<Value*, Value*>; |
306 | RemappingTable remap; |
307 | }; |
308 | |
309 | // For all blocks except graph->block(), convert multiple block |
310 | // returns to a TupleConstruct. This is required for turning the |
311 | // blocks into Methods. (and in the case that self is nullptr, |
312 | // it is required to properly inline the blocks). |
313 | void convertReturnsToTuples(Block* b) { |
314 | for (Node* n : b->nodes()) { |
315 | if (n->kind() == prim::TracedFork) { |
316 | convertReturnsToTuples(n->blocks()[0]); |
317 | } else if (n->kind() == prim::TracedModuleForward) { |
318 | TORCH_INTERNAL_ASSERT(n->blocks().size() == 1); |
319 | convertReturnsToTuples(n->blocks()[0]); |
320 | |
321 | Graph* g = b->owningGraph(); |
322 | Block* sub_block = n->blocks()[0]; |
323 | if (sub_block->outputs().size() > 1) { |
324 | { |
325 | // Make block returns go through a Tuple |
326 | WithInsertPoint guard(sub_block->return_node()); |
327 | Node* return_tup = |
328 | g->insertNode(g->createTuple(sub_block->outputs())); |
329 | while (!sub_block->outputs().empty()) { |
330 | sub_block->eraseOutput(0); |
331 | } |
332 | sub_block->registerOutput(return_tup->output()); |
333 | } |
334 | |
335 | // Make node outputs a single tuple; |
336 | std::vector<TypePtr> types; |
337 | for (size_t i = 0; i < n->outputs().size(); ++i) { |
338 | types.push_back(n->output(i)->type()); |
339 | } |
340 | Value* tup_output = n->addOutput()->setType(TupleType::create(types)); |
341 | Node* tup_unpack = g->createTupleUnpack(tup_output)->insertAfter(n); |
342 | for (size_t i = 0; i < tup_unpack->outputs().size(); ++i) { |
343 | auto rev_idx = tup_unpack->outputs().size() - i - 1; |
344 | n->output(rev_idx)->replaceAllUsesWith(tup_unpack->output(rev_idx)); |
345 | n->eraseOutput(rev_idx); |
346 | } |
347 | } else if (sub_block->outputs().empty()) { |
348 | WithInsertPoint guard(sub_block->return_node()); |
349 | sub_block->registerOutput(g->insertNode(g->createNone())->output()); |
350 | n->addOutput()->setType(NoneType::get()); |
351 | } |
352 | } |
353 | } |
354 | } |
355 | |
356 | // Lambda lift Values (i.e. add Graph inputs for the purpose of |
357 | // referencing values that dominate the block) and convert |
358 | // the block to a Graph. blocks()[0] on each TracedModuleForward then |
359 | // appears as a Graph attribute attr::Subgraph |
360 | void lambdaLiftBlocksAndConvertToGraph(Block* b) { |
361 | for (Node* n : b->nodes()) { |
362 | if (isEligibleNode(n)) { |
363 | lambdaLiftBlocksAndConvertToGraph(n->blocks()[0]); |
364 | |
365 | auto graph = std::make_shared<Graph>(); |
366 | std::unordered_map<Value*, Value*> remaps; |
367 | graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) { |
368 | if (!remaps.count(v)) { |
369 | remaps[v] = graph->addInput()->copyMetadata(v); |
370 | n->addInput(v); |
371 | } |
372 | return remaps[v]; |
373 | }); |
374 | LintGraph(graph); |
375 | n->g_(attr::Subgraph, graph); |
376 | n->eraseBlock(0); |
377 | } |
378 | } |
379 | } |
380 | |
381 | // Find a unique name to add this method as |
382 | // We try {method_name}, {method_name}1, {method_name}2, ... |
383 | std::string mangleMethodName( |
384 | const std::string& method_name, |
385 | const ClassTypePtr& mod_type) { |
386 | for (size_t method_idx = 0;; method_idx++) { |
387 | auto mangled = method_name; |
388 | if (method_idx != 0) { |
389 | mangled += c10::to_string(method_idx); |
390 | } |
391 | bool found = false; |
392 | for (Function* fn : mod_type->methods()) { |
393 | if (fn->name() == mangled) { |
394 | found = true; |
395 | break; |
396 | } |
397 | } |
398 | if (!found) { |
399 | return mangled; |
400 | } |
401 | } |
402 | TORCH_INTERNAL_ASSERT(false); |
403 | } |
404 | |
405 | // Register the attr::Subgraph Graph values as Functions in the |
406 | // class compilation unit and register that Function as a method |
407 | // on the corresponding Module in the Module hierarchy. Note that we |
408 | // unique the methods by naming them forward, forward1, forward2... |
409 | void createMethodCalls(const std::shared_ptr<Graph>& g) { |
410 | for (auto node_itr = g->nodes().begin(); node_itr != g->nodes().end();) { |
411 | Node* n = *node_itr++; |
412 | if (n->kind() == prim::TracedFork) { |
413 | createMethodCalls(n->g(attr::Subgraph)); |
414 | } else if (n->kind() == prim::TracedModuleForward) { |
415 | WithInsertPoint ip(n); |
416 | |
417 | ClassTypePtr callee_mod_type = n->input(0)->type()->expect<ClassType>(); |
418 | |
419 | createMethodCalls(n->g(attr::Subgraph)); |
420 | |
421 | auto mangled_method_name = mangleMethodName("forward" , callee_mod_type); |
422 | auto qualname = c10::QualifiedName( |
423 | callee_mod_type->name().value(), mangled_method_name); |
424 | Function* f = callee_mod_type->compilation_unit()->create_function( |
425 | qualname, n->g(attr::Subgraph)); |
426 | callee_mod_type->addMethod(f); |
427 | |
428 | std::vector<NamedValue> nvs; |
429 | for (Value* i : n->inputs()) { |
430 | nvs.emplace_back(i->node()->sourceRange(), i); |
431 | } |
432 | auto schema = matchSchema(f->getSchema(), n->sourceRange(), *g, nvs, {}); |
433 | Value* retval = g->insertMethodCall(f->qualname().name(), schema); |
434 | n->output()->replaceAllUsesWith(retval); |
435 | n->destroy(); |
436 | } |
437 | } |
438 | } |
439 | |
440 | void inlineScopeBlocks(Block* b) { |
441 | for (auto n_itr = b->nodes().begin(); n_itr != b->nodes().end();) { |
442 | Node* n = *n_itr++; |
443 | for (Block* sub_b : n->blocks()) { |
444 | inlineScopeBlocks(sub_b); |
445 | } |
446 | if (n->kind() == prim::TracedModuleForward) { |
447 | // Convert the block to a graph so we can inline it |
448 | auto graph = std::make_shared<Graph>(); |
449 | std::unordered_map<Value*, Value*> remaps; |
450 | graph->block()->cloneFrom(n->blocks()[0], [&](Value* v) { |
451 | remaps[v] = graph->block()->addInput()->copyMetadata(v); |
452 | n->addInput(v); |
453 | return remaps[v]; |
454 | }); |
455 | |
456 | WithInsertPoint insert_point(n); |
457 | AT_ASSERT(n->inputs().size() == graph->inputs().size()); |
458 | auto new_outputs = insertGraph(*n->owningGraph(), *graph, n->inputs()); |
459 | const auto& old_outputs = n->outputs(); |
460 | |
461 | AT_ASSERT(new_outputs.size() == old_outputs.size()); |
462 | for (const auto i : c10::irange(old_outputs.size())) { |
463 | old_outputs[i]->replaceAllUsesWith(new_outputs[i]); |
464 | } |
465 | n->destroy(); |
466 | } |
467 | } |
468 | } |
469 | |
470 | void convertTracedForksToRealForks(const std::shared_ptr<Graph>& g) { |
471 | for (auto itr = g->nodes().begin(); itr != g->nodes().end();) { |
472 | Node* n = *itr++; |
473 | if (n->kind() == prim::TracedFork) { |
474 | WithInsertPoint guard(n); |
475 | Node* new_fork_node = |
476 | g->insertNode(g->create(prim::fork, n->outputs().size())) |
477 | ->copyAttributes(*n); |
478 | for (Value* i : n->inputs()) { |
479 | new_fork_node->addInput(i); |
480 | } |
481 | for (size_t i = 0; i < new_fork_node->outputs().size(); ++i) { |
482 | new_fork_node->outputs()[i]->copyMetadata(n->outputs()[i]); |
483 | n->outputs()[i]->replaceAllUsesWith(new_fork_node->outputs()[i]); |
484 | } |
485 | n->destroy(); |
486 | } |
487 | } |
488 | } |
489 | |
490 | // Run a few clean-up passes to make the graph a bit cleaner. |
491 | void runCleanupPasses(const std::shared_ptr<Graph>& g) { |
492 | for (Node* n : g->nodes()) { |
493 | if (n->kind() == prim::TracedFork) { |
494 | auto subgraph = n->g(attr::Subgraph); |
495 | if (getInlineEverythingMode()) { |
496 | Inline(*subgraph); |
497 | } |
498 | convertTracedForksToRealForks(subgraph); |
499 | LowerSimpleTuples(subgraph); |
500 | EliminateDeadCode(subgraph); |
501 | LintGraph(subgraph); |
502 | } |
503 | } |
504 | if (getInlineEverythingMode()) { |
505 | Inline(*g); |
506 | } |
507 | convertTracedForksToRealForks(g); |
508 | LowerSimpleTuples(g); |
509 | EliminateDeadCode(g); |
510 | LintGraph(g); |
511 | } |
512 | |
513 | void runCleanupPasses(Module* m) { |
514 | auto methods = m->get_methods(); |
515 | for (auto module : m->children()) { |
516 | runCleanupPasses(&module); |
517 | } |
518 | for (auto& method : methods) { |
519 | runCleanupPasses(method.graph()); |
520 | } |
521 | } |
522 | |
523 | } // namespace |
524 | |
525 | void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) { |
526 | if (self) { |
527 | ConvertTracedAttrReferences().run(graph); |
528 | } else { |
529 | for (Node* n : graph->nodes()) { |
530 | TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr); |
531 | } |
532 | } |
533 | MakeDefsDominateUses().run(graph->block()); |
534 | convertReturnsToTuples(graph->block()); |
535 | if (!self) { |
536 | // We have no Module, so we're just going to inline everything. |
537 | // This should give us a totally flat graph. |
538 | inlineScopeBlocks(graph->block()); |
539 | // For TracedFork nodes |
540 | lambdaLiftBlocksAndConvertToGraph(graph->block()); |
541 | runCleanupPasses(graph); |
542 | } else { |
543 | lambdaLiftBlocksAndConvertToGraph(graph->block()); |
544 | createMethodCalls(graph); |
545 | runCleanupPasses(self); |
546 | // `graph` isn't referenced in `self` yet, so we need to run |
547 | // this separately |
548 | runCleanupPasses(graph); |
549 | } |
550 | } |
551 | |
552 | } // namespace jit |
553 | } // namespace torch |
554 | |