1 | #include <torch/csrc/jit/passes/create_functional_graphs.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <torch/csrc/jit/ir/alias_analysis.h> |
5 | #include <torch/csrc/jit/passes/constant_pooling.h> |
6 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
7 | #include <torch/csrc/utils/memory.h> |
8 | |
9 | #include <cstddef> |
10 | #include <limits> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | namespace { |
16 | |
17 | struct FunctionalGraphSlicer { |
18 | FunctionalGraphSlicer(std::shared_ptr<Graph> graph) |
19 | : graph_(std::move(graph)) {} |
20 | |
21 | void run() { |
22 | bool changed = true; |
23 | // TODO: more sane strategy |
24 | size_t MAX_NUM_ITERATIONS = 4; |
25 | |
26 | // First, analyze the functional subset of the graph, and then create |
27 | // functional graphs. The graph gets mutated when we create functional |
28 | // subgraphs, invalidating the AliasDb, so we need to do our analysis |
29 | // first. |
30 | for (size_t i = 0; i < MAX_NUM_ITERATIONS && changed; ++i) { |
31 | aliasDb_ = torch::make_unique<AliasDb>(graph_); |
32 | AnalyzeFunctionalSubset(graph_->block()); |
33 | changed = CreateFunctionalGraphsImpl(graph_->block()); |
34 | } |
35 | } |
36 | |
37 | private: |
38 | bool isEmptyFunctionalGraph(Node* n) { |
39 | auto g = n->g(attr::Subgraph); |
40 | return g->inputs().empty() && g->outputs().empty(); |
41 | } |
42 | |
43 | void nonConstNodes(Block* block, size_t* num) { |
44 | for (auto it = block->nodes().begin(); |
45 | it != block->nodes().end() && *num < minSubgraphSize_; |
46 | ++it) { |
47 | Node* n = *it; |
48 | if (n->kind() == prim::Constant) { |
49 | continue; |
50 | } |
51 | *num = *num + 1; |
52 | for (Block* b : n->blocks()) { |
53 | nonConstNodes(b, num); |
54 | } |
55 | } |
56 | } |
57 | |
58 | bool inlineIfTooSmall(Node* n) { |
59 | AT_ASSERT(n->kind() == prim::FunctionalGraph); |
60 | auto subgraph = SubgraphUtils::getSubgraph(n); |
61 | size_t num_modes = 0; |
62 | nonConstNodes(subgraph->block(), &num_modes); |
63 | if (num_modes < minSubgraphSize_) { |
64 | SubgraphUtils::unmergeSubgraph(n); |
65 | return true; |
66 | } |
67 | return false; |
68 | } |
69 | |
70 | bool CreateFunctionalGraphsImpl(Block* block) { |
71 | /* |
72 | Iterate the block in reverse and create FunctionalSubgraphs. |
73 | When we encounter a node that isn't functional, we skip it. Otherwise, |
74 | we try to merge the functional node into the current functional subgraph. |
75 | If it can't be merged into the current functional subgraph node, then we |
76 | start a functional subgraph group. |
77 | */ |
78 | bool changed = false; |
79 | std::vector<Node*> functional_graph_nodes; |
80 | |
81 | Node* functional_subgraph_node = |
82 | graph_->createWithSubgraph(prim::FunctionalGraph) |
83 | ->insertBefore(block->return_node()); |
84 | auto reverse_iter = block->nodes().reverse(); |
85 | std::vector<Value*> graph_outputs; |
86 | for (auto it = reverse_iter.begin(); it != reverse_iter.end();) { |
87 | Node* n = *it++; |
88 | |
89 | // constants get copied into the graph |
90 | if (n->kind() == prim::Constant || n == functional_subgraph_node) { |
91 | continue; |
92 | } |
93 | |
94 | // if `n` is functional, all of its blocks will be merged into the |
95 | // new functional subgraph, so we only need to recurse if it is not |
96 | // functional |
97 | if (!functional_nodes_.count(n)) { |
98 | for (Block* b : n->blocks()) { |
99 | auto block_changed = CreateFunctionalGraphsImpl(b); |
100 | changed = block_changed && changed; |
101 | } |
102 | continue; |
103 | } |
104 | |
105 | if (n->kind() == prim::FunctionalGraph && |
106 | isEmptyFunctionalGraph(functional_subgraph_node)) { |
107 | functional_subgraph_node->destroy(); |
108 | functional_subgraph_node = n; |
109 | continue; |
110 | } |
111 | |
112 | changed = true; |
113 | if (aliasDb_->moveBeforeTopologicallyValid(n, functional_subgraph_node)) { |
114 | SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node); |
115 | } else { |
116 | functional_graph_nodes.emplace_back(functional_subgraph_node); |
117 | functional_subgraph_node = |
118 | graph_->createWithSubgraph(prim::FunctionalGraph)->insertAfter(n); |
119 | SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node); |
120 | } |
121 | } |
122 | functional_graph_nodes.emplace_back(functional_subgraph_node); |
123 | |
124 | for (Node* functional_node : functional_graph_nodes) { |
125 | if (!inlineIfTooSmall(functional_node)) { |
126 | ConstantPooling(functional_node->g(attr::Subgraph)); |
127 | } |
128 | } |
129 | return changed; |
130 | } |
131 | |
132 | bool AnalyzeFunctionalSubset(Node* n) { |
133 | // TODO: clarify hasSideEffects, isNondeterministic |
134 | bool is_functional_node = true; |
135 | |
136 | // Functional Graphs are not responsible for maintaining aliasing |
137 | // relationships. If an output of a functional graph escapes scope |
138 | // or is mutated then we might change semantics of the program if |
139 | // aliasing relationships are changed. |
140 | // We don't allow any node in the functional graph to output a value |
141 | // that escapes scope or is mutated, and we don't allow any mutating nodes |
142 | // into the graph. |
143 | // - allow functional graphs to have at most one value that can escape scope |
144 | // - allow outputs which alias the wildcard set but do not "re-escape" |
145 | for (Value* v : n->outputs()) { |
146 | bool has_writers = aliasDb_->hasWriters(v); |
147 | bool escapes_scope = aliasDb_->escapesScope(v); |
148 | if (has_writers) { |
149 | mutated_values_.insert(v); |
150 | } |
151 | is_functional_node = is_functional_node && !escapes_scope && !has_writers; |
152 | } |
153 | |
154 | for (Block* block : n->blocks()) { |
155 | auto functional_block = AnalyzeFunctionalSubset(block); |
156 | is_functional_node = is_functional_node && functional_block; |
157 | } |
158 | |
159 | is_functional_node = is_functional_node && !aliasDb_->isMutable(n); |
160 | if (is_functional_node) { |
161 | functional_nodes_.insert(n); |
162 | } |
163 | return is_functional_node; |
164 | } |
165 | |
166 | void AnalyzeFunctionalSubset(at::ArrayRef<Block*> blocks) { |
167 | for (Block* block : blocks) { |
168 | AnalyzeFunctionalSubset(block); |
169 | } |
170 | } |
171 | |
172 | bool AnalyzeFunctionalSubset(Block* block) { |
173 | bool is_functional_block = true; |
174 | // block inputs will not yet have been iterated through, |
175 | // so we need to add them to our set of mutated & escape values. |
176 | for (Value* v : block->inputs()) { |
177 | bool has_writers = aliasDb_->hasWriters(v); |
178 | if (has_writers) { |
179 | mutated_values_.insert(v); |
180 | } |
181 | } |
182 | // if a block output is not functional, then the corresponding output for |
183 | // the node that contains the block will not be functional either, so we do |
184 | // not need to analyze the block outputs here. |
185 | for (Node* n : block->nodes()) { |
186 | bool functional = AnalyzeFunctionalSubset(n); |
187 | is_functional_block = is_functional_block && functional; |
188 | } |
189 | return is_functional_block; |
190 | } |
191 | |
192 | std::unordered_set<Node*> functional_nodes_; |
193 | std::unordered_set<Value*> mutated_values_; |
194 | std::shared_ptr<Graph> graph_; |
195 | std::unique_ptr<AliasDb> aliasDb_ = nullptr; |
196 | size_t minSubgraphSize_ = 6; |
197 | }; |
198 | |
199 | void InlineFunctionalGraphs(Block* block) { |
200 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
201 | Node* n = *it; |
202 | it++; |
203 | for (Block* b : n->blocks()) { |
204 | InlineFunctionalGraphs(b); |
205 | } |
206 | if (n->kind() == prim::FunctionalGraph) { |
207 | SubgraphUtils::unmergeSubgraph(n); |
208 | } |
209 | } |
210 | } |
211 | |
212 | } // namespace |
213 | |
214 | void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph) { |
215 | // Run Constant Pooling so constants get hoisted |
216 | ConstantPooling(graph); |
217 | FunctionalGraphSlicer func(graph); |
218 | func.run(); |
219 | // Creation of Functional Subgraphs & Deinlining creates excess constants |
220 | ConstantPooling(graph); |
221 | } |
222 | |
223 | void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) { |
224 | InlineFunctionalGraphs(graph->block()); |
225 | } |
226 | |
227 | } // namespace jit |
228 | } // namespace torch |
229 | |