1#include <torch/csrc/jit/passes/lower_tuples.h>
2
3#include <ATen/core/functional.h>
4#include <c10/util/Exception.h>
5#include <c10/util/irange.h>
6#include <torch/csrc/jit/ir/constants.h>
7#include <torch/csrc/jit/jit_log.h>
8#include <torch/csrc/jit/passes/dead_code_elimination.h>
9
10#include <utility>
11
12namespace torch {
13namespace jit {
14
15namespace {
16
17// operators where we expect to find tuples as inputs/outputs
18// this is to assert we are only doing modifications when we know
19// we can flatten tuples
20std::unordered_set<Symbol> supported_ops = {
21 prim::If,
22 prim::Loop,
23 prim::Uninitialized,
24 prim::TupleUnpack,
25 prim::TupleConstruct,
26 prim::TupleIndex,
27 prim::TupleSlice,
28 prim::Param,
29 prim::Return,
30 prim::PythonOp,
31 aten::format,
32 prim::Uninitialized,
33 aten::__getitem__};
34
35// Flatten block inputs and insert a tuple construct in the block
36static void flattenTupleInLoopParams(Node* n, size_t index) {
37 auto input = n->inputs().at(index);
38 TupleTypePtr tt = input->type()->cast<TupleType>();
39 TORCH_INTERNAL_ASSERT(tt);
40
41 Block* block = n->blocks().at(0);
42 Node* block_node = n;
43
44 std::vector<Value*> new_node_inputs = {};
45 auto new_construct_node =
46 block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
47 for (size_t j = 0; j < tt->elements().size(); ++j) {
48 auto new_block_in = block->insertInput(index + j);
49 new_construct_node->addInput(new_block_in);
50 block_node->insertInput(index + j + 1, input->node()->inputs().at(j));
51 }
52 new_construct_node->output()->setType(block->inputs().at(index - 1)->type());
53 new_construct_node->copyMetadata(n);
54 block->inputs().at(index - 1)->replaceAllUsesWith(
55 new_construct_node->output());
56 block->eraseInput(index - 1);
57 block_node->removeInput(index);
58}
59
60// Flatten tuple outputs of the block node and append a TupleConstruct
61// node after the block node if there is an outer block.
// Flatten tuple outputs of the block node and append a TupleConstruct
// node after the block node if there is an outer block.
//
// `n` is a block's Return node; `index` is the position of a tuple-typed
// return value whose producer is a prim::TupleConstruct.
static void flattenTupleInBlockReturn(Node* n, size_t index) {
  auto input = n->inputs().at(index);
  Block* block = n->owningBlock();
  Node* block_node = block->owningNode();
  Node* new_construct_node = nullptr;
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  // 1- Add flattened tuple to block outputs
  // The elements are spliced in right after the tuple output itself, then
  // the original tuple-typed output is dropped.
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    block->insertOutput(index + j + 1, input->node()->inputs().at(j));
  }
  block->eraseOutput(index);

  // Top-level graph block: no owning node mirrors these outputs, so done.
  if (block_node == nullptr)
    return;
  // 2- For uses of the block node in the outer block,
  // flatten the blocknode outputs and insert a tuple construct
  // to replace that.
  // Loop block has an extra element (iter counter)
  if (block_node->kind() == prim::Loop)
    index = index - 1;
  auto tuple_output = block_node->outputs().at(index);
  // When node has multiple blocks, do not flatten outputs on the second block
  // again
  if (!(tuple_output->type()->cast<TupleType>()))
    return;

  // Mirror the flattened outputs on the owning node and rebuild the tuple
  // immediately after it so downstream uses are unaffected.
  new_construct_node = block->owningGraph()->create(prim::TupleConstruct);
  new_construct_node->insertAfter(block_node);
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    auto new_block_out = block_node->insertOutput(index + j + 1);
    new_construct_node->addInput(new_block_out);
  }
  // Replace the block node with the new TupleConstruct node
  new_construct_node->output()->setType(tuple_output->type());
  new_construct_node->copyMetadata(block_node);
  tuple_output->replaceAllUsesWith(new_construct_node->output());
  block_node->eraseOutput(index);
}
102
// Rewrite a tuple-reading node (TupleUnpack / TupleIndex / TupleSlice) to
// consume the inputs of the TupleConstruct that produced its tuple operand,
// leaving the construct node to be swept by dead-code elimination.
//
// If `must_remove_tuples` is true, failure to match (no visible
// TupleConstruct producer, or a non-constant index) is a hard error;
// otherwise the node is silently left in place (best-effort mode, used by
// LowerSimpleTuples).
void removeTupleNodes(Node* n, bool must_remove_tuples) {
  if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
      n->kind() != prim::TupleSlice) {
    return;
  }
  // tuple index has two inputs, tuple and index
  auto construct_node = n->inputs().at(0)->node();
  if (construct_node->kind() != prim::TupleConstruct) {
    if (must_remove_tuples) {
      AT_ERROR(n->kind().toQualString(), " not matched to tuple construct");
    }
    return;
  }
  if (n->kind() == prim::TupleUnpack) {
    // Each unpacked output is exactly the corresponding construct input.
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i));
    }
  } else if (n->kind() == prim::TupleIndex) {
    auto idx = n->inputs().at(1);
    auto maybe_int = constant_as<int64_t>(idx);
    if (!maybe_int) {
      if (must_remove_tuples) {
        AT_ERROR(n->sourceRange(), "tuple index with non-constant index");
      }
      return;
    }
    auto int_idx = *maybe_int;
    size_t len = construct_node->output()->type()->containedTypes().size();
    // Normalize Python-style negative indices.
    if (int_idx < 0) {
      int_idx += len;
    }
    // currently, we allow non-constant tuple index if the tuple is of one type.
    // so we need to check bounds here
    // (an out-of-range index leaves the node untouched rather than erroring)
    if (int_idx >= 0 && static_cast<size_t>(int_idx) < len) {
      n->output()->replaceAllUsesWith(construct_node->inputs().at(int_idx));
    }
  } else if (n->kind() == prim::TupleSlice) {
    // Build a smaller TupleConstruct holding the [beg, end) range of the
    // original construct's inputs and redirect uses of the slice to it.
    std::vector<Value*> values;
    int64_t beg = n->i(attr::beg);
    int64_t end = n->i(attr::end);
    for (int64_t i = beg; i < end; i += 1) {
      values.push_back(construct_node->inputs().at(i));
    }
    auto graph = n->owningGraph();
    auto tuple_out = graph->createTuple(values);
    tuple_out->copyMetadata(n);
    WithInsertPoint insert(n);
    graph->insertNode(tuple_out);
    n->output()->replaceAllUsesWith(tuple_out->output());
  }
}
154} // anonymous namespace
155
156static void LowerAllTuples(Block* block);
157
158static void RemoveTupleConstants(Node* n) {
159 if (!(n->kind() == prim::Constant &&
160 n->output()->type()->cast<TupleType>())) {
161 return;
162 }
163
164 auto g = n->owningGraph();
165 auto tuple = toIValue(n->output()).value().toTuple();
166 const auto& tuple_elements = tuple->elements();
167 WithInsertPoint insert(n);
168 std::vector<Value*> elements;
169 for (const auto& elem : tuple_elements) {
170 auto constant = insertConstant(*n->owningGraph(), elem);
171 elements.push_back(constant);
172 }
173 auto tuple_type = n->output()->type()->expect<TupleType>();
174 auto tuple_construct = g->insertNode(n->owningGraph()->createTuple(
175 elements, tuple_type->schema() ? std::move(tuple_type) : nullptr));
176 tuple_construct->copyMetadata(n);
177
178 // insert the tuple first before recursing on its elements, so that its
179 // elements will have a use
180 for (Value* elem : elements) {
181 RemoveTupleConstants(elem->node());
182 }
183
184 n->replaceAllUsesWith(tuple_construct);
185}
186
// Flatten every tuple-typed input of `n` in place:
// op(a, tup, b) --> op(a, t0, t1, b), where t0/t1 come from the
// prim::TupleConstruct that produced `tup`. prim::Loop and prim::Return
// inputs are handled by the block-aware helpers above.
// NOTE: `insert_point` is currently unused here; the parameter mirrors the
// flattenOutputs signature (both are invoked from VisitNode).
static void flattenInputs(Node* n, Node* insert_point) {
  // flatten the input list op(a, tup, b) --> op(a, t0, t1, b)
  for (size_t i = 0; i < n->inputs().size();) {
    auto input = n->inputs()[i];
    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
      TORCH_CHECK(
          (input->node()->kind() == prim::TupleConstruct),
          "tuple use not matched to tuple construct. Instead found: ",
          n->kind().toQualString());
      if (supported_ops.count(n->kind()) > 0) {
        if (n->kind() == prim::Loop) {
          // This function supports all node types with blocks that take tuple
          // inputs.
          flattenTupleInLoopParams(n, i);
        } else if (n->kind() == prim::Return) {
          flattenTupleInBlockReturn(n, i);
        } else {
          // Splice the tuple's elements in right after the tuple input,
          // then drop the tuple input itself.
          for (size_t j = 0; j < tt->elements().size(); ++j) {
            n->insertInput(i + 1 + j, input->node()->inputs().at(j));
          }
          n->removeInput(i);
        }
        // note: no update to i
        // since tuples might be nested we need to recursively scan
        // the new flattened inputs
      } else {
        TORCH_WARN(
            "tuple appears in op inputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
224
// Flatten every used tuple-typed output of `n`:
// (a, b, tup, c) -> (a, b, t0, t1, c), with a replacement
// `tup = TupleConstruct(t0, t1)` inserted before `insert_point` so existing
// uses of `tup` stay valid. Outputs with no uses are skipped.
static void flattenOutputs(Node* n, Node* insert_point) {
  // flatten the outputs list
  auto& graph = *n->owningGraph();
  for (size_t i = 0; i < n->outputs().size();) {
    Value* output = n->outputs()[i];
    // Unused tuple outputs need no replacement construct; leave them for DCE.
    if (!output->hasUses()) {
      ++i;
      continue;
    }

    // (a, b, tup, c) -> (a, b, t0, t1, c)
    // and:
    // tup = (t0, t1)
    // is placed at the current insertion point
    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
      if (supported_ops.count(n->kind()) > 0) {
        // New flattened outputs go right after the tuple output; the tuple
        // output itself is erased below.
        for (const auto j : c10::irange(tt->elements().size())) {
          n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
        }
        auto new_tup =
            graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
        new_tup->copyMetadata(n);
        new_tup->insertBefore(insert_point);
        // Chain subsequent replacement constructs after this one (local
        // update only; the caller's insert point is unchanged).
        insert_point = new_tup;
        output->replaceAllUsesWith(new_tup->output());
        n->eraseOutput(i);
        // note: no update to i to handle nested tuples
      } else {
        TORCH_WARN(
            "tuple appears in the op outputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}
264
265static void VisitNode(Node* n, Node* insert_point) {
266 // tuple construction operators will become dead when the unpacks are replaced
267 if (n->kind() == prim::TupleConstruct) {
268 return;
269 }
270 // note: changing the second argument to false changes this pass from a
271 // complete lowering pass to one that removes tuples when possible. When
272 // tuples are first-class in the interpreter, we should still run this pass to
273 // remove extraneous uses
274 if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
275 n->kind() == prim::TupleSlice) {
276 removeTupleNodes(n, /*must_remove_tuples*/ true);
277 return;
278 }
279 flattenInputs(n, insert_point);
280 for (auto b : n->blocks()) {
281 LowerAllTuples(b);
282 }
283 flattenOutputs(n, insert_point);
284}
285
// Lower all tuples within a single block (and, via VisitNode, its
// sub-blocks): block parameters, every node, and the block return.
static void LowerAllTuples(Block* block) {
  // tuples in parameter lists of a block behave exactly the same as
  // _outputs_ of normal instructions, since the param_node represents the
  // parameters as outputs, we can handle it by simply visiting the node
  VisitNode(block->param_node(), *block->nodes().begin());
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    // Advance before visiting: `n` may be rewritten, and the following node
    // serves as the insertion point for replacement TupleConstructs.
    // (NOTE(review): `*it` after `it++` may be the end sentinel; presumably
    // the node list's sentinel dereference is valid here — verify against
    // the IR node-list implementation.)
    auto n = *it++;
    RemoveTupleConstants(n);
    VisitNode(n, *it);
  }
  // tuples in return lists of blocks behave exactly the same as
  // _inputs_ of normal instructions, so we can use VisitNode here as well
  // insert_point is null because it will never be used since return nodes
  // have no outputs
  VisitNode(block->return_node(), nullptr);
}
303
304static void EnsureNoTuples(ArrayRef<Value*> values) {
305 for (Value* v : values) {
306 TORCH_CHECK(
307 v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
308 }
309}
310
311static void EnsureNoTuples(Block* block) {
312 for (Node* n : block->nodes()) {
313 for (Block* b : n->blocks()) {
314 EnsureNoTuples(b);
315 }
316 EnsureNoTuples(n->outputs());
317 }
318}
319
// Public entry point: lower every tuple in the graph, sweep the now-dead
// TupleConstruct nodes with DCE, then assert no tuple-typed value survived.
void LowerAllTuples(const std::shared_ptr<Graph>& graph) {
  LowerAllTuples(graph->block());
  GRAPH_DUMP("After LowerAllTuples: ", graph);
  EliminateDeadCode(graph->block());
  EnsureNoTuples(graph->block());
}
326
327void LowerSimpleTuples(Block* block) {
328 for (auto n : block->nodes()) {
329 removeTupleNodes(n, /*must_remove_tuples*/ false);
330 for (auto b : n->blocks()) {
331 LowerSimpleTuples(b);
332 }
333 }
334}
335
// Public entry point for the best-effort variant: removes tuples where
// possible (no error when a tuple cannot be matched to its construct),
// then cleans up dead nodes.
void LowerSimpleTuples(const std::shared_ptr<Graph>& graph) {
  LowerSimpleTuples(graph->block());
  GRAPH_DUMP("After LowerSimpleTuples: ", graph);
  EliminateDeadCode(graph);
}
341} // namespace jit
342} // namespace torch
343