1#include <torch/csrc/jit/passes/create_functional_graphs.h>
2
3#include <c10/util/Exception.h>
4#include <torch/csrc/jit/ir/alias_analysis.h>
5#include <torch/csrc/jit/passes/constant_pooling.h>
6#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
7#include <torch/csrc/utils/memory.h>
8
9#include <cstddef>
10#include <limits>
11
12namespace torch {
13namespace jit {
14
15namespace {
16
17struct FunctionalGraphSlicer {
18 FunctionalGraphSlicer(std::shared_ptr<Graph> graph)
19 : graph_(std::move(graph)) {}
20
21 void run() {
22 bool changed = true;
23 // TODO: more sane strategy
24 size_t MAX_NUM_ITERATIONS = 4;
25
26 // First, analyze the functional subset of the graph, and then create
27 // functional graphs. The graph gets mutated when we create functional
28 // subgraphs, invalidating the AliasDb, so we need to do our analysis
29 // first.
30 for (size_t i = 0; i < MAX_NUM_ITERATIONS && changed; ++i) {
31 aliasDb_ = torch::make_unique<AliasDb>(graph_);
32 AnalyzeFunctionalSubset(graph_->block());
33 changed = CreateFunctionalGraphsImpl(graph_->block());
34 }
35 }
36
37 private:
38 bool isEmptyFunctionalGraph(Node* n) {
39 auto g = n->g(attr::Subgraph);
40 return g->inputs().empty() && g->outputs().empty();
41 }
42
43 void nonConstNodes(Block* block, size_t* num) {
44 for (auto it = block->nodes().begin();
45 it != block->nodes().end() && *num < minSubgraphSize_;
46 ++it) {
47 Node* n = *it;
48 if (n->kind() == prim::Constant) {
49 continue;
50 }
51 *num = *num + 1;
52 for (Block* b : n->blocks()) {
53 nonConstNodes(b, num);
54 }
55 }
56 }
57
58 bool inlineIfTooSmall(Node* n) {
59 AT_ASSERT(n->kind() == prim::FunctionalGraph);
60 auto subgraph = SubgraphUtils::getSubgraph(n);
61 size_t num_modes = 0;
62 nonConstNodes(subgraph->block(), &num_modes);
63 if (num_modes < minSubgraphSize_) {
64 SubgraphUtils::unmergeSubgraph(n);
65 return true;
66 }
67 return false;
68 }
69
70 bool CreateFunctionalGraphsImpl(Block* block) {
71 /*
72 Iterate the block in reverse and create FunctionalSubgraphs.
73 When we encounter a node that isn't functional, we skip it. Otherwise,
74 we try to merge the functional node into the current functional subgraph.
75 If it can't be merged into the current functional subgraph node, then we
76 start a functional subgraph group.
77 */
78 bool changed = false;
79 std::vector<Node*> functional_graph_nodes;
80
81 Node* functional_subgraph_node =
82 graph_->createWithSubgraph(prim::FunctionalGraph)
83 ->insertBefore(block->return_node());
84 auto reverse_iter = block->nodes().reverse();
85 std::vector<Value*> graph_outputs;
86 for (auto it = reverse_iter.begin(); it != reverse_iter.end();) {
87 Node* n = *it++;
88
89 // constants get copied into the graph
90 if (n->kind() == prim::Constant || n == functional_subgraph_node) {
91 continue;
92 }
93
94 // if `n` is functional, all of its blocks will be merged into the
95 // new functional subgraph, so we only need to recurse if it is not
96 // functional
97 if (!functional_nodes_.count(n)) {
98 for (Block* b : n->blocks()) {
99 auto block_changed = CreateFunctionalGraphsImpl(b);
100 changed = block_changed && changed;
101 }
102 continue;
103 }
104
105 if (n->kind() == prim::FunctionalGraph &&
106 isEmptyFunctionalGraph(functional_subgraph_node)) {
107 functional_subgraph_node->destroy();
108 functional_subgraph_node = n;
109 continue;
110 }
111
112 changed = true;
113 if (aliasDb_->moveBeforeTopologicallyValid(n, functional_subgraph_node)) {
114 SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
115 } else {
116 functional_graph_nodes.emplace_back(functional_subgraph_node);
117 functional_subgraph_node =
118 graph_->createWithSubgraph(prim::FunctionalGraph)->insertAfter(n);
119 SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
120 }
121 }
122 functional_graph_nodes.emplace_back(functional_subgraph_node);
123
124 for (Node* functional_node : functional_graph_nodes) {
125 if (!inlineIfTooSmall(functional_node)) {
126 ConstantPooling(functional_node->g(attr::Subgraph));
127 }
128 }
129 return changed;
130 }
131
132 bool AnalyzeFunctionalSubset(Node* n) {
133 // TODO: clarify hasSideEffects, isNondeterministic
134 bool is_functional_node = true;
135
136 // Functional Graphs are not responsible for maintaining aliasing
137 // relationships. If an output of a functional graph escapes scope
138 // or is mutated then we might change semantics of the program if
139 // aliasing relationships are changed.
140 // We don't allow any node in the functional graph to output a value
141 // that escapes scope or is mutated, and we don't allow any mutating nodes
142 // into the graph.
143 // - allow functional graphs to have at most one value that can escape scope
144 // - allow outputs which alias the wildcard set but do not "re-escape"
145 for (Value* v : n->outputs()) {
146 bool has_writers = aliasDb_->hasWriters(v);
147 bool escapes_scope = aliasDb_->escapesScope(v);
148 if (has_writers) {
149 mutated_values_.insert(v);
150 }
151 is_functional_node = is_functional_node && !escapes_scope && !has_writers;
152 }
153
154 for (Block* block : n->blocks()) {
155 auto functional_block = AnalyzeFunctionalSubset(block);
156 is_functional_node = is_functional_node && functional_block;
157 }
158
159 is_functional_node = is_functional_node && !aliasDb_->isMutable(n);
160 if (is_functional_node) {
161 functional_nodes_.insert(n);
162 }
163 return is_functional_node;
164 }
165
166 void AnalyzeFunctionalSubset(at::ArrayRef<Block*> blocks) {
167 for (Block* block : blocks) {
168 AnalyzeFunctionalSubset(block);
169 }
170 }
171
172 bool AnalyzeFunctionalSubset(Block* block) {
173 bool is_functional_block = true;
174 // block inputs will not yet have been iterated through,
175 // so we need to add them to our set of mutated & escape values.
176 for (Value* v : block->inputs()) {
177 bool has_writers = aliasDb_->hasWriters(v);
178 if (has_writers) {
179 mutated_values_.insert(v);
180 }
181 }
182 // if a block output is not functional, then the corresponding output for
183 // the node that contains the block will not be functional either, so we do
184 // not need to analyze the block outputs here.
185 for (Node* n : block->nodes()) {
186 bool functional = AnalyzeFunctionalSubset(n);
187 is_functional_block = is_functional_block && functional;
188 }
189 return is_functional_block;
190 }
191
192 std::unordered_set<Node*> functional_nodes_;
193 std::unordered_set<Value*> mutated_values_;
194 std::shared_ptr<Graph> graph_;
195 std::unique_ptr<AliasDb> aliasDb_ = nullptr;
196 size_t minSubgraphSize_ = 6;
197};
198
199void InlineFunctionalGraphs(Block* block) {
200 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
201 Node* n = *it;
202 it++;
203 for (Block* b : n->blocks()) {
204 InlineFunctionalGraphs(b);
205 }
206 if (n->kind() == prim::FunctionalGraph) {
207 SubgraphUtils::unmergeSubgraph(n);
208 }
209 }
210}
211
212} // namespace
213
214void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
215 // Run Constant Pooling so constants get hoisted
216 ConstantPooling(graph);
217 FunctionalGraphSlicer func(graph);
218 func.run();
219 // Creation of Functional Subgraphs & Deinlining creates excess constants
220 ConstantPooling(graph);
221}
222
223void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
224 InlineFunctionalGraphs(graph->block());
225}
226
227} // namespace jit
228} // namespace torch
229