1 | #include <torch/csrc/jit/passes/canonicalize.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/jit/ir/ir_views.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | // Canonicalize a graph, renumbering it so that all structurally equivalent |
10 | // graphs have same numbers. |
11 | // keep_unique_names: If false, canonicalizes unique names by removing them |
12 | // and replacing them with normal value names. |
13 | // Otherwise, ignores values with unique names. |
14 | std::shared_ptr<Graph> Canonicalize( |
15 | const std::shared_ptr<Graph>& graph, |
16 | bool keep_unique_names) { |
17 | auto r = std::make_shared<Graph>(graph->current_scope()); |
18 | std::unordered_map<Value*, Value*> rn_env; |
19 | auto rn_fn = [&](Value* v) { return rn_env.at(v); }; |
20 | for (auto* input : graph->inputs()) { |
21 | auto* r_input = r->addInput(); |
22 | r_input->copyMetadata(input); |
23 | if (!keep_unique_names) |
24 | r_input->setDebugName("" ); |
25 | rn_env[input] = r_input; |
26 | } |
27 | for (auto* node : graph->nodes()) { |
28 | auto* r_node = r->createClone(node, rn_fn); |
29 | if (!keep_unique_names) { |
30 | for (auto* output : r_node->outputs()) { |
31 | output->setDebugName("" ); |
32 | } |
33 | } |
34 | r->appendNode(r_node); |
35 | auto outputs = node->outputs(); |
36 | auto r_outputs = r_node->outputs(); |
37 | for (const auto i : c10::irange(outputs.size())) { |
38 | rn_env[outputs.at(i)] = r_outputs.at(i); |
39 | } |
40 | if (node->hasAttribute(attr::Subgraph)) { |
41 | r_node->g_( |
42 | attr::Subgraph, |
43 | Canonicalize(node->g(attr::Subgraph), keep_unique_names)); |
44 | } |
45 | } |
46 | for (auto* output : graph->outputs()) { |
47 | r->registerOutput(rn_fn(output)); |
48 | } |
49 | |
50 | return r; |
51 | } |
52 | |
53 | // Which index in b's owning Node is b |
54 | size_t blockIndex(const Block* b) { |
55 | auto n = b->owningNode(); |
56 | AT_ASSERT(n); |
57 | for (size_t i = 0; i < n->blocks().size(); ++i) { |
58 | if (n->blocks()[i] == b) { |
59 | return i; |
60 | } |
61 | } |
62 | AT_ASSERT(false); |
63 | } |
64 | |
65 | /* |
66 | * This establishes a canonical ordering of nodes. |
67 | * If n1 and n2 are in the same block, whichever node appears first |
68 | * is before the other. |
69 | * If n1 and n2 are contained in different blocks of an if node, |
70 | * then whichever block is in the true block is ordered before the other. |
71 | * If n1 contains n2, then n1 is before n2. This has the nice property that |
72 | * whichever node appears first in a dump of the graph is before the other. |
73 | * NB: this is not a topological index. Topologically, two nodes in |
74 | * different blocks of an if node are not topologically < or > each other. |
75 | */ |
76 | bool isBefore(Node* n1, Node* n2) { |
77 | // Invalid to call with the same node as both args |
78 | AT_ASSERT(n1 != n2); |
79 | |
80 | // Set n1 and n2 to be the number of blocks from the Graph block |
81 | size_t d_1 = n1->blocksFromGraphBlock(); |
82 | size_t d_2 = n2->blocksFromGraphBlock(); |
83 | |
84 | for (; d_1 > d_2; --d_1) { |
85 | n1 = n1->owningBlock()->owningNode(); |
86 | // n2 contains n1 |
87 | if (n1 == n2) { |
88 | return false; |
89 | } |
90 | } |
91 | |
92 | for (; d_2 > d_1; --d_2) { |
93 | n2 = n2->owningBlock()->owningNode(); |
94 | // n1 contains n2 |
95 | if (n2 == n1) { |
96 | return true; |
97 | } |
98 | } |
99 | |
100 | // Now they are the same numer of blocks from the graph block, |
101 | // recurse upwards, checking if they are on the same block |
102 | while (true) { |
103 | if (n1->owningBlock() == n2->owningBlock()) { |
104 | return n1->isBefore(n2); |
105 | } |
106 | |
107 | auto new_n1 = n1->owningBlock()->owningNode(); |
108 | auto new_n2 = n2->owningBlock()->owningNode(); |
109 | |
110 | AT_ASSERT(new_n1 != nullptr); |
111 | AT_ASSERT(new_n2 != nullptr); |
112 | |
113 | if (new_n1 == new_n2) { |
114 | // take whichever node is in the earlier block |
115 | auto index_1 = blockIndex(n1->owningBlock()); |
116 | auto index_2 = blockIndex(n2->owningBlock()); |
117 | return index_1 < index_2; |
118 | } |
119 | |
120 | n1 = new_n1; |
121 | n2 = new_n2; |
122 | } |
123 | } |
124 | |
125 | bool isBefore(const Use& a, const Use& b) { |
126 | // If two uses are the same node, we order on offset |
127 | if (a.user == b.user) { |
128 | return a.offset < b.offset; |
129 | } |
130 | |
131 | return isBefore(a.user, b.user); |
132 | } |
133 | |
134 | bool isAfter(const Use& a, const Use& b) { |
135 | if (a.user == b.user && a.offset == b.offset) { |
136 | return false; |
137 | } |
138 | return !isBefore(a, b); |
139 | } |
140 | |
141 | bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) { |
142 | return checking_before ? isBefore(a, b) : isAfter(a, b); |
143 | } |
144 | |
145 | c10::optional<const Use> firstOrLastUse(Value* v, bool find_first) { |
146 | if (v->uses().empty()) { |
147 | return c10::nullopt; |
148 | } |
149 | Use extreme_use = v->uses()[0]; |
150 | for (size_t i = 1; i < v->uses().size(); ++i) { |
151 | auto n_use = v->uses()[i]; |
152 | if (!isBeforeOrAfter(extreme_use, n_use, find_first)) { |
153 | extreme_use = n_use; |
154 | } |
155 | } |
156 | |
157 | return extreme_use; |
158 | } |
159 | |
160 | std::vector<c10::optional<const Use>> gatherFirstUses( |
161 | at::ArrayRef<Value*> values) { |
162 | return fmap(values, [&](Value* v) -> c10::optional<const Use> { |
163 | return firstOrLastUse(v, true); |
164 | }); |
165 | } |
166 | |
167 | std::vector<size_t> sort_indexes(at::ArrayRef<Value*> values) { |
168 | // initialize original index locations |
169 | std::vector<size_t> idx(values.size()); |
170 | std::iota(idx.begin(), idx.end(), 0); |
171 | |
172 | std::vector<c10::optional<const Use>> first_uses = gatherFirstUses(values); |
173 | |
174 | // Sort values based on canonical ordering of their first usage |
175 | std::sort(idx.begin(), idx.end(), [&first_uses](size_t i1, size_t i2) { |
176 | // if neither has any uses, use original ordering. Since the |
177 | // only values that jitter are ones added by the compiler and are guaranteed |
178 | // to have uses, original ordering is fine. |
179 | if (first_uses[i1] == c10::nullopt && first_uses[i2] == c10::nullopt) { |
180 | return i1 < i2; |
181 | } |
182 | if (first_uses[i1] == c10::nullopt) { |
183 | return false; |
184 | } else if (first_uses[i2] == c10::nullopt) { |
185 | return true; |
186 | } |
187 | |
188 | auto fst_v1 = *first_uses[i1]; |
189 | auto fst_v2 = *first_uses[i2]; |
190 | |
191 | return isBefore(fst_v1, fst_v2); |
192 | }); |
193 | |
194 | return idx; |
195 | } |
196 | |
197 | void CanonicalizeLoopOutputs(Node* n) { |
198 | auto new_indices = sort_indexes(n->outputs()); |
199 | LoopView(n).permuteLoopCarried(new_indices); |
200 | } |
201 | |
202 | void CanonicalizeIfOutputs(Node* n) { |
203 | auto new_indices = sort_indexes(n->outputs()); |
204 | IfView(n).permuteOutputs(new_indices); |
205 | } |
206 | |
207 | void CanonicalizeOutputs(Block* block) { |
208 | // We iterate in reverse since ordering of a node's outputs is dependent on |
209 | // the value use following it in the graph |
210 | for (Node* n : block->nodes().reverse()) { |
211 | switch (n->kind()) { |
212 | case prim::Loop: { |
213 | CanonicalizeLoopOutputs(n); |
214 | } break; |
215 | case prim::If: { |
216 | CanonicalizeIfOutputs(n); |
217 | } break; |
218 | } |
219 | // Since an a control flow node's outputs are after |
220 | // the values outputted within its blocks, first canonicalize |
221 | // the nodes outputs and then recurse on its blocks |
222 | for (Block* b : n->blocks()) { |
223 | CanonicalizeOutputs(b); |
224 | } |
225 | } |
226 | } |
227 | |
228 | // Canonicalize a graph's control flow node outputs. We do this to solve jitter |
229 | // issues with outputs added to control flow nodes after the first pass of |
230 | // compilation in ir_emitter.cpp |
231 | void CanonicalizeOutputs(std::shared_ptr<Graph>& graph) { |
232 | CanonicalizeOutputs(graph->block()); |
233 | } |
234 | } // namespace jit |
235 | } // namespace torch |
236 | |