1 | #include <torch/csrc/jit/passes/lower_tuples.h> |
2 | |
3 | #include <ATen/core/functional.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/jit/ir/constants.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
9 | |
10 | #include <utility> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | namespace { |
16 | |
17 | // operators where we expect to find tuples as inputs/outputs |
18 | // this is to assert we are only doing modifications when we know |
19 | // we can flatten tuples |
20 | std::unordered_set<Symbol> supported_ops = { |
21 | prim::If, |
22 | prim::Loop, |
23 | prim::Uninitialized, |
24 | prim::TupleUnpack, |
25 | prim::TupleConstruct, |
26 | prim::TupleIndex, |
27 | prim::TupleSlice, |
28 | prim::Param, |
29 | prim::Return, |
30 | prim::PythonOp, |
31 | aten::format, |
32 | prim::Uninitialized, |
33 | aten::__getitem__}; |
34 | |
// Flatten block inputs and insert a tuple construct in the block
//
// `n` is a prim::Loop node (see the caller in flattenInputs) and `index`
// is the position, in n->inputs(), of a loop-carried input whose type is
// a tuple produced by a prim::TupleConstruct. The tuple is flattened into
// individual loop-carried values, and a fresh TupleConstruct is prepended
// inside the loop body so existing uses of the old tuple-typed block
// parameter keep seeing a tuple value.
// NOTE(review): block inputs are addressed at `index - 1` while node
// inputs use `index` — presumably because the body's first block input is
// the iteration counter, which has no counterpart among the node's
// loop-carried inputs; confirm against the prim::Loop calling convention.
static void flattenTupleInLoopParams(Node* n, size_t index) {
  auto input = n->inputs().at(index);
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  Block* block = n->blocks().at(0);
  Node* block_node = n;

  // NOTE(review): this vector is never used.
  std::vector<Value*> new_node_inputs = {};
  auto new_construct_node =
      block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    // New block parameter for element j; it feeds the reconstructed tuple.
    auto new_block_in = block->insertInput(index + j);
    new_construct_node->addInput(new_block_in);
    // Forward element j of the TupleConstruct that produced `input` as a
    // new loop-carried input on the loop node itself.
    block_node->insertInput(index + j + 1, input->node()->inputs().at(j));
  }
  new_construct_node->output()->setType(block->inputs().at(index - 1)->type());
  new_construct_node->copyMetadata(n);
  // Redirect all uses of the old tuple-typed block parameter to the
  // reconstructed tuple, then drop the now-unused parameter and the
  // original tuple input on the node.
  block->inputs().at(index - 1)->replaceAllUsesWith(
      new_construct_node->output());
  block->eraseInput(index - 1);
  block_node->removeInput(index);
}
59 | |
// Flatten tuple outputs of the block node and append a TupleConstruct
// node after the block node if there is an outer block.
//
// `n` is a block's Return node; `index` is the position, in n->inputs(),
// of a tuple-typed return value produced by a prim::TupleConstruct. The
// tuple is flattened into individual block outputs, and — when the block
// belongs to an enclosing node — that node's corresponding output is
// flattened too, with a fresh TupleConstruct inserted after the node so
// outer uses still see a tuple value.
static void flattenTupleInBlockReturn(Node* n, size_t index) {
  auto input = n->inputs().at(index);
  Block* block = n->owningBlock();
  Node* block_node = block->owningNode();
  Node* new_construct_node = nullptr;
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  // 1- Add flattened tuple to block outputs
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    // Splice the tuple's elements in right after the tuple output, then
    // remove the tuple output itself.
    block->insertOutput(index + j + 1, input->node()->inputs().at(j));
  }
  block->eraseOutput(index);

  // Top-level block: there is no enclosing node whose outputs need fixing.
  if (block_node == nullptr)
    return;
  // 2- For uses of the block node in the outer block,
  // flatten the blocknode outputs and insert a tuple construct
  // to replace that.
  // Loop block has an extra element (iter counter)
  if (block_node->kind() == prim::Loop)
    index = index - 1;
  auto tuple_output = block_node->outputs().at(index);
  // When node has multiple blocks, do not flatten outputs on the second block
  // again
  if (!(tuple_output->type()->cast<TupleType>()))
    return;

  new_construct_node = block->owningGraph()->create(prim::TupleConstruct);
  new_construct_node->insertAfter(block_node);
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    // New flattened output on the enclosing node, fed into the
    // reconstructed tuple for outer consumers.
    auto new_block_out = block_node->insertOutput(index + j + 1);
    new_construct_node->addInput(new_block_out);
  }
  // Replace the block node with the new TupleConstruct node
  new_construct_node->output()->setType(tuple_output->type());
  new_construct_node->copyMetadata(block_node);
  tuple_output->replaceAllUsesWith(new_construct_node->output());
  block_node->eraseOutput(index);
}
102 | |
103 | void removeTupleNodes(Node* n, bool must_remove_tuples) { |
104 | if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex && |
105 | n->kind() != prim::TupleSlice) { |
106 | return; |
107 | } |
108 | // tuple index has two inputs, tuple and index |
109 | auto construct_node = n->inputs().at(0)->node(); |
110 | if (construct_node->kind() != prim::TupleConstruct) { |
111 | if (must_remove_tuples) { |
112 | AT_ERROR(n->kind().toQualString(), " not matched to tuple construct" ); |
113 | } |
114 | return; |
115 | } |
116 | if (n->kind() == prim::TupleUnpack) { |
117 | for (size_t i = 0; i < n->outputs().size(); ++i) { |
118 | n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i)); |
119 | } |
120 | } else if (n->kind() == prim::TupleIndex) { |
121 | auto idx = n->inputs().at(1); |
122 | auto maybe_int = constant_as<int64_t>(idx); |
123 | if (!maybe_int) { |
124 | if (must_remove_tuples) { |
125 | AT_ERROR(n->sourceRange(), "tuple index with non-constant index" ); |
126 | } |
127 | return; |
128 | } |
129 | auto int_idx = *maybe_int; |
130 | size_t len = construct_node->output()->type()->containedTypes().size(); |
131 | if (int_idx < 0) { |
132 | int_idx += len; |
133 | } |
134 | // currently, we allow non-constant tuple index if the tuple is of one type. |
135 | // so we need to check bounds here |
136 | if (int_idx >= 0 && static_cast<size_t>(int_idx) < len) { |
137 | n->output()->replaceAllUsesWith(construct_node->inputs().at(int_idx)); |
138 | } |
139 | } else if (n->kind() == prim::TupleSlice) { |
140 | std::vector<Value*> values; |
141 | int64_t beg = n->i(attr::beg); |
142 | int64_t end = n->i(attr::end); |
143 | for (int64_t i = beg; i < end; i += 1) { |
144 | values.push_back(construct_node->inputs().at(i)); |
145 | } |
146 | auto graph = n->owningGraph(); |
147 | auto tuple_out = graph->createTuple(values); |
148 | tuple_out->copyMetadata(n); |
149 | WithInsertPoint insert(n); |
150 | graph->insertNode(tuple_out); |
151 | n->output()->replaceAllUsesWith(tuple_out->output()); |
152 | } |
153 | } |
154 | } // anonymous namespace |
155 | |
156 | static void LowerAllTuples(Block* block); |
157 | |
158 | static void RemoveTupleConstants(Node* n) { |
159 | if (!(n->kind() == prim::Constant && |
160 | n->output()->type()->cast<TupleType>())) { |
161 | return; |
162 | } |
163 | |
164 | auto g = n->owningGraph(); |
165 | auto tuple = toIValue(n->output()).value().toTuple(); |
166 | const auto& tuple_elements = tuple->elements(); |
167 | WithInsertPoint insert(n); |
168 | std::vector<Value*> elements; |
169 | for (const auto& elem : tuple_elements) { |
170 | auto constant = insertConstant(*n->owningGraph(), elem); |
171 | elements.push_back(constant); |
172 | } |
173 | auto tuple_type = n->output()->type()->expect<TupleType>(); |
174 | auto tuple_construct = g->insertNode(n->owningGraph()->createTuple( |
175 | elements, tuple_type->schema() ? std::move(tuple_type) : nullptr)); |
176 | tuple_construct->copyMetadata(n); |
177 | |
178 | // insert the tuple first before recursing on its elements, so that its |
179 | // elements will have a use |
180 | for (Value* elem : elements) { |
181 | RemoveTupleConstants(elem->node()); |
182 | } |
183 | |
184 | n->replaceAllUsesWith(tuple_construct); |
185 | } |
186 | |
// Flatten tuple-typed inputs of `n`: op(a, tup, b) --> op(a, t0, t1, b).
//
// Only ops in `supported_ops` are rewritten; for other ops a warning is
// emitted and the tuple input is left in place (a full lowering pass will
// later fail its EnsureNoTuples check).
// NOTE(review): `insert_point` is unused here; presumably kept so the
// signature mirrors flattenOutputs.
static void flattenInputs(Node* n, Node* insert_point) {
  // flatten the input list op(a, tup, b) --> op(a, t0, t1, b)
  for (size_t i = 0; i < n->inputs().size();) {
    auto input = n->inputs()[i];
    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
      // Lowering needs to see the tuple's elements, so the value must be
      // produced by a visible TupleConstruct.
      TORCH_CHECK(
          (input->node()->kind() == prim::TupleConstruct),
          "tuple use not matched to tuple construct. Instead found: ",
          n->kind().toQualString());
      if (supported_ops.count(n->kind()) > 0) {
        if (n->kind() == prim::Loop) {
          // This function supports all node types with blocks that take tuple
          // inputs.
          flattenTupleInLoopParams(n, i);
        } else if (n->kind() == prim::Return) {
          flattenTupleInBlockReturn(n, i);
        } else {
          // Splice the tuple's elements in after position i, then drop
          // the tuple input itself.
          for (size_t j = 0; j < tt->elements().size(); ++j) {
            n->insertInput(i + 1 + j, input->node()->inputs().at(j));
          }
          n->removeInput(i);
        }
        // note: no update to i
        // since tuples might be nested we need to recursively scan
        // the new flattened inputs
      } else {
        TORCH_WARN(
            "tuple appears in op inputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
224 | |
// Flatten tuple-typed outputs of `n`:
//   (a, b, tup, c) -> (a, b, t0, t1, c)
// with a replacement `tup = TupleConstruct(t0, t1)` inserted before
// `insert_point` so existing consumers of the tuple keep working until
// they themselves are lowered.
static void flattenOutputs(Node* n, Node* insert_point) {
  // flatten the outputs list
  auto& graph = *n->owningGraph();
  for (size_t i = 0; i < n->outputs().size();) {
    Value* output = n->outputs()[i];
    // An unused tuple output needs no reconstruction; skip it.
    if (!output->hasUses()) {
      ++i;
      continue;
    }

    // (a, b, tup, c) -> (a, b, t0, t1, c)
    // and:
    // tup = (t0, t1)
    // is placed at the current insertion point
    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
      if (supported_ops.count(n->kind()) > 0) {
        // New flattened outputs go right after the tuple output, typed
        // element-wise from the tuple type.
        for (const auto j : c10::irange(tt->elements().size())) {
          n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
        }
        auto new_tup =
            graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
        new_tup->copyMetadata(n);
        new_tup->insertBefore(insert_point);
        // Subsequent reconstructions are inserted after this one so their
        // relative order matches the outputs' order.
        insert_point = new_tup;
        output->replaceAllUsesWith(new_tup->output());
        n->eraseOutput(i);
        // note: no update to i to handle nested tuples
      } else {
        TORCH_WARN(
            "tuple appears in the op outputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
264 | |
265 | static void VisitNode(Node* n, Node* insert_point) { |
266 | // tuple construction operators will become dead when the unpacks are replaced |
267 | if (n->kind() == prim::TupleConstruct) { |
268 | return; |
269 | } |
270 | // note: changing the second argument to false changes this pass from a |
271 | // complete lowering pass to one that removes tuples when possible. When |
272 | // tuples are first-class in the interpreter, we should still run this pass to |
273 | // remove extraneous uses |
274 | if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex || |
275 | n->kind() == prim::TupleSlice) { |
276 | removeTupleNodes(n, /*must_remove_tuples*/ true); |
277 | return; |
278 | } |
279 | flattenInputs(n, insert_point); |
280 | for (auto b : n->blocks()) { |
281 | LowerAllTuples(b); |
282 | } |
283 | flattenOutputs(n, insert_point); |
284 | } |
285 | |
// Lower every tuple in `block` and, via VisitNode, in all nested blocks.
static void LowerAllTuples(Block* block) {
  // tuples in parameter lists of a block behave exactly the same as
  // _outputs_ of normal instructions, since the param_node represents the
  // parameters as outputs, we can handle it by simply visiting the node
  VisitNode(block->param_node(), *block->nodes().begin());
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    // Advance `it` first: n may be mutated, and `*it` (the next node) is
    // the insertion point for any replacement TupleConstruct nodes.
    auto n = *it++;
    // Expand tuple-valued constants into TupleConstructs first so the
    // generic lowering in VisitNode can see their elements.
    RemoveTupleConstants(n);
    // NOTE(review): when n is the last node, `it` equals end();
    // presumably dereferencing the end iterator yields the block's
    // sentinel node, which is a valid insertion point — confirm against
    // the generic_graph_node_list iterator contract.
    VisitNode(n, *it);
  }
  // tuples in return lists of blocks behave exactly the same as
  // _inputs_ of normal instructions, so we can use VisitNode here as well
  // insert_point is null because it will never be used since return nodes
  // have no outputs
  VisitNode(block->return_node(), nullptr);
}
303 | |
304 | static void EnsureNoTuples(ArrayRef<Value*> values) { |
305 | for (Value* v : values) { |
306 | TORCH_CHECK( |
307 | v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples." ); |
308 | } |
309 | } |
310 | |
311 | static void EnsureNoTuples(Block* block) { |
312 | for (Node* n : block->nodes()) { |
313 | for (Block* b : n->blocks()) { |
314 | EnsureNoTuples(b); |
315 | } |
316 | EnsureNoTuples(n->outputs()); |
317 | } |
318 | } |
319 | |
// Public entry point: fully lower all tuples in `graph`, then verify that
// none remain.
void LowerAllTuples(const std::shared_ptr<Graph>& graph) {
  LowerAllTuples(graph->block());
  GRAPH_DUMP("After LowerAllTuples: ", graph);
  // The rewrites leave the original TupleConstruct/consumer nodes dead;
  // remove them before asserting that no tuple-typed values survive.
  EliminateDeadCode(graph->block());
  EnsureNoTuples(graph->block());
}
326 | |
327 | void LowerSimpleTuples(Block* block) { |
328 | for (auto n : block->nodes()) { |
329 | removeTupleNodes(n, /*must_remove_tuples*/ false); |
330 | for (auto b : n->blocks()) { |
331 | LowerSimpleTuples(b); |
332 | } |
333 | } |
334 | } |
335 | |
// Public entry point for the best-effort variant: removes tuple ops that
// can be statically matched to a TupleConstruct and cleans up the dead
// nodes; remaining tuples are left intact (no EnsureNoTuples check).
void LowerSimpleTuples(const std::shared_ptr<Graph>& graph) {
  LowerSimpleTuples(graph->block());
  GRAPH_DUMP("After LowerSimpleTuples: ", graph);
  EliminateDeadCode(graph);
}
341 | } // namespace jit |
342 | } // namespace torch |
343 | |