1 | #include <torch/csrc/jit/passes/inliner.h> |
2 | |
3 | #include <ATen/core/interned_strings.h> |
4 | #include <torch/csrc/jit/api/function_impl.h> |
5 | #include <torch/csrc/jit/api/module.h> |
6 | #include <torch/csrc/jit/frontend/error_report.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
// Pulls every name from ::c10::prim into torch::jit::prim, so code in this
// file can spell node kinds as `prim::CallFunction`, `prim::Constant`, etc.
namespace prim {
using namespace ::c10::prim;
}
15 | |
16 | GraphFunction* tryToGraphFunction(Node* n) { |
17 | if (n->kind() == prim::CallFunction) { |
18 | AT_ASSERT(n->input(0)->node()->kind() == prim::Constant); |
19 | auto function_constant = n->input(0)->node(); |
20 | auto fun_type = function_constant->output()->type()->expect<FunctionType>(); |
21 | return tryToGraphFunction(*fun_type->function()); |
22 | } |
23 | if (n->kind() == prim::CallMethod) { |
24 | const std::string& name = n->s(attr::name); |
25 | if (auto class_type = n->input(0)->type()->cast<ClassType>()) { |
26 | Function& function = class_type->getMethod(name); |
27 | return tryToGraphFunction(function); |
28 | } |
29 | } |
30 | return nullptr; |
31 | } |
32 | |
33 | void inlineCalls(Block* block) { |
34 | for (auto it = block->nodes().begin(), end = block->nodes().end(); |
35 | it != end;) { |
36 | Node* cur = *it++; |
37 | switch (cur->kind()) { |
38 | case prim::CallFunction: { |
39 | if (auto graphFunction = tryToGraphFunction(cur)) { |
40 | auto function_constant = cur->input(0)->node(); |
41 | auto fun_type = |
42 | function_constant->output()->type()->expect<FunctionType>(); |
43 | |
44 | cur->removeInput(0); |
45 | GRAPH_UPDATE( |
46 | "Inlining function '" , |
47 | fun_type->function()->name(), |
48 | "' to " , |
49 | *cur); |
50 | |
51 | std::shared_ptr<Graph> g = nullptr; |
52 | // inline optimized graph for debugging/testing purposes. |
53 | // we only insert fallback functions in JIT optimized graphs for |
54 | // execution, not on the Graph that is used for serialization |
55 | bool fallback = |
56 | function_constant->hasAttribute(Symbol::attr("fallback" )); |
57 | if (fallback && graphFunction->get_executor().isOptimized()) { |
58 | auto exec_plans = |
59 | graphFunction->get_executor().getDebugState().execution_plans; |
60 | if (!exec_plans.empty()) { |
61 | g = exec_plans.begin()->second.graph; |
62 | // optimized_graph() calls Inline, so we only need to explicitly |
63 | // invoke inlining on the jit optimized graph with recursive |
64 | // fallback funciton calls |
65 | Inline(*g.get()); |
66 | } |
67 | } |
68 | if (g == nullptr) { |
69 | g = graphFunction->optimized_graph(); |
70 | } |
71 | |
72 | GRAPH_UPDATE("Function body: " , g); |
73 | inlineCallTo(cur, graphFunction, g.get()); |
74 | } |
75 | } break; |
76 | case prim::CallMethod: { |
77 | if (auto graphFunction = tryToGraphFunction(cur)) { |
78 | GRAPH_UPDATE("Inlining method '" , cur->s(attr::name), "' to " , *cur); |
79 | GRAPH_UPDATE("Function body: " , graphFunction->optimized_graph()); |
80 | inlineCallTo(cur, graphFunction); |
81 | } |
82 | } break; |
83 | default: { |
84 | for (auto b : cur->blocks()) { |
85 | inlineCalls(b); |
86 | } |
87 | } break; |
88 | } |
89 | } |
90 | } |
91 | |
92 | void Inline(Graph& graph) { |
93 | GRAPH_DUMP("Before Inlining: " , &graph); |
94 | inlineCalls(graph.block()); |
95 | GRAPH_DUMP("After Inlining: " , &graph); |
96 | } |
97 | |
98 | } // namespace jit |
99 | } // namespace torch |
100 | |