1#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
2
3#include <c10/util/Exception.h>
4#include <c10/util/irange.h>
5#include <instrumentation.h>
6#include <parser.h>
7#include <partition.h>
8#include <transform_view.h>
9#include <utils.h>
10#include <torch/csrc/jit/frontend/ir_emitter.h>
11#include <torch/csrc/jit/ir/alias_analysis.h>
12#include <torch/csrc/jit/jit_log.h>
13#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
14#include <torch/csrc/jit/passes/constant_pooling.h>
15#include <torch/csrc/jit/passes/dead_code_elimination.h>
16#include <torch/csrc/jit/passes/pass_manager.h>
17#include <torch/csrc/jit/passes/remove_mutation.h>
18#include <torch/csrc/jit/passes/restore_mutation.h>
19#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
20#include <torch/csrc/jit/runtime/autodiff.h>
21#include <torch/csrc/jit/runtime/custom_operator.h>
22#include <torch/csrc/jit/runtime/graph_iterator.h>
23#include <torch/csrc/jit/runtime/operator.h>
24
25#include <torch/csrc/jit/ir/alias_analysis.h>
26#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
27
28#include <queue>
29#include <unordered_map>
30
31namespace torch {
32namespace jit {
33namespace fuser {
34namespace cuda {
35
36constexpr size_t NVRTC_KERNEL_ARG_LIMIT = 128;
37
38namespace {
39
40bool usedOnlyInDtype(Value* v) {
41 const auto& uses = v->uses();
42 if (uses.empty()) {
43 return false;
44 }
45 return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
46 return u.user->matches("prim::dtype(Tensor a) -> int");
47 });
48}
49
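// Inserts a prim::BroadcastSizes node combining several int[] size values into
// a single broadcasted size. Illustrative sketch of the emitted IR (value
// names are made up):
//   %sx : int[] = aten::size(%x)
//   %sy : int[] = aten::size(%y)
//   %s : int[] = prim::BroadcastSizes(%sx, %sy)
// The node is inserted after the latest of the defining nodes so that all of
// its inputs are available at the insertion point.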
50Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
51 AT_ASSERT(!sizes.empty());
52 Graph* graph = sizes[0]->owningGraph();
53 Node* insertion_point = sizes[0]->node()->next();
54 for (size_t i = 1; i < sizes.size(); i++) {
55 if (insertion_point->isBefore(sizes[i]->node()->next())) {
56 insertion_point = sizes[i]->node()->next();
57 }
58 }
59 WithInsertPoint guard(insertion_point);
60 Node* broadcast_n =
61 graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
62 broadcast_n->output()->setType(ListType::ofInts());
63 return broadcast_n->output();
64}
65
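// Materializes the value profiled on a prim::profile_ivalue node as a graph
// constant. For illustration only (attribute names are the ones handled
// below, values are made up):
//   prim::profile_ivalue[profiled_int_list=[2, 3]] -> prim::Constant[value=[2, 3]]
//   prim::profile_ivalue[profiled_bool=1]          -> prim::Constant[value=True]
// Returns nullptr if the node carries no profiling information.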
66Value* createConditionalConstant(Node* profile_ivalue) {
67 TORCH_INTERNAL_ASSERT(profile_ivalue->kind() == prim::profile_ivalue);
68
69 auto graph = profile_ivalue->owningGraph();
70
71 IValue val; // default to None
72 if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int_list"))) {
73 // int[]
74 val = IValue(profile_ivalue->is(Symbol::attr("profiled_int_list")));
75 } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool_list"))) {
76 // bool[]
77 auto int_list = profile_ivalue->is(Symbol::attr("profiled_bool_list"));
78 std::vector<bool> bool_list(int_list.begin(), int_list.end());
79 val = IValue(bool_list);
80 } else if (profile_ivalue->hasAttribute(
81 Symbol::attr("profiled_reduction_size"))) {
82 // int[]
83 val = IValue(profile_ivalue->is(Symbol::attr("profiled_reduction_size")));
84 } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_view_size"))) {
85 // int[]
86 val = IValue(profile_ivalue->is(Symbol::attr("profiled_view_size")));
87 } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool"))) {
88 // bool
89 val = IValue(
90 static_cast<bool>(profile_ivalue->i(Symbol::attr("profiled_bool"))));
91 } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int"))) {
92 // int
93 val = IValue(
94 static_cast<int>(profile_ivalue->i(Symbol::attr("profiled_int"))));
95 } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_str"))) {
96 // str
97 val = IValue(static_cast<std::string>(
98 profile_ivalue->s(Symbol::attr("profiled_str"))));
99 } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_ival"))) {
100 // ival
101 val = IValue(profile_ivalue->ival(Symbol::attr("profiled_ival")));
102 } else {
103 GRAPH_DEBUG("no profile info in profile_ivalue node: ", *profile_ivalue);
104 TORCH_WARN_ONCE(
105 __func__,
106 " profile_node ",
107 *profile_ivalue,
108 " does not have profile information");
109 return nullptr;
110 }
111
112 return graph->insertConstant(val);
113}
114
115struct CudaGraphFuser {
116 using FusionCallback = std::function<bool(Node*)>;
117
118 Block* block_;
119 std::unique_ptr<AliasDb> aliasDb_;
120 std::shared_ptr<Graph> graph_;
121 Symbol kind_ = prim::CudaFusionGroup;
122 std::unordered_map<Value*, Value*> fusion_value_to_runtime_shape_;
123
  // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
  // The specific limit is a function of constant memory size, the amount
  // available for passing arguments, and other implementation details, so we
  // select a safe limit here.
  // This limit is also applied to other devices in the fuser by default.
  // Change it with setInputArgLimit.
130 size_t subgraph_arg_limit_ = NVRTC_KERNEL_ARG_LIMIT;
131
132 CudaGraphFuser(Block* block, std::shared_ptr<Graph> graph)
133 : block_(block), graph_(std::move(graph)) {}
134
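  // Hypothetical usage sketch for a caller constructing this pass directly
  // (illustration only; the default NVRTC_KERNEL_ARG_LIMIT is normally kept):
  //   CudaGraphFuser fuser(graph->block(), graph);
  //   fuser.setInputArgLimit(64); // be more conservative than the default
  //   fuser.run();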
135 void setInputArgLimit(size_t limit) {
136 subgraph_arg_limit_ = limit;
137 }
138
139 value_list tensorInputs(Node* node) {
140 return filter(node->inputs(), [](Value* v) {
141 return v->type()->isSubtypeOf(*TensorType::get());
142 });
143 }
144
145 bool calculatesSize(Node* node) {
146 return node->matches("aten::size(Tensor self) -> int[]");
147 }
148
149 bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
150 auto defining_node = producer->node();
151 for (auto o : defining_node->outputs()) {
152 for (auto u : o->uses()) {
153 if (u.user != consumer && !calculatesSize(u.user))
154 return false;
155 }
156 }
157 return true;
158 }
159
160 Graph& getSubgraph(Node* n) {
161 AT_ASSERT(n->kind() == kind_);
162 return *n->g(attr::Subgraph);
163 }
164
165 void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
166 // Now we have two fusion groups!
167 // Revert the fusion - place all inner nodes of producer back in the outer
168 // graph.
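    // Sketch of the merge (illustrative names): given
    //   %p = prim::CudaFusionGroup_0(...)  // producer_group
    //   %c = prim::CudaFusionGroup_1(%p)   // consumer_group
    // the producer's subgraph nodes are first cloned back into the outer
    // graph, then pushed one by one into the consumer's subgraph, leaving a
    // single fusion group that contains both bodies.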
169 std::vector<Node*> temporary_nodes;
170 auto producer_subgraph = &getSubgraph(producer_group);
171
172 // Initialize a map of inner graph values to outer graph values
173 std::unordered_map<Value*, Value*> inner_to_outer;
174 auto inner_inputs = producer_subgraph->inputs();
175 auto outer_inputs = producer_group->inputs();
176 for (const auto i : c10::irange(inner_inputs.size())) {
177 inner_to_outer[inner_inputs[i]] = outer_inputs[i];
178 }
179
180 // Clone all nodes
181 for (auto inner : producer_subgraph->nodes()) {
182 Node* outer = block_->owningGraph()->createClone(
183 inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
184 outer->insertBefore(producer_group);
185 temporary_nodes.emplace_back(outer);
186 auto inner_outputs = inner->outputs();
187 auto outer_outputs = outer->outputs();
188 for (const auto i : c10::irange(inner_outputs.size())) {
189 inner_to_outer[inner_outputs[i]] = outer_outputs[i];
190 }
191 }
192
193 // Replace uses of producer_group outputs and destroy the producer
194 auto subgraph_outputs = producer_subgraph->outputs();
195 for (const auto i : c10::irange(subgraph_outputs.size())) {
196 auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
197 producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
198 }
199 producer_group->destroy();
200 producer_group =
201 nullptr; // Just to get a clear error in case someone uses it
202
203 // Inline the temporary nodes into the first group
204 auto consumer_subgraph = &getSubgraph(consumer_group);
205 for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
206 ++it) {
207 Node* node = *it;
208 Node* merged = mergeNodeIntoGroup(consumer_group, node);
209 // If any of the outputs are still used then we need to add them
210 auto outputs = node->outputs();
211 for (const auto i : c10::irange(outputs.size())) {
212 auto output = outputs[i];
213 if (output->uses().size() == 0)
214 continue;
215 consumer_subgraph->registerOutput(merged->outputs()[i]);
216 auto new_output = consumer_group->addOutput();
217 output->replaceAllUsesWith(new_output);
218 new_output->setType(output->type());
219 }
220 node->destroy();
221 }
222 }
223
224 // insert a producer node into a consuming fusion group.
225 // DOES NOT WORK if n is a consumer of an output of the fusion group
226 // returns the node _inside_ the group that represents the node
227 Node* mergeNodeIntoGroup(Node* group, Node* n) {
228 AT_ASSERT(n->kind() != kind_);
229 auto& subgraph = getSubgraph(group);
230 // map from nodes in the surrounding graph to parameters in the fusion
231 // group's subgraph that correspond to them
232 std::unordered_map<Value*, Value*> inputs_map;
233 size_t i = 0;
234 size_t tensor_insert_idx = 0;
235 for (auto input : group->inputs()) {
236 inputs_map[input] = subgraph.inputs()[i++];
237 if (input->type()->isSubtypeOf(*TensorType::get()))
238 tensor_insert_idx = i;
239 }
240 // add n's inputs to the fusion group's input list if we don't already have
241 // them
242 // we insert tensors first because the fuser assumes that to be the case
243 // (as a legacy from tensors only)
244 WithInsertPoint guard(*subgraph.nodes().begin());
245 for (auto input : n->inputs()) {
246 if (inputs_map.count(input) == 0) {
        // TODO: we are following this convention for no good reason;
        // we don't need tensors to come before any other inputs.
249 if (input->type()->isSubtypeOf(*TensorType::get())) {
250 auto in_group = subgraph.insertInput(tensor_insert_idx);
251 in_group->setType(input->type());
252 inputs_map[input] = in_group;
253 group->insertInput(tensor_insert_idx, input);
254 tensor_insert_idx++;
255 } else if (
            // TODO: extend the supported input types here.
257 (input->type()->isSubtypeOf(*FloatType::get()) &&
258 input->node()->kind() != prim::Constant)) {
259 auto in_group = subgraph.addInput();
260 in_group->setType(input->type());
261 inputs_map[input] = in_group;
262 group->addInput(input);
263 } else if (input->node()->kind() == prim::Constant) {
264 // inline the constants directly in the body of the fused group.
265 Node* in_const =
266 subgraph.createClone(input->node(), [&](Value* v) -> Value* {
267 if (v->node()->kind() != prim::profile_ivalue) {
268 throw std::runtime_error(
269 std::string(
270 "merging constant with unexpected input from node") +
271 v->node()->kind().toDisplayString());
272 }
273 group->addInput(v->node()->output());
274
275 // we are doing this just to keep alias_analysis silent with
276 // their checks
277 auto in_group = subgraph.addInput();
278 in_group->setType(v->type());
279 return in_group;
280 });
281 subgraph.insertNode(in_const);
282 inputs_map[input] = in_const->output();
283 } else {
          // TODO: we need to figure out which scalar input types are supported
285 auto in_group = subgraph.addInput();
286 in_group->setType(input->type());
287 inputs_map[input] = in_group;
288 group->addInput(input);
289 }
290 }
291 }
292 // copy n into the graph, remapping its inputs to internal nodes
293 Node* in_graph = subgraph.createClone(
294 n, [&](Value* k) -> Value* { return inputs_map[k]; });
295 // if n's outputs are already inputs to the fusion group,
296 // we need to remove them because n is now inside the fusion group.
297 //
298 // i.e.,
299 // x = f(w); group(x, y, z) becomes group(w, y, z).
300 // x, y, z = f(w); group(x, y, z) becomes group(w).
301 //
302 // remapping nodes that used the input to the newly-merged node
303 // n is not an input when the fusion group is empty
304 auto inputs = group->inputs();
305 for (const auto i : c10::irange(n->outputs().size())) {
306 auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
307 if (it != inputs.end()) {
308 size_t p = it - inputs.begin();
309 group->removeInput(p);
310 subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
311 subgraph.eraseInput(p);
312 }
313 }
314 return subgraph.insertNode(in_graph);
315 }
316
317 // turn consumer node n into a fusion group with just n inside
318 // to prepare for fusion and replace uses of n with the new group
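  // For illustration (names made up), a single node
  //   %b : Tensor = aten::relu(%a)
  // becomes
  //   %b : Tensor = prim::CudaFusionGroup_0(%a)
  // whose subgraph contains just the cloned relu and returns its output.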
319 Node* createSingletonFusionGroup(Node* n) {
320 auto group = block_->owningGraph()->createWithSubgraph(kind_);
    // propagate position information for the new node so we can always
    // have a valid mapping
323 group->insertBefore(n);
324 Node* mergedNode = mergeNodeIntoGroup(group, n);
325 for (const auto i : c10::irange(n->outputs().size())) {
326 getSubgraph(group).registerOutput(mergedNode->output(i));
327 auto sel = group->addOutput();
328 sel->copyMetadata(n->output(i));
329 }
330 n->replaceAllUsesWith(group);
331 n->destroy();
332 return group;
333 }
334
335 at::optional<Node*> tryFuse(Node* consumer, Node* producer) {
336 // this handles cases where producer can be moved _into_ the fusion group of
337 // consumer.
338 // TODO: extend to fusion of consumer into _producer's_ fusion blob
339 // if the consumer allInputsAreThisProducer(consumer,producer)
340 // we can move the consumer up into the producer.
341 // but this requires better handling of merging fusion groups so it is not
342 // done now
343 bool shouldFuse =
344 fuser::cuda::isFusibleCudaFusionGroup(consumer, producer) &&
345 // Rearrange nodes such that all uses of producer's outputs are after
346 // consumer. Fusion will rewrite those later uses to use the version of
347 // producer generated by the fused blob. In this case, producer becomes
348 // an output of the fusion group.
349 aliasDb_->moveBeforeTopologicallyValid(producer, consumer);
350
351 if (!shouldFuse) {
352 return at::nullopt;
353 }
354
355 if ((consumer->inputs().size() + consumer->outputs().size() +
356 producer->inputs().size() + producer->outputs().size()) >
357 subgraph_arg_limit_) {
358 return at::nullopt;
359 }
360
361 auto group = consumer;
362 if (consumer->kind() != kind_) {
363 group = createSingletonFusionGroup(consumer);
364 }
365
366 if (producer->kind() == kind_) {
367 mergeFusionGroups(group, producer);
368 return group;
369 }
370 Node* merged = mergeNodeIntoGroup(group, producer);
    // Remaining uses of this producer can occur because we allow fusion in
    // cases where uses remain after the consumer. If these exist, re-route
    // them to the version of producer created in the FusionGroup.
375
376 // We need to apply this to all outputs from producer->node();
377 auto producer_outputs = producer->outputs();
378 for (const auto i : c10::irange(producer_outputs.size())) {
379 if (producer_outputs[i]->uses().size() != 0) {
380 getSubgraph(group).registerOutput(merged->outputs()[i]);
381 Value* new_producer = group->addOutput();
382 new_producer->copyMetadata(producer_outputs[i]);
383 producer_outputs[i]->replaceAllUsesWith(new_producer);
384 }
385 }
386 producer->destroy();
387 return group;
388 }
389
390 c10::optional<Node*> findFusedChunk(Node* group, Value* input) {
391 AT_ASSERT(group->kind() == kind_);
392 auto it = std::find(group->inputs().begin(), group->inputs().end(), input);
393 if (it == group->inputs().end()) {
394 return c10::nullopt;
395 }
396 size_t input_index = it - group->inputs().begin();
397 auto& subgraph = getSubgraph(group);
398 auto* subgraph_input = subgraph.inputs().at(input_index);
399 // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
400 auto* node = subgraph_input->uses().at(0).user;
401 if (node->kind() == prim::ConstantChunk) {
402 AT_ASSERT(subgraph_input->uses().size() == 1);
403 return node;
404 }
405 return c10::nullopt;
406 }
407
408 void fuseChunkByReusingExistingFusedChunk(
409 Node* group,
410 Node* chunk,
411 Node* existingFusedChunk) {
412 if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
413 return;
414 }
415 auto& subgraph = getSubgraph(group);
416 for (const auto i : c10::irange(chunk->outputs().size())) {
417 // Find the input to the FusionGroup (group)
418 auto* replacement_val = existingFusedChunk->outputs().at(i);
419 auto* val = chunk->outputs().at(i);
420 auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
421 auto input_index = it - group->inputs().begin();
422
423 // Rewrite the graph to use replacement_val
424 auto group_input = subgraph.inputs().at(input_index);
425 group_input->replaceAllUsesWith(replacement_val);
426
427 // Remove the input, it's no longer needed
428 group->removeInput(input_index);
429 subgraph.eraseInput(input_index);
430 }
431 chunk->destroy();
432 }
433
434 value_list sortReverseTopological(ArrayRef<Value*> inputs) {
435 value_list result;
436 for (auto i : inputs) {
437 if (i->node()->owningBlock() == block_) {
438 result.push_back(i);
439 }
440 }
441 // Sort in reverse topological order
442 std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
443 return a->node()->isAfter(b->node());
444 });
445 return result;
446 }
447
448 at::ArrayRef<Value*> broadcast_tensors(value_list inputs) {
449 AT_ASSERT(inputs.size() > 0);
450 auto* g = inputs[0]->owningGraph();
451 auto* input_list =
452 g->insertNode(g->createList(TensorType::get(), inputs))->output();
453 auto* output_list = g->insert(aten::broadcast_tensors, {input_list});
454 auto* unpack_node = g->insertNode(
455 g->create(prim::ListUnpack, {output_list}, inputs.size()));
456 return unpack_node->outputs();
457 }
458
459 void insertExplicitBroadcast(Node* node) {
460 WithInsertPoint insert_guard{node};
461 auto tensors = tensorInputs(node);
462 auto new_tensors = broadcast_tensors(tensors);
463
464 // Replace tensors inputs with broadcasted values
465 auto new_tensors_it = new_tensors.begin();
466 for (const auto i : c10::irange(node->inputs().size())) {
467 if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) {
468 AT_ASSERT(new_tensors_it != new_tensors.end());
469 node->replaceInput(i, *(new_tensors_it++));
470 }
471 }
472 }
473
474 Node* promoteChunkToBroadcastingChunk(Node* chunk) {
475 AT_ASSERT(chunk->kind() == prim::ConstantChunk);
476
477 size_t nchunks = chunk->i(attr::chunks);
478 Node* bchunk =
479 chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
480 bchunk->addInput(chunk->input());
481 for (const auto i : c10::irange(nchunks)) {
482 auto* old_output = chunk->outputs().at(i);
483 auto* new_output = bchunk->outputs().at(i);
484 new_output->copyMetadata(old_output);
485 old_output->replaceAllUsesWith(new_output);
486 }
487 bchunk->copyAttributes(*chunk);
488 bchunk->insertAfter(chunk);
489 chunk->destroy();
490 return bchunk;
491 }
492
  // In places where an op can be fused into a consumer but a chunk is in the
  // way, distribute the chunk to the op's operands:
495 // replace a,b = chunk(op(x,y,z)) with:
496 // x', y', z' = broadcast_tensors([x, y, z])
497 // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type)
498 // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type)
499 // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type)
  // a = op(x0,y0,z0) (a and b keep their sizes but are now contiguous)
  // b = op(x1,y1,z1)
502 //
503 // The graph fuser uses an intermediate prim::BroadcastingChunk node to
504 // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts
505 // all of its inputs and then chunks each input, in order, the same way.
506 // The above graph is equivalent to:
507 // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z)
508 // a = op(x0,y0,z0)
  //   b = op(x1,y1,z1)
510 //
511 // NB: The explicit broadcast is important for correctness.
512 // Let's say we have:
513 // %z = aten::mul(%x, %y)
514 // %z.1, %z.2 = aten::chunk(%z, ...)
515 // ... = prim::CudaFusionGroup(%z.1, %z.2, ...)
516 // It's possible that %x and %y do not have the same size as %z and
517 // need to be expanded first so that they can be chunked like %z
518 //
519 // NB: Chunk motion only occurs with fusable consumers, which implies
520 // that there is always some other operation, e.g., a+b, that happens
521 // after the chunk, and will be put into the fusion group. This is
522 // important, because distributing the chunk changes the contiguity
523 // of a and b, and so the results would be invalid, except that we know
524 // that simple_mappable operations will restore contiguity before
525 // we exit the fusion group.
526 //
527 // NB: The intermediate BroadcastingChunk is important for moving chunks past
528 // more than one operation: the graph fuser is not able to easily move
529 // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusable
530 // ops
531 // x = f(v, w)
532 // z = g(x, y)
533 // a, b = chunk(z)
534 // c = h(a, b)
535 // becomes (with the broadcast_tensors + chunk approach):
536 // x = f(v, w)
537 // x', y' = broadcast_tensors([x, y])
538 // ax, bx = chunk(x')
539 // ay, by = chunk(y')
540 // a = g(ax, ay)
541 // b = g(bx, by)
542 // c = h(a, b)
543 // The broadcast_tensors node makes it harder to move f into the resulting
544 // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior
545 // together results in:
546 // x = f(v, w)
547 // ax, bx, ay, by = BroadcastingChunk(x, y)
548 // a = g(ax, ay)
549 // b = g(bx, by)
550 // c = h(a, b)
551 // making it easier to move f after the BroadcastingChunk:
552 // ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
553 // ax = f(av, aw)
554 // by = f(bv, bw)
555 // a = g(ax, ay)
556 // b = g(bx, by)
557 // c = h(a, b)
558
559 bool tryToMoveChunk(Node* consumer, Value* producer) {
560 // is the output from a chunk/bchunk node?
561 auto* chunk = producer->node();
562 if (chunk->kind() != prim::ConstantChunk &&
563 chunk->kind() != prim::BroadcastingChunk)
564 return false;
565
566 // try to find a producer to move after the chunk/bchunk. The producer must
567 // be fusable into the consumer.
568 auto it = std::find_if(
569 chunk->inputs().begin(),
570 chunk->inputs().end(),
571 [&](Value* producer_for_chunk) {
572 return fuser::cuda::isFusibleCudaFusionGroup(
573 consumer, producer_for_chunk->node()) &&
574 isElementWiseNode(consumer) &&
575 allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
576 });
577 if (it == chunk->inputs().end()) {
578 return false;
579 }
580 Value* producer_for_chunk = *it;
581 size_t producer_index = it - chunk->inputs().begin();
582
583 // all uses of the chunk must be in this consumer
584 for (auto s : chunk->outputs()) {
585 for (auto u : s->uses()) {
586 if (u.user != consumer)
587 return false;
588 }
589 }
    // producers with multiple outputs are not supported here
591 Node* producer_for_chunk_node = producer_for_chunk->node();
592 AT_ASSERT(producer_for_chunk_node->outputs().size() == 1);
593
594 // Convert chunk to bchunk, if it isn't one already. The bchunk represents a
595 // broadcast and one or more chunk operations.
596 auto* bchunk = chunk;
597 if (chunk->kind() == prim::ConstantChunk) {
598 bchunk = promoteChunkToBroadcastingChunk(chunk);
599 }
600 size_t nchunks = bchunk->i(attr::chunks);
601 TORCH_INTERNAL_ASSERT(nchunks > 0, "number of chunks cannot be zero");
602 WithInsertPoint guard(bchunk->next());
603
604 std::vector<Value*> producer_chunk_outputs;
605 for (const auto i : c10::irange(nchunks)) {
606 producer_chunk_outputs.push_back(
607 bchunk->output(nchunks * producer_index + i));
608 }
609
610 // Add each of op's operands to the bchunk node.
611 // chunked_inputs[input_nr][chunk_output_idx]
612 // = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
613 std::vector<std::vector<Value*>> chunked_inputs;
614
615 // We have asserted single output earlier
616 auto producer_output_sizes =
617 producer_for_chunk_node->output()->type()->cast<TensorType>()->sizes();
618
619 for (auto input : producer_for_chunk_node->inputs()) {
      // XXX: we only work with pointwise ops in here, so we know it is valid
      // to push the chunk only through tensor arguments (and all other args
      // can be safely ignored).
623 if (!input->type()->isSubtypeOf(*TensorType::get()))
624 continue;
625
626 // if 'input' is already an input to the bchunk, reuse it.
627 auto bchunk_inputs = bchunk->inputs();
628 auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input);
629 if (it != bchunk_inputs.end()) {
630 chunked_inputs.emplace_back();
631 auto input_index = std::distance(bchunk_inputs.begin(), it);
632 for (const auto chunk : c10::irange(nchunks)) {
633 chunked_inputs.back().push_back(
634 bchunk->outputs().at(nchunks * input_index + chunk));
635 }
636 continue;
637 }
638
639 // NB: I decided not to use cloneFrom here, because if we make cloneFrom
640 // copy selects one day, it is definitely not what you want here (selects
641 // have different types).
642 // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
643 // to copy select nodes now that we have refactored to have a Value
644 // distinct from Node.
645 bchunk->addInput(input);
646 chunked_inputs.emplace_back(); // alas, to not be C++17
647
648 // properly compute strides for BroadcastingChunk
649 //
      // We copy the stride of each dimension from input to output for
      // BroadcastingChunk. Note that Chunk should not alter strides; however,
      // a broadcasted dimension should have a stride of 0. Broadcasting can
      // happen on existing dimensions of the input (case1), as well as on
      // extended dimensions that do not exist in the input (case2).
655 // e.g.
656 // If we look at an input tensor t0 with shape [3, 1] broadcasted to
657 // output tensor t1 with shape [4, 1, 3, 3],
658 // We set stride to zero in case of broadcast, which could happen in:
659 // case1: t1.dim[3] (broadcasted as in the description above)
660 // case2: t1.dim[0] (broadcasted implicitly)
661 std::vector<int64_t> strides;
662 auto input_type = input->type()->cast<TensorType>();
663 auto input_sizes = input_type->sizes();
664 auto input_strides = input_type->strides();
665 if (producer_output_sizes.isComplete() && input_sizes.isComplete() &&
666 input_strides.isComplete()) {
667 auto input_c_sizes = input_sizes.concrete_sizes().value();
668 auto input_c_strides = input_strides.concrete_sizes().value();
669 auto output_c_sizes = producer_output_sizes.concrete_sizes().value();
670 int output_index = int(output_c_sizes.size()) - 1;
671 strides.resize(output_index + 1);
672 AT_ASSERT(output_index >= int(input_c_sizes.size()) - 1);
673 for (int input_index = int(input_c_sizes.size()) - 1; input_index >= 0;
674 input_index--, output_index--) {
          // in broadcast case 1, we set the stride to 0;
          // otherwise, the stride remains the same.
677 if (input_c_sizes[input_index] == 1 &&
678 output_c_sizes[output_index] != 1) {
679 strides[output_index] = 0;
680 } else {
681 strides[output_index] = input_c_strides[input_index];
682 }
683 }
684
685 // continue on expanding dimensions to set stride to 0 for case2
686 while (output_index >= 0) {
687 strides[output_index] =
688 output_c_sizes[output_index] == 1 ? strides[output_index + 1] : 0;
689 output_index--;
690 }
691 }
692
693 for (auto chunk_sel : producer_chunk_outputs) {
694 Value* input_chunk_sel = bchunk->addOutput();
695 auto chunk_sel_type = chunk_sel->type()->cast<TensorType>();
696 if (strides.empty() || !chunk_sel_type->sizes().isComplete()) {
697 input_chunk_sel->setType(chunk_sel_type);
698 } else {
699 input_chunk_sel->setType(chunk_sel_type->withSizesStrides(
700 chunk_sel_type->sizes().concrete_sizes().value(), strides));
701 }
702 chunked_inputs.back().push_back(input_chunk_sel);
703 }
704 }
705
706 // apply the op to each chunk of the chunked operands,
707 // and then rewrite the graph to use them!
708 for (auto chunk_sel : producer_chunk_outputs) {
709 auto original_inputs = producer_for_chunk_node->inputs();
710 Node* chunked_op =
711 block_->owningGraph()->create(producer_for_chunk_node->kind());
712 chunked_op->copyAttributes(*producer_for_chunk_node);
713 chunked_op->output()->setType(chunk_sel->type());
714 auto chunked_inputs_it = chunked_inputs.begin();
715 for (Value* original_input : original_inputs) {
716 if (original_input->type()->isSubtypeOf(*TensorType::get())) {
717 AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
718 chunked_op->addInput(
719 // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
720 chunked_inputs_it->at(chunk_sel->offset() % nchunks));
721 ++chunked_inputs_it;
722 } else {
723 chunked_op->addInput(original_input);
724 }
725 }
726 bchunk->owningGraph()->insertNode(chunked_op);
727 chunk_sel->replaceAllUsesWith(chunked_op->output());
728 }
729
730 bchunk->removeInput(producer_index);
731 for (const auto _ : c10::irange(nchunks)) {
732 bchunk->eraseOutput(nchunks * producer_index);
733 }
734
735 // The output of producer_for_chunk_node could have been used in some
736 // aten::size operators, so we need to clean those up as well (we simply
737 // broadcast all its tensor inputs).
    // We need to insert these early in the graph, i.e. immediately after
    // the producer_for_chunk_node, since the size users (_size_if_not_same)
    // may appear before the bchunk.
741 WithInsertPoint guard2(producer_for_chunk_node);
742 auto size_calc_uses = producer_for_chunk_node->output()->uses();
743 if (!size_calc_uses.empty()) {
744 auto tensor_inputs = filter(
745 producer_for_chunk_node->inputs(),
746 [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
747 auto tensor_sizes = fmap(tensor_inputs, [](Value* v) {
748 return v->owningGraph()->insert(aten::size, {v});
749 });
750 AT_ASSERT(!tensor_sizes.empty());
751 Value* output_size = tensor_sizes.size() == 1
752 ? tensor_sizes[0]
753 : broadcastSizes(tensor_sizes);
754 for (Use u : size_calc_uses) {
755 u.user->output()->replaceAllUsesWith(output_size);
756 u.user->destroy();
757 }
758 }
759 producer_for_chunk_node->destroy();
760 return true;
761 }
762
763 // returns where to continue scanning, and whether any fusion was made
764 std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
765 if (fuser::cuda::isFusibleCudaFusionGroup(consumer)) {
      // handle inputs in reverse topological order as well...
      // otherwise, in f(a, a+b), a will appear to be used twice if we consider
      // the f-a fusion before the f-(a+b) fusion.
769 auto inputs = sortReverseTopological(consumer->inputs());
770 for (auto producer : inputs) {
771 if (tryToMoveChunk(consumer, producer)) {
772 // the chunk before this consumer was re-arranged to allow fusion,
773 // we scan this consumer again to perform the fusion
774 return std::make_pair(consumer->reverseIterator(), true);
775 }
776 if (getSingletonFusion() && consumer->kind() != kind_) {
777 consumer = createSingletonFusionGroup(consumer);
778 }
779 auto fusion_group = tryFuse(consumer, producer->node());
780 if (fusion_group) {
          // after fusion, consumer moves into a FusionGroup, so `inputs` is no
          // longer valid; rescan the new FusionGroup for more fusions...
783 return std::make_pair(fusion_group.value()->reverseIterator(), true);
784 }
785
        // horizontal fusion only applies to non-scalar tensor inputs
787 if (getHorizontalFusion() &&
788 producer->type()->isSubtypeOf(*TensorType::get()) &&
789 !is_cpu_scalar(*producer->type()->cast<TensorType>())) {
          // fusing nodes sharing inputs; this could save memory bandwidth by
          // reducing the number of tensor reads.
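          // Sketch of the sibling pattern targeted here (illustrative names):
          //   %u = prim::CudaFusionGroup_0(%p)  // earlier sibling use of %p
          //   %c = aten::relu(%p)               // current consumer
          // Both read %p, so merging them into one group lets the generated
          // kernel load %p once.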
792 for (const auto& u : producer->uses()) {
            // only merge nodes before consumer, since any sibling after
            // consumer has already considered merging this consumer into it.
796 if (u.user->isBefore(consumer)) {
797 auto fusion_group = tryFuse(consumer, u.user);
798 if (fusion_group) {
799 return std::make_pair(
800 fusion_group.value()->reverseIterator(), true);
801 }
802 }
803 }
804 }
805 }
806 }
807 return std::make_pair(++consumer->reverseIterator(), false);
808 }
809
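  // Lowers each intermediate prim::BroadcastingChunk back into real ops: an
  // explicit aten::broadcast_tensors over its inputs followed by one
  // prim::ConstantChunk per input. Sketch (illustrative):
  //   a0, a1, b0, b1 = BroadcastingChunk(%a, %b)
  // becomes
  //   %a', %b' = aten::broadcast_tensors([%a, %b])
  //   a0, a1 = prim::ConstantChunk(%a')
  //   b0, b1 = prim::ConstantChunk(%b')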
810 void replaceIntermediateBroadcastingChunks() {
811 for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
812 auto* node = *it;
813 ++it; // We might delete node, so increment the iterator now.
814 if (node->kind() != prim::BroadcastingChunk) {
815 continue;
816 }
817 auto* bchunk = node;
818 insertExplicitBroadcast(bchunk);
819
820 auto* graph = block_->owningGraph();
821 size_t nchunks = bchunk->i(attr::chunks);
822 WithInsertPoint guard(bchunk->next());
823
      // Split the bchunk into bchunk->inputs().size() chunk nodes.
825 for (const auto input_offset : c10::irange(bchunk->inputs().size())) {
826 auto* input = bchunk->inputs().at(input_offset);
827
828 Node* new_chunk =
829 graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
830 new_chunk->copyAttributes(*bchunk);
831 for (const auto output_offset : c10::irange(nchunks)) {
832 auto new_output = new_chunk->addOutput();
833 auto old_output =
834 bchunk->outputs().at(input_offset * nchunks + output_offset);
835 new_output->copyMetadata(old_output);
836 old_output->replaceAllUsesWith(new_output);
837 }
838 }
839 bchunk->destroy();
840 }
841 }
842
843 bool usedInDtype(Value* v) {
844 const auto& uses = v->uses();
845 return std::any_of(uses.begin(), uses.end(), [](const Use& u) {
846 return u.user->matches("prim::dtype(Tensor a) -> int");
847 });
848 }
849
850 bool usedOnlyInDtypeAndSize(Value* v) {
851 const auto& uses = v->uses();
852 return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
853 return u.user->matches("prim::dtype(Tensor a) -> int") ||
854 u.user->matches("aten::size(Tensor self) -> int[]");
855 });
856 }
857
858 // Builds up expressions that compute shapes of all intermediates (and
859 // outputs) of the fusion group, based on the sizes of inputs. You should run
860 // DCE to remove those that you end up not using.
861 // TODO: Add shape support for view, reshape, unsqueeze, and squeeze
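  // For illustration (a sketch, not the exact emitted IR): for a group
  //   %o : Tensor = prim::CudaFusionGroup_0(%x, %y)
  // whose subgraph computes an elementwise op over %x and %y, this emits
  //   %sx : int[] = aten::size(%x)
  //   %sy : int[] = aten::size(%y)
  //   %so : int[] = prim::BroadcastSizes(%sx, %sy)
  // outside the group, so that aten::size users of %o can later be rewired to
  // %so without keeping %o alive.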
862 std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
863 WithInsertPoint insert_guard{fusion_group->next()};
864 std::unordered_map<Value*, Value*> shape_of;
865
866 Graph* graph = fusion_group->owningGraph();
867 auto subgraph = fusion_group->g(attr::Subgraph);
868
869 auto inputs = fusion_group->inputs();
870 auto sinputs = subgraph->inputs();
871 AT_ASSERT(inputs.size() == sinputs.size());
872 for (const auto i : c10::irange(inputs.size())) {
873 if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
874 auto sinput_value = graph->insert(aten::size, {inputs[i]});
875 shape_of[sinputs[i]] = sinput_value;
876 sinput_value->node()->moveBefore(fusion_group);
877 }
878 }
879
880 // When we have a guarantee that an output won't be removed, because it's
881 // used in expressions that don't involve size checks, we can use its size
882 // instead of computing a long chain of broadcasts, starting from the
883 // beginning of the kernel.
884 auto outputs = fusion_group->outputs();
885 auto soutputs = subgraph->outputs();
886 AT_ASSERT(outputs.size() == soutputs.size());
887 for (const auto i : c10::irange(outputs.size())) {
888 if (usedOnlyInDtypeAndSize(outputs[i]))
889 continue;
890 if (soutputs[i]->type()->isSubtypeOf(TensorType::get())) {
891 shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
892 }
893 }
894
895 // Place all the shape expressions for intermediates in fusion
896 // before the CudaFusionGroup
897 graph->setInsertPoint(fusion_group);
898
899 // hmmm, do I need to setInsertPoint...
900 const auto map_inputs = [&](Value* v) -> Value* {
901 // if constant ever has an input, it has to come from
902 // profile_ivalue dependency
903 if (v->node()->kind() == prim::Param &&
904 fusion_group->input(v->offset())->node()->kind() ==
905 prim::profile_ivalue) {
906 // we need to map it along profile_ivalue dependency
907 return fusion_group->input(v->offset());
908 } else {
909 throw std::runtime_error(
910 std::string("unexpected input from node") +
911 v->node()->kind().toDisplayString());
912 }
913 };
914
915 for (Node* n : subgraph->nodes()) {
916 // XXX: Use of shape_of.emplace is crucial to the output shape
917 // optimization!
918 if (n->kind() == prim::FusedConcat) {
919 // This is a bit more involved, because we have to account for the case
920 // when inputs have different shapes, but fortunately those tensors are
921 // always outputs, and so we can simply avoid replacing their queries,
922 // because it won't help us.
923 continue;
924 }
925 if (n->kind() == prim::Constant) {
926 continue;
927 }
928 if (n->kind() == prim::ConstantChunk) {
929 TORCH_INTERNAL_ASSERT(
930 shape_of.count(n->input()) > 0,
931 "buildShapeExpressions failed at accessing input shapes");
932 Node* sizes_node = graph->insertNode(
933 graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
934 sizes_node->i_(attr::dim, n->i(attr::dim));
935 sizes_node->i_(attr::chunks, n->i(attr::chunks));
936 Value* regular_size = sizes_node->outputs().at(0);
937 Value* last_size = sizes_node->outputs().at(1);
938 regular_size->setType(ListType::ofInts());
939 last_size->setType(ListType::ofInts());
940 auto outputs = n->outputs();
941 for (Value* o : outputs.slice(0, outputs.size() - 1)) {
942 shape_of.emplace(o, regular_size);
943 }
944 shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
945 continue;
946 }
947 // extended shape expression support to reduction operations
948 // TODO: `aten::sum` is too flexible, we should restrict for a better
949 // match
950 // TODO: Add python tests where we check for existing ops and their
951 // shape expression logic.
952 static std::unordered_set<Symbol> reduction_ops(
953 {aten::sum, aten::mean, aten::var, aten::std});
954 if (reduction_ops.find(n->kind()) != reduction_ops.end()) {
        // TODO: expand support to wire non-constant inputs; this is currently
        // blocked by the profiling executor not being able to profile scalar
        // inputs.
957 TORCH_INTERNAL_ASSERT(
958 n->input(1)->node()->kind() == prim::Constant &&
959 n->input(2)->node()->kind() == prim::Constant,
960 "only supports reduction axes and keepdim being constant");
961
962 Node* in1_const = graph->createClone(n->input(1)->node(), map_inputs);
963 graph->insertNode(in1_const);
964 Node* in2_const = graph->createClone(n->input(2)->node(), map_inputs);
965 graph->insertNode(in2_const);
966
967 TORCH_INTERNAL_ASSERT(
968 shape_of.count(n->input(0)) > 0,
969 "buildShapeExpressions failed at accessing input shapes");
970 std::vector<Value*> inputs = {
971 shape_of.at(n->input(0)), in1_const->output(), in2_const->output()};
972 Node* size_node =
973 graph->insertNode(graph->create(prim::ReductionSizes, inputs, 1));
974 Value* size = size_node->output(0);
975 size->setType(ListType::ofInts());
976 shape_of.emplace(n->output(), size);
977 continue;
978 }
979 // TODO: output(1) & output(2) should also be marked
980 if (n->kind() == aten::native_layer_norm) {
981 TORCH_INTERNAL_ASSERT(
982 shape_of.count(n->input(0)) > 0,
983 "buildShapeExpressions failed at accessing input shapes");
984 shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
985 continue;
986 }
987 // TODO: output(1) & output(2) should also be marked
988 if (n->kind() == aten::native_layer_norm_backward) {
989 TORCH_INTERNAL_ASSERT(
990 shape_of.count(n->input(0)) > 0,
991 "buildShapeExpressions failed at accessing input shapes");
992 shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
993 if (shape_of.count(n->input(5)) > 0) {
994 shape_of.emplace(n->output(1), shape_of.at(n->input(5)));
995 }
996 if (shape_of.count(n->input(6)) > 0) {
997 shape_of.emplace(n->output(2), shape_of.at(n->input(6)));
998 }
999 continue;
1000 }
1001 // TODO: output(1) & output(2) should also be marked
1002 if (n->kind() == aten::native_batch_norm ||
1003 n->kind() == aten::_batch_norm_impl_index) {
1004 TORCH_INTERNAL_ASSERT(
1005 shape_of.count(n->input(0)) > 0,
1006 "buildShapeExpressions failed at accessing input shapes");
1007 shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
1008 continue;
1009 }
1010 // TODO: output(1) & output(2) should also be marked
1011 if (n->kind() == aten::native_batch_norm_backward) {
1012 TORCH_INTERNAL_ASSERT(
1013 shape_of.count(n->input(0)) > 0,
1014 "buildShapeExpressions failed at accessing input shapes");
1015 shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
1016 if (shape_of.count(n->input(2)) > 0) {
1017 shape_of.emplace(n->output(1), shape_of.at(n->input(2)));
1018 // use shape of weight here for grad_bias
1019 shape_of.emplace(n->output(2), shape_of.at(n->input(2)));
1020 }
1021 continue;
1022 }
1023 if (n->kind() == aten::_batch_norm_impl_index_backward) {
1024 TORCH_INTERNAL_ASSERT(
1025 shape_of.count(n->input(1)) > 0,
1026 "buildShapeExpressions failed at accessing input shapes");
1027 shape_of.emplace(n->output(0), shape_of.at(n->input(1)));
1028 if (shape_of.count(n->input(3)) > 0) {
1029 shape_of.emplace(n->output(1), shape_of.at(n->input(3)));
1030 // use shape of weight here for grad_bias
1031 shape_of.emplace(n->output(2), shape_of.at(n->input(3)));
1032 }
1033 continue;
1034 }
1035 if (n->kind() == aten::native_dropout) {
1036 TORCH_INTERNAL_ASSERT(
1037 shape_of.count(n->input(0)) > 0,
1038 "buildShapeExpressions failed at accessing input shapes");
1039 shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
1040 shape_of.emplace(n->output(1), shape_of.at(n->input(0)));
1041 continue;
1042 }
1043 if (n->kind() == prim::unsqueeze_copy) {
1044 TORCH_INTERNAL_ASSERT(
1045 shape_of.count(n->input(0)) > 0,
1046 "buildShapeExpressions failed at accessing input shapes");
1047 TORCH_INTERNAL_ASSERT(
1048 n->input(1)->node()->kind() == prim::Constant,
1049 "only supports unsqueeze axes being constant");
1050 Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs);
1051 graph->insertNode(dim_const);
1052 std::vector<Value*> inputs = {
1053 shape_of.at(n->input(0)), dim_const->output()};
1054 Node* size_node = graph->insertNode(graph->create(
1055 Symbol::fromQualString("prim::infer_unsqueeze_size"), inputs, 1));
1056 Value* size = size_node->output(0);
1057 size->setType(ListType::ofInts());
1058 shape_of.emplace(n->output(), size);
1059 continue;
1060 }
1061 if (n->kind() == prim::squeeze_copy) {
1062 TORCH_INTERNAL_ASSERT(
1063 shape_of.count(n->input(0)) > 0,
1064 "buildShapeExpressions failed at accessing input shapes");
1065 TORCH_INTERNAL_ASSERT(
1066 n->inputs().size() == 2 || n->inputs().size() == 1,
1067 "prim::squeeze_copy expects one or two inputs");
1068 std::vector<Value*> inputs = {shape_of.at(n->input(0))};
1069
1070 if (n->inputs().size() == 2) {
1071 TORCH_INTERNAL_ASSERT(
1072 n->input(1)->node()->kind() == prim::Constant,
1073 "only supports squeeze axes being constant");
1074 Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs);
1075 graph->insertNode(dim_const);
1076 inputs.push_back(dim_const->output());
1077 }
1078 Node* size_node = graph->insertNode(graph->create(
1079 Symbol::fromQualString("prim::infer_squeeze_size"), inputs, 1));
1080 Value* size = size_node->output(0);
1081 size->setType(ListType::ofInts());
1082 shape_of.emplace(n->output(), size);
1083 continue;
1084 }
1085
1086 auto tensor_inputs = filter(n->inputs(), [](Value* v) {
1087 return v->type()->isSubtypeOf(*TensorType::get());
1088 });
1089 auto shapes = fmap(tensor_inputs, [&](Value* v) {
1090 TORCH_INTERNAL_ASSERT(
1091 shape_of.count(v) > 0,
1092 "buildShapeExpressions failed at accessing input shapes");
1093 return shape_of.at(v);
1094 });
1095 AT_ASSERT(!shapes.empty());
1096 shape_of.emplace(
1097 n->output(0),
1098 shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
1099 }
1100 return shape_of;
1101 }
1102
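  // Fusion group outputs that are only consumed by aten::size (and possibly
  // prim::dtype) do not need to be produced by the kernel. Sketch
  // (illustrative): given
  //   %o = prim::CudaFusionGroup_0(...)
  //   %s = aten::size(%o)
  // the aten::size use is rewired to the shape expression built for %o, and
  // %o is erased from the group outputs once no dtype consumer remains.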
1103 void removeOutputsUsedOnlyInSize(Node* fusion_group) {
1104 if (fusion_group->kind() != prim::CudaFusionGroup)
1105 return;
1106 auto subgraph = fusion_group->g(attr::Subgraph);
1107
1108 // TODO: failure in buildShapeExpressions should not break fusion execution,
1109 // we can add a try/catch here to bailout from removeOutputsUsedOnlyInSize.
1110 GRAPH_DEBUG("before build shape expression: ", *graph_);
1111 auto shape_map = buildShapeExpressions(fusion_group);
1112 fusion_value_to_runtime_shape_.insert(shape_map.begin(), shape_map.end());
1113 GRAPH_DEBUG("after build shape expression: ", *graph_);
1114
1115 auto outputs = fusion_group->outputs().vec();
1116 auto soutputs = subgraph->outputs().vec();
1117 // XXX: Iterating in this order is not only good for performance reasons!
1118 // It is also crucial for correctness (i has to reflect the current true
1119 // index of outputs[i])!
1120 for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
1121 auto output = outputs[i];
1122 auto soutput = soutputs[i];
1123 if (usedOnlyInDtypeAndSize(output) && shape_map.count(soutput) > 0) {
1124 bool has_dtype = usedInDtype(output);
1125 auto uses = output->uses();
1126 for (Use u : uses) {
1127 if (u.user->matches("aten::size(Tensor self) -> int[]")) {
1128 u.user->output()->replaceAllUsesWith(shape_map.at(soutput));
1129 u.user->destroy();
1130 } else if (u.user->matches("prim::dtype(Tensor a) -> int")) {
1131 continue;
1132 } else {
1133 AT_ASSERT(
1134 false,
1135 "unrecognized consumer should not trigger removeOutputsUsedOnlyInSize");
1136 }
1137 }
1138 // We only wipe the output when there's no more dtype consumer.
1139 // This is to be removed by `removeOutputUsedOnlyInDtype`
1140 if (!has_dtype) {
1141 fusion_group->eraseOutput(i);
1142 subgraph->eraseOutput(i);
1143 }
1144 }
1145 }
1146 GRAPH_DEBUG("after build shape expression and re-wiring: ", *graph_);
1147 }
1148
1149 void refreshAliasDb() {
1150 aliasDb_ = torch::make_unique<AliasDb>(graph_);
1151 }
1152
1153 void removeNoopBinaryOps(Block* block) {
1154 for (Node* node : block->nodes()) {
1155 for (Block* b : node->blocks()) {
1156 removeNoopBinaryOps(b);
1157 }
1158
1159 if (node->matches(
1160 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1161 /*const_inputs=*/{attr::alpha, attr::other}) ||
1162 node->matches(
1163 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1164 /*const_inputs=*/{attr::alpha, attr::other})) {
        // x + 0 == x - 0 == x
        // if either scalar input is a float, then removing this operator could
        // remove type promotion and affect semantics
1168 auto scalar_type =
1169 node->input(0)->type()->expectRef<TensorType>().scalarType();
1170 if (!scalar_type.has_value() ||
1171 !at::isFloatingType(scalar_type.value())) {
1172 auto inps = node->inputs();
1173 if (!inps.at(1)->type()->isSubtypeOf(IntType::get()) ||
1174 !inps.at(2)->type()->isSubtypeOf(IntType::get())) {
1175 continue;
1176 }
1177 }
1178
1179 if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
1180 node->get<at::Scalar>(attr::other)->toDouble() == 0) {
1181 GRAPH_UPDATE(
1182 getHeader(node),
1183 " (x + 0 == x - 0 == x) is replaced with ",
1184 node->input(0)->debugName());
1185 node->output()->replaceAllUsesWith(node->input(0));
1186 }
1187 } else if (
1188 node->matches(
1189 "aten::mul(Tensor self, Scalar other) -> Tensor",
1190 /*const_inputs=*/attr::other) ||
1191 node->matches(
1192 "aten::div(Tensor self, Scalar other) -> Tensor",
1193 /*const_inputs=*/attr::other)) {
        // x * 1 == x / 1 == x
        // if the node is a division or `other` isn't an integer, then removing
        // this operator could remove type promotion and affect semantics
1197 auto scalar_type =
1198 node->input(0)->type()->expectRef<TensorType>().scalarType();
1199 if (!scalar_type.has_value() ||
1200 !at::isFloatingType(scalar_type.value())) {
1201 if (node->kind() == aten::div ||
1202 !node->input(1)->type()->isSubtypeOf(IntType::get())) {
1203 continue;
1204 }
1205 }
1206
1207 if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
1208 GRAPH_UPDATE(
1209 getHeader(node),
1210 " (x * 1 == x / 1 == x) is replaced with ",
1211 node->input(0)->debugName());
1212 node->output()->replaceAllUsesWith(node->input(0));
1213 }
1214 }
1215 }
1216 }
1217
1218 void optimizeFusedGraphs() {
1219 for (Node* node : block_->nodes()) {
1220 if (node->kind() != kind_) {
1221 continue;
1222 }
1223 auto subgraph = node->g(attr::Subgraph);
1224 GRAPH_DEBUG("before optimizing: ", *subgraph);
1225 removeNoopBinaryOps(subgraph->block());
1226 EliminateDeadCode(subgraph);
1227 EliminateCommonSubexpression(subgraph);
1228 ConstantPooling(subgraph);
1229 GRAPH_DEBUG("after optimizing: ", *subgraph);
1230 }
1231 }
1232
1233 void run() {
1234 // Run the pass until no changes are made.
    // This is necessary because the algorithm can miss out on certain fusion
    // opportunities if run only once. Consider this graph:
1237 //
1238 // %1 = f(...)
1239 // %2 = g(%1)
1240 // %3 = h(%1)
1241 // %4 = l(%3)
1242 // return (%4, %2)
1243 //
1244 // where f, g, h, l are simple map ops.
1245 // The first iteration will fuse %4 and %3, and see that %1 is an input, but
1246 // can't be fused, because it has a different use before the fusion group
1247 // in our topological ordering. Then, %2 will be considered, and fused with
1248 // %1. If we do another iteration, the algorithm will consider the fusion of
1249 // these two groups and fix the situation.
1250 bool any_changed = true;
1251 while (any_changed) {
1252 any_changed = false;
1253 refreshAliasDb();
1254 for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
1255 bool changed = false;
1256 std::tie(it, changed) = scanNode(*it);
1257 any_changed |= changed;
1258 }
1259 }
1260
1261 GRAPH_DEBUG("after scan and merge", *graph_);
1262 refreshAliasDb();
1263
1264 optimizeFusedGraphs();
1265
1266 // The graph fuser can add intermediate prim::BroadcastingChunk nodes.
1267 // Replace them with broadcasts + chunks.
1268 replaceIntermediateBroadcastingChunks();
1269
1270 // Fuse starting chunks into the group.
1271 // for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
1272 // it = scanNodeForChunks(*it);
1273 //}
1274
1275 GRAPH_DEBUG("before removeOutputsUsedOnlyInSize", *graph_);
1276 // Remove outputs that have been added only because we need their size
1277 for (Node* n : block_->nodes()) {
1278 removeOutputsUsedOnlyInSize(n);
1279 }
1280 GRAPH_DEBUG("after removeOutputsUsedOnlyInSize", *graph_);
1281
1282 for (Node* node : block_->nodes()) {
1283 for (Block* sub_block : node->blocks()) {
1284 CudaGraphFuser sub_block_cfg(sub_block, graph_);
1285 sub_block_cfg.run();
1286 // Accumulate runtime shapes for all sub-blocks
1287 fusion_value_to_runtime_shape_.insert(
1288 sub_block_cfg.fusion_value_to_runtime_shape_.begin(),
1289 sub_block_cfg.fusion_value_to_runtime_shape_.end());
1290 }
1291 }
1292 }
1293};
1294
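// Removes the guarded fast path for a prim::CudaFusionGuard whose profiling
// information turned out to be unusable, keeping only the fallback. Sketch of
// the pattern this expects (illustrative):
//   %ok = prim::CudaFusionGuard[types=[...]](...)
//   %out = prim::If(%ok)
//     block0(): ... prim::CudaFusionGroup ...
//     block1(): %out = prim::FallbackGraph(...)
// The prim::FallbackGraph node is hoisted in place of the prim::If, and the
// guard and the prim::If are destroyed.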
1295void removeCudaFusionPathForGuardNode(Node* n) {
1296 auto uses = n->output()->uses();
1297 TORCH_INTERNAL_ASSERT(
1298 uses.size() == 1,
1299 "CudaFusionGuard should only be used once by prim::If or prim::ListConstruct");
1300 Node* if_node = uses[0].user;
1301 if (if_node->kind() != prim::If) {
1302 TORCH_INTERNAL_ASSERT(
1303 if_node->kind() == prim::ListConstruct,
1304 "CudaFusionGuard is not used by neither prim::If or prim::ListConstruct");
1305 // break all inputs so producer prim::CudaFusionGuard can be removed later
1306 if_node->removeAllInputs();
1307 auto list_use = if_node->output()->uses();
1308 TORCH_INTERNAL_ASSERT(
1309 list_use.size() == 1 && list_use[0].user->kind() == aten::all,
1310 "prim::ListConstruct should only be used once by aten::all");
1311 auto all_use = list_use[0].user->output()->uses();
1312 TORCH_INTERNAL_ASSERT(
1313 all_use.size() == 1 && all_use[0].user->kind() == prim::If,
1314 "aten::all should only be used once by prim::If");
1315 if_node = all_use[0].user;
1316 }
1317
1318 auto fall_back_graph = if_node->blocks()[1];
1319 Node* fallback_node = nullptr;
1320 for (auto fb_n : fall_back_graph->nodes()) {
1321 TORCH_INTERNAL_ASSERT(
1322 fb_n->kind() == prim::FallbackGraph,
1323 "CudaFusionGuard fallback path should only have single fallback node");
1324 TORCH_INTERNAL_ASSERT(
1325 fallback_node == nullptr,
1326 "CudaFusionGuard fallback path should only have single fallback node");
1327 fallback_node = fb_n;
1328 }
1329
1330 TORCH_INTERNAL_ASSERT(
1331 fallback_node != nullptr,
1332 "CudaFusionGuard fallback path found no fallback node");
1333 fallback_node->moveBefore(n);
1334
1335 TORCH_INTERNAL_ASSERT(
1336 fallback_node->outputs().size() == if_node->outputs().size(),
1337 "CudaFusionGuard fallback should have same number of outputs as with nesting if block");
1338
1339 if_node->replaceAllUsesWith(fallback_node);
1340 if_node->destroy();
1341 n->destroy();
1342}
1343
1344bool missingCompleteTypes(const std::vector<TypePtr>& types) {
1345 for (const auto& type : types) {
1346 if (auto tensor_type = type->cast<TensorType>()) {
      // if we found one missing value, we know that we are not going to be
      // able to generate a kernel, so we bail out.
1349 if (!tensor_type->device().has_value() ||
1350 !tensor_type->dim().has_value() ||
1351 !tensor_type->scalarType().has_value()) {
1352 return true;
1353 }
1354 }
1355 }
1356 return false;
1357}
1358
1359void removeFusionWithMissingProfilingInformation(Block* block) {
  FUSER_PERF_SCOPE("removeFusionWithMissingProfilingInformation");
1361 std::vector<Node*> removeCudaFusionNodes;
1362
1363 for (auto node : block->nodes()) {
1364 if (node->kind() == prim::CudaFusionGuard &&
1365 missingCompleteTypes(node->tys(attr::types))) {
1366 removeCudaFusionNodes.push_back(node);
1367 }
1368 for (auto sub_block : node->blocks()) {
1369 removeFusionWithMissingProfilingInformation(sub_block);
1370 }
1371 }
1372
1373 for (auto node : removeCudaFusionNodes) {
1374 removeCudaFusionPathForGuardNode(node);
1375 }
1376}
1377
1378void compileFusionRecursive(Block* block) {
1379 FUSER_PERF_SCOPE("compileFusionRecursive");
1380
1381 for (auto node : block->nodes()) {
1382 if (node->kind() == prim::CudaFusionGroup) {
1383 fuser::cuda::compileFusionGroup(node);
1384 }
1385 for (auto sub_block : node->blocks()) {
1386 compileFusionRecursive(sub_block);
1387 }
1388 }
1389}
1390
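// Simplifies prim::BroadcastSizes expressions: removes no-op single-input
// broadcasts, deduplicates repeated inputs, and folds chains. For example (an
// illustrative sketch):
//   %s1 = prim::BroadcastSizes(%a, %b)
//   %s2 = prim::BroadcastSizes(%s1, %c)
// becomes
//   %s2 = prim::BroadcastSizes(%a, %b, %c)
// when %s1 has no other users.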
1391void PeepholeOptimizeShapeExpressions(Block* block) {
1392 FUSER_PERF_SCOPE("PeepholeOptimizeShapeExpressions");
1393
1394 auto nodes = block->nodes();
1395 for (auto it = nodes.begin(); it != nodes.end(); ++it) {
1396 Node* node = *it;
1397 for (Block* subblock : node->blocks()) {
1398 PeepholeOptimizeShapeExpressions(subblock);
1399 }
1400 if (node->kind() == prim::BroadcastSizes) {
1401 // Remove no-op broadcasts.
1402 if (node->inputs().size() == 1) {
1403 node->output()->replaceAllUsesWith(node->input());
1404 it.destroyCurrent();
1405 continue;
1406 }
1407 // Deduplicate inputs, but use their unique() values to ensure
1408 // this process only depends on the graph.
1409 std::map<size_t, Value*> unique_to_value;
1410 for (Value* input : node->inputs()) {
1411 unique_to_value.emplace(input->unique(), input);
1412 }
1413 if (unique_to_value.size() != node->inputs().size()) {
1414 std::vector<Value*> inputs;
1415 inputs.reserve(unique_to_value.size());
1416 for (auto& entry : unique_to_value) {
1417 inputs.push_back(entry.second);
1418 }
1419 if (inputs.size() == 1) {
1420 node->output()->replaceAllUsesWith(inputs[0]);
1421 } else {
1422 WithInsertPoint insert_guard{node};
1423 node->output()->replaceAllUsesWith(broadcastSizes(inputs));
1424 }
1425 it.destroyCurrent();
1426 --it; // Revisit the node with deduplicated inputs
1427 continue;
1428 }
      // Compose simple chains of broadcasts into a single node.
1430 const auto& uses = node->output()->uses();
1431 if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
1432 Node* user = uses[0].user;
1433 user->removeInput(uses[0].offset);
1434 // NB: we don't care about deduplication in here, as we will visit user
1435 // later.
1436 for (Value* i : node->inputs()) {
1437 user->addInput(i);
1438 }
1439 it.destroyCurrent();
1440 }
1441 }
1442 }
1443}
1444
1445// view_sizes_runtime is the profiled-ivalue argument for view-size.
1446// view_sizes_constant_list is the constant list recorded during profiling runs.
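// For illustration (a sketch of the emitted check): given the runtime size of
// the view's input tensor and the profiled view sizes, this inserts
//   %ok : bool = prim::CudaFusionViewGuard(%self_sizes, %view_sizes, %constraints)
// before the versioning prim::If, where %constraints is a constant string
// built from analyzeViewConstraint.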
1447Value* guardView(
1448 Node* fusion,
1449 std::unordered_map<Value*, Value*>& fusion_value_to_runtime_size,
1450 Node* versioning_if,
1451 Node* view,
1452 Value* view_sizes_runtime) {
1453 // 1. Get self tensor sizes and view_sizes
1454 auto self_value = view->inputs().front();
1455 auto self_type = self_value->type()->cast<TensorType>();
1456 auto self_sizes_constant_list = getTensorSizes(self_type);
1457
1458 auto view_sizes_constant_list =
1459 constant_as<c10::List<int64_t>>(view->inputs().back());
1460 TORCH_INTERNAL_ASSERT(view_sizes_constant_list.has_value());
1461 std::vector<int64_t> view_sizes = view_sizes_constant_list->vec();
1462 // 2. Get constraints for self tensor and view_sizes
1463 auto constraints =
1464 analyzeViewConstraint(self_sizes_constant_list, view_sizes);
1465
1466 // 3. Add constraints as constant to graph
1467 auto full_constraints = fusion->owningGraph()->insertConstant(
1468 IValue(constraints.conglomerateString()));
1469 full_constraints->node()->moveBefore(versioning_if);
1470
1471 // 4. Create CudaFusionViewGuard using input tensor, profile_ivalue
1472 // for view_sizes list, and constraints
1473 TORCH_INTERNAL_ASSERT(
1474 fusion_value_to_runtime_size.find(self_value) !=
1475 fusion_value_to_runtime_size.end(),
1476 "Failed to find runtime size for fusion value:\t",
1477 self_value->node()->kind().toDisplayString());
1478 Node* viewcheck_node =
1479 fusion->owningGraph()
1480 ->create(
1481 c10::Symbol::fromQualString("prim::CudaFusionViewGuard"),
1482 {fusion_value_to_runtime_size.at(self_value),
1483 view_sizes_runtime,
1484 full_constraints},
1485 1)
1486 ->insertBefore(versioning_if);
1487 return viewcheck_node->output();
1488}
1489
1490//! [ Note -- CudaFusionGuard implementation ]
1491//!
1492//! shamelessly copying code from NNC (tensorexpr_fuser) with very little
1493//! modification, original code at:
1494//! `../../passes/tensorexpr_fuser.cpp:guardFusionGroup`
1495//!
1496//! Add prim::CudaFusionGuard node to ensure that accepted profiling information
1497//! is not violated at runtime.
1498//!
1499//! We replace a single
1500//!
1501//! outputs = prim::CudaFusionGroup[cache_id](inputs)
1502//!
1503//! with the following pattern:
1504//!
1505//! %1 : bool = prim::CudaFusionGuard[types=[...]](inputs)
1506//! outputs = prim::If(%1)
1507//! block0():
1508//! outputs = prim::CudaFusionGroup[cache_id](inputs)
1509//! -> (outputs)
1510//! block1():
1511//! %2 : Function = prim::Constant[name="fallback_function", fallback=1]()
1512//!            outputs = prim::CallFunction(%2, inputs)
1513//! -> (outputs)
1514//!
1515//! `prim::CudaFusionGuard` stores all profiled data types in attribute
1516//! `attr::types`.
1517//! At runtime, we check input tensors against our profiled data types and
1518//! return an output (bool) that holds the result of the check.
1519//! See [ Note -- type guard logic in CudaFusionGuard ]
1520//!
1521//! This ensures that `prim::CudaFusionGroup` only executes on compatible
1522//! inputs. In case of a check failure, execution goes through the false block,
1523//! which recursively goes through another profiling / optimization iteration
1524//! (the retry count can be tuned via `bailout_depth`).
1525//!
1526//! TODO: we also need to assert/check reduction axes and replace it with
1527//! constants in `CudaFusionGroup`
1528void guardFusionGroup(
1529 Node* fusion,
1530 std::unordered_map<Value*, Value*>& fusion_value_to_runtime_size) {
1531 // Fixup types of the subgraph inputs
1532 std::vector<TypePtr> guard_types;
1533 std::vector<Value*> tensor_inputs_to_check;
1534 std::set<size_t> profiled_ivalue_indices;
1535
1536 for (const auto index : c10::irange(fusion->inputs().size())) {
1537 Value* input = fusion->inputs()[index];
1538 if (input->type()->cast<TensorType>()) {
1539      // We only check inputs of the fusion group and expect NNC to infer
1540      // intermediate and output shapes
1541
1542      // note: modified from the original implementation, we are guarding
1543      // fusion outputs
1544 if (input->node()->kind() == prim::Constant) {
1545 continue;
1546 }
1547 tensor_inputs_to_check.push_back(input);
1548 guard_types.push_back(input->type());
1549 } else if (input->node()->kind() == prim::profile_ivalue) {
1550 // Conditional constant from profiled_ivalue, should be guarded
1551 profiled_ivalue_indices.insert(index);
1552 }
1553 }
1554
1555 // insert the if block first;
1556 auto versioning_if =
1557 fusion->owningGraph()->create(prim::If, fusion->outputs().size());
1558 for (const auto idx : c10::irange(fusion->outputs().size())) {
1559 versioning_if->output(idx)->setType(fusion->output(idx)->type());
1560 fusion->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
1561 }
1562 auto true_block = versioning_if->addBlock();
1563 auto false_block = versioning_if->addBlock();
1564
1565 // insert typecheck_node;
1566 Node* typecheck_node =
1567 fusion->owningGraph()
1568 ->create(prim::CudaFusionGuard, tensor_inputs_to_check, 1)
1569 ->insertBefore(fusion);
1570 // fix output to BoolType
1571 typecheck_node->output()->setType(BoolType::get());
1572 Value* typecheck_result = typecheck_node->output();
1573 typecheck_node->tys_(attr::types, guard_types);
1574
1575 versioning_if->insertAfter(typecheck_node);
1576
1577 auto fusion_graph = fusion->g(attr::Subgraph);
1578 std::vector<Value*> check_flags = {};
1579
1580 // Fill in the false block. It should contain the unoptimized
1581 // copy of the fused subgraph, unless we have conditional constants from
1582 // profiled_ivalue;
1583 std::shared_ptr<Graph> fb_graph; // resource holder;
1584 // Restore the dependency for constant introduced by profiled_ivalue within
1585 // the graph.
1586 if (!profiled_ivalue_indices.empty()) {
1587    // This is necessary to clean up the fallback graph, which was copied
1588    // from the subgraph; the two graphs differ because we cannot use
1589    // conditional constants in the fallback
1590
1591 // 1. RESTORE conditional constant dependency in fallback group;
1592 fb_graph = fusion_graph->copy();
1593 GRAPH_DEBUG("re-wiring fallback graph", *fb_graph);
1594
1595 for (const auto& offset : profiled_ivalue_indices) {
1596 auto val = fb_graph->inputs()[offset];
1597 auto uses = val->uses();
1598      // since we are updating uses of val in the loop, we have to copy
1599      // val->uses() beforehand.
1600 for (const auto& use : uses) {
1601 // re-wire inputs and remove conditional constant nodes;
1602 TORCH_INTERNAL_ASSERT(
1603 use.user->kind() == prim::Constant,
1604 "profile_ivalue at index: ",
1605 offset,
1606 " can only be used by conditional constant, instead got: ",
1607 use.user->kind().toDisplayString());
1608 use.user->output()->replaceAllUsesWith(val);
1609 use.user->destroy();
1610 }
1611 }
1612
1613 WithInsertPoint guard(false_block->return_node());
1614 const auto subgraph_outputs =
1615 insertGraph(*fusion->owningGraph(), *fb_graph, fusion->inputs());
1616 for (Value* output : subgraph_outputs) {
1617 false_block->registerOutput(output);
1618 }
1619 // types get copied to the fallback graph, so remove specializations before
1620 // replacing
1621 // TODO: this is not exposed here, I need to remove that before inserting
1622 // the graph
1623 // removeTensorTypeSpecializations(false_block);
1624 replaceBlockWithFallbackGraph(false_block, fusion->inputs());
1625
1626 // 2. REMOVE conditional constant dependency in fusion group
1627 size_t compensation = 0;
1628
1629    // get a constant true, which is used by the double-xor bool equality
1630    // pattern below
1630 auto const_true = fusion->owningGraph()->insertConstant(IValue(true));
1631 const_true->node()->moveBefore(versioning_if);
1632
1633 for (const auto& original_offset : profiled_ivalue_indices) {
1634 size_t offset = original_offset - compensation;
1635
1636 // step a. handle fusion
1637 // remove inputs to fusion, and update check logic for fallback
1638 auto profiled_ival = fusion->input(offset)->node()->input();
1639 auto const_o = createConditionalConstant(fusion->input(offset)->node());
1640 TORCH_INTERNAL_ASSERT(
1641 const_o,
1642          "profile_ivalue node is expected to have profile information, at node: ",
1643 *fusion->input(offset)->node());
1644 const_o->node()->moveBefore(versioning_if);
1645 Value* ivalue_check = nullptr;
1646
1647 if (fusion->input(offset)->node()->hasAttribute(
1648 Symbol::attr("profiled_bool"))) {
1649        // aten::eq doesn't support comparison between two booleans
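        // so we emit (profiled_ival ^ const_o) ^ true, which evaluates to
        // !(profiled_ival ^ const_o), i.e. profiled_ival == const_o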
1650 auto xor_n = fusion->owningGraph()
1651 ->create(aten::__xor__, {profiled_ival, const_o}, 1)
1652 ->insertBefore(versioning_if);
1653 xor_n->output()->setType(BoolType::get());
1654 ivalue_check =
1655 fusion->owningGraph()
1656 ->create(aten::__xor__, {xor_n->output(), const_true}, 1)
1657 ->insertBefore(versioning_if)
1658 ->output();
1659 } else if (fusion->input(offset)->node()->hasAttribute(
1660 Symbol::attr("profiled_reduction_size"))) {
1661 // TODO(profile_size): check sizes here with special size comparison op
1662 // TORCH_INTERNAL_ASSERT(false, "not implemented yet");
1663 ivalue_check =
1664 fusion->owningGraph()
1665 ->create(
1666 c10::Symbol::fromQualString("prim::CudaFusionSizeEq"),
1667 {profiled_ival, const_o},
1668 1)
1669 ->insertBefore(versioning_if)
1670 ->output();
1671 } else if (fusion->input(offset)->node()->hasAttribute(
1672 Symbol::attr("profiled_view_size"))) {
1673 // TODO: Add support for dynamic split to view guard
1674
1675 // Path from profile-ivalue to prim::view_copy operation
1676 // profile-ivalue -> Constant -> CudaFusionGroup
1677 // Get argument position in CudaFusionGroup
1678 // Get argument in subgraph for CudaFusionGroup
1679 // CudaFusionGroup argument -> Constant List -> prim::view_copy
1680 auto subgraph_arg = fusion_graph->inputs()[offset];
1681 auto constant = subgraph_arg->uses().front().user->output();
1682
1683 TORCH_INTERNAL_ASSERT(!constant->uses().empty());
1684 auto view = constant->uses().front().user;
1685 TORCH_INTERNAL_ASSERT(
1686 view->kind() == prim::view_copy ||
1687 view->kind() == prim::reshape_copy);
1688
1689 ivalue_check = guardView(
1690 fusion,
1691 fusion_value_to_runtime_size,
1692 versioning_if,
1693 view,
1694 profiled_ival);
1695 } else if (fusion->input(offset)->node()->hasAttribute(
1696 Symbol::attr("profiled_ival"))) {
1697 ivalue_check =
1698 fusion->owningGraph()
1699 ->create(
1700 c10::Symbol::fromQualString("prim::CudaFusionIvalGuard"),
1701 {profiled_ival, const_o},
1702 1)
1703 ->insertBefore(versioning_if)
1704 ->output();
1705 } else {
1706 ivalue_check = fusion->owningGraph()
1707 ->create(aten::eq, {profiled_ival, const_o}, 1)
1708 ->insertBefore(versioning_if)
1709 ->output();
1710 }
1711 ivalue_check->setType(BoolType::get());
1712
1713 // aggregate flags;
1714 check_flags.emplace_back(ivalue_check);
1715
1716 // remove inputs to fusion;
1717 fusion->removeInput(offset);
1718
1719 // step b. remove the extra dependency inside fusion;
1720 for (const auto& use : fusion_graph->inputs()[offset]->uses()) {
1721 TORCH_INTERNAL_ASSERT(
1722 use.user->kind() == prim::Constant,
1723 "profile_ivalue at index: ",
1724 offset,
1725 " can only be used by conditional constant, instead got: ",
1726 use.user->kind().toDisplayString());
1727 use.user->removeAllInputs();
1728 }
1729 fusion_graph->eraseInput(offset);
1730 compensation++;
1731 }
1732 // update graph in fusion node
1733 fusion->g_(attr::Subgraph, fusion_graph);
1734 }
1735
1736 if (!check_flags.empty()) {
1737    // attach the output from CudaFusionGuard to the profile_ivalue checks
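    // i.e. typecheck_result becomes
    //   aten::all([ivalue_check_0, ..., ivalue_check_n, CudaFusionGuard output])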
1738 check_flags.emplace_back(typecheck_result);
1739 auto graph = fusion->owningGraph();
1740 auto bool_list_node =
1741 graph->insertNode(graph->createList(BoolType::get(), check_flags));
1742 bool_list_node->moveBefore(versioning_if);
1743 Value* bool_list = bool_list_node->output();
1744 // new typecheck_result
1745 typecheck_result = graph->insert(aten::all, {bool_list});
1746 typecheck_result->node()->moveBefore(versioning_if);
1747 }
1748
1749 if (profiled_ivalue_indices.empty()) {
1750 WithInsertPoint guard(false_block->return_node());
1751 const auto subgraph_outputs =
1752 insertGraph(*fusion->owningGraph(), *fusion_graph, fusion->inputs());
1753 for (Value* output : subgraph_outputs) {
1754 false_block->registerOutput(output);
1755 }
1756 // types get copied to the fallback graph, so remove specializations before
1757 // replacing
1758 // TODO: this is not exposed here, I need to remove that before inserting
1759 // the graph
1760 // removeTensorTypeSpecializations(false_block);
1761 replaceBlockWithFallbackGraph(false_block, fusion->inputs());
1762 }
1763
1764 // wiring up if block
1765 versioning_if->addInput(typecheck_result);
1766
1767 // Fill in the true block. It has all inputs type-checked and its
1768 // body should be the fusion group node.
1769 fusion->moveBefore(true_block->return_node());
1770 for (Value* output : fusion->outputs()) {
1771 true_block->registerOutput(output);
1772 }
1773}
1774
1775void guardFusionGroups(
1776 Block* block,
1777 std::unordered_map<Value*, Value*>& fusion_value_to_runtime_size) {
1778 std::vector<Node*> fusions;
1779 for (Node* n : block->nodes()) {
1780 for (Block* b : n->blocks()) {
1781 guardFusionGroups(b, fusion_value_to_runtime_size);
1782 }
1783 if (n->kind() == prim::CudaFusionGroup) {
1784 fusions.push_back(n);
1785 }
1786 }
1787 for (Node* fusion : fusions) {
1788 // step 1: a. add prim::CudaFusionGuard and fallback logic
1789 // b. insert guard logic of profile_ivalue with if block
1790 // c. restore conditional constant to non-constant for fallback
1791 guardFusionGroup(fusion, fusion_value_to_runtime_size);
1792 }
1793}
1794
1795void dumpFusionGroups(std::shared_ptr<Graph>& g) {
1796 DepthFirstGraphNodeIterator it(g);
1797 Node* n = nullptr;
1798 GRAPH_DEBUG("Exporting all NVFuser fusions:");
1799 while ((n = it.next()) != nullptr) {
1800 if (n->kind() == prim::FallbackGraph) {
1801 GRAPH_EXPORT("", n->g(attr::Subgraph));
1802 }
1803 }
1804}
1805
1806// rewire const integer index & empty byte-typed reserve space tensor outputs,
1807// so `CudaFusionGroup` doesn't have to handle those
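// The offset checks below assume that, among aten::_batch_norm_impl_index's
// outputs, offset 3 is the byte-typed reserve buffer and offset 4 is the
// integer backend-chooser index.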
1808void alterBatchNormImplIndex(Node* node) {
1809 std::set<size_t> bn_index_out_indices;
1810 std::set<size_t> bn_buffer_out_indices;
1811
1812 auto subgraph = node->g(attr::Subgraph);
1813 for (const auto i : c10::irange(subgraph->outputs().size())) {
1814 auto val = subgraph->outputs()[i];
1815 if (val->node()->kind() == aten::_batch_norm_impl_index &&
1816 val->offset() == 4) {
1817 bn_index_out_indices.emplace(i);
1818 } else if (
1819 val->node()->kind() == aten::_batch_norm_impl_index &&
1820 val->offset() == 3) {
1821 bn_buffer_out_indices.emplace(i);
1822 }
1823 }
1824
1825 if (!bn_index_out_indices.empty()) {
1826    // we set the index output to 0 so the backward pass goes through
1827    // native_batch_norm, which is what we support;
1828 auto const_1 = node->owningGraph()->insertConstant(IValue(0));
1829 const_1->node()->moveBefore(node);
1830 for (auto i : bn_index_out_indices) {
1831 node->outputs()[i]->replaceAllUsesWith(const_1);
1832 }
1833 }
1834
1835 if (!bn_buffer_out_indices.empty()) {
1836 auto graph = node->owningGraph();
1837 std::vector<int64_t> sizes{0}; // empty tensor with no size;
1838 // std::vector<int64_t> sizes; // empty tensor with no size;
1839 auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes));
1840 const_size_0->node()->moveBefore(node);
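    // dtype argument for the aten::empty call below; 0 corresponds to
    // at::ScalarType::Byte, matching the byte-typed reserve buffer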
1841 auto const_0 = node->owningGraph()->insertConstant(IValue(0));
1842 const_0->node()->moveBefore(node);
1843 auto none_val = node->owningGraph()->insertConstant(IValue());
1844 none_val->node()->moveBefore(node);
1845 auto device =
1846 graph->insertNode(graph->create(prim::device, {node->inputs()[0]}, 1));
1847 device->moveBefore(node);
1848 device->output()->setType(DeviceObjType::get());
1849 auto empty_tensor = graph->insertNode(graph->create(
1850 aten::empty,
1851 {const_size_0, const_0, none_val, device->output(), none_val, none_val},
1852 1));
1853 empty_tensor->moveBefore(node);
1854 for (auto i : bn_buffer_out_indices) {
1855 node->outputs()[i]->replaceAllUsesWith(empty_tensor->output());
1856 }
1857 }
1858
1859 bn_index_out_indices.insert(
1860 bn_buffer_out_indices.begin(), bn_buffer_out_indices.end());
1861 for (auto iter = bn_index_out_indices.crbegin();
1862 iter != bn_index_out_indices.crend();
1863 ++iter) {
1864 subgraph->eraseOutput(*iter);
1865 node->eraseOutput(*iter);
1866 }
1867}
1868
1869// rewire empty byte-typed reserve space tensor input to an empty float-typed
1870// tensor, because `CudaFusionGroup` doesn't support byte-typed tensor, nor does
1871// it use reserve space.
1872void alterBatchNormImplIndexBackward(Node* node) {
1873 std::set<size_t> bn_buffer_in_indices;
1874
1875 auto subgraph = node->g(attr::Subgraph);
1876 for (auto n : subgraph->nodes()) {
1877 if (n->kind() == aten::_batch_norm_impl_index_backward) {
1878      // The 11th input is `reserve`, which is not used by the codegen kernel
1879      // and whose `Byte` type is not supported. So we disconnect it here to
1880      // avoid a codegen error.
1881 auto byte_input = n->inputs()[11];
1882 // TODO: let's check the data type for buffer and skip if it's good
1883      // TODO: we can actually support it by adding an extra input to the
1884      // subgraph
1885 // TODO: assert on empty buffer
1886 TORCH_INTERNAL_ASSERT(
1887 byte_input->node() == subgraph->param_node(),
1888 "Assumption that reserve input to aten::_batch_norm_impl_index_backward comes from forward graph is broken");
1889 bn_buffer_in_indices.emplace(byte_input->offset());
1890 }
1891 }
1892
1893 if (!bn_buffer_in_indices.empty()) {
1894 auto graph = node->owningGraph();
1895 std::vector<int64_t> sizes{0}; // empty tensor with no size;
1896 // std::vector<int64_t> sizes{}; // empty tensor with no size;
1897 auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes));
1898 const_size_0->node()->moveBefore(node);
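    // dtype argument for the aten::empty call below; 6 corresponds to
    // at::ScalarType::Float, matching the float type set on the subgraph input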
1899 auto const_0 = node->owningGraph()->insertConstant(IValue(6));
1900 const_0->node()->moveBefore(node);
1901 auto none_val = node->owningGraph()->insertConstant(IValue());
1902 none_val->node()->moveBefore(node);
1903 auto device =
1904 graph->insertNode(graph->create(prim::device, {node->inputs()[1]}, 1));
1905 device->moveBefore(node);
1906 device->output()->setType(DeviceObjType::get());
1907 auto empty_tensor = graph->insertNode(graph->create(
1908 aten::empty,
1909 {const_size_0, const_0, none_val, device->output(), none_val, none_val},
1910 1));
1911 empty_tensor->moveBefore(node);
1912
1913 for (const auto& item : bn_buffer_in_indices) {
1914 subgraph->inputs()[item]->setType(
1915 node->inputs()[item]->type()->cast<TensorType>()->withScalarType(
1916 at::ScalarType::Float));
1917 node->replaceInput(item, empty_tensor->output());
1918 }
1919 }
1920}
1921
1922void alterBatchNormImpls(Block* block) {
1923 std::vector<Node*> fusions;
1924 for (Node* n : block->nodes()) {
1925 for (Block* b : n->blocks()) {
1926 alterBatchNormImpls(b);
1927 }
1928 if (n->kind() == prim::CudaFusionGroup) {
1929 fusions.push_back(n);
1930 }
1931 }
1932 for (Node* fusion : fusions) {
1933 // remove index & reserve from outputs;
1934 alterBatchNormImplIndex(fusion);
1935 // remove reserve from inputs;
1936 alterBatchNormImplIndexBackward(fusion);
1937 }
1938}
1939
1940// We absorb `prim::dtype` node into CudaFusion structure. The structure below
1941//
1942// %1 = prim::CudaFusionGuard(...)
1943// %2, %3 = prim::If(...)
1944// block0():
1945// %4, %5 = prim::CudaFusionGroup(...)
1946// -> (%4, %5)
1947// block1():
1948// %6, %7 = prim::FallbackGraph(...)
1949// -> (%6, %7)
1950// %4 = prim::dtype(%3)
1951// ... (uses %2, %4, but never references %3 any more)
1952//
1953// is updated to:
1954//
1955// %1 = prim::CudaFusionGuard(...)
1956// %2, %3 = prim::If(...)
1957// block0():
1958// %4 = prim::CudaFusionGroup(...) # %5 is also removed from subgraph
1959// %8 = prim::Constant[value=...]()
1960// -> (%4, %8)
1961// block1():
1962// %6, %7 = prim::FallbackGraph(...)
1963// %9 = prim::dtype(%7)
1964// -> (%6, %9)
1965// # %4 = prim::dtype(%3) is removed. All references to %4 are replaced with %3
1966// ... (uses %2, %3, but never references %4 any more)
1967void removeOutputUsedOnlyInDtype(Node* fusion_node) {
1968 auto fusion_block = fusion_node->owningBlock();
1969 TORCH_INTERNAL_ASSERT(
1970 fusion_block->owningNode() &&
1971 fusion_block->owningNode()->kind() == prim::If,
1972 "CudaFusionGroup should be inside `prim::CudaFusionGuard` / `prim::If`");
1973
1974 auto if_node = fusion_block->owningNode();
1975 auto fusion_node_graph = fusion_node->g(attr::Subgraph);
1976 auto fallback_block = if_node->blocks()[1];
1977
1978 bool updated = false;
1979 // Iterating in this order is crucial for correctness (i has to reflect the
1980 // current true index of outputs[i])!
1981 for (int64_t i = static_cast<int64_t>(if_node->outputs().size()) - 1; i >= 0;
1982 --i) {
1983 auto output = if_node->outputs()[i];
1984 // output only used in dtype, we eliminate the output and rely on
1985 // profiled/static scalar type inference to save on memory IO.
1986 if (usedOnlyInDtype(output)) {
1987 updated = true;
1988 {
1989 // update fusion_block to output profiled scalar type
1990 auto fusion_output = fusion_block->outputs()[i];
1991 auto tensor_type = fusion_output->type()->cast<TensorType>();
1992 TORCH_INTERNAL_ASSERT(
1993 tensor_type, "non tensor fed to dtype is not supported");
1994 auto scalar_type = tensor_type->scalarType();
1995 TORCH_INTERNAL_ASSERT(
1996 scalar_type.has_value(),
1997 "ScalarType should be static for Tensors in fusion for amp optimization");
1998 auto type_const =
1999 fusion_block->owningGraph()->insertConstant(IValue(scalar_type));
2000 type_const->setType(IntType::get());
2001 type_const->node()->moveBefore(fusion_block->return_node());
2002 fusion_block->replaceOutput(i, type_const);
2003
2004 // removing the dangling output tensor from CudaFusionGroup would
2005 // require tracing output i from block to output j in CudaFusionGroup.
2006 // We choose to instead do that later by simply checking uses
2007 }
2008
2009 {
2010 // update fallback_block to output dtype instead of tensor
2011 auto tensor_output = fallback_block->outputs()[i];
2012 auto dtype_node = fallback_block->owningGraph()->create(
2013 prim::dtype, tensor_output, 1);
2014 dtype_node->output()->setType(IntType::get());
2015 fallback_block->appendNode(dtype_node);
2016 fallback_block->replaceOutput(i, dtype_node->output());
2017 }
2018
2019      // we just short-cut the `dtype` node since we are already outputting dtype
2020 auto uses = output->uses();
2021 for (Use u : uses) {
2022 AT_ASSERT(u.user->matches("prim::dtype(Tensor a) -> int"));
2023 u.user->output()->replaceAllUsesWith(output);
2024 u.user->destroy();
2025 }
2026 output->setType(IntType::get());
2027 }
2028 }
2029
2030 if (updated) {
2031    // Remove fusion node outputs with no uses;
2032 for (int64_t i = static_cast<int64_t>(fusion_node->outputs().size()) - 1;
2033 i >= 0;
2034 --i) {
2035 if (fusion_node->output(i)->uses().empty()) {
2036 GRAPH_UPDATE(
2037 "removing output: ", i, " from fusion node: ", *fusion_node);
2038 fusion_node->eraseOutput(i);
2039 fusion_node_graph->eraseOutput(i);
2040 }
2041 }
2042
2043 fusion_node->g_(attr::Subgraph, fusion_node_graph);
2044 }
2045}
2046
2047// For output tensors in a fusion group that are only used by a dtype node,
2048// with CudaFusionGuard we can short-cut them with a constant dtype directly
2049// to save memory IO bandwidth.
2050// The reason we do this after inserting the guard, instead of during graph
2051// fusion/partitioning, is that we need to handle the fallback differently:
2052// the fallback is not inside CudaFusionGuard, and hence doesn't have the
2053// dtype as a constant.
2054void removeOutputUsedOnlyInDtype(Block* block) {
2055 std::vector<Node*> fusions;
2056 for (Node* n : block->nodes()) {
2057 for (Block* b : n->blocks()) {
2058 removeOutputUsedOnlyInDtype(b);
2059 }
2060 if (n->kind() == prim::CudaFusionGroup) {
2061 fusions.push_back(n);
2062 }
2063 }
2064 for (Node* fusion : fusions) {
2065    // short-cut outputs that are only used by a dtype node;
2066 removeOutputUsedOnlyInDtype(fusion);
2067 }
2068}
2069
2070void RemoveProfileIValue(Node* profile_ivalue) {
2071 for (const auto& use : profile_ivalue->output()->uses()) {
2072 if (use.user->kind() == prim::Constant) {
2073 use.user->output()->replaceAllUsesWith(profile_ivalue->input());
2074 use.user->destroy();
2075 }
2076 }
2077 profile_ivalue->output()->replaceAllUsesWith(profile_ivalue->input());
2078 profile_ivalue->destroy();
2079}
2080
2081void ExtractProfileIValue(Node* profile_ivalue) {
2082 auto const_o = createConditionalConstant(profile_ivalue);
2083 if (const_o) {
2084 auto const_n = const_o->node();
2085 const_n->moveAfter(profile_ivalue);
2086 profile_ivalue->output()->replaceAllUsesAfterNodeWith(const_n, const_o);
2087    // special wiring: we add this input to the constant simply to create a
2088    // dependency, which we can trace and remove later;
2089 const_n->addInput(profile_ivalue->output());
2090 } else {
2091 // no profile value available, remove profile_ivalue node;
2092 RemoveProfileIValue(profile_ivalue);
2093 }
2094}
2095
2096// Break a `linear` layer into `matmul` and `add_optional`. This allows us to
2097// fuse the binary operation without supporting gemm.
2098// Note that we do not break a `linear` layer without bias.
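// Roughly (value names illustrative):
//   %out = aten::linear(%input, %weight, %bias)
// becomes
//   %wt  = aten::t(%weight)
//   %mm  = aten::matmul(%input, %wt)
//   %out = prim::add_optional(%mm, %bias)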
2099void decomposeLinearOps(Block* block) {
2100 std::vector<Node*> linear_nodes;
2101 for (Node* n : block->nodes()) {
2102 for (Block* b : n->blocks()) {
2103 decomposeLinearOps(b);
2104 }
2105 // only decompose `linear` layer with bias
2106 if (n->kind() == aten::linear &&
2107 !n->input(2)->type()->isSubtypeOf(
2108 static_cast<c10::TypePtr>(NoneType::get()))) {
2109 linear_nodes.push_back(n);
2110 }
2111 }
2112
2113 auto graph = block->owningGraph();
2114 for (Node* n : linear_nodes) {
2115 WithInsertPoint guard(n);
2116 auto weight_t = graph->insertNode(graph->create(aten::t, {n->input(1)}, 1));
2117 auto matmul = graph->insertNode(
2118 graph->create(aten::matmul, {n->input(0), weight_t->output()}, 1));
2119 auto input_tensor_type = n->input(0)->type()->cast<c10::TensorType>();
2120 if (!input_tensor_type) {
2121 TORCH_WARN_ONCE(
2122 "linear input 0 is required to be tensor for linear decompose");
2123 continue;
2124 }
2125 auto mat0_size = input_tensor_type->sizes().concrete_sizes();
2126 auto mat1_size =
2127 n->input(1)->type()->cast<c10::TensorType>()->sizes().concrete_sizes();
2128
2129    // TODO: Continuing here won't be necessary once we can handle matmul;
2130    // right now we are splitting the linear between matmul & bias_add. Our
2131    // fuser can only take the second half, and we need the size information.
2132 if (!mat0_size.has_value() || !mat1_size.has_value()) {
2133 TORCH_WARN_ONCE(
2134 "concrete shape for linear input & weight are required to decompose into matmul + bias");
2135 continue;
2136 }
2137
2138    // only decompose for inputs with nDims >= 4, since lower-rank linear is
2139    // already fused in eager mode
2140 if (mat0_size->size() < 4) {
2141 continue;
2142 }
2143
2144 auto out_size = mat0_size.value();
2145 TORCH_INTERNAL_ASSERT(
2146 mat1_size->size() == 2 || mat1_size->size() == 1,
2147 "weight dimension for linear is expected to be 1 or 2, but got: ",
2148 mat1_size->size());
2149 if (mat1_size->size() == 2) {
2150 out_size[out_size.size() - 1] = mat1_size.value()[0];
2151 } else if (mat1_size->size() == 1) {
2152 out_size.pop_back();
2153 }
2154 matmul->output()->setType(input_tensor_type->withSizes(out_size));
2155
2156 // TODO: memory stride should be considered here, our inference above is not
2157 // safe.
2158 auto bias = graph->insertNode(
2159 graph->create(prim::add_optional, {matmul->output(0), n->input(2)}, 1));
2160 bias->output()->setType(matmul->output(0)->type());
2161
2162 n->output()->replaceAllUsesWith(bias->output());
2163 n->destroy();
2164 }
2165}
2166
2167// Replace 'operation' with 'operation_copy' to guard alias operations.
2168// Currently supports expand, expand_as, permute, transpose, and t (see map below)
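// e.g. %y = aten::t(%x) becomes %y = prim::t_copy(%x), but only when neither
// %x nor %y has writers according to the alias database.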
2169void replaceAliasOpsWithCopy(std::shared_ptr<Graph>& graph, Block* block) {
2170 static std::unordered_map<Symbol, Symbol> alias_to_copy_mapping(
2171 {{aten::expand, prim::expand_copy},
2172 {aten::expand_as, prim::expand_as_copy},
2173 {aten::permute, prim::permute_copy},
2174 {aten::transpose, prim::transpose_copy},
2175 {aten::t, prim::t_copy}});
2176 // TODO: revert disabled aten::view
2177 // ({{aten::view, prim::view_copy},
2178 // {aten::reshape, prim::reshape_copy},
2179 // {aten::squeeze, prim::squeeze_copy},
2180 // {aten::unsqueeze, prim::unsqueeze_copy},
2181 // {aten::flatten, prim::flatten_copy}});
2182
2183 std::vector<Node*> maybe_safe_alias_nodes;
2184 for (Node* n : block->nodes()) {
2185 for (Block* b : n->blocks()) {
2186 replaceAliasOpsWithCopy(graph, b);
2187 }
2188 if (alias_to_copy_mapping.find(n->kind()) != alias_to_copy_mapping.end()) {
2189 maybe_safe_alias_nodes.push_back(n);
2190 }
2191 }
2192
2193 auto alias_db = std::make_unique<AliasDb>(graph);
2194
2195 auto safeToChangeAliasToCopy = [&alias_db](Node* n) {
2196 return !alias_db->hasWriters(n->input(0)) &&
2197 !alias_db->hasWriters(n->output(0));
2198 };
2199
2200 auto replaceAliasWithCopy = [&graph, &alias_db](Node* n) {
2201 WithInsertPoint guard(n);
2202 auto copy_op = graph->insertNode(
2203 graph->create(alias_to_copy_mapping[n->kind()], n->inputs(), 1));
2204 copy_op->output()->setType(n->output(0)->type());
2205
2206 // adding newly created value into alias_db;
2207 alias_db->createValue(copy_op->output());
2208
2209 n->output()->replaceAllUsesWith(copy_op->output());
2210 n->destroy();
2211 };
2212
2213 for (Node* n : maybe_safe_alias_nodes) {
2214 if (!safeToChangeAliasToCopy(n)) {
2215 continue;
2216 }
2217 replaceAliasWithCopy(n);
2218 }
2219}
2220
2221// Revert all 'operation_copy' back to 'operation' except in CudaFusionGroup,
2222// i.e., any non-fused alias operation, including within prim::FallbackGraph.
2223// Currently supports expand, expand_as, permute, transpose, and t (see map below)
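// e.g. a leftover %y = prim::t_copy(%x) outside of any CudaFusionGroup is
// rewritten back to %y = aten::t(%x).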
2224void revertAliasCopyOps(std::shared_ptr<Graph>& graph, Block* block) {
2225 static std::unordered_map<Symbol, Symbol> copy_to_alias_mapping(
2226 {{prim::expand_copy, aten::expand},
2227 {prim::expand_as_copy, aten::expand_as},
2228 {prim::permute_copy, aten::permute},
2229 {prim::transpose_copy, aten::transpose},
2230 {prim::t_copy, aten::t}});
2231 // TODO: revert disabled aten::view
2232 // ({{prim::view_copy, aten::view},
2233 // {prim::flatten_copy, aten::flatten},
2234 // {prim::reshape_copy, aten::reshape},
2235 // {prim::squeeze_copy, aten::squeeze},
2236 // {prim::unsqueeze_copy, aten::unsqueeze}});
2237
2238 std::vector<Node*> alias_copy_ops;
2239 for (Node* n : block->nodes()) {
2240 // Allow alias copy ops in CudaFusionGroup
2241 if (n->kind() == prim::CudaFusionGroup) {
2242 continue;
2243 }
2244 // Revert alias copy ops within FallbackGraph
2245 if (n->kind() == prim::FallbackGraph) {
2246 auto subgraph = n->g(attr::Subgraph);
2247 revertAliasCopyOps(subgraph, subgraph->block());
2248 }
2249 for (Block* b : n->blocks()) {
2250 revertAliasCopyOps(graph, b);
2251 }
2252 // Revert any non-fused alias copy ops
2253 if (copy_to_alias_mapping.find(n->kind()) != copy_to_alias_mapping.end()) {
2254 alias_copy_ops.push_back(n);
2255 }
2256 }
2257
2258 auto replaceCopyWithAlias = [&graph](Node* n) {
2259 WithInsertPoint guard(n);
2260 auto alias_op = graph->insertNode(
2261 graph->create(copy_to_alias_mapping[n->kind()], n->inputs(), 1));
2262 alias_op->output()->setType(n->output(0)->type());
2263 n->output()->replaceAllUsesWith(alias_op->output());
2264 n->destroy();
2265 };
2266
2267 for (Node* n : alias_copy_ops) {
2268 replaceCopyWithAlias(n);
2269 }
2270}
2271
2272// Break a `conv2d` layer into a bias-less `conv2d` and `add_optional`. This
2273// allows us to fuse the bias addition without supporting convolution itself.
2274// Note that we do not break a `conv2d` layer without bias.
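// Roughly (value names illustrative; the original conv2d node is kept, with
// its bias input replaced by None):
//   %out = aten::conv2d(%input, %weight, %bias, ...)
// becomes
//   %b1  = aten::unsqueeze(%bias, -1)   // (C)    -> (C, 1)
//   %b2  = aten::unsqueeze(%b1, -1)     // (C, 1) -> (C, 1, 1)
//   %y   = aten::conv2d(%input, %weight, None, ...)
//   %out = prim::add_optional(%y, %b2)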
2275void decomposeConvOps(Block* block) {
2276 std::vector<Node*> conv_nodes;
2277 for (Node* n : block->nodes()) {
2278 for (Block* b : n->blocks()) {
2279 decomposeConvOps(b);
2280 }
2281 // TODO: expand this to convXd
2282 // only decompose `conv2d` layer with bias.
2283 if (n->kind() == aten::conv2d &&
2284 n->input(2)->type()->isSubtypeOf(TensorType::get())) {
2285 conv_nodes.push_back(n);
2286 }
2287 }
2288
2289 auto graph = block->owningGraph();
2290 for (Node* n : conv_nodes) {
2291 // TODO: only handling conv2d at this moment, expand this to convXd
2292 WithInsertPoint guard(n);
2293
2294 auto const_neg_1 = n->owningGraph()->insertConstant(IValue(-1));
2295 auto const_none = n->owningGraph()->insertConstant(IValue());
2296
2297 auto bias_tensor_type = n->input(2)->type()->cast<c10::TensorType>();
2298 auto bias_size_opt = bias_tensor_type->sizes().concrete_sizes();
2299 if (!bias_size_opt.has_value()) {
2300 TORCH_WARN_ONCE(
2301 "concrete shape for bias input is required to decompose into conv + bias");
2302 continue;
2303 }
2304 // bias shape (C)
2305 auto bias_size = bias_size_opt.value();
2306
2307 auto tmp = graph->insertNode(
2308 graph->create(aten::unsqueeze, {n->input(2), const_neg_1}, 1));
2309 // new shape (C, 1)
2310 bias_size.emplace_back(1);
2311 tmp->output()->setType(bias_tensor_type->withSizes(bias_size));
2312
2313 auto unsqueezed_bias = graph->insertNode(
2314 graph->create(aten::unsqueeze, {tmp->output(), const_neg_1}, 1));
2315 // new shape (C, 1, 1)
2316 bias_size.emplace_back(1);
2317 unsqueezed_bias->output()->setType(bias_tensor_type->withSizes(bias_size));
2318
2319    // replace the bias input with none
2320 n->replaceInput(2, const_none);
2321
2322 // add bias as a new node
2323 auto bias_n = graph->insertNode(graph->create(
2324 prim::add_optional, {n->output(0), unsqueezed_bias->output()}, 1));
2325 bias_n->output()->setType(n->output(0)->type());
2326 // moving add_optional after conv2d since it uses its output.
2327 bias_n->moveAfter(n);
2328
2329 // replace later uses
2330 n->output(0)->replaceAllUsesAfterNodeWith(bias_n, bias_n->output());
2331 }
2332}
2333
2334bool removeInplaceOperations(const std::shared_ptr<Graph>& graph) {
2335 // TODO: we should probably get a list that's close to what our fuser handles
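  // e.g. aten::add_(self, other) is rewritten to its functional aten::add
  // counterpart by RemoveTensorMutation below, when it is safe to do so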
2336 static std::unordered_set<Symbol> inplace_ops = []() {
2337 std::unordered_set<Symbol> target_ops;
2338 for (const auto& iter : activation_type_promotion_mapping) {
2339 std::string name = std::string(iter.first.toQualString()) + "_";
2340 target_ops.insert(Symbol::fromQualString(name));
2341 }
2342
2343 target_ops.insert(Symbol::fromQualString("aten::add_"));
2344 target_ops.insert(Symbol::fromQualString("aten::mul_"));
2345 target_ops.insert(Symbol::fromQualString("aten::div_"));
2346 target_ops.insert(Symbol::fromQualString("aten::sub_"));
2347 return target_ops;
2348 }();
2349
2350 return RemoveTensorMutation(
2351 graph, [&](Node* node) { return inplace_ops.count(node->kind()) != 0; });
2352}
2353
2354// Recursively traverse blocks, gather all nodes with given symbol,
2355// and then apply mutator function.
2356void mutateNode(
2357 Block* block,
2358 Symbol symbol,
2359 const std::function<void(Node*)>& func) {
2360 // Recursively call mutateNode on blocks
2361 // Gather all nodes with given symbol
2362 std::vector<Node*> nodes;
2363 for (Node* n : block->nodes()) {
2364 for (Block* b : n->blocks()) {
2365 mutateNode(b, symbol, func);
2366 }
2367 if (n->kind() == symbol) {
2368 nodes.push_back(n);
2369 }
2370 }
2371
2372  // Apply the mutator function to every node
2373 for (Node* n : nodes) {
2374 func(n);
2375 }
2376}
2377
2378// For the given CudaFusionGroup, separate nested views and remove any unused
2379// intermediate views
2380void separateNestedViews(Node* cuda_fusion_group) {
2381 TORCH_INTERNAL_ASSERT(cuda_fusion_group->kind() == prim::CudaFusionGroup);
2382
2383 auto isView = [](Node* node) {
2384 static std::unordered_set<Symbol> alias_op_set(
2385 {prim::view_copy, prim::reshape_copy});
2386 return alias_op_set.find(node->kind()) != alias_op_set.end();
2387 };
2388
2389 // node -> input / output values
2390 auto isNestedView = [&isView](Node* node) {
2391 return isView(node) && isView(node->input(0)->node());
2392 };
2393
2394 auto subgraph = cuda_fusion_group->g(attr::Subgraph);
2395 for (auto node : subgraph->block()->nodes()) {
2396 if (isNestedView(node)) {
2397 // grandparent -> (view / reshape) parent -> (view / reshape) node
2398 auto parent_value = node->input(0);
2399 auto parent = parent_value->node();
2400
2401 auto grandparent_value = parent->input(0);
2402 C10_UNUSED auto grandparent = grandparent_value->node();
2403
2404 // Before: gp -> x -> n
2405 // After: gp -> x / gp -> n
2406 // Delete x if no more uses
2407 node->replaceInputWith(parent_value, grandparent_value);
2408 if (!parent->hasUses()) {
2409 parent->destroy();
2410 }
2411 }
2412 }
2413}
2414
2415} // anonymous namespace
2416
2417void CudaFuseGraph(std::shared_ptr<Graph>& graph) {
2418 FUSER_PERF_SCOPE("nvFuser::Manager::CudaFuseGraph");
2419 GRAPH_DUMP("Before Fusion: ", graph);
2420
2421  // TODO: extract & guard profile_ivalue; but how do we restore it?
2422  // We don't have a way to store an edge/node in an attribute, so we abuse the
2423  // data-flow dependency and add inputs to the conditional constant generated
2424  // by prim::profile_ivalue
2425 mutateNode(graph->block(), prim::profile_ivalue, ExtractProfileIValue);
2426 GRAPH_DEBUG("insert conditional constant from profile_ivalue: ", *graph);
2427
2428 // TODO: we need to properly restore shape information after fusion.
2429 // shamelessly use tool from NNC.
2430 RemoveProfileNodesAndSpecializeTypes(graph);
2431 GRAPH_DEBUG("After Profiling Nodes Removed: ", *graph);
2432
2433 // replace inplace operation to functional version to expose fusion
2434 // opportunities
2435 removeInplaceOperations(graph);
2436 GRAPH_DEBUG("Remove inplace operations: ", *graph);
2437
2438 // TODO: separate passes into different file;
2439 if (isOptionEnabled(EnableOption::LinearDecomposition)) {
2440 // TODO: restore decomposition after fusion, in case we are decomposing
2441 // operation that can't be fused;
2442 decomposeLinearOps(graph->block());
2443 }
2444 GRAPH_DEBUG("After decompose Linear Ops by nvfuser: ", *graph);
2445
2446 if (isOptionEnabled(EnableOption::ConvDecomposition)) {
2447 decomposeConvOps(graph->block());
2448 }
2449  GRAPH_DEBUG("After decompose Conv Ops by nvfuser: ", *graph);
2450
2451 replaceAliasOpsWithCopy(graph, graph->block());
2452 GRAPH_DEBUG("replace alias_op with alias_copy by nvfuser: ", *graph);
2453
2454 CudaGraphFuser cgf(graph->block(), graph);
2455 cgf.run();
2456 GRAPH_DEBUG("After Fusion: ", *graph);
2457
2458 // guard input types as well as conditional constants from
2459 // aten::profile_ivalue
2460 guardFusionGroups(graph->block(), cgf.fusion_value_to_runtime_shape_);
2461 GRAPH_DEBUG("After Guard Fusion: ", *graph);
2462
2463 // mutate `aten::_batch_norm_impl_index` and
2464 // `aten::_batch_norm_impl_index_backward` node in the fusion group to WAR
2465 // the lack of fusion support on integer output as well as byte-typed tensor.
2466 alterBatchNormImpls(graph->block());
2467 GRAPH_DEBUG("After _batch_norm_impl_index: ", *graph);
2468
2469 mutateNode(graph->block(), prim::profile_ivalue, RemoveProfileIValue);
2470
2471 GRAPH_DEBUG("Before remove missing profiling: ", *graph);
2472 removeFusionWithMissingProfilingInformation(graph->block());
2473 GRAPH_DEBUG("After remove missing profiling: ", *graph);
2474
2475 // optimization targeting AMP
2476 removeOutputUsedOnlyInDtype(graph->block());
2477 GRAPH_DEBUG("After removeOutputUsedOnlyInDtype: ", *graph);
2478
2479 mutateNode(graph->block(), prim::CudaFusionGroup, separateNestedViews);
2480 GRAPH_DEBUG(
2481      "separate nested and delete redundant views in CudaFusionGroup: ", *graph);
2482
2483 revertAliasCopyOps(graph, graph->block());
2484 GRAPH_DEBUG("revert alias_copy ops by nvfuser: ", *graph);
2485
2486 dumpFusionGroups(graph);
2487
2488 // After FuseGraph some common subexpressions may come back
2489 EliminateCommonSubexpression(graph);
2490 // We might have emitted a fair amount of useless shape propagating code, so
2491 // remove it
2492 EliminateDeadCode(graph);
2493
2494  GRAPH_DEBUG("After CSE & dead code removal: ", *graph);
2495 // Improve the quality of shape propagation code that was left
2496 PeepholeOptimizeShapeExpressions(graph->block());
2497 GRAPH_DEBUG("After PeepholeOptimizeShapeExpressions: ", *graph);
2498
2499 // TODO: we need to properly restore shape information after fusion.
2500 // shamelessly use tool from NNC.
2501 RemoveTensorTypeSpecializations(graph);
2502
2503 GRAPH_DUMP("Before Compilation: ", graph);
2504 // Compile CudaFusionGroup
2505 compileFusionRecursive(graph->block());
2506}
2507
2508} // namespace cuda
2509} // namespace fuser
2510} // namespace jit
2511} // namespace torch
2512