1 | #include <torch/csrc/jit/passes/concat_opt.h> |
2 | |
3 | #include <algorithm> |
4 | #include <unordered_set> |
5 | #include <vector> |
6 | |
7 | #include <torch/csrc/jit/ir/alias_analysis.h> |
8 | #include <torch/csrc/jit/ir/ir.h> |
9 | #include <torch/csrc/jit/ir/named_value.h> |
10 | #include <torch/csrc/jit/jit_log.h> |
11 | #include <torch/csrc/jit/passes/constant_pooling.h> |
12 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
13 | #include <torch/csrc/jit/passes/remove_mutation.h> |
14 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
15 | |
16 | namespace torch { |
17 | namespace jit { |
18 | |
19 | namespace { |
20 | |
21 | void removeCatNodeFromGraph(Node* n) { |
22 | TORCH_INTERNAL_ASSERT(n->kind() == aten::cat); |
23 | auto inp_list = n->input(0); |
24 | GRAPH_UPDATE("Deleting\n" , *n); |
25 | n->destroy(); |
26 | if (!inp_list->hasUses()) { |
27 | GRAPH_UPDATE("Deleting\n" , *inp_list->node()); |
28 | inp_list->node()->destroy(); |
29 | } |
30 | } |
31 | |
32 | bool equal(at::ArrayRef<Value*> list1, at::ArrayRef<Value*> list2) { |
33 | return list1.size() == list2.size() && |
34 | std::equal(list1.begin(), list1.end(), list2.begin()); |
35 | } |
36 | |
// Optimizes prim::VarConcat nodes that share a common run of tensor inputs:
// if the tensor inputs of an earlier VarConcat reappear as the prefix or the
// suffix of a later VarConcat's tensor inputs, the later node is rewritten
// to take the earlier node's output in place of the repeated inputs, so the
// shared tensors are only concatenated once.
class ConcatCommonInputsEliminator {
 public:
  explicit ConcatCommonInputsEliminator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Runs the pass. Returns true if any concat node was replaced.
  bool run() {
    handleBlock(graph_->block());
    return postprocess();
  }

 private:
  // Recursively scans `block` (and all nested blocks) for prim::VarConcat
  // nodes. This phase only records candidate replacements; the graph is
  // mutated later, in postprocess().
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == prim::VarConcat) {
        handleCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Examines one prim::VarConcat node. Its inputs are the tensors to
  // concatenate, followed by the concat dimension as the last input.
  void handleCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for CSE opt: " , node);

    auto curr_all_inputs = node->inputs();
    // All inputs except the trailing dimension argument.
    auto curr_tensor_inputs =
        curr_all_inputs.slice(0, curr_all_inputs.size() - 1);
    auto curr_dim = curr_all_inputs.back();

    // Save the input list and the current cat node, so that this can be
    // used for subsequent cat nodes, unless there are writes to this cat
    // node. When there are writes to this cat node, its output does not
    // represent this concatenated list beyond the writes. Currently, we do
    // not perform such fine-grained analysis. So, if there are any writes to
    // the output, we do not use this cat node for optimization here.
    if (!getOrCreateAliasDb()->hasWriters(node->output())) {
      concated_outputs_.insert(node);
    }

    if (curr_tensor_inputs.size() <= 2) {
      // The case when concat has 2 input tensors could only be optimized if
      // there is another concat of the exact same 2 input tensors. That case
      // is expected to be handled by the CSE pass.
      return;
    }

    // Now, we check if the first N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // first 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%0, %1, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%11, %2, <dim>) // Replace first 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14

    auto curr_tensor_inputs_prefix =
        curr_tensor_inputs.slice(0, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_prefix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated output if it does not
          // dominate the current concat node.
          continue;
        }

        // Build the replacement node (previous concat's output plus the one
        // remaining tensor). It is NOT inserted into the graph here; that
        // happens in postprocess().
        std::vector<Value*> new_inputs = {
            prev->output(), curr_tensor_inputs.back(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Now, we check if the last N-1 elements in %inputs appeared in any of
    // the previous cat ops.
    //
    // Example:
    //    %11 = prim::VarConcat(%1, %2, <dim>)
    //    ...
    //    %13 = prim::VarConcat(%0, %1, %2, <dim>) // last 2 inputs same as %11
    //    ...
    //        = %13 ... // Use %13
    //
    // After CSE opt:
    //    %11 = prim::VarConcat(%1, %2, <dim>)
    //    ...
    //    %14 = prim::VarConcat(%0, %11, <dim>) // Replace last 2 inputs
    //                                          // with %11
    //    ...
    //        = %14 ... // Replace use of %13 with %14
    auto curr_tensor_inputs_suffix =
        curr_tensor_inputs.slice(1, curr_tensor_inputs.size() - 1);
    for (const auto& prev : concated_outputs_) {
      auto prev_all_inputs = prev->inputs();
      auto prev_tensor_inputs =
          prev_all_inputs.slice(0, prev_all_inputs.size() - 1);
      auto prev_dim = prev_all_inputs.back();
      if (equal(curr_tensor_inputs_suffix, prev_tensor_inputs) &&
          curr_dim == prev_dim) {
        if (!node->isDominatedBy(prev)) {
          // We can't use the previous concatenated list if it does not
          // dominate the current list.
          continue;
        }

        std::vector<Value*> new_inputs = {
            curr_tensor_inputs.front(), prev->output(), curr_dim};
        auto new_concat =
            node->owningGraph()->create(prim::VarConcat, new_inputs);
        new_concat->output()->setType(node->output()->type());
        concats_to_replace_[node] = new_concat;
        return;
      }
    }

    // Do we need to handle other cases where N-2 or lesser elements from
    // %inputs appear in any of the previous cat ops?
    // TODO.
  }

  // Applies the replacements recorded by handleCat(): inserts each new
  // concat node, reroutes all uses of the old node's output to it, and
  // destroys the old node. Returns true if anything changed.
  bool postprocess() {
    // Replace the list nodes that have been marked.
    bool changed = false;
    for (auto it : concats_to_replace_) {
      auto curr_node = it.first;
      auto new_node = it.second;
      GRAPH_UPDATE("Inserting\n" , *new_node, "before\n" , *curr_node);
      new_node->insertBefore(curr_node);
      GRAPH_UPDATE("Replacing uses of\n" , *curr_node, "with\n" , *new_node);
      curr_node->output()->replaceAllUsesWith(new_node->output());
      GRAPH_UPDATE("Deleting\n" , *curr_node);
      curr_node->destroy();
      changed = true;
    }
    return changed;
  }

  // Lazily builds the alias DB on first use; reused for all handleCat calls.
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  // Concat nodes whose outputs have no writers; usable as reuse sources.
  std::unordered_set<Node*> concated_outputs_;
  // Maps an existing concat node to its (not yet inserted) replacement.
  std::unordered_map<Node*, Node*> concats_to_replace_;
};
207 | |
208 | } // namespace |
209 | |
210 | bool EliminateConcatCommonInputs(const std::shared_ptr<Graph>& graph) { |
211 | GRAPH_DUMP("Before eliminating Concat common inputs" , graph); |
212 | bool changed = ConcatCommonInputsEliminator(graph).run(); |
213 | if (changed) { |
214 | GRAPH_DUMP("After eliminating Concat common inputs" , graph); |
215 | } |
216 | return changed; |
217 | } |
218 | |
219 | namespace { |
220 | |
// Rewrites each eligible aten::cat into an explicitly allocated output
// buffer (aten::empty) filled via aten::slice + aten::copy_ pairs, one pair
// per concat input. A follow-up step (reuseBuffersInCopies) then merges
// buffers where one expanded cat's output feeds another expanded cat.
class ConcatExpander {
 public:
  explicit ConcatExpander(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  // Entry point: expand all eligible cats, remove the original cat nodes,
  // then try to collapse chained copy buffers into one.
  void run() {
    handleBlock(graph_->block());
    cleanupExpandedCatOps();
    GRAPH_DUMP("Before reusing copy buffers: " , graph_);
    reuseBuffersInCopies();
  }

 private:
  // Recursively visits `block` and its nested blocks, expanding aten::cat
  // nodes. Removal of the original cat nodes is deferred to
  // cleanupExpandedCatOps(), so node iteration here stays safe.
  void handleBlock(Block* block) {
    for (auto node : block->nodes()) {
      if (node->kind() == aten::cat) {
        expandCat(node);
      }
      for (Block* block : node->blocks()) {
        handleBlock(block);
      }
    }
  }

  // Expand cat node into multiple copy nodes.
  //
  // Example:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %10 = prim::ListConstruct(%2, %3)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //         = %11 ... // Use %11
  //
  // After expanding cat:
  //     %2 = aten::clamp(%0, ...)
  //     %3 = aten::clamp(%1, ...)
  //     %20 = aten::empty(...)          // cat output buffer
  //     %21 = aten::slice(%20, ...)     // slice for %2
  //     %22 = aten::copy_(%21, %2)      // copy %2
  //     %23 = aten::slice(%20, ...)     // slice for %3
  //     %24 = aten::copy_(%23, %3)      // copy %3
  //     ...
  //         = %20 ... // Use %20 in place of %11
  void expandCat(Node* node) {
    GRAPH_DEBUG("Considering cat node for expansion: " , node);
    // Do not optimize cat nodes whose inputs are mutated in the graph.
    // TODO: Improve this by checking if it is mutated in the graph region
    // where this optimization is applied.
    if (getOrCreateAliasDb()->hasWriters(node->input(0))) {
      return;
    }
    if (node->input(0)->node()->kind() != prim::ListConstruct) {
      // Unknown form of input to `cat` op.
      return;
    }
    if (!allShapesAreKnown(node)) {
      // Can't expand when shapes are not known for the `cat` op.
      return;
    }
    for (auto cat_inp : node->input(0)->node()->inputs()) {
      if (!shapeIsKnown(cat_inp)) {
        // Can't expand when shapes of the inputs to `cat` are not known.
        return;
      }
    }
    // TODO: Handle non-contiguous Tensors.
    // For example, how to handle the cases where the inputs are all channels
    // last?

    auto maybe_cat_dim = constant_as<int64_t>(node->input(1));
    if (!maybe_cat_dim) {
      // Can't expand when cat dimension is not a constant.
      return;
    }
    auto cat_dim_value = maybe_cat_dim.value();
    auto cat_dim = node->input(1);

    // Set the insertion point to the current `cat` node.
    WithInsertPoint guard(node);
    auto none = graph_->insertConstant(IValue());
    auto one = graph_->insertConstant(1);

    // Insert the constants needed for the `cat` output buffer size.
    auto tensortype = node->output()->type()->expect<TensorType>();
    TORCH_INTERNAL_ASSERT(tensortype);
    auto tensortype_sizes = tensortype->sizes();
    std::vector<Value*> cat_out_size;
    for (size_t i = 0; i < tensortype_sizes.size(); ++i) {
      cat_out_size.push_back(graph_->insertConstant(tensortype_sizes[i]));
    }

    // Create a list of int for `cat` output buffer size.
    auto cat_out_size_list = graph_->createList(IntType::get(), cat_out_size);
    cat_out_size_list->insertBefore(node);

    // Create an empty buffer to be used as `cat` output buffer.
    // TODO: Handle tensors with different dtype, layout, device, memory
    // format, etc.
    auto cat_out_empty = graph_->create(
        aten::empty,
        {cat_out_size_list->output(), none, none, none, none, none});
    cat_out_empty->insertBefore(node);

    // For every input to this `cat` node:
    //   * Create a slice of `cat` output buffer.
    auto cat_out_value = cat_out_empty->output();
    auto cat_inp_list = node->input(0)->node();
    int start_idx = 0;
    auto start = graph_->insertConstant(start_idx);
    for (auto cat_inp : cat_inp_list->inputs()) {
      // Create a slice of the cat output buffer that correspond to
      // this input size and position in the output.
      // NOTE(review): dynamic_cast requires the type to be exactly
      // TensorType (a subtype would yield nullptr and trip the assert) —
      // confirm this strictness is intended.
      auto cat_inp_tensor_type =
          dynamic_cast<TensorType*>(cat_inp->type().get());
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type);
      TORCH_INTERNAL_ASSERT(cat_inp_tensor_type->dim());
      auto cat_inp_tensortype_sizes = cat_inp_tensor_type->sizes();
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      int end_idx = start_idx + *cat_inp_tensortype_sizes[cat_dim_value];
      auto end = graph_->insertConstant(end_idx);

      // Slice covers [start, end) along the cat dimension, with step 1.
      auto slice = graph_->create(
          aten::slice, {cat_out_value, cat_dim, start, end, one});
      GRAPH_UPDATE("Inserting\n" , *slice, "before\n" , *node);
      slice->insertBefore(node);
      slices_added_.push_back(slice);

      // Insert a copy from this input to the output slice.
      auto copy = graph_->create(aten::copy_, {slice->output(), cat_inp});
      GRAPH_UPDATE("Inserting\n" , *copy, "before\n" , *node);
      copy->insertBefore(node);
      copies_added_.push_back(copy);

      // Next input's slice starts where this one ended.
      start_idx = end_idx;
      start = end;
    }

    // Replace the uses of `cat` node with the cat output buffer.
    replace_uses_with_[node->output()] = cat_out_value;
    nodes_to_remove_.insert(node);
  }

  // Returns true if `v` is a non-tensor value, or a tensor whose shape is
  // completely known and whose rank is non-zero.
  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
        return false;
      }
    }
    return true;
  }

  // Returns true if the shapes of all inputs AND outputs of `node` are known
  // (per shapeIsKnown above).
  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  // Reroutes uses of expanded cat outputs to their buffers and destroys the
  // original cat nodes (and their now-unused input lists).
  void cleanupExpandedCatOps() {
    for (auto it : replace_uses_with_) {
      GRAPH_UPDATE(
          "Replacing uses of\n" ,
          *it.first->node(),
          "with\n" ,
          *it.second->node());
      it.first->replaceAllUsesWith(it.second);
    }
    for (auto n : nodes_to_remove_) {
      removeCatNodeFromGraph(n);
    }
  }

  // Moves `node`, and recursively the producers of all its inputs, before
  // `before`. NOTE(review): producers are moved unconditionally, so a node
  // reachable through multiple inputs may be relocated more than once;
  // relies on the IR being acyclic.
  void moveBefore(Node* node, Node* before) {
    // In order to move a node before another node, we need to move
    // all the nodes it depends on as well.
    for (auto inp : node->inputs()) {
      moveBefore(inp->node(), before);
    }
    node->moveBefore(before);
  }

  // Reuse buffers in copies wherever possible.
  //
  // For example, consider the following sequence of ops:
  //     %10 = prim::ListConstruct(%0, %1)
  //     %11 = aten::cat(%10, ...)
  //     ...
  //     %12 = prim::ListConstruct(%11, %2)  // Uses the result of above cat
  //     %13 = aten::cat(%12, ...)
  //
  // Once these cat ops are expanded into copies, we will have two buffers;
  // one for %11 and another for %13. This can be optimized by using only one
  // buffer. We can only have the buffer that represents %13 and use a view
  // (slice) of that one as the buffer for %11.
  //
  // If any of the copies added earlier has `aten::empty` as its source,
  // those cases can be replaced with a single buffer.
  //
  // Example:
  //     %20 = aten::empty(...)          // cat.1 output buffer
  //     %21 = aten::slice(%20, ...)
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%20, ...)
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     %30 = aten::empty(...)          // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)
  //     %32 = aten::copy_(%31, %20)     // src of copy is aten::empty
  //     // so, we reuse this buffer above
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  //
  // After reusing copy buffers:
  //     %30 = aten::empty(...)          // cat.2 output buffer
  //     %31 = aten::slice(%30, ...)     // move %31 and inputs before %20
  //     %21 = aten::slice(%31, ...)     // use %31 in place of %20
  //     %22 = aten::copy_(%21, %2)
  //     %23 = aten::slice(%31, ...)     // use %31 in place of %20
  //     %24 = aten::copy_(%23, %3)
  //     ...
  //     ...                             // copy to %31 is now removed
  //     %33 = aten::slice(%30, ...)
  //     %34 = aten::copy_(%33, %4)
  void reuseBuffersInCopies() {
    for (auto copy : copies_added_) {
      auto src = copy->input(1);
      auto dst = copy->input(0);
      // Only copies whose source is itself an expanded-cat buffer
      // (aten::empty) can be folded away.
      if (src->node()->kind() != aten::empty) {
        continue;
      }

      // Move the destination node before the source.
      GRAPH_UPDATE("Moving\n" , *dst->node(), "before\n" , *src->node());
      moveBefore(dst->node(), src->node());

      GRAPH_UPDATE("Replacing\n" , *src->node(), "with\n" , *dst->node());
      src->replaceAllUsesWith(dst);

      GRAPH_UPDATE("Deleting\n" , *src->node());
      src->node()->destroy();

      GRAPH_UPDATE("Deleting\n" , *copy);
      copy->destroy();
    }
  }

  // Lazily builds the alias DB on first use.
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  // Original cat nodes to destroy once their uses have been rerouted.
  std::unordered_set<Node*> nodes_to_remove_;
  // Maps each expanded cat's output to the buffer value replacing it.
  std::unordered_map<Value*, Value*> replace_uses_with_;
  std::vector<Node*> copies_added_;
  std::vector<Node*> slices_added_;
};
493 | |
494 | } // namespace |
495 | |
496 | void ExpandConcatAndEliminateRedundancy(const std::shared_ptr<Graph>& graph) { |
497 | ConcatExpander(graph).run(); |
498 | GRAPH_DUMP("After expanding Concat and eliminating redundancy" , graph); |
499 | } |
500 | |
501 | namespace { |
502 | |
503 | size_t determineUsageIdx(Value* value, Node* user) { |
504 | const auto idx = |
505 | std::find(user->inputs().begin(), user->inputs().end(), value) - |
506 | user->inputs().begin(); |
507 | TORCH_CHECK(idx != user->inputs().size()); |
508 | return idx; |
509 | } |
510 | |
511 | std::vector<Value*> getConcatInputs(Node* concat) { |
512 | TORCH_CHECK(concat->kind() == aten::cat); |
513 | auto* list = concat->input(0); |
514 | auto* list_construct = list->node(); |
515 | TORCH_CHECK(list_construct->kind() == prim::ListConstruct); |
516 | return list_construct->inputs().vec(); |
517 | } |
518 | |
// Combines chained aten::cat ops: when a cat's only use is as an element of
// the prim::ListConstruct feeding another cat along the same (statically
// known) dimension, the inner cat's inputs are spliced directly into the
// outer cat's input list, eliminating the intermediate concat.
class ConcatCombiner {
 public:
  // NOTE: graph_ is declared before aliasDb_ below, so it is initialized
  // first, as required by aliasDb_(graph_).
  explicit ConcatCombiner(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(graph_) {}

  // Runs the pass; returns true if any concats were combined. The inner
  // cats/lists left unused by the rewrite are removed by DCE afterwards.
  bool run() {
    collectOptimizableConcats();
    bool changed = combineConcats();
    if (changed) {
      EliminateDeadCode(graph_);
    }
    return changed;
  }

 private:
  // Given a concat node, see if it can be optimized with another.
  // If so, add a CombinablePair to combinable_concats_.
  void handleConcat(Node* node) {
    auto* list = node->input(0);
    auto* list_node = list->node();

    const auto dim_opt = toIValue(node->input(1));
    // We need to be able to determine dim statically to match it with another
    // concat.
    if (!dim_opt || !dim_opt->isInt()) {
      return;
    }
    const auto dim = dim_opt->toInt();

    // Check that the input of this node is an unmodified list construct
    if (list_node->kind() != prim::ListConstruct ||
        !aliasDb_.couldMoveBeforeTopologically(list_node, node)) {
      return;
    }

    // Check that the only output of this node is used in an unmodified list
    // construct.
    const auto& concat_uses = node->output()->uses();
    if (concat_uses.size() != 1) {
      return;
    }

    auto* next_list = concat_uses[0].user;
    if (next_list->kind() != prim::ListConstruct) {
      return;
    }

    const auto& next_list_uses = next_list->output()->uses();
    if (next_list_uses.size() != 1) {
      return;
    }

    auto* next_concat = next_list_uses[0].user;

    if (next_concat->kind() == aten::cat) {
      // Dimension must be determined statically and match the one we've already
      // seen.
      const auto next_dim_opt = toIValue(next_concat->input(1));
      if (!next_dim_opt || next_dim_opt->toInt() != dim) {
        return;
      }
      combinable_concats_.emplace_back(
          node, next_concat, determineUsageIdx(node->output(), next_list));
    }
  }

  // Scans the whole graph (depth-first, including sub-blocks) and records
  // every combinable pair of cats in combinable_concats_.
  void collectOptimizableConcats() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    for (auto* node = graph_it.next(); node != nullptr;
         node = graph_it.next()) {
      if (node->kind() == aten::cat) {
        handleConcat(node);
      }
    }
  }

  // Creates (but does not insert) a prim::ListConstruct node over `inputs`.
  Node* createListConstruct(const std::deque<Value*>& inputs) {
    auto* output = graph_->create(prim::ListConstruct);
    for (auto* v : inputs) {
      output->addInput(v);
    }
    return output;
  }

  using ListConstructInputs = std::shared_ptr<std::deque<Value*>>;
  // Construct a map (concat node) -> (new list inputs for this node).
  // std::deque is used so we can do O(1) insertions to the front.
  // The shared_ptr lets an accumulated input list be handed off from a
  // first_concat entry to its second_concat when combinables chain.
  std::unordered_map<Node*, ListConstructInputs> getListConstructInputs() {
    std::unordered_map<Node*, ListConstructInputs> cur_list_construct_inputs;
    for (const auto& combinable : combinable_concats_) {
      // Combine the list inputs of first_concat with those of second_concat
      const auto& inputs_to_add = getConcatInputs(combinable.second_concat);

      auto it = cur_list_construct_inputs.find(combinable.first_concat);
      std::shared_ptr<std::deque<Value*>> cur_list;
      if (it != cur_list_construct_inputs.end()) {
        cur_list = it->second;
        // We're moving all inputs to second_concat.
        cur_list_construct_inputs.erase(combinable.first_concat);
      } else {
        cur_list = std::make_shared<std::deque<Value*>>();
      }
      cur_list_construct_inputs.emplace(combinable.second_concat, cur_list);

      // If cur_list is not empty, it's guaranteed to already contain all of
      // first_concat's inputs.
      if (cur_list->empty()) {
        const auto& starting_values = getConcatInputs(combinable.first_concat);
        cur_list->insert(
            cur_list->end(), starting_values.begin(), starting_values.end());
      }

      // Splice second_concat's other inputs around first_concat's inputs:
      // those before idx go to the front, those after idx to the back.
      cur_list->insert(
          cur_list->begin(),
          inputs_to_add.begin(),
          inputs_to_add.begin() + combinable.idx);

      cur_list->insert(
          cur_list->end(),
          inputs_to_add.begin() + combinable.idx + 1,
          inputs_to_add.end());
    }
    return cur_list_construct_inputs;
  }

  // Rebuilds the input lists of the surviving concats. Returns false when
  // there is nothing to combine; otherwise performs the rewrite and
  // returns true.
  bool combineConcats() {
    if (combinable_concats_.empty()) {
      return false;
    }

    auto list_construct_inputs = getListConstructInputs();

    for (const auto& node_and_new_list : list_construct_inputs) {
      auto* node = node_and_new_list.first;
      auto& inputs = node_and_new_list.second;

      auto* new_list_construct = createListConstruct(*inputs);
      auto* old_list_construct = node->input(0)->node();
      new_list_construct->output()->setType(
          old_list_construct->output()->type());
      new_list_construct->insertBefore(node);
      old_list_construct->replaceAllUsesWith(new_list_construct);
    }
    return true;
  }

  // Represents an optimizable pair of concat nodes.
  // - first_concat must appear before second_concat
  // - idx is the index where first_concat's inputs must be inserted into
  //   second_concat's new inputs.
  // Example:
  //    %inputs.1 = prim::ListConstruct(%0, %0)
  //    %concat.1 = aten::cat(%inputs.1, %dim)
  //    %inputs.2 = prim::ListConstruct(%1, %concat.1, %1)
  //    %concat.2 = aten::cat(%inputs.2, %dim)
  // -> first_concat = &concat.1, second_concat = &concat.2, idx = 1
  struct CombinableConcat {
    CombinableConcat(Node* a, Node* b, size_t i)
        : first_concat(a), second_concat(b), idx(i) {}

    Node* first_concat;
    Node* second_concat;
    size_t idx;
  };

  std::vector<CombinableConcat> combinable_concats_;

  std::shared_ptr<Graph> graph_;
  AliasDb aliasDb_;
};
689 | |
690 | } // namespace |
691 | |
692 | bool CombineConcats(const std::shared_ptr<Graph>& graph) { |
693 | bool changed = ConcatCombiner(graph).run(); |
694 | GRAPH_DUMP("After combining concats" , graph); |
695 | return changed; |
696 | } |
697 | |
698 | } // namespace jit |
699 | } // namespace torch |
700 | |