#include <torch/csrc/jit/passes/concat_opt.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch {
namespace jit {

namespace {

void removeCatNodeFromGraph(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == aten::cat);
  auto inp_list = n->input(0);
  GRAPH_UPDATE("Deleting\n", *n);
  n->destroy();
  if (!inp_list->hasUses()) {
    GRAPH_UPDATE("Deleting\n", *inp_list->node());
    inp_list->node()->destroy();
  }
}

bool equal(at::ArrayRef<Value*> list1, at::ArrayRef<Value*> list2) {
  return list1.size() == list2.size() &&
      std::equal(list1.begin(), list1.end(), list2.begin());
}

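// Eliminates redundant common inputs among prim::VarConcat nodes: when the
// leading or trailing inputs of a concat match an earlier concat that
// dominates it, the earlier concat's output is reused in its place (see the
// examples in handleCat below).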
class ConcatCommonInputsEliminator {
 public:
  explicit ConcatCommonInputsEliminator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    handleBlock(graph_->block());
    return postprocess();
  }

 private:
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == prim::VarConcat) {
        handleCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  void handleCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for CSE opt: ", node);

    auto curr_all_inputs = node->inputs();
    auto curr_tensor_inputs =
        curr_all_inputs.slice(0, curr_all_inputs.size() - 1);
    auto curr_dim = curr_all_inputs.back();

    // Save the current cat node so that it can be reused by subsequent cat
    // nodes, unless there are writes to its output. When the output is
    // written to, it no longer represents the concatenated list beyond those
    // writes. We do not currently perform such fine-grained analysis, so if
    // the output has any writers we do not use this cat node for the
    // optimization here.
    if (!getOrCreateAliasDb()->hasWriters(node->output())) {
      concated_outputs_.insert(node);
    }

    if (curr_tensor_inputs.size() <= 2) {
      // A concat with only 2 input tensors can be optimized only if there is
      // another concat of exactly the same 2 input tensors, and that case is
      // expected to be handled by the CSE pass.
      return;
    }

    // Now, we check if the first N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14

    auto curr_tensor_inputs_prefix =
        curr_tensor_inputs.slice(0, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_prefix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated output if it does not
          // dominate the current concat node.
          continue;
        }

        std::vector<Value*> new_inputs = {
            prev->output(), curr_tensor_inputs.back(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Now, we check if the last N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%1, %2, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // last 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%1, %2, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%0, %11, <dim>) // Replace last 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14
    auto curr_tensor_inputs_suffix =
        curr_tensor_inputs.slice(1, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_suffix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated list if it does not
          // dominate the current list.
          continue;
        }

        std::vector<Value*> new_inputs = {
            curr_tensor_inputs.front(), prev->output(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // TODO: Do we need to handle other cases where N-2 or fewer elements
    // from %inputs appear in any of the previous cat ops?
  }

  bool postprocess() {
    // Replace the cat nodes that have been marked for replacement.
    bool changed = false;
    for (auto it : concats_to_replace_) {
      auto curr_node = it.first;
      auto new_node = it.second;
      GRAPH_UPDATE("Inserting\n", *new_node, "before\n", *curr_node);
      new_node->insertBefore(curr_node);
      GRAPH_UPDATE("Replacing uses of\n", *curr_node, "with\n", *new_node);
      curr_node->output()->replaceAllUsesWith(new_node->output());
      GRAPH_UPDATE("Deleting\n", *curr_node);
      curr_node->destroy();
      changed = true;
    }
    return changed;
  }

  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::unordered_set<Node*> concated_outputs_;
  std::unordered_map<Node*, Node*> concats_to_replace_;
};

} // namespace

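// Eliminates redundant common inputs among prim::VarConcat ops in the graph,
// as described above. Returns true if the graph was modified.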
bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before eliminating Concat common inputs", graph);
  bool changed = ConcatCommonInputsEliminator(graph).run();
  if (changed) {
    GRAPH_DUMP("After eliminating Concat common inputs", graph);
  }
  return changed;
}

namespace {

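// Expands aten::cat ops into an explicit output buffer (aten::empty) plus a
// slice/copy_ pair per input, then reuses those buffers across chained
// copies where possible (see expandCat and reuseBuffersInCopies below).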
class ConcatExpander {
 public:
  explicit ConcatExpander(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    handleBlock(graph_->block());
    cleanupExpandedCatOps();
    GRAPH_DUMP("Before reusing copy buffers: ", graph_);
    reuseBuffersInCopies();
  }

 private:
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == aten::cat) {
        expandCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Expand cat node into multiple copy nodes.
  //
  // Example:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %10 = prim::ListConstruct(%2, %3)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //         = %11 ... // Use %11
  //
  // After expanding cat:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %20 = aten::empty(...)        // cat output buffer
  //     %21 = aten::slice(%20, ...)   // slice for %2
  //     %22 = aten::copy_(%21, %2)    // copy %2
  //     %23 = aten::slice(%20, ...)   // slice for %3
  //     %24 = aten::copy_(%23, %3)    // copy %3
  //     ...
  //         = %20 ... // Use %20 in place of %11
  void expandCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for expansion: ", node);
    // Do not optimize cat nodes whose inputs are mutated in the graph.
    // TODO: Improve this by checking if it is mutated in the graph region
    // where this optimization is applied.
    if (getOrCreateAliasDb()->hasWriters(node->input(0))) {
      return;
    }
    if (node->input(0)->node()->kind() != prim::ListConstruct) {
      // Unknown form of input to `cat` op.
      return;
    }
    if (!allShapesAreKnown(node)) {
      // Can't expand when shapes are not known for the `cat` op.
      return;
    }
    for (auto cat_inp : node->input(0)->node()->inputs()) {
      if (!shapeIsKnown(cat_inp)) {
        // Can't expand when shapes of the inputs to `cat` are not known.
        return;
      }
    }
    // TODO: Handle non-contiguous Tensors.
    // For example, how to handle the cases where the inputs are all channels
    // last?

    auto maybe_cat_dim = constant_as<int64_t>(node->input(1));
    if (!maybe_cat_dim) {
      // Can't expand when the cat dimension is not a constant.
      return;
    }
    auto cat_dim_value = maybe_cat_dim.value();
    auto cat_dim = node->input(1);

    // Set the insertion point to the current `cat` node.
    WithInsertPoint guard(node);
    auto none = graph_->insertConstant(IValue());
    auto one = graph_->insertConstant(1);

    // Insert the constants needed for the `cat` output buffer size.
    auto tensortype = node->output()->type()->expect<TensorType>();
    TORCH_INTERNAL_ASSERT(tensortype);
    auto tensortype_sizes = tensortype->sizes();
    std::vector<Value*> cat_out_size;
    for (size_t i = 0; i < tensortype_sizes.size(); ++i) {
      cat_out_size.push_back(graph_->insertConstant(tensortype_sizes[i]));
    }

    // Create a list of ints for the `cat` output buffer size.
    auto cat_out_size_list = graph_->createList(IntType::get(), cat_out_size);
    cat_out_size_list->insertBefore(node);

    // Create an empty buffer to be used as the `cat` output buffer.
    // TODO: Handle tensors with different dtype, layout, device, memory
    // format, etc.
    auto cat_out_empty = graph_->create(
        aten::empty,
        {cat_out_size_list->output(), none, none, none, none, none});
    cat_out_empty->insertBefore(node);

    // For every input to this `cat` node:
    //   * Create a slice of the `cat` output buffer.
    auto cat_out_value = cat_out_empty->output();
    auto cat_inp_list = node->input(0)->node();
    int start_idx = 0;
    auto start = graph_->insertConstant(start_idx);
    for (auto cat_inp : cat_inp_list->inputs()) {
      // Create a slice of the cat output buffer that corresponds to
      // this input's size and position in the output.
      auto cat_inp_tensor_type =
          dynamic_cast<TensorType*>(cat_inp->type().get());
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type);
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type->dim());
      auto cat_inp_tensortype_sizes = cat_inp_tensor_type->sizes();
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      int end_idx = start_idx + *cat_inp_tensortype_sizes[cat_dim_value];
      auto end = graph_->insertConstant(end_idx);

      auto slice = graph_->create(
          aten::slice, {cat_out_value, cat_dim, start, end, one});
      GRAPH_UPDATE("Inserting\n", *slice, "before\n", *node);
      slice->insertBefore(node);
      slices_added_.push_back(slice);

      // Insert a copy from this input to the output slice.
      auto copy = graph_->create(aten::copy_, {slice->output(), cat_inp});
      GRAPH_UPDATE("Inserting\n", *copy, "before\n", *node);
      copy->insertBefore(node);
      copies_added_.push_back(copy);

      start_idx = end_idx;
      start = end;
    }

    // Replace the uses of the `cat` node with the cat output buffer.
    replace_uses_with_[node->output()] = cat_out_value;
    nodes_to_remove_.insert(node);
  }

  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
        return false;
      }
    }
    return true;
  }

  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes.
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  void cleanupExpandedCatOps() {
    for (auto it : replace_uses_with_) {
      GRAPH_UPDATE(
          "Replacing uses of\n",
          *it.first->node(),
          "with\n",
          *it.second->node());
      it.first->replaceAllUsesWith(it.second);
    }
    for (auto n : nodes_to_remove_) {
      removeCatNodeFromGraph(n);
    }
  }

  void moveBefore(Node* node, Node* before) {
    // In order to move a node before another node, we need to move
    // all the nodes it depends on as well.
    for (auto inp : node->inputs()) {
      moveBefore(inp->node(), before);
    }
    node->moveBefore(before);
  }

  // Reuse buffers in copies wherever possible.
  //
  // For example, consider the following sequence of ops:
  //     %10 = prim::ListConstruct(%0, %1)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //     %12 = prim::ListConstruct(%11, %2) // Uses the result of the above cat
  //     %13 = aten::cat(%12, ...)
  //
  // Once these cat ops are expanded into copies, we will have two buffers:
  // one for %11 and another for %13. This can be optimized by using only one
  // buffer: we keep only the buffer that represents %13 and use a view
  // (slice) of it as the buffer for %11.
  //
  // If any of the copies added earlier has `aten::empty` as its source,
  // those cases can be replaced with a single buffer.
  //
  // Example:
  //     %20 = aten::empty(...)        // cat.1 output buffer
  //     %21 = aten::slice(%20, ...)
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%20, ...)
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     %30 = aten::empty(...)        // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)
  //     %32 = aten::copy_(%31, %20)   // src of copy is aten::empty,
  //                                   // so we reuse its buffer above
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  //
  // After reusing copy buffers:
  //     %30 = aten::empty(...)        // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)   // move %31 and its inputs before %20
  //     %21 = aten::slice(%31, ...)   // use %31 in place of %20
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%31, ...)   // use %31 in place of %20
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     ...                           // copy to %31 is now removed
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  void reuseBuffersInCopies() {
    for (auto copy : copies_added_) {
      auto src = copy->input(1);
      auto dst = copy->input(0);
      if (src->node()->kind() != aten::empty) {
        continue;
      }

      // Move the destination node before the source.
      GRAPH_UPDATE("Moving\n", *dst->node(), "before\n", *src->node());
      moveBefore(dst->node(), src->node());

      GRAPH_UPDATE("Replacing\n", *src->node(), "with\n", *dst->node());
      src->replaceAllUsesWith(dst);

      GRAPH_UPDATE("Deleting\n", *src->node());
      src->node()->destroy();

      GRAPH_UPDATE("Deleting\n", *copy);
      copy->destroy();
    }
  }

  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::unordered_set<Node*> nodes_to_remove_;
  std::unordered_map<Value*, Value*> replace_uses_with_;
  std::vector<Node*> copies_added_;
  std::vector<Node*> slices_added_;
};

} // namespace

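// Expands aten::cat ops in the graph into explicit buffer allocation plus
// slice/copy_ ops and reuses buffers where possible, as described above.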
void ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph>& graph) {
  ConcatExpander(graph).run();
  GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph);
}

namespace {

size_t determineUsageIdx(Value* value, Node* user) {
  const auto idx =
      std::find(user->inputs().begin(), user->inputs().end(), value) -
      user->inputs().begin();
  TORCH_CHECK(idx != user->inputs().size());
  return idx;
}

std::vector<Value*> getConcatInputs(Node* concat) {
  TORCH_CHECK(concat->kind() == aten::cat);
  auto* list = concat->input(0);
  auto* list_construct = list->node();
  TORCH_CHECK(list_construct->kind() == prim::ListConstruct);
  return list_construct->inputs().vec();
}

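// Combines chains of aten::cat ops that concatenate along the same dim into a
// single cat. A sketch of the rewrite (value names are illustrative):
//     %inputs.1 = prim::ListConstruct(%0, %0)
//     %concat.1 = aten::cat(%inputs.1, %dim)
//     %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
//     %concat.2 = aten::cat(%inputs.2, %dim)
// becomes
//     %inputs = prim::ListConstruct(%1, %0, %0, %1)
//     %concat.2 = aten::cat(%inputs, %dim)
// and the now-unused inner cat is removed by dead code elimination.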
class ConcatCombiner {
 public:
  explicit ConcatCombiner(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(graph_) {}

  bool run() {
    collectOptimizableConcats();
    bool changed = combineConcats();
    if (changed) {
      EliminateDeadCode(graph_);
    }
    return changed;
  }

 private:
  // Given a concat node, see if it can be optimized with another.
  // If so, add a CombinableConcat to combinable_concats_.
  void handleConcat(Node* node) {
    auto* list = node->input(0);
    auto* list_node = list->node();

    const auto dim_opt = toIValue(node->input(1));
    // We need to be able to determine dim statically to match it with
    // another concat.
    if (!dim_opt || !dim_opt->isInt()) {
      return;
    }
    const auto dim = dim_opt->toInt();

    // Check that the input of this node is an unmodified list construct.
    if (list_node->kind() != prim::ListConstruct ||
        !aliasDb_.couldMoveBeforeTopologically(list_node, node)) {
      return;
    }

    // Check that the only output of this node is used in an unmodified list
    // construct.
    const auto& concat_uses = node->output()->uses();
    if (concat_uses.size() != 1) {
      return;
    }

    auto* next_list = concat_uses[0].user;
    if (next_list->kind() != prim::ListConstruct) {
      return;
    }

    const auto& next_list_uses = next_list->output()->uses();
    if (next_list_uses.size() != 1) {
      return;
    }

    auto* next_concat = next_list_uses[0].user;

    if (next_concat->kind() == aten::cat) {
      // The dimension must be determined statically and match the one we've
      // already seen.
      const auto next_dim_opt = toIValue(next_concat->input(1));
      if (!next_dim_opt || next_dim_opt->toInt() != dim) {
        return;
      }
      combinable_concats_.emplace_back(
          node, next_concat, determineUsageIdx(node->output(), next_list));
    }
  }

  void collectOptimizableConcats() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    for (auto* node = graph_it.next(); node != nullptr;
         node = graph_it.next()) {
      if (node->kind() == aten::cat) {
        handleConcat(node);
      }
    }
  }

  Node* createListConstruct(const std::deque<Value*>& inputs) {
    auto* output = graph_->create(prim::ListConstruct);
    for (auto* v : inputs) {
      output->addInput(v);
    }
    return output;
  }

  using ListConstructInputs = std::shared_ptr<std::deque<Value*>>;
  // Construct a map (concat node) -> (new list inputs for this node).
  // std::deque is used so we can do O(1) insertions to the front.
  std::unordered_map<Node*, ListConstructInputs> getListConstructInputs() {
    std::unordered_map<Node*, ListConstructInputs> cur_list_construct_inputs;
    for (const auto& combinable : combinable_concats_) {
      // Combine the list inputs of first_concat with those of second_concat.
      const auto& inputs_to_add = getConcatInputs(combinable.second_concat);

      auto it = cur_list_construct_inputs.find(combinable.first_concat);
      std::shared_ptr<std::deque<Value*>> cur_list;
      if (it != cur_list_construct_inputs.end()) {
        cur_list = it->second;
        // We're moving all inputs to second_concat.
        cur_list_construct_inputs.erase(combinable.first_concat);
      } else {
        cur_list = std::make_shared<std::deque<Value*>>();
      }
      cur_list_construct_inputs.emplace(combinable.second_concat, cur_list);

      // If cur_list is not empty, it's guaranteed to already contain all of
      // first_concat's inputs.
      if (cur_list->empty()) {
        const auto& starting_values = getConcatInputs(combinable.first_concat);
        cur_list->insert(
            cur_list->end(), starting_values.begin(), starting_values.end());
      }

      cur_list->insert(
          cur_list->begin(),
          inputs_to_add.begin(),
          inputs_to_add.begin() + combinable.idx);

      cur_list->insert(
          cur_list->end(),
          inputs_to_add.begin() + combinable.idx + 1,
          inputs_to_add.end());
    }
    return cur_list_construct_inputs;
  }

  bool combineConcats() {
    if (combinable_concats_.empty()) {
      return false;
    }

    auto list_construct_inputs = getListConstructInputs();

    for (const auto& node_and_new_list : list_construct_inputs) {
      auto* node = node_and_new_list.first;
      auto& inputs = node_and_new_list.second;

      auto* new_list_construct = createListConstruct(*inputs);
      auto* old_list_construct = node->input(0)->node();
      new_list_construct->output()->setType(
          old_list_construct->output()->type());
      new_list_construct->insertBefore(node);
      old_list_construct->replaceAllUsesWith(new_list_construct);
    }
    return true;
  }

  // Represents an optimizable pair of concat nodes.
  // - first_concat must appear before second_concat
  // - idx is the index where first_concat's inputs must be inserted into
  //   second_concat's new inputs.
  // Example:
  //    %inputs.1 = prim::ListConstruct(%0, %0)
  //    %concat.1 = aten::cat(%inputs.1, %dim)
  //    %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
  //    %concat.2 = aten::cat(%inputs.2, %dim)
  // -> first_concat = &concat.1, second_concat = &concat.2, idx = 1
  struct CombinableConcat {
    CombinableConcat(Node* a, Node* b, size_t i)
        : first_concat(a), second_concat(b), idx(i) {}

    Node* first_concat;
    Node* second_concat;
    size_t idx;
  };

  std::vector<CombinableConcat> combinable_concats_;

  std::shared_ptr<Graph> graph_;
  AliasDb aliasDb_;
};

} // namespace

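// Combines chains of aten::cat ops along the same dim into a single cat, as
// described above. Returns true if the graph was modified.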
bool CombineConcats(const std::shared_ptr<Graph>& graph) {
  bool changed = ConcatCombiner(graph).run();
  GRAPH_DUMP("After combining concats", graph);
  return changed;
}

} // namespace jit
} // namespace torch