#pragma once

#include <memory>

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
5
6namespace torch {
7namespace jit {
8
9// Directly after tracing, we have an ill-formed graph with blocks inserted.
10// Example:
11//
12// graph(%self : ClassType<Module>,
13// %input.1 : Float(3, 4)):
14// %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
15// %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
16// %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
17// = prim::TracedModuleForward[scope="__module.relu1"]()
18// block0():
19// %input : Float(3, 4) = aten::relu(%input.1),
20// -> ()
21// = prim::TracedModuleForward[scope="__module.relu2"](),
22// block0():
23// = prim::TracedModuleForward[scope="__module.relu2.rrr"](),
24// block0():
25// %6 : Float(3, 4) = aten::relu(%input),
26// -> ()
27// -> ()
28// return (%6)
29//
30// In this pass, we:
31// 1) Lift Value defs to as high of a scope as needed to ensure that
32// they dominate all their uses. For example, `input` in the above
33// graph needs to be lifted to the top-level block so that its use
34// in the second `relu` operator is dominated.
35// 2) Lambda lift the blocks. This ensures that all values used within
36// each scope have their defs captured.
37// 3) Convert the scope blocks into methods on their respective Modules,
38// and convert TracedModuleForward nodes to CallMethod nodes into those
39// methods.
40//
41// Then, we'll have a well-formed graph with proper method calls.
42TORCH_API void FixupTraceScopeBlocks(
43 std::shared_ptr<Graph>& graph,
44 Module* self);
45
46} // namespace jit
47} // namespace torch
48