#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Directly after tracing, we have an ill-formed graph with blocks inserted.
// Example:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
//    = prim::TracedModuleForward[scope="__module.relu1"]()
//     block0():
//       %input : Float(3, 4) = aten::relu(%input.1)
//       -> ()
//    = prim::TracedModuleForward[scope="__module.relu2"]()
//     block0():
//        = prim::TracedModuleForward[scope="__module.relu2.rrr"]()
//         block0():
//           %6 : Float(3, 4) = aten::relu(%input)
//           -> ()
//       -> ()
//   return (%6)
//
// In this pass, we:
//   1) Lift Value defs to as high a scope as needed to ensure that
//      they dominate all their uses. For example, `input` in the above
//      graph needs to be lifted to the top-level block so that its use
//      in the second `relu` operator is dominated by its definition.
//   2) Lambda lift the blocks, so that every value used within a scope
//      has its def captured as an explicit input (see the sketch after
//      this list).
//   3) Convert the scope blocks into methods on their respective Modules,
//      and convert TracedModuleForward nodes into CallMethod nodes that
//      call those methods.
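//
// As a rough, hand-written sketch (not the pass's exact output): after
// steps 1 and 2, the `relu1` block from the example would take its
// capture as an explicit input and return the value that escapes to the
// enclosing scope, roughly:
//
//    %input : Float(3, 4) = prim::TracedModuleForward[scope="__module.relu1"](%input.1)
//     block0(%input.1 : Float(3, 4)):
//       %input.2 : Float(3, 4) = aten::relu(%input.1)
//       -> (%input.2)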
//
// After these steps, we have a well-formed graph with proper method calls.
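//
// For illustration only (again a hand-written sketch, not verbatim
// printer output), the example above would end up looking roughly like:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %input : Float(3, 4) = prim::CallMethod[name="forward"](%1, %input.1)
//   %6 : Float(3, 4) = prim::CallMethod[name="forward"](%2, %input)
//   return (%6)
//
// with the method synthesized for `relu2` itself calling into the method
// synthesized for `rrr`.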
TORCH_API void FixupTraceScopeBlocks(
    std::shared_ptr<Graph>& graph,
    Module* self);
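
// A minimal sketch of a call site (hypothetical names; in PyTorch this
// pass is invoked by the tracer once the traced graph has been built):
//
//   std::shared_ptr<Graph> graph = /* graph produced by tracing */;
//   Module module = /* module that was traced */;
//   FixupTraceScopeBlocks(graph, &module);
//
// `self` is a raw pointer, presumably so that null can be passed when a
// free function (rather than a module method) was traced.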

} // namespace jit
} // namespace torch