#include <arith.h>
#include <codegen.h>
#include <disjoint_set.h>
#include <fusion.h>
#include <fusion_segmenter.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_cloner.h>
#include <ir_printer.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel.h>
#include <lower2device.h>
#include <lower_bank_conflict.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

static thread_local Fusion* ACTIVE_FUSION = nullptr; // NOLINT

FusionGuard::FusionGuard(Fusion* fusion) {
  prev_fusion = ACTIVE_FUSION;
  ACTIVE_FUSION = fusion;
}

FusionGuard::~FusionGuard() {
  ACTIVE_FUSION = prev_fusion;
}

Fusion* FusionGuard::getCurFusion() {
  return ACTIVE_FUSION;
}

void FusionGuard::setCurFusion(Fusion* fusion) {
  ACTIVE_FUSION = fusion;
}
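
// Example (a hedged sketch): FusionGuard is an RAII guard that makes a fusion
// the active one for IR construction and restores the previous active fusion
// on scope exit, so nested guards behave like a stack:
//
//   Fusion fusion;
//   {
//     FusionGuard fg(&fusion);
//     // IR nodes built here register themselves with `fusion` via
//     // FusionGuard::getCurFusion().
//   }
//   // The previously active fusion (possibly nullptr) is restored here.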

void swap(Fusion& a, Fusion& b) noexcept {
  FUSER_PERF_SCOPE("Fusion swap");

  using std::swap;

  swap(static_cast<IrContainer&>(a), static_cast<IrContainer&>(b));

  swap(a.inputs_, b.inputs_);
  swap(a.outputs_, b.outputs_);

  swap(a.io_alias_, b.io_alias_);
  swap(a.permuted_input_map_, b.permuted_input_map_);
  swap(a.permuted_output_map_, b.permuted_output_map_);
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
    const KernelArgumentHolder& args) {
  FUSER_PERF_SCOPE("Segment Fusion");
  return SegmentCandidateFinder::segment(this, args);
}

IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
  to->clear();
  auto ir_cloner = IrContainer::copy(from, to);

  for (auto val : from->vals_) {
    ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
    ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
  }

  to->inputs_ = ir_cloner.clone(from->inputs_);
  to->outputs_ = ir_cloner.clone(from->outputs_);
  for (auto inp : to->inputs_) {
    inp->setIsFusionInput(true);
  }
  for (auto out : to->outputs_) {
    out->setIsFusionOutput(true);
  }

  // TODO: put this into ir_cloner instead
  for (const auto& entry : from->io_alias_) {
    Val* copied_output = ir_cloner.clone(entry.first);
    Val* copied_input = ir_cloner.clone(entry.second);
    to->io_alias_[copied_output] = copied_input;
  }

  to->permuted_input_map_ = from->permuted_input_map_;
  to->permuted_output_map_ = from->permuted_output_map_;

  to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
  // This should never be true on copy, but copying for completeness.
  to->is_during_update_uses_ = from->is_during_update_uses_;

  return ir_cloner;
}

// Clang-tidy complains when using the default constructor for IrContainer
// instead of the copy constructor. Fusion::copy has a call to
// IrContainer::copy, so it's redundant to use the IrContainer copy
// constructor, but it is harmless since Fusion::copy starts by calling
// clear().
Fusion::Fusion(const Fusion& other) : IrContainer(other) {
  FUSER_PERF_SCOPE("Fusion copy");
  Fusion::copy(&other, this);
}

Fusion::Fusion(Fusion&& other) noexcept {
  FUSER_PERF_SCOPE("Fusion move");
  swap(*this, other);
}

Fusion& Fusion::operator=(const Fusion& other) {
  FUSER_PERF_SCOPE("Fusion copy assign");
  Fusion copy(other);
  clear();
  swap(*this, copy);
  return *this;
}
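
// Note (a hedged editorial observation): the copy assignment above follows
// the copy-and-swap idiom. The copy constructor does all the work that can
// fail, while clear() and swap() are noexcept, so `*this` is only modified
// once the copy has already succeeded.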

Fusion& Fusion::operator=(Fusion&& other) noexcept {
  FUSER_PERF_SCOPE("Fusion move assign");
  clear();
  swap(*this, other);
  return *this;
}

Fusion::~Fusion() {
  clear();
}

void Fusion::clear() noexcept {
  FUSER_PERF_SCOPE("Fusion clear");

  IrContainer::clear();

  inputs_.clear();
  outputs_.clear();

  io_alias_.clear();

  permuted_input_map_.clear();
  permuted_output_map_.clear();

  all_tv_uses_valid_ = false;
  is_during_update_uses_ = false;
}

void Fusion::removeExpr(Expr* expr) {
  assertInContainer(expr, "Cannot remove expr ");
  // If we hit this error too frequently, we could relax the restriction so
  // that removing something that doesn't exist simply does nothing. For now,
  // we're going with the strictest model, which errors.

  for (auto out : expr->outputs()) {
    out->setDefinition(nullptr);
  }

  for (auto inp : expr->inputs()) {
    auto uses_copy = inp->uses();
    auto it = std::find(uses_copy.begin(), uses_copy.end(), expr);
    if (it != uses_copy.end()) {
      uses_copy.erase(it);
      inp->setUses(uses_copy);
    }
  }

  IrContainer::removeExpr(expr);
}

void Fusion::removeVal(Val* val) {
  assertInContainer(val, "Cannot remove val ");

  TORCH_CHECK(
      !val->isFusionInput(),
      "Cannot remove val as it is an input of the fusion.");
  TORCH_CHECK(
      !val->isFusionOutput(),
      "Cannot remove val as it is an output of the fusion.");

  Expr* orig = val->definition();
  if (orig != nullptr) {
    removeExpr(orig);
  }

  for (Expr* use : unordered_uses(val)) {
    removeExpr(use);
  }
  IrContainer::removeVal(val);
}

void Fusion::addInput(Val* input) {
  assertInContainer(input, "Cannot register input ");

  TORCH_INTERNAL_ASSERT(
      input->getDataType() != DataType::Index,
      "Data type Index is a compile-time-only local data type; it cannot be used as an input in case it was generated from another kernel.");

  if (input->getValType().value() == ValType::TensorView) {
    auto tv = input->as<TensorView>();
    tv->setMemoryType(MemoryType::Global);
  } else if (input->getValType().value() == ValType::Scalar) {
    TORCH_CHECK(
        !input->isConst(),
        "Immediate scalar value cannot be added as an input. It is not necessary to pass it as an input.");
  }

  inputs_.push_back(input);
  input->setIsFusionInput(true);

  all_tv_uses_valid_ = false;
}

void Fusion::addOutput(Val* output) {
  // We currently don't support explicitly outputting aliased inputs. This is
  // because they are already marked as outputs for in-place update. It's
  // tricky to allow marking them explicitly as real outputs, since that
  // requires us to register/identify outputs not only by `Val*` pointer, but
  // also by index; it also requires us to magically arrange the `outputs_`
  // entries in the proper order, which isn't intuitive for `outputs_` in a
  // fusion.
  // We could in principle solve this by calling addOutput on io_alias_ keys
  // after the fusion is fully defined, but that doesn't work either at this
  // time, since segmentation would unfortunately call addOutput after the
  // io_alias_ map has been marked. Tracking this in #1488.
  // TORCH_CHECK(io_alias_.count(output) == 0,
  //             "can't register aliased output as real output");

  assertInContainer(output, "Cannot register output ");
  if (output->getValType().value() == ValType::TensorView) {
    auto tv = output->as<TensorView>();
    tv->setMemoryType(MemoryType::Global);
  }
  outputs_.push_back(output);
  output->setIsFusionOutput(true);

  all_tv_uses_valid_ = false;
}
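
// Example (a hedged sketch, using helper names like makeSymbolicTensor() as
// they appear in the nvFuser tests, and add() from arith.h): inputs and
// outputs must be registered explicitly, and registering a TensorView as
// either one moves it to global memory:
//
//   Fusion fusion;
//   FusionGuard fg(&fusion);
//   TensorView* tv0 = makeSymbolicTensor(2); // test helper, 2D tensor
//   TensorView* tv1 = makeSymbolicTensor(2); // test helper, 2D tensor
//   fusion.addInput(tv0);
//   fusion.addInput(tv1);
//   TensorView* tv2 = add(tv0, tv1);
//   fusion.addOutput(tv2); // tv2 now has MemoryType::Global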

void Fusion::removeInput(Val* input) {
  auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
  if (find_input != inputs_.end()) {
    inputs_.erase(find_input);
  }
  input->setIsFusionInput(false);
  all_tv_uses_valid_ = false;
}

void Fusion::removeOutput(Val* output) {
  auto find_output = std::find(outputs_.begin(), outputs_.end(), output);
  if (find_output != outputs_.end()) {
    outputs_.erase(find_output);
  }
  output->setIsFusionOutput(false);
  all_tv_uses_valid_ = false;
}

void Fusion::replaceOutput(Val* output, Val* replacement) {
  auto find_output = std::find(outputs_.begin(), outputs_.end(), output);
  TORCH_CHECK(find_output != outputs_.end(), "Unable to find output in Fusion");

  // The check above guarantees the output exists, so replace it
  // unconditionally.
  std::replace_if(
      outputs_.begin(),
      outputs_.end(),
      [&output](Val* v) { return v == output; },
      replacement);

  if (replacement->getValType().value() == ValType::TensorView) {
    replacement->setIsFusionOutput(true);
    replacement->as<TensorView>()->setMemoryType(MemoryType::Global);
  }
  if (output->getValType().value() == ValType::TensorView) {
    output->setIsFusionOutput(false);
    output->as<TensorView>()->setMemoryType(MemoryType::Local);
  }
  resetTvUses();

  // Temporary WAR for issue #1112
  // (https://github.com/csarofeen/pytorch/issues/1112)
  if (io_alias_.count(output) != 0) {
    auto input = io_alias_[output];
    io_alias_.erase(output);
    io_alias_[replacement] = input;
  }
}

std::vector<Expr*> Fusion::exprs() {
  return StmtSort::getExprs(this);
}

std::vector<Val*> Fusion::inputsOf(Val* val) {
  return InputsOf::output(this, val);
}

void Fusion::validateInputs() {
  std::unordered_set<Val*> all_inputs;
  for (Val* out : outputs()) {
    for (Val* input : inputsOf(out)) {
      all_inputs.insert(input);
    }
  }

  std::unordered_set<Val*> input_dims;
  auto inp_tvs = ir_utils::filterByType<TensorView>(inputs());
  for (auto tv : inp_tvs) {
    for (auto id : tv->getMaybeRFactorDomain()) {
      input_dims.emplace(id->extent());
    }
  }
  for (Val* input : all_inputs) {
    if (!input->isConstScalar()) {
      TORCH_CHECK(
          input->isFusionInput() ||
              // TODO: Switch:
              inContainer(input),
          // to: input_dims.find(input) != input_dims.end(),
          // https://github.com/csarofeen/pytorch/issues/1365
          "Could not figure out how ",
          input->toString(),
          " is generated, however it was not specified as an input.");
    }
  }
}

void Fusion::print() {
  FUSER_PERF_SCOPE("Fusion::print");

  FusionGuard fg(this);
  std::cout << "\n%kernel {\n";
  IrMathPrinter op_exprs(std::cout);
  op_exprs.handle(this);
  std::cout << "\nTransformPrinter : \n";
  IrTransformPrinter t_exprs(std::cout);
  t_exprs.handle(this);
  std::cout << "}\n\n";
}

void Fusion::printKernel(DataType index_type) {
  FUSER_PERF_SCOPE("Fusion::printKernel");
  TORCH_INTERNAL_ASSERT(
      !this->isA<kir::Kernel>(),
      "Cannot \"print kernel\" of a kernel container. ",
      "This would require lowering during lowering.");
  std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel());
}

std::unordered_map<std::string, std::pair<int, int>> Fusion::bankConflictInfo(
    DataType index_type) {
  GpuLower lower(this, index_type);
  auto kernel = lower.kernel();
  auto info = getBankConflictInfo(kernel);
  // The kernel container owning the exprs goes out of scope when this
  // function returns, so we key the returned map by the exprs' string
  // representation instead of by pointer.
  std::unordered_map<std::string, std::pair<int, int>> result;
  result.reserve(info.size());
  for (auto i : info) {
    result[i.first->toString()] = i.second;
  }
  return result;
}
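
// Example (a hedged sketch, assuming the (int, int) pair holds the input and
// output conflict ways as reported by getBankConflictInfo, and that the
// caller supplies an index type, e.g. DataType::Int, if the declaration has
// no default):
//
//   for (const auto& kv : fusion.bankConflictInfo(DataType::Int)) {
//     std::cout << kv.first << ": conflict ways (in, out) = ("
//               << kv.second.first << ", " << kv.second.second << ")\n";
//   }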

void Fusion::printMath(bool from_outputs_only) {
  FUSER_PERF_SCOPE("Fusion::printMath");

  FusionGuard fg(this);
  auto exprs_for_print = exprs();
  std::cout << "Inputs:" << std::endl;
  for (auto inp : inputs()) {
    std::cout << " " << inp << ", " << inp->getDataType().value() << std::endl;
  }

  std::cout << "Outputs:" << std::endl;
  for (auto out : outputs()) {
    std::cout << " " << out << ", " << out->getDataType().value() << std::endl;
  }

  // If we want everything in the fusion, grab all values without uses to
  // traverse from.
  if (!from_outputs_only) {
    std::vector<Val*> leaf_vals;
    for (auto val : deterministic_vals()) {
      if (val->uses().empty()) {
        leaf_vals.push_back(val);
      }
    }
    exprs_for_print = StmtSort::getExprs(this, leaf_vals);
  }

  std::cout << "\n%kernel_math {\n";
  for (auto expr : exprs_for_print) {
    std::cout << expr;
  }
  std::cout << "}\n\n";
}

std::vector<Val*> Fusion::inputsAndCreated() {
  auto result = inputs_;
  for (auto expr : exprs()) {
    auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
    if (tv_inputs.empty()) {
      for (auto v : expr->outputs()) {
        result.emplace_back(v);
      }
    }
  }
  return result;
}

void Fusion::printTransforms() {
  FUSER_PERF_SCOPE("Fusion::printTransforms");

  FusionGuard fg(this);
  IrTransformPrinter t_exprs(std::cout);
  t_exprs.handle(this);
}

void Fusion::registerVal(Val* val) {
  if (inContainer(val)) {
    return;
  }

  if (val->fusion()) {
    TORCH_CHECK(
        val->fusion() == this, val, " does not belong to this fusion.");
  }

  IrContainer::registerVal(val);
}

void Fusion::registerExpr(Expr* expr) {
  if (inContainer(expr)) {
    return;
  }

  if (expr->fusion()) {
    TORCH_CHECK(
        expr->fusion() == this, expr, " does not belong to this fusion.");
  }

  IrContainer::registerExpr(expr);

  bool has_tv = false;

  for (Val* input : expr->inputs()) {
    has_tv = has_tv || input->isA<TensorView>();
    assertInContainer(input, "Input to expr is invalid, ");
    auto uses_copy = input->uses();
    if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
        uses_copy.end()) {
      uses_copy.push_back(expr);
      input->setUses(uses_copy);
    }
  }

  // Kernel is the only container type that is non-SSA. This is mainly (maybe
  // only) because of initialization expressions, which would overwrite tensor
  // view definitions.
  bool is_ssa = !this->isA<kir::Kernel>();

  for (Val* output : expr->outputs()) {
    has_tv = has_tv || output->isA<TensorView>();
    assertInContainer(output, "Output to expr is invalid, ");
    if (output->definition() != nullptr && is_ssa) {
      removeExpr(output->definition());
    }
    if (is_ssa || output->definition() == nullptr) {
      output->setDefinition(expr);
    }
  }

  if (has_tv) {
    resetTvUses();
  }
}
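
// Example (a hedged sketch): in an SSA container, registering a second
// expr that defines the same output replaces the first definition rather
// than erroring:
//
//   TensorView* tv2 = add(tv0, tv1); // tv2->definition() is the add
//   // If another expr producing tv2 were registered, the original add
//   // would be removed via removeExpr() and the new expr would become
//   // tv2->definition(). Only kir::Kernel containers skip this and keep
//   // the first definition (non-SSA, to allow init exprs).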

void Fusion::resetTvUses() {
  FUSER_PERF_SCOPE("Fusion::resetTvUses");
  is_during_update_uses_ = true;

  // getExprs only uses definitions, so even if we've already modified uses to
  // remove dead exprs, this could reinsert them. getExprs is also bounded by
  // inputs, as registered inputs return nullptr as their definition.
  const auto all_tvs = ir_utils::filterByType<TensorView>(vals_);
  const auto used_exprs = StmtSort::getExprs(this);

  for (auto tv : all_tvs) {
    tv->setUses({});
  }

  // Same as in registerExpr
  for (auto expr : used_exprs) {
    for (Val* input : expr->inputs()) {
      auto uses_copy = input->uses();
      if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
          uses_copy.end()) {
        uses_copy.push_back(expr);
        input->setUses(uses_copy);
      }
    }
  }

  all_tv_uses_valid_ = true;
  is_during_update_uses_ = false;
}

std::vector<Val*> Fusion::usedMathVals() {
  // Note that using fusion->inputs() as the argument for the first
  // parameter of getAllValsBetween does not grab all used vals, as
  // there can be vals that are created inside a fusion without using
  // anything from the inputs. See, for example, tv0 in the
  // FusionOuterSplit test.
  const auto inputs = InputsOf::outputs(this, outputs());
  auto used_math_vals = DependencyCheck::getAllValsBetween(
      {inputs.begin(), inputs.end()}, outputs());
  // When an expr has multiple outputs and only some of them are used, the
  // rest aren't included in used_math_vals. However, we want them to be
  // included as they must show up in the fusion.
  std::vector<Val*> vals_to_add;
  std::unordered_set<Val*> added_vals;

  for (auto val : used_math_vals) {
    auto def = val->definition();
    if (def == nullptr || def->outputs().size() < 2) {
      continue;
    }
    for (auto out : def->outputs()) {
      if (std::find(used_math_vals.begin(), used_math_vals.end(), out) ==
          used_math_vals.end()) {
        if (!added_vals.count(out)) {
          vals_to_add.push_back(out);
          added_vals.insert(out);
        }
      }
    }
  }

  used_math_vals.insert(
      used_math_vals.end(), vals_to_add.begin(), vals_to_add.end());

  return used_math_vals;
}
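
// Example (a hedged sketch, assuming the Welford helper from arith.h, whose
// result carries three outputs: avg, var_sum, and n):
//
//   auto wf = Welford(tv0, {0});
//   fusion.addOutput(wf.avg); // var_sum and n are never consumed
//
// getAllValsBetween alone would miss var_sum and n; the loop above adds
// them back so every output of the multi-output expr shows up in the
// fusion.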

std::vector<Val*> Fusion::terminatingMathVals() {
  VectorOfUniqueEntries<Val*> result;
  auto used_vals = usedMathVals();
  for (auto v : used_vals) {
    // Locate the vals that are not used by any expr but have valid
    // definitions.
    if (unordered_uses(v).empty() && v->definition() != nullptr) {
      result.pushBack(v);
    }
  }
  return result.vector();
}

std::unordered_set<Expr*> Fusion::unordered_uses(const Val* val) const {
  return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
}

Expr* Fusion::definition(const Val* val) const {
  assertInContainer(val, "Cannot detect the definition of val, ");
  return val->definition();
}

// Indicates whether the kernel has to set itself up to generate random
// numbers.
bool Fusion::isStochastic() {
  for (auto expr : exprs()) {
    if (expr->getExprType() == ExprType::RNGOp) {
      return true;
    }
  }
  return false;
}

std::vector<Val*> Fusion::getTerminatingOutputs() const {
  FUSER_PERF_SCOPE("getTerminatingOutputs");

  auto is_reachable_to_output = [](Val* val) {
    // Traverse to the consumers of val and see if any of them is an output.
    std::deque<Val*> consumers;
    for (auto use : val->uses()) {
      for (auto consumer : use->outputs()) {
        consumers.push_back(consumer);
      }
    }
    while (!consumers.empty()) {
      auto consumer = consumers.back();
      consumers.pop_back();
      if (consumer->isFusionOutput()) {
        return true;
      }
      // consumer is not an output; proceed to its consumers
      for (auto use : consumer->uses()) {
        for (auto consumer_of_consumer : use->outputs()) {
          consumers.push_back(consumer_of_consumer);
        }
      }
    }
    return false;
  };

  std::vector<Val*> terminating_outputs;

  for (auto out : outputs()) {
    // If another output is reachable from this output, this one is not
    // terminating.
    if (is_reachable_to_output(out)) {
      continue;
    }
    terminating_outputs.push_back(out);
  }

  return terminating_outputs;
}
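
// Example (a hedged sketch, with hypothetical vals out0 and out1): if both
// are registered as outputs and out1 consumes out0, e.g.
//
//   TensorView* out1 = add(out0, out0);
//   fusion.addOutput(out0);
//   fusion.addOutput(out1);
//
// then out0 reaches out1 through its uses, so getTerminatingOutputs()
// returns only out1; traversing definitions backward from terminating
// outputs still reaches every output.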

bool Fusion::isAliasCompatible(Val* left, Val* right) {
  // Nullptr check
  if (left == nullptr || right == nullptr) {
    return false;
  }

  // DataType check
  if (!left->getDataType().has_value() || !right->getDataType().has_value() ||
      left->getDataType().value() != right->getDataType().value()) {
    return false;
  }

  // ValType check
  if (!left->getValType().has_value() || !right->getValType().has_value() ||
      left->getValType().value() != right->getValType().value()) {
    return false;
  }

  // Check that both values have the same number of dimensions if they are
  // TensorViews
  if (ir_utils::isTV(left) && ir_utils::isTV(right)) {
    return left->as<TensorView>()->nDims() == right->as<TensorView>()->nDims();
  }
  return false;
}

void Fusion::aliasOutputToInput(Val* output, Val* input) {
  // The output must not already be registered, because we may need to cast it
  // below when the input was cast.
  TORCH_INTERNAL_ASSERT(
      !output->isFusionOutput(),
      "Do NOT add an aliased output to the fusion outputs outside of `aliasOutputToInput`.");

  if (!input->isFusionInput()) {
    auto input_expr = input->definition();
    // TORCH_INTERNAL_ASSERT(input_def.etype() == ExprType::UnaryOp, "expected
    // unary op for aliased input");
    TORCH_INTERNAL_ASSERT(
        input_expr->isA<UnaryOp>(), "expected unary op for aliased input");
    auto input_uop = input_expr->as<UnaryOp>();
    TORCH_INTERNAL_ASSERT(
        input_uop->getUnaryOpType() == UnaryOpType::Cast,
        "expected aliased input to be the output of a cast op");
    input = input_uop->in();
  }
  TORCH_INTERNAL_ASSERT(
      input->getDataType().has_value() && output->getDataType().has_value(),
      "aliasing an output to an input requires both to have a DataType");

  if (input->getDataType().value() != output->getDataType().value()) {
    output = castOp(input->getDataType().value(), output);
  }
  // TODO: the output should be marked at the end of fusion definition (#1488)
  addOutput(output);

  TORCH_INTERNAL_ASSERT(
      isAliasCompatible(input, output),
      "The input and output values are not alias-compatible.");
  io_alias_[output] = input;
}
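
// Example (a hedged sketch): in-place update of an input, as used for ops
// that write back into their argument:
//
//   fusion.addInput(tv0);
//   TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1.0));
//   fusion.aliasOutputToInput(tv1, tv0);
//   // tv1 is now a fusion output that reuses tv0's buffer at runtime;
//   // getOutputAlias(tv1) returns tv0.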

Val* Fusion::getOutputAlias(Val* output) {
  auto search = io_alias_.find(output);
  if (search != io_alias_.end()) {
    return search->second;
  }
  return nullptr;
}

std::unordered_set<int> Fusion::getOutputAliasIndices() const {
  if (io_alias_.empty()) {
    return {};
  }

  std::unordered_set<int> alias_indices;

  for (const auto i : c10::irange(outputs_.size())) {
    if (io_alias_.count(outputs_[i]) != 0) {
      alias_indices.insert(i);
    }
  }
  return alias_indices;
}

std::vector<std::pair<int, int>> Fusion::getInputAliasIndices() const {
  if (io_alias_.empty()) {
    return {};
  }

  std::vector<std::pair<int, int>> alias_indices;
  for (const auto i : c10::irange(outputs_.size())) {
    if (io_alias_.count(outputs_[i]) != 0) {
      bool found = false;
      for (const auto j : c10::irange(inputs_.size())) {
        if (io_alias_.at(outputs_[i]) == inputs_[j]) {
          alias_indices.emplace_back(i, j);
          found = true;
          break;
        }
      }
      TORCH_INTERNAL_ASSERT(
          found,
          "io_alias_ mapping failure: the input aliased by this output is not present in the fusion inputs");
    }
  }
  // We can't assert that every alias was visited here; in a segmented fusion,
  // not all aliased outputs are present.

  return alias_indices;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch