#include <torch/csrc/jit/passes/graph_fuser.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <queue>
#include <unordered_map>
#include <utility>

namespace torch {
namespace jit {

namespace {

// What is a simple mappable operator? It:
// - Has a single tensor output
// - Output and all tensor inputs have the same shape
// - Output and all tensor inputs have the same scalar type,
//   or all tensor inputs have the same scalar type and
//   output is identified in PropagateInputShapes
// - Output and all tensor inputs should be on the same device
// - Produces dense non-overlapping outputs
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
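// For illustration only (hypothetical IR): a pointwise node such as
//   %c : Float(4, 4) = aten::mul(%a : Float(4, 4), %b : Float(4, 4))
// fits these criteria, whereas a reduction like aten::sum or a matmul does
// not, since its output shape differs from its inputs'.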
bool isSimpleMap(Node* node) {
  static OperatorSet simple_mappable{{
      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",

      "aten::abs(Tensor self) -> Tensor",
      "aten::acos(Tensor self) -> Tensor",
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::asin(Tensor self) -> Tensor",
      "aten::atan(Tensor self) -> Tensor",
      "aten::atan2(Tensor self, Tensor other) -> Tensor",
      "aten::ceil(Tensor self) -> Tensor",
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
      "aten::cos(Tensor self) -> Tensor",
      "aten::cosh(Tensor self) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::expm1(Tensor self) -> Tensor",
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
      "aten::floor(Tensor self) -> Tensor",
      "aten::fmod(Tensor self, Tensor other) -> Tensor",
      "aten::frac(Tensor self) -> Tensor",
      "aten::lgamma(Tensor self) -> Tensor",
      "aten::log(Tensor self) -> Tensor",
      "aten::log10(Tensor self) -> Tensor",
      "aten::log1p(Tensor self) -> Tensor",
      "aten::log2(Tensor self) -> Tensor",
      "aten::logit(Tensor self, float? eps=None) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::pow(Tensor self, Tensor exponent) -> Tensor",
      "aten::pow(Tensor self, Scalar exponent) -> Tensor",
      "aten::pow(Scalar self, Tensor exponent) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::relu(Tensor self) -> Tensor",
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
      "aten::remainder(Tensor self, Tensor other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
      "aten::rsqrt(Tensor self) -> Tensor",
      "aten::sigmoid(Tensor self) -> Tensor",
      "aten::sin(Tensor self) -> Tensor",
      "aten::sinh(Tensor self) -> Tensor",
      "aten::sqrt(Tensor self) -> Tensor",
      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::tan(Tensor self) -> Tensor",
      "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::tanh(Tensor self) -> Tensor",
      "aten::trunc(Tensor self) -> Tensor",
      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",

      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",

      "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",

      "aten::type_as(Tensor self, Tensor other) -> Tensor",
  }};
  if (!node->isMemberOf(simple_mappable)) {
    return false;
  }
  for (Value* input : node->inputs()) {
    if (input->type()->isSubtypeOf(*TensorType::get()) ||
        input->type()->isSubtypeOf(*FloatType::get())) {
      continue;
    }
    if (input->node()->kind() != prim::Constant) {
      return false;
    }
  }
  return true;
}

struct GraphFuser {
  using FusionCallback = std::function<bool(GraphFuser*, Node*)>;

  Block* block_;
  AliasDb* aliasDb_;
  std::shared_ptr<Graph> graph_;
  FusionCallback callback_ = [](GraphFuser* gf, Node* n) {
    return gf->isFusableDefault(n, gf->strict_fuser_check_);
  };
  Symbol kind_ = prim::FusionGroup;
  bool strict_fuser_check_ = false;

  // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
  // The specific limit is a function of constant memory size, amount
  // available to pass arguments, and some implementation dependence. Select
  // a safe limit here.
  // This limit is also applied to other devices in the fuser by default.
  // Change with setInputArgLimit.
  size_t subgraph_arg_limit_ = 128;

  GraphFuser(AliasDb* aliasDb, Block* block, bool strict_fuser_check)
      : block_(block),
        aliasDb_(aliasDb),
        strict_fuser_check_(strict_fuser_check) {}

  // Custom passes require kind to be specified
  GraphFuser(
      AliasDb* aliasDb,
      Block* block,
      FusionCallback callback,
      Symbol kind,
      bool strict_fuser_check = false)
      : block_(block),
        aliasDb_(aliasDb),
        callback_(std::move(callback)),
        kind_(kind),
        strict_fuser_check_(strict_fuser_check) {}

  void setInputArgLimit(size_t limit) {
    subgraph_arg_limit_ = limit;
  }

  value_list tensorInputs(Node* node) {
    return filter(node->inputs(), [](Value* v) {
      return v->type()->isSubtypeOf(*TensorType::get());
    });
  }

  bool isFusable(Node* node) {
    return callback_(this, node);
  }

  bool isFusableDevice(Value* v, bool strict_fuser_check) {
    if (!v->type()->isSubtypeOf(*TensorType::get())) {
      return true;
    }
    auto device = v->type()->expectRef<TensorType>().device();
    if (!device) {
      return !strict_fuser_check;
    }
    if ((*device).is_cpu()) {
      return canFuseOnCPULegacy();
    } else if ((*device).is_cuda()) {
      return canFuseOnGPU();
    } else if ((*device).is_xpu()) {
      return false;
    } else {
      TORCH_CHECK_NOT_IMPLEMENTED(false, "Unknown device for graph fuser");
    }
  }

  // Default fusability check - used when the user doesn't pass in
  // a callback.
  bool isFusableDefault(Node* node, bool strict_fuser_check) {
    bool fusableDevice = true;
    for (const auto& output : node->outputs()) {
      if (!output->uses().empty()) {
        fusableDevice &= isFusableDevice(output, strict_fuser_check);
      }
    }
    return fusableDevice && isFusableMap(node);
  }

  bool isFusableMap(Node* node) {
    // We don't want to bother with cross-block node movements, as they
    // are not necessarily correct.
    if (node->owningBlock() != block_)
      return false;
    return node->kind() == prim::FusionGroup || isSimpleMap(node);
  }

  bool isFusableCatNode(Node* node) {
    if (node->kind() != aten::cat)
      return false;
    if (!node->is_constant(attr::dim))
      return false;

    auto tensors_node = node->namedInput(attr::tensors)->node();
    if ((tensors_node->inputs().size() + node->outputs().size()) >
        subgraph_arg_limit_) {
      return false;
    }
    if (tensors_node->kind() != prim::ListConstruct)
      return false;
    // NB: Technically, other uses of the list aren't a big problem for us.
    // It would be enough to place the prim::FusedConcat before the
    // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would
    // still be satisfied. However, I don't expect this to be necessary any
    // time soon, so we simply assume that we don't have to deal with it.
    if (tensors_node->output()->uses().size() > 1)
      return false;
    return true;
  }

  bool calculatesSize(Node* node) {
    return node->matches("aten::size(Tensor self) -> int[]");
  }

  bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
    auto defining_node = producer->node();
    for (auto o : defining_node->outputs()) {
      for (auto u : o->uses()) {
        if (u.user != consumer && !calculatesSize(u.user))
          return false;
      }
    }
    return true;
  }

  Graph& getSubgraph(Node* n) {
    AT_ASSERT(n->kind() == kind_);
    return *n->g(attr::Subgraph);
  }

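  // For illustration (hypothetical IR): given
  //   %p = prim::FusionGroup_0(%a)
  //   %q = prim::FusionGroup_1(%p, %b)
  // merging first inlines the producer group's body back into the outer
  // graph, then merges each inlined node into the consumer group, yielding
  // a single fusion group that takes %a and %b directly.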
  void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
    // Now we have two fusion groups!
    // Revert the fusion - place all inner nodes of producer back in the outer
    // graph.
    std::vector<Node*> temporary_nodes;
    auto producer_subgraph = &getSubgraph(producer_group);

    // Initialize a map of inner graph values to outer graph values
    std::unordered_map<Value*, Value*> inner_to_outer;
    auto inner_inputs = producer_subgraph->inputs();
    auto outer_inputs = producer_group->inputs();
    for (const auto i : c10::irange(inner_inputs.size())) {
      inner_to_outer[inner_inputs[i]] = outer_inputs[i];
    }

    // Clone all nodes
    for (auto inner : producer_subgraph->nodes()) {
      Node* outer = block_->owningGraph()->createClone(
          inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
      outer->insertBefore(producer_group);
      temporary_nodes.emplace_back(outer);
      auto inner_outputs = inner->outputs();
      auto outer_outputs = outer->outputs();
      for (const auto i : c10::irange(inner_outputs.size())) {
        inner_to_outer[inner_outputs[i]] = outer_outputs[i];
      }
    }

    // Replace uses of producer_group outputs and destroy the producer
    auto subgraph_outputs = producer_subgraph->outputs();
    for (const auto i : c10::irange(subgraph_outputs.size())) {
      auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
      producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
      // new producer outputs have same aliasing properties as outer_output
      aliasDb_->replaceWithNewValue(producer_group->outputs()[i], outer_output);
    }
    producer_group->destroy();
    producer_group =
        nullptr; // Just to get a clear error in case someone uses it

    // Inline the temporary nodes into the first group
    auto consumer_subgraph = &getSubgraph(consumer_group);
    for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
         ++it) {
      Node* node = *it;
      Node* merged = mergeNodeIntoGroup(consumer_group, node);
      // If any of the outputs are still used then we need to add them
      auto outputs = node->outputs();
      for (const auto i : c10::irange(outputs.size())) {
        auto output = outputs[i];
        if (output->uses().empty())
          continue;
        consumer_subgraph->registerOutput(merged->outputs()[i]);
        auto new_output = consumer_group->addOutput();
        output->replaceAllUsesWith(new_output);
        aliasDb_->replaceWithNewValue(output, new_output);
        new_output->setType(output->type());
      }
      node->destroy();
    }
  }

  // insert a producer node into a consuming fusion group.
  // DOES NOT WORK if n is a consumer of an output of the fusion group
  // returns the node _inside_ the group that represents the node
  Node* mergeNodeIntoGroup(Node* group, Node* n) {
    AT_ASSERT(n->kind() != kind_);
    auto& subgraph = getSubgraph(group);
    // map from nodes in the surrounding graph to parameters in the fusion
    // group's subgraph that correspond to them
    std::unordered_map<Value*, Value*> inputs_map;
    size_t i = 0;
    size_t tensor_insert_idx = 0;
    AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
    for (auto input : group->inputs()) {
      inputs_map[input] = subgraph.inputs()[i++];
      if (input->type()->isSubtypeOf(*TensorType::get()))
        tensor_insert_idx = i;
    }
    // add n's inputs to the fusion group's input list if we don't already
    // have them
    // we insert tensors first because the fuser assumes that to be the case
    // (as a legacy from tensors only)
    WithInsertPoint guard(*subgraph.nodes().begin());
    for (auto input : n->inputs()) {
      if (inputs_map.count(input) == 0) {
        if (input->type()->isSubtypeOf(*TensorType::get())) {
          auto in_group = subgraph.insertInput(tensor_insert_idx);
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->insertInput(tensor_insert_idx, input);
          tensor_insert_idx++;
        } else if (
            (input->type()->isSubtypeOf(*FloatType::get()) &&
             input->node()->kind() != prim::Constant) ||
            (n->kind() == aten::_grad_sum_to_size &&
             input->type()->isSubtypeOf(*ListType::ofInts()))) {
          auto in_group = subgraph.addInput();
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->addInput(input);
        } else {
          // We don't support passing in scalars as arguments to fused kernels,
          // so we generally don't allow fusing tensor-scalar operations unless
          // the scalar is constant. In those cases we inline the constants
          // directly in the body of the fused group.
          AT_ASSERT(input->node()->kind() == prim::Constant);
          Node* in_const =
              subgraph.createClone(input->node(), [](Value*) -> Value* {
                throw std::runtime_error("unexpected input");
              });
          subgraph.insertNode(in_const);
          inputs_map[input] = in_const->output();
        }
      }
    }
    // copy n into the graph, remapping its inputs to internal nodes
    Node* in_graph = subgraph.createClone(
        n, [&](Value* k) -> Value* { return inputs_map[k]; });
    // if n's outputs are already inputs to the fusion group,
    // we need to remove them because n is now inside the fusion group.
    //
    // i.e.,
    //   x = f(w); group(x, y, z) becomes group(w, y, z).
    //   x, y, z = f(w); group(x, y, z) becomes group(w).
    //
    // remapping nodes that used the input to the newly-merged node
    // n is not an input when the fusion group is empty
    auto inputs = group->inputs();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
      if (it != inputs.end()) {
        size_t p = it - inputs.begin();
        group->removeInput(p);
        subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
        subgraph.eraseInput(p);
      }
    }
    return subgraph.insertNode(in_graph);
  }

  // turn consumer node n into a fusion group with just n inside
  // to prepare for fusion and replace uses of n with the new group
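  // For illustration (hypothetical IR): %y = aten::relu(%x) becomes
  //   %y = prim::FusionGroup_0(%x)
  // whose subgraph contains a single aten::relu node; all former uses of
  // the relu output now refer to the group's output.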
  Node* createSingletonFusionGroup(Node* n) {
    auto group = block_->owningGraph()->createWithSubgraph(kind_);
    // propagate position information for the new node so we can always
    // have a valid mapping
    group->insertBefore(n);
    Node* mergedNode = mergeNodeIntoGroup(group, n);
    getSubgraph(group).registerOutput(mergedNode->output());
    auto sel = group->addOutput();
    sel->copyMetadata(n->output());
    aliasDb_->replaceWithNewValue(n->output(), sel);
    n->replaceAllUsesWith(group);
    n->destroy();
    return group;
  }

  at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
    // this handles cases where producer can be moved _into_ the fusion group
    // of consumer.
    // TODO: extend to fusion of consumer into _producer's_ fusion blob
    // if the consumer allInputsAreThisProducer(consumer, producer)
    // we can move the consumer up into the producer.
    // but this requires better handling of merging fusion groups so it is not
    // done now
    bool shouldFuse = isFusable(producer->node()) &&
        // Rearrange nodes such that all uses of producer are after the
        // consumer. Fusion will rewrite those later uses to use the version of
        // producer generated by the fused blob. In this case, producer becomes
        // an output of the fusion group.
        aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);

    if (!shouldFuse) {
      return at::nullopt;
    }

    if ((consumer->inputs().size() + consumer->outputs().size() +
         producer->node()->inputs().size() +
         producer->node()->outputs().size()) > subgraph_arg_limit_) {
      return at::nullopt;
    }

    auto group = consumer;
    if (consumer->kind() != kind_) {
      group = createSingletonFusionGroup(consumer);
    }

    if (producer->node()->kind() == kind_) {
      mergeFusionGroups(group, producer->node());
      return group;
    }
    AT_ASSERT(producer->node()->outputs().size() == 1);
    Node* merged = mergeNodeIntoGroup(group, producer->node());
    // remaining uses of this producer can occur because we allow
    // fusion in cases where uses remain after the consumer
    // if these exist, re-route them to the version of producer
    // created in FusionGroup
    if (!producer->uses().empty()) {
      getSubgraph(group).registerOutput(merged->output());
      Value* new_producer = group->addOutput();
      new_producer->copyMetadata(producer);
      aliasDb_->replaceWithNewValue(producer, new_producer);
      producer->replaceAllUsesWith(new_producer);
    }
    producer->node()->destroy();
    return group;
  }

  bool canFuseChunk(Node* consumer, Value* producer) {
    if (consumer->kind() != prim::FusionGroup) {
      return false;
    }
    // Does the chunk have constant chunks/dim?
    auto* chunk = producer->node();
    if (chunk->kind() != prim::ConstantChunk)
      return false;
    // And all uses of the chunk are in this consumer
    for (auto s : chunk->outputs()) {
      for (auto u : s->uses()) {
        if (u.user != consumer) {
          return false;
        }
      }
    }
    // And isn't a no-op chunk (chunks == 1). Have CSE clean this up.
    // We could fuse this but it's better to just delete the node.
    if (chunk->i(attr::chunks) == 1) {
      return false;
    }
    return true;
  }

  c10::optional<Node*> findFusedChunk(Node* group, Value* input) {
    AT_ASSERT(group->kind() == prim::FusionGroup);
    auto it = std::find(group->inputs().begin(), group->inputs().end(), input);
    if (it == group->inputs().end()) {
      return c10::nullopt;
    }
    size_t input_index = it - group->inputs().begin();
    auto& subgraph = getSubgraph(group);
    auto* subgraph_input = subgraph.inputs().at(input_index);
    // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
    auto* node = subgraph_input->uses().at(0).user;
    if (node->kind() == prim::ConstantChunk) {
      AT_ASSERT(subgraph_input->uses().size() == 1);
      return node;
    }
    return c10::nullopt;
  }

  void fuseChunkByReusingExistingFusedChunk(
      Node* group,
      Node* chunk,
      Node* existingFusedChunk) {
    if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
      return;
    }
    auto& subgraph = getSubgraph(group);
    for (size_t i = 0; i < chunk->outputs().size(); ++i) {
      // Find the input to the FusionGroup (group)
      auto* replacement_val = existingFusedChunk->outputs().at(i);
      auto* val = chunk->outputs().at(i);
      auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
      auto input_index = it - group->inputs().begin();

      // Rewrite the graph to use replacement_val
      auto group_input = subgraph.inputs().at(input_index);
      group_input->replaceAllUsesWith(replacement_val);

      // Remove the input, it's no longer needed
      group->removeInput(input_index);
      subgraph.eraseInput(input_index);
    }
    chunk->destroy();
  }

  // There are two invariants for prim::ConstantChunk:
  // (1) the tensor input to prim::ConstantChunk must be an input to the
  //     fusion group
  // (2) no two ConstantChunks in the same FusionGroup can share a tensor
  //     input.
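  // For illustration (hypothetical IR): if %x is already chunked inside the
  // group, a second chunk of %x outside the group,
  //   %a, %b = prim::ConstantChunk[chunks=2](%x)
  // is not merged as a new node; when the chunk counts match, its outputs are
  // instead rewired to the outputs of the existing fused chunk (see
  // fuseChunkByReusingExistingFusedChunk above).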
  graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
    auto* chunk = producer->node();
    AT_ASSERT(consumer->kind() == prim::FusionGroup);
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    // if producer's input is already an input to a prim::ConstantChunk node,
    // we cannot add a new prim::ConstantChunk node because of invariant (2).
    auto* chunked_tensor = producer->node()->input();
    if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
      fuseChunkByReusingExistingFusedChunk(
          consumer, chunk, *existingFusedChunk);
      return consumer->reverseIterator();
    }

    // Move prim::ConstantChunk into the FusionGroup
    mergeNodeIntoGroup(consumer, chunk);
    chunk->destroy();
    return consumer->reverseIterator();
  }

  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  graph_node_list::iterator scanNodeForChunks(Node* consumer) {
    if (consumer->kind() == prim::FusionGroup) {
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto producer : inputs) {
        if (!canFuseChunk(consumer, producer)) {
          continue;
        }
        return fuseChunk(consumer, producer);
      }
    }
    return ++consumer->reverseIterator();
  }

  at::ArrayRef<Value*> broadcast_tensors(value_list inputs) {
    AT_ASSERT(!inputs.empty());
    auto* g = inputs[0]->owningGraph();
    auto* input_list =
        g->insertNode(g->createList(TensorType::get(), inputs))->output();
    aliasDb_->createValue(input_list);
    auto* output_list = g->insert(aten::broadcast_tensors, {input_list});
    aliasDb_->createValue(output_list);
    auto* unpack_node = g->insertNode(
        g->create(prim::ListUnpack, {output_list}, inputs.size()));

    // We are doing:
    //   input_list = listConstruct(a, b, ...)
    //   output_list = broadcast_tensors(input_list)
    //   a_broadcasted, b_broadcasted = listUnpack(output_list)
    // `a_broadcasted` should receive the same aliasing info as `a`
    TORCH_INTERNAL_ASSERT(unpack_node->outputs().size() == inputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      Value* original_input = inputs[i];
      Value* broadcasted_output = unpack_node->outputs()[i];
      aliasDb_->copyValue(original_input, broadcasted_output);
    }

    return unpack_node->outputs();
  }

  void insertExplicitBroadcast(Node* node) {
    WithInsertPoint insert_guard{node};
    auto tensors = tensorInputs(node);
    auto new_tensors = broadcast_tensors(std::move(tensors));

    // Replace tensor inputs with broadcasted values
    auto new_tensors_it = new_tensors.begin();
    for (size_t i = 0; i < node->inputs().size(); ++i) {
      if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) {
        AT_ASSERT(new_tensors_it != new_tensors.end());
        node->replaceInput(i, *(new_tensors_it++));
      }
    }
  }

  Node* promoteChunkToBroadcastingChunk(Node* chunk) {
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    size_t nchunks = chunk->i(attr::chunks);
    Node* bchunk =
        chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
    bchunk->addInput(chunk->input());
    for (const auto i : c10::irange(nchunks)) {
      auto* old_output = chunk->outputs().at(i);
      auto* new_output = bchunk->outputs().at(i);
      new_output->copyMetadata(old_output);
      aliasDb_->replaceWithNewValue(old_output, new_output);
      old_output->replaceAllUsesWith(new_output);
    }
    bchunk->copyAttributes(*chunk);
    bchunk->insertAfter(chunk);
    chunk->destroy();
    return bchunk;
  }

  // in places where op can be fused into a consumer but chunk is in the way,
  // distribute chunk to op's operands:
  // replace a,b = chunk(op(x,y,z)) with:
  //   x', y', z' = broadcast_tensors([x, y, z])
  //   x0,x1 = chunk(x') (x0 has a's type, x1 has b's type)
  //   y0,y1 = chunk(y') (y0 has a's type, y1 has b's type)
  //   z0,z1 = chunk(z') (z0 has a's type, z1 has b's type)
  //   a = op(x0,y0,z0) (a and b keep their sizes but are now contiguous)
  //   b = op(x1,y1,z1)
  //
  // The graph fuser uses an intermediate prim::BroadcastingChunk node to
  // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts
  // all of its inputs and then chunks each input, in order, the same way.
  // The above graph is equivalent to:
  //   x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z)
  //   a = op(x0,y0,z0)
  //   b = op(x1,y1,z1)
  //
  // NB: The explicit broadcast is important for correctness.
  // Let's say we have:
  //   %z = aten::mul(%x, %y)
  //   %z.1, %z.2 = aten::chunk(%z, ...)
  //   ... = prim::FusionGroup(%z.1, %z.2, ...)
  // It's possible that %x and %y do not have the same size as %z and
  // need to be expanded first so that they can be chunked like %z
  //
  // NB: Chunk motion only occurs with fusable consumers, which implies
  // that there is always some other operation, e.g., a+b, that happens
  // after the chunk, and will be put into the fusion group. This is
  // important, because distributing the chunk changes the contiguity
  // of a and b, and so the results would be invalid, except that we know
  // that simple_mappable operations will restore contiguity before
  // we exit the fusion group.
  //
  // NB: The intermediate BroadcastingChunk is important for moving chunks
  // past more than one operation: the graph fuser is not able to easily move
  // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible
  // ops:
  //   x = f(v, w)
  //   z = g(x, y)
  //   a, b = chunk(z)
  //   c = h(a, b)
  // becomes (with the broadcast_tensors + chunk approach):
  //   x = f(v, w)
  //   x', y' = broadcast_tensors([x, y])
  //   ax, bx = chunk(x')
  //   ay, by = chunk(y')
  //   a = g(ax, ay)
  //   b = g(bx, by)
  //   c = h(a, b)
  // The broadcast_tensors node makes it harder to move f into the resulting
  // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior
  // together results in:
  //   x = f(v, w)
  //   ax, bx, ay, by = BroadcastingChunk(x, y)
  //   a = g(ax, ay)
  //   b = g(bx, by)
  //   c = h(a, b)
  // making it easier to move f after the BroadcastingChunk:
  //   ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
  //   ax = f(av, aw)
  //   bx = f(bv, bw)
  //   a = g(ax, ay)
  //   b = g(bx, by)
  //   c = h(a, b)

  bool tryToMoveChunk(Node* consumer, Value* producer) {
    // is the output from a chunk/bchunk node?
    auto* chunk = producer->node();
    if (chunk->kind() != prim::ConstantChunk &&
        chunk->kind() != prim::BroadcastingChunk)
      return false;

    // try to find a producer to move after the chunk/bchunk. The producer
    // must be fusible into the consumer.
    auto it = std::find_if(
        chunk->inputs().begin(),
        chunk->inputs().end(),
        [&](Value* producer_for_chunk) {
          return isFusableMap(producer_for_chunk->node()) &&
              allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
        });
    if (it == chunk->inputs().end()) {
      return false;
    }
    Value* producer_for_chunk = *it;
    size_t producer_index = it - chunk->inputs().begin();

    // all uses of the chunk must be in this consumer
    for (auto s : chunk->outputs()) {
      for (auto u : s->uses()) {
        if (u.user != consumer)
          return false;
      }
    }
    // ops with multiple return values are not supported
    Node* producer_for_chunk_node = producer_for_chunk->node();
    AT_ASSERT(producer_for_chunk_node->outputs().size() == 1);

    // Convert chunk to bchunk, if it isn't one already. The bchunk represents
    // a broadcast and one or more chunk operations.
    auto* bchunk = chunk;
    if (chunk->kind() == prim::ConstantChunk) {
      bchunk = promoteChunkToBroadcastingChunk(chunk);
    }
    size_t nchunks = bchunk->i(attr::chunks);
    WithInsertPoint guard(bchunk->next());

    std::vector<Value*> producer_chunk_outputs;
    for (const auto i : c10::irange(nchunks)) {
      producer_chunk_outputs.push_back(
          bchunk->output(nchunks * producer_index + i));
    }

    // Add each of op's operands to the bchunk node.
    // chunked_inputs[input_nr][chunk_output_idx]
    //   = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
    std::vector<std::vector<Value*>> chunked_inputs;

    for (auto input : producer_for_chunk_node->inputs()) {
      // XXX: we only work with pointwise ops in here, so we know it is valid
      // to push the concat only through tensor arguments (and all other args
      // can be safely ignored).
      if (!input->type()->isSubtypeOf(*TensorType::get()))
        continue;

      // if 'input' is already an input to the bchunk, reuse it.
      auto bchunk_inputs = bchunk->inputs();
      auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input);
      if (it != bchunk_inputs.end()) {
        chunked_inputs.emplace_back();
        auto input_index = std::distance(bchunk_inputs.begin(), it);
        for (const auto chunki : c10::irange(nchunks)) {
          chunked_inputs.back().push_back(
              bchunk->outputs().at(nchunks * input_index + chunki));
        }
        continue;
      }

      // NB: I decided not to use cloneFrom here, because if we make cloneFrom
      // copy selects one day, it is definitely not what you want here (selects
      // have different types).
      // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
      // to copy select nodes now that we have refactored to have a Value
      // distinct from Node.
      bchunk->addInput(input);
      chunked_inputs.emplace_back(); // alas, to not be C++17
      for (auto chunk_sel : producer_chunk_outputs) {
        Value* input_chunk_sel = bchunk->addOutput();
        input_chunk_sel->setType(chunk_sel->type());
        // Add a fresh value for each output element of the broadcasting chunk
        // node. This is safe because it will be consumed only by the chunked
        // ops.
        aliasDb_->createValue(input_chunk_sel);
        chunked_inputs.back().push_back(input_chunk_sel);
      }
    }

    // apply the op to each chunk of the chunked operands,
    // and then rewrite the graph to use them!
    for (auto chunk_sel : producer_chunk_outputs) {
      auto original_inputs = producer_for_chunk_node->inputs();
      Node* chunked_op =
          block_->owningGraph()->create(producer_for_chunk_node->kind());
      chunked_op->copyAttributes(*producer_for_chunk_node);
      chunked_op->output()->setType(chunk_sel->type());
      auto chunked_inputs_it = chunked_inputs.begin();
      for (Value* original_input : original_inputs) {
        if (original_input->type()->isSubtypeOf(*TensorType::get())) {
          AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
          chunked_op->addInput(
              // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
              chunked_inputs_it->at(chunk_sel->offset() % nchunks));
          ++chunked_inputs_it;
        } else {
          chunked_op->addInput(original_input);
        }
      }
      bchunk->owningGraph()->insertNode(chunked_op);
      chunk_sel->replaceAllUsesWith(chunked_op->output());
      aliasDb_->replaceWithNewValue(chunk_sel, chunked_op->output());
    }

    bchunk->removeInput(producer_index);
    for (const auto i : c10::irange(nchunks)) {
      (void)i; // Suppress unused variable warning
      bchunk->eraseOutput(nchunks * producer_index);
    }

    // The output of producer_for_chunk_node could have been used in some
    // aten::size operators, so we need to clean those up as well (we simply
    // broadcast all its tensor inputs).
    // We need to insert these early in the graph, i.e. immediately after
    // the producer_for_chunk_node, as we will have the _size_if_not_same
    // that may be before the bchunk.
    WithInsertPoint guard2(producer_for_chunk_node);
    auto size_calc_uses = producer_for_chunk_node->output()->uses();
    if (!size_calc_uses.empty()) {
      auto tensor_inputs = filter(
          producer_for_chunk_node->inputs(),
          [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
      auto tensor_sizes = fmap(tensor_inputs, [&](Value* v) {
        Value* output = v->owningGraph()->insert(aten::size, {v});
        aliasDb_->createValue(output);
        return output;
      });
      AT_ASSERT(!tensor_sizes.empty());
      Value* output_size = tensor_sizes.size() == 1
          ? tensor_sizes[0]
          : broadcastSizes(tensor_sizes, aliasDb_);
      for (Use u : size_calc_uses) {
        u.user->output()->replaceAllUsesWith(output_size);
        u.user->destroy();
      }
    }
    producer_for_chunk_node->destroy();
    return true;
  }

  // returns where to continue scanning, and whether any fusion was made
  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (isFusable(consumer)) {
      // handle inputs in reverse topological order as well...
      // otherwise in f(a,a+b) it will appear a is used twice if we consider
      // the f-a fusion before the f-(a+b) fusion first.
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto producer : inputs) {
        if (tryToMoveChunk(consumer, producer)) {
          // the chunk before this consumer was re-arranged to allow fusion,
          // so we scan this consumer again to perform the fusion
          return std::make_pair(consumer->reverseIterator(), true);
        }
        auto fusion_group = tryFuse(consumer, producer);
        if (fusion_group) {
          // after fusion, consumer moves into a FusionGroup, so inputs is no
          // longer valid, so we rescan the new FusionGroup for more fusions...
          return std::make_pair(fusion_group.value()->reverseIterator(), true);
        }
      }
    }
    return std::make_pair(++consumer->reverseIterator(), false);
  }

  void replaceIntermediateBroadcastingChunks() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      auto* node = *it;
      ++it; // We might delete node, so increment the iterator now.
      if (node->kind() != prim::BroadcastingChunk) {
        continue;
      }
      auto* bchunk = node;
      insertExplicitBroadcast(bchunk);

      auto* graph = block_->owningGraph();
      size_t nchunks = bchunk->i(attr::chunks);
      WithInsertPoint guard(bchunk->next());

      // Split the bchunk into bchunk->inputs().size() number of chunk nodes.
      for (size_t input_offset = 0; input_offset < bchunk->inputs().size();
           input_offset++) {
        auto* input = bchunk->inputs().at(input_offset);

        Node* new_chunk =
            graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
        new_chunk->copyAttributes(*bchunk);
        for (const auto output_offset : c10::irange(nchunks)) {
          auto new_output = new_chunk->addOutput();
          auto old_output =
              bchunk->outputs().at(input_offset * nchunks + output_offset);
          new_output->copyMetadata(old_output);
          aliasDb_->replaceWithNewValue(old_output, new_output);
          old_output->replaceAllUsesWith(new_output);
        }
      }
      bchunk->destroy();
    }
  }

  // Builds up expressions that compute shapes of all intermediates (and
  // outputs) of the fusion group, based on the sizes of inputs. You should
  // run DCE to remove those that you end up not using.
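  // For illustration (hypothetical IR): for a fusion group computing
  //   %c = aten::mul(%a, %b)
  // this emits, in the outer graph,
  //   %a_size = aten::size(%a)
  //   %b_size = aten::size(%b)
  //   %c_size = prim::BroadcastSizes(%a_size, %b_size)
  // so that size queries can be answered without keeping %c as an output.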
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
    WithInsertPoint insert_guard{fusion_group->next()};
    std::unordered_map<Value*, Value*> shape_of;

    Graph* graph = fusion_group->owningGraph();
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto inputs = fusion_group->inputs();
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
        Value* soutput = graph->insert(aten::size, {inputs[i]});
        aliasDb_->createValue(soutput);
        shape_of[sinputs[i]] = soutput;
      }
    }

    // When we have a guarantee that an output won't be removed, because it's
    // used in expressions that don't involve size checks, we can use its size
    // instead of computing a long chain of broadcasts, starting from the
    // beginning of the kernel.
    auto outputs = fusion_group->outputs();
    auto soutputs = subgraph->outputs();
    AT_ASSERT(outputs.size() == soutputs.size());
    for (const auto i : c10::irange(outputs.size())) {
      if (usedOnlyInSize(outputs[i]))
        continue;
      Value* soutput = graph->insert(aten::size, {outputs[i]});
      aliasDb_->createValue(soutput);
      shape_of[soutputs[i]] = soutput;
    }

    for (Node* n : subgraph->nodes()) {
      // XXX: Use of shape_of.emplace is crucial to the output shape
      // optimization!
      if (n->kind() == prim::FusedConcat) {
        // This is a bit more involved, because we have to account for the case
        // when inputs have different shapes, but fortunately those tensors are
        // always outputs, and so we can simply avoid replacing their queries,
        // because it won't help us.
        continue;
      }
      if (n->kind() == prim::Constant) {
        continue;
      }
      if (n->kind() == prim::ConstantChunk) {
        Node* sizes_node = graph->insertNode(
            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
        sizes_node->i_(attr::dim, n->i(attr::dim));
        sizes_node->i_(attr::chunks, n->i(attr::chunks));
        for (Value* output : sizes_node->outputs()) {
          aliasDb_->createValue(output);
        }
        Value* regular_size = sizes_node->outputs().at(0);
        Value* last_size = sizes_node->outputs().at(1);
        regular_size->setType(ListType::ofInts());
        last_size->setType(ListType::ofInts());
        auto outputs = n->outputs();
        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
          shape_of.emplace(o, regular_size);
        }
        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
        continue;
      }
      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
        return v->type()->isSubtypeOf(*TensorType::get());
      });
      auto shapes =
          fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
      AT_ASSERT(!shapes.empty());
      shape_of.emplace(
          n->output(),
          shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes, aliasDb_));
    }
    return shape_of;
  }

  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
    if (fusion_group->kind() != prim::FusionGroup)
      return;
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto shape_of = buildShapeExpressions(fusion_group);
    auto outputs = fusion_group->outputs().vec();
    auto soutputs = subgraph->outputs().vec();
    // XXX: Iterating in this order is not only good for performance reasons!
    // It is also crucial for correctness (i has to reflect the current true
    // index of outputs[i])!
    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
      auto output = outputs[i];
      auto soutput = soutputs[i];
      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
        auto uses = output->uses();
        for (Use u : uses) {
          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
          u.user->destroy();
        }
        fusion_group->eraseOutput(i);
        subgraph->eraseOutput(i);
      }
    }
  }

  bool canFuseWithConcat(Value* producer, Node* before_check) {
    if (!isFusable(producer->node())) {
      return false;
    }
    // NB: it is important that this check happens after isFusable, which
    // checks that the blocks match, and that it's not a special node like
    // prim::Param
    if (!aliasDb_->couldMoveBeforeTopologically(
            producer->node(), before_check)) {
      return false;
    }

    // If the number of kernel args could exceed the limit, skip.
    if ((before_check->inputs().size() + before_check->outputs().size() +
         producer->node()->inputs().size() +
         producer->node()->outputs().size()) > subgraph_arg_limit_) {
      return false;
    }

    // Fusion groups can be merged with concat's group if and only if
    // the value they produce isn't already coming from a concat
    if (producer->node()->kind() == prim::FusionGroup) {
      auto subgraph = producer->node()->g(attr::Subgraph);
      auto* node = subgraph->outputs().at(producer->offset())->node();
      return node->kind() != prim::FusedConcat;
    }
    return true;
  }

  Node* createFusedConcat(Node* node) {
    AT_ASSERT(node->kind() == aten::cat);

    Graph* graph = node->owningGraph();
    Node* list_construct = node->namedInput(attr::tensors)->node();
    int64_t dim = node->get<int64_t>(attr::dim).value();

    Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
                          ->i_(attr::dim, dim);
    fused_cat->insertBefore(list_construct);
    fused_cat->output()->copyMetadata(node->output());
    aliasDb_->copyValue(node->output(), fused_cat->output());

    // NB: this deletes the fused_cat node from the original graph
    return createSingletonFusionGroup(fused_cat);
  }

  void fuseConcats() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
         ++it) {
      Node* cat = *it;
      if (!isFusableCatNode(cat)) {
        continue;
      }
      Node* list_construct = cat->namedInput(attr::tensors)->node();
      Node* fused_cat = createFusedConcat(cat);
      Value* fused_cat_out = fused_cat->output();

      auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
      size_t input_idx = 0;
      bool any_fused = false;
      while (input_idx < sorted_inputs.size()) {
        Value* input = sorted_inputs[input_idx++];
        if (!canFuseWithConcat(input, fused_cat)) {
          continue;
        }
        any_fused = true;
        auto maybe_group = tryFuse(fused_cat, input);
        AT_ASSERT(maybe_group && maybe_group == fused_cat);
        // We could have destroyed multiple inputs when performing this fusion,
        // so we have to recompute the list and iterate over it again.
        sorted_inputs = sortReverseTopological(fused_cat->inputs());
        input_idx = 0;
      }

      if (any_fused) {
        cat->output()->replaceAllUsesWith(fused_cat_out);
        it.destroyCurrent();
        if (list_construct->output()->uses().empty()) {
          list_construct->destroy();
        }
      } else {
        fused_cat->destroy();
      }
    }
  }

  void optimizeFusedGraphs() {
    for (Node* node : block_->nodes()) {
      if (node->kind() != prim::FusionGroup) {
        continue;
      }
      auto subgraph = node->g(attr::Subgraph);
      EliminateDeadCode(subgraph);
      EliminateCommonSubexpression(subgraph);
      ConstantPooling(subgraph);
    }
  }

  void run() {
// TODO: old fuser is not maintained internally, somewhere it is being turned
// on inadvertently for certain workflows. make this a no-op until we identify
// the location
#if defined(FBCODE_CAFFE2)
    return;
#endif

    // Run the pass until no changes are made.
    // This is necessary, because the algorithm can miss out on certain fusion
    // opportunities if run only once. Consider this graph:
    //
    //   %1 = f(...)
    //   %2 = g(%1)
    //   %3 = h(%1)
    //   %4 = l(%3)
    //   return (%4, %2)
    //
    // where f, g, h, l are simple map ops.
    // The first iteration will fuse %4 and %3, and see that %1 is an input,
    // but can't be fused, because it has a different use before the fusion
    // group in our topological ordering. Then, %2 will be considered, and
    // fused with %1. If we do another iteration, the algorithm will consider
    // the fusion of these two groups and fix the situation.
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        bool changed;
        std::tie(it, changed) = scanNode(*it);
        any_changed |= changed;
      }
    }

    fuseConcats();

    optimizeFusedGraphs();

    // The graph fuser can add intermediate prim::BroadcastingChunk nodes.
    // Replace them with broadcasts + chunks.
    replaceIntermediateBroadcastingChunks();

    // Fuse starting chunks into the group.
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      it = scanNodeForChunks(*it);
    }

    // Remove outputs that have been added only because we need their size
    for (Node* n : block_->nodes()) {
      removeOutputsUsedOnlyInSize(n);
    }

    for (Node* node : block_->nodes()) {
      for (Block* sub_block : node->blocks()) {
        GraphFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_)
            .run();
      }
    }
  }
};

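// Peephole-optimizes prim::BroadcastSizes nodes left behind by the shape
// expressions above: removes single-input broadcasts, deduplicates inputs,
// and composes chains of broadcasts. For illustration (hypothetical IR),
//   %s1 = prim::BroadcastSizes(%a, %a, %b)
//   %s2 = prim::BroadcastSizes(%s1, %c)
// can be rewritten as a single
//   %s2 = prim::BroadcastSizes(%a, %b, %c)
// provided %s1 has no other uses.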
void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) {
  auto nodes = block->nodes();
  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
    Node* node = *it;
    for (Block* subblock : node->blocks()) {
      PeepholeOptimizeShapeExpressions(subblock, db);
    }
    if (node->kind() == prim::BroadcastSizes) {
      // Remove no-op broadcasts.
      if (node->inputs().size() == 1) {
        node->output()->replaceAllUsesWith(node->input());
        it.destroyCurrent();
        continue;
      }
      // Deduplicate inputs, but use their unique() values to ensure
      // this process only depends on the graph.
      std::map<size_t, Value*> unique_to_value;
      for (Value* input : node->inputs()) {
        unique_to_value.emplace(input->unique(), input);
      }
      if (unique_to_value.size() != node->inputs().size()) {
        std::vector<Value*> inputs;
        inputs.reserve(unique_to_value.size());
        for (auto& entry : unique_to_value) {
          inputs.push_back(entry.second);
        }
        if (inputs.size() == 1) {
          node->output()->replaceAllUsesWith(inputs[0]);
        } else {
          WithInsertPoint insert_guard{node};
          node->output()->replaceAllUsesWith(broadcastSizes(inputs, db));
        }
        it.destroyCurrent();
        --it; // Revisit the node with deduplicated inputs
        continue;
      }
      // Compose simple chains of broadcasts into a single node.
      const auto& uses = node->output()->uses();
      if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
        Node* user = uses[0].user;
        user->removeInput(uses[0].offset);
        // NB: we don't care about deduplication in here, as we will visit
        // user later.
        for (Value* i : node->inputs()) {
          user->addInput(i);
        }
        it.destroyCurrent();
      }
    }
  }
}

} // anonymous namespace

static bool cpu_fuser_enabled_legacy = false;

bool canFuseOnCPULegacy() {
  return cpu_fuser_enabled_legacy;
}

void overrideCanFuseOnCPULegacy(bool value) {
  cpu_fuser_enabled_legacy = value;
}

void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
  AliasDb db(graph);
  GraphFuser(&db, graph->block(), strict_fuser_check).run();
  Lint(&db);
  // After FuseGraph some common subexpressions may come back
  EliminateCommonSubexpression(graph);
  // We might have emitted a fair amount of useless shape propagating code, so
  // remove it
  EliminateDeadCode(graph);
  // Improve the quality of shape propagation code that was left
  PeepholeOptimizeShapeExpressions(graph->block(), &db);
}

void CustomFuseGraph(
    std::shared_ptr<Graph>& graph,
    const std::function<bool(Node*)>& fn,
    Symbol kind,
    size_t arg_limit) {
  AliasDb db(graph);
  auto g = GraphFuser(
      &db,
      graph->block(),
      [=](GraphFuser* gf, Node* n) { return fn(n) || n->kind() == kind; },
      kind);
  g.setInputArgLimit(arg_limit);
  g.run();
  Lint(&db);
}
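
// A minimal usage sketch (hypothetical caller code; the predicate, symbol
// name, and arg limit below are illustrative, not prescribed):
//
//   std::shared_ptr<Graph> graph = ...;
//   CustomFuseGraph(
//       graph,
//       [](Node* n) { return n->kind() == aten::relu; },
//       Symbol::fromQualString("my_backend::FusionGroup"),
//       /*arg_limit=*/128);
//
// Every node accepted by the predicate (plus existing groups of the given
// kind) is grouped into my_backend::FusionGroup subgraphs, and the resulting
// graph is linted against the alias database.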

} // namespace jit
} // namespace torch