1#include <torch/csrc/jit/passes/canonicalize.h>
2
3#include <c10/util/irange.h>
4#include <torch/csrc/jit/ir/ir_views.h>
5
6namespace torch {
7namespace jit {
8
9// Canonicalize a graph, renumbering it so that all structurally equivalent
10// graphs have same numbers.
11// keep_unique_names: If false, canonicalizes unique names by removing them
12// and replacing them with normal value names.
13// Otherwise, ignores values with unique names.
14std::shared_ptr<Graph> Canonicalize(
15 const std::shared_ptr<Graph>& graph,
16 bool keep_unique_names) {
17 auto r = std::make_shared<Graph>(graph->current_scope());
18 std::unordered_map<Value*, Value*> rn_env;
19 auto rn_fn = [&](Value* v) { return rn_env.at(v); };
20 for (auto* input : graph->inputs()) {
21 auto* r_input = r->addInput();
22 r_input->copyMetadata(input);
23 if (!keep_unique_names)
24 r_input->setDebugName("");
25 rn_env[input] = r_input;
26 }
27 for (auto* node : graph->nodes()) {
28 auto* r_node = r->createClone(node, rn_fn);
29 if (!keep_unique_names) {
30 for (auto* output : r_node->outputs()) {
31 output->setDebugName("");
32 }
33 }
34 r->appendNode(r_node);
35 auto outputs = node->outputs();
36 auto r_outputs = r_node->outputs();
37 for (const auto i : c10::irange(outputs.size())) {
38 rn_env[outputs.at(i)] = r_outputs.at(i);
39 }
40 if (node->hasAttribute(attr::Subgraph)) {
41 r_node->g_(
42 attr::Subgraph,
43 Canonicalize(node->g(attr::Subgraph), keep_unique_names));
44 }
45 }
46 for (auto* output : graph->outputs()) {
47 r->registerOutput(rn_fn(output));
48 }
49
50 return r;
51}
52
53// Which index in b's owning Node is b
54size_t blockIndex(const Block* b) {
55 auto n = b->owningNode();
56 AT_ASSERT(n);
57 for (size_t i = 0; i < n->blocks().size(); ++i) {
58 if (n->blocks()[i] == b) {
59 return i;
60 }
61 }
62 AT_ASSERT(false);
63}
64
65/*
66 * This establishes a canonical ordering of nodes.
67 * If n1 and n2 are in the same block, whichever node appears first
68 * is before the other.
69 * If n1 and n2 are contained in different blocks of an if node,
70 * then whichever block is in the true block is ordered before the other.
71 * If n1 contains n2, then n1 is before n2. This has the nice property that
72 * whichever node appears first in a dump of the graph is before the other.
73 * NB: this is not a topological index. Topologically, two nodes in
74 * different blocks of an if node are not topologically < or > each other.
75 */
76bool isBefore(Node* n1, Node* n2) {
77 // Invalid to call with the same node as both args
78 AT_ASSERT(n1 != n2);
79
80 // Set n1 and n2 to be the number of blocks from the Graph block
81 size_t d_1 = n1->blocksFromGraphBlock();
82 size_t d_2 = n2->blocksFromGraphBlock();
83
84 for (; d_1 > d_2; --d_1) {
85 n1 = n1->owningBlock()->owningNode();
86 // n2 contains n1
87 if (n1 == n2) {
88 return false;
89 }
90 }
91
92 for (; d_2 > d_1; --d_2) {
93 n2 = n2->owningBlock()->owningNode();
94 // n1 contains n2
95 if (n2 == n1) {
96 return true;
97 }
98 }
99
100 // Now they are the same numer of blocks from the graph block,
101 // recurse upwards, checking if they are on the same block
102 while (true) {
103 if (n1->owningBlock() == n2->owningBlock()) {
104 return n1->isBefore(n2);
105 }
106
107 auto new_n1 = n1->owningBlock()->owningNode();
108 auto new_n2 = n2->owningBlock()->owningNode();
109
110 AT_ASSERT(new_n1 != nullptr);
111 AT_ASSERT(new_n2 != nullptr);
112
113 if (new_n1 == new_n2) {
114 // take whichever node is in the earlier block
115 auto index_1 = blockIndex(n1->owningBlock());
116 auto index_2 = blockIndex(n2->owningBlock());
117 return index_1 < index_2;
118 }
119
120 n1 = new_n1;
121 n2 = new_n2;
122 }
123}
124
125bool isBefore(const Use& a, const Use& b) {
126 // If two uses are the same node, we order on offset
127 if (a.user == b.user) {
128 return a.offset < b.offset;
129 }
130
131 return isBefore(a.user, b.user);
132}
133
134bool isAfter(const Use& a, const Use& b) {
135 if (a.user == b.user && a.offset == b.offset) {
136 return false;
137 }
138 return !isBefore(a, b);
139}
140
141bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) {
142 return checking_before ? isBefore(a, b) : isAfter(a, b);
143}
144
145c10::optional<const Use> firstOrLastUse(Value* v, bool find_first) {
146 if (v->uses().empty()) {
147 return c10::nullopt;
148 }
149 Use extreme_use = v->uses()[0];
150 for (size_t i = 1; i < v->uses().size(); ++i) {
151 auto n_use = v->uses()[i];
152 if (!isBeforeOrAfter(extreme_use, n_use, find_first)) {
153 extreme_use = n_use;
154 }
155 }
156
157 return extreme_use;
158}
159
160std::vector<c10::optional<const Use>> gatherFirstUses(
161 at::ArrayRef<Value*> values) {
162 return fmap(values, [&](Value* v) -> c10::optional<const Use> {
163 return firstOrLastUse(v, true);
164 });
165}
166
167std::vector<size_t> sort_indexes(at::ArrayRef<Value*> values) {
168 // initialize original index locations
169 std::vector<size_t> idx(values.size());
170 std::iota(idx.begin(), idx.end(), 0);
171
172 std::vector<c10::optional<const Use>> first_uses = gatherFirstUses(values);
173
174 // Sort values based on canonical ordering of their first usage
175 std::sort(idx.begin(), idx.end(), [&first_uses](size_t i1, size_t i2) {
176 // if neither has any uses, use original ordering. Since the
177 // only values that jitter are ones added by the compiler and are guaranteed
178 // to have uses, original ordering is fine.
179 if (first_uses[i1] == c10::nullopt && first_uses[i2] == c10::nullopt) {
180 return i1 < i2;
181 }
182 if (first_uses[i1] == c10::nullopt) {
183 return false;
184 } else if (first_uses[i2] == c10::nullopt) {
185 return true;
186 }
187
188 auto fst_v1 = *first_uses[i1];
189 auto fst_v2 = *first_uses[i2];
190
191 return isBefore(fst_v1, fst_v2);
192 });
193
194 return idx;
195}
196
197void CanonicalizeLoopOutputs(Node* n) {
198 auto new_indices = sort_indexes(n->outputs());
199 LoopView(n).permuteLoopCarried(new_indices);
200}
201
202void CanonicalizeIfOutputs(Node* n) {
203 auto new_indices = sort_indexes(n->outputs());
204 IfView(n).permuteOutputs(new_indices);
205}
206
207void CanonicalizeOutputs(Block* block) {
208 // We iterate in reverse since ordering of a node's outputs is dependent on
209 // the value use following it in the graph
210 for (Node* n : block->nodes().reverse()) {
211 switch (n->kind()) {
212 case prim::Loop: {
213 CanonicalizeLoopOutputs(n);
214 } break;
215 case prim::If: {
216 CanonicalizeIfOutputs(n);
217 } break;
218 }
219 // Since an a control flow node's outputs are after
220 // the values outputted within its blocks, first canonicalize
221 // the nodes outputs and then recurse on its blocks
222 for (Block* b : n->blocks()) {
223 CanonicalizeOutputs(b);
224 }
225 }
226}
227
228// Canonicalize a graph's control flow node outputs. We do this to solve jitter
229// issues with outputs added to control flow nodes after the first pass of
230// compilation in ir_emitter.cpp
231void CanonicalizeOutputs(std::shared_ptr<Graph>& graph) {
232 CanonicalizeOutputs(graph->block());
233}
234} // namespace jit
235} // namespace torch
236