1#include <arith.h>
2#include <fusion.h>
3#include <ir_builder.h>
4#include <ir_iostream.h>
5#include <ir_utils.h>
6#include <lower_utils.h>
7
8#include <set>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14namespace ir_utils {
15
16std::vector<int64_t> normalizeNew2Old(
17 const std::vector<int64_t>& new2old_in,
18 size_t ndims) {
19 TORCH_CHECK(
20 new2old_in.size() == ndims,
21 "There must be a transpose mapping for each dimension in domain");
22
23 // Canonicalize dimensions by wrapping each dim for the given ndims
24 std::vector<int64_t> new2old;
25 std::transform(
26 new2old_in.begin(),
27 new2old_in.end(),
28 std::inserter(new2old, new2old.begin()),
29 [ndims](int64_t entry) { return entry < 0 ? entry + ndims : entry; });
30
31 // Check if any adjusted values are < 0, or >= nDims, which are invalid
32 TORCH_CHECK(
33 std::none_of(
34 new2old.begin(),
35 new2old.end(),
36 [ndims](int64_t entry) {
37 return entry < 0 || (unsigned int)entry >= ndims;
38 }),
39 "New2Old axes are not within the number of dimensions of the provided domain.\t",
40 new2old);
41
42 // Going to use sets, to see if any duplicate values are in the map.
43 std::set<int64_t> old_pos_set;
44 std::transform(
45 new2old.begin(),
46 new2old.end(),
47 std::inserter(old_pos_set, old_pos_set.begin()),
48 [](int64_t entry) { return entry; });
49
50 // Error out if duplicate values are found.
51 TORCH_CHECK(
52 new2old.size() == ndims && old_pos_set.size() == new2old.size(),
53 "Duplicate entries in transformation map.");
54
55 // END VALIDATION CHECKS
56 return new2old;
57}
58
59std::vector<int> normalizeOld2New(
60 const std::unordered_map<int, int>& old2new_in,
61 size_t ndims) {
62 // adjust based on negative values (any negative values gets nDims added to
63 // it)
64 std::unordered_map<int, int> old2new;
65 std::transform(
66 old2new_in.begin(),
67 old2new_in.end(),
68 std::inserter(old2new, old2new.begin()),
69 [ndims](std::unordered_map<int, int>::value_type entry) {
70 return std::unordered_map<int, int>::value_type({
71 entry.first < 0 ? entry.first + ndims : entry.first,
72 entry.second < 0 ? entry.second + ndims : entry.second,
73 });
74 });
75
76 // Check if any adjusted values are < 0, or >= nDims, which are invalid
77
78 TORCH_CHECK(
79 std::none_of(
80 old2new.begin(),
81 old2new.end(),
82 [ndims](std::unordered_map<int, int>::value_type entry) {
83 return entry.first < 0 || (unsigned int)entry.first >= ndims ||
84 entry.second < 0 || (unsigned int)entry.second >= ndims;
85 }),
86 "Reorder axes are not within the number of dimensions of the provided domain.");
87
88 // Going to use sets, to see if any duplicate values are in the map.
89
90 std::set<int> old_pos_set;
91 std::transform(
92 old2new.begin(),
93 old2new.end(),
94 std::inserter(old_pos_set, old_pos_set.begin()),
95 [](std::unordered_map<int, int>::value_type entry) {
96 return entry.first;
97 });
98
99 std::set<int> new_pos_set;
100 std::transform(
101 old2new.begin(),
102 old2new.end(),
103 std::inserter(new_pos_set, new_pos_set.begin()),
104 [](std::unordered_map<int, int>::value_type entry) {
105 return entry.second;
106 });
107
108 // Error out if duplicate values are found.
109 TORCH_CHECK(
110 old_pos_set.size() == old2new.size() &&
111 new_pos_set.size() == old2new.size(),
112 "Duplicate entries in transformation map sent to TensorView reorder.");
113
114 // END VALIDATION CHECKS
115
116 std::vector<int> new2old(ndims, -1);
117
118 // Go through each old and new position, make sure they're within [0, ndims)
119 for (std::pair<int, int> elem : old2new) {
120 int old_pos = elem.first;
121 int new_pos = elem.second;
122 new2old[new_pos] = old_pos;
123 }
124
125 // old_positions that already have a new position
126 std::set<int> old_positions(new2old.begin(), new2old.end());
127 old_positions.erase(-1);
128
129 // All available new positions
130 std::set<int> all_positions;
131 for (decltype(ndims) i{0}; i < ndims; i++)
132 all_positions.insert(i);
133
134 // Check what positions haven't been specified.
135 std::set<int> positions_left;
136 std::set_difference(
137 all_positions.begin(),
138 all_positions.end(),
139 old_positions.begin(),
140 old_positions.end(),
141 std::inserter(positions_left, positions_left.end()));
142
143 // Fill in positions that weren't specified, in relative order,
144 // in empty spots in the set of new positions.
145 // new2old[new_position] = old_position
146 auto it = positions_left.begin(); // old positions left
147 std::transform(
148 new2old.begin(), new2old.end(), new2old.begin(), [&it](int i) -> int {
149 return i == -1 ? *it++ : i;
150 });
151
152 return new2old;
153}
154
155namespace ValReplacement {
156// Create New Expr given producer - [an input for the expression]
157// Creates a new Expr substituting current with producer
158struct SubstituteInExpr : public OptInDispatch {
159 public:
160 static Expr* subsitute(Expr* expr, Val* reference, Val* substitute) {
161 TORCH_INTERNAL_ASSERT(
162 expr != nullptr && reference != nullptr && substitute != nullptr,
163 "Nullptr arg found.");
164 SubstituteInExpr sie(reference, substitute);
165 sie.handle(expr);
166 TORCH_INTERNAL_ASSERT(
167 sie.expr_ != nullptr,
168 "Substitution failed of ",
169 reference,
170 " with ",
171 substitute);
172 return sie.expr_;
173 }
174
175 private:
176 explicit SubstituteInExpr(Val* reference, Val* substitute)
177 : reference_(reference), substitute_(substitute) {}
178
179 void handle(Expr* expr) final {
180 OptInDispatch::handle(expr);
181 }
182
183 void handle(FullOp* full_expr) final {
184 auto out = reference_->sameAs(full_expr->output(0)) ? substitute_
185 : full_expr->output(0);
186 expr_ = IrBuilder::create<FullOp>(
187 full_expr->container(),
188 out,
189 full_expr->getFillValue(),
190 full_expr->dtype());
191 }
192
193 void handle(ARangeOp* arange_expr) final {
194 auto start = reference_->sameAs(arange_expr->start())
195 ? substitute_
196 : arange_expr->start();
197 auto end = reference_->sameAs(arange_expr->end()) ? substitute_
198 : arange_expr->end();
199 auto step = reference_->sameAs(arange_expr->step()) ? substitute_
200 : arange_expr->step();
201 auto out = reference_->sameAs(arange_expr->output(0))
202 ? substitute_
203 : arange_expr->output(0);
204 expr_ = IrBuilder::create<ARangeOp>(
205 arange_expr->container(),
206 out,
207 start,
208 end,
209 step,
210 arange_expr->dtype(),
211 arange_expr->getLinearLogicalIndex());
212 }
213
214 void handle(EyeOp* eye_expr) final {
215 auto out = reference_->sameAs(eye_expr->output(0)) ? substitute_
216 : eye_expr->output(0);
217 expr_ = IrBuilder::create<EyeOp>(
218 eye_expr->container(),
219 out,
220 eye_expr->dtype(),
221 eye_expr->getIndex1(),
222 eye_expr->getIndex2());
223 }
224
225 void handle(UnaryOp* unary_expr) final {
226 auto in =
227 reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in();
228 auto out =
229 reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out();
230 expr_ = IrBuilder::create<UnaryOp>(
231 unary_expr->container(), unary_expr->getUnaryOpType(), out, in);
232 }
233
234 void handle(BinaryOp* binary_expr) final {
235 auto lhs = reference_->sameAs(binary_expr->lhs()) ? substitute_
236 : binary_expr->lhs();
237 auto rhs = reference_->sameAs(binary_expr->rhs()) ? substitute_
238 : binary_expr->rhs();
239 auto out = reference_->sameAs(binary_expr->out()) ? substitute_
240 : binary_expr->out();
241
242 expr_ = IrBuilder::create<BinaryOp>(
243 binary_expr->container(),
244 binary_expr->getBinaryOpType(),
245 out,
246 lhs,
247 rhs);
248 }
249
250 void handle(TernaryOp* ternary_expr) final {
251 auto in1 = reference_->sameAs(ternary_expr->in1()) ? substitute_
252 : ternary_expr->in1();
253 auto in2 = reference_->sameAs(ternary_expr->in2()) ? substitute_
254 : ternary_expr->in2();
255 auto in3 = reference_->sameAs(ternary_expr->in3()) ? substitute_
256 : ternary_expr->in3();
257 auto out = reference_->sameAs(ternary_expr->out()) ? substitute_
258 : ternary_expr->out();
259 expr_ = IrBuilder::create<TernaryOp>(
260 ternary_expr->container(),
261 ternary_expr->getTernaryOpType(),
262 out,
263 in1,
264 in2,
265 in3);
266 }
267
268 void handle(RNGOp* rng_expr) final {
269 std::vector<Val*> subsituted_params;
270 for (auto v : rng_expr->getParameters()) {
271 subsituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
272 }
273 auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_
274 : rng_expr->output(0);
275 expr_ = IrBuilder::create<RNGOp>(
276 rng_expr->container(),
277 rng_expr->getRNGOpType(),
278 out,
279 rng_expr->dtype(),
280 subsituted_params,
281 rng_expr->getRNGOffset(),
282 rng_expr->getPhiloxIndex());
283 }
284
285 void handle(ReductionOp* reduction_expr) final {
286 auto init = reference_->sameAs(reduction_expr->init())
287 ? substitute_
288 : reduction_expr->init();
289 auto out = reference_->sameAs(reduction_expr->out())
290 ? substitute_
291 : reduction_expr->out();
292 auto in = reference_->sameAs(reduction_expr->in()) ? substitute_
293 : reduction_expr->in();
294
295 expr_ = IrBuilder::create<ReductionOp>(
296 reduction_expr->container(),
297 reduction_expr->getReductionOpType(),
298 init,
299 out,
300 in);
301 }
302
303 void handle(GroupedReductionOp* grouped_reduction_expr) final {
304 std::vector<Val*> outputs;
305 std::transform(
306 grouped_reduction_expr->outputs().begin(),
307 grouped_reduction_expr->outputs().end(),
308 std::back_inserter(outputs),
309 [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; });
310
311 std::vector<Val*> inputs;
312 std::transform(
313 grouped_reduction_expr->inputs().begin(),
314 grouped_reduction_expr->inputs().end(),
315 std::back_inserter(inputs),
316 [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; });
317
318 std::vector<Val*> init_vals;
319 std::transform(
320 grouped_reduction_expr->initVals().begin(),
321 grouped_reduction_expr->initVals().end(),
322 std::back_inserter(init_vals),
323 [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; });
324
325 expr_ = IrBuilder::create<GroupedReductionOp>(
326 grouped_reduction_expr->container(),
327 grouped_reduction_expr->getReductionOpTypes(),
328 init_vals,
329 outputs,
330 inputs);
331 }
332
333 void handle(BroadcastOp* broadcast_expr) final {
334 auto out = reference_->sameAs(broadcast_expr->out())
335 ? substitute_
336 : broadcast_expr->out();
337 auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_
338 : broadcast_expr->in();
339
340 expr_ = IrBuilder::create<BroadcastOp>(
341 broadcast_expr->container(),
342 out,
343 in,
344 broadcast_expr->getBroadcastDimFlags());
345 }
346
347 void handle(TransposeOp* transpose_expr) final {
348 TORCH_INTERNAL_ASSERT(
349 substitute_->isA<TensorView>(),
350 "All args to transpose must be tensor view, but received a non-TensorView for replacement: ",
351 substitute_);
352 auto out = reference_->sameAs(transpose_expr->out())
353 ? substitute_->as<TensorView>()
354 : transpose_expr->out();
355 auto in = reference_->sameAs(transpose_expr->in())
356 ? substitute_->as<TensorView>()
357 : transpose_expr->in();
358 expr_ = IrBuilder::create<TransposeOp>(
359 transpose_expr->container(), out, in, transpose_expr->new2old());
360 }
361
362 void handle(ExpandOp* expand_expr) final {
363 auto out = reference_->sameAs(expand_expr->out())
364 ? substitute_->as<TensorView>()
365 : expand_expr->out();
366 auto in = reference_->sameAs(expand_expr->in())
367 ? substitute_->as<TensorView>()
368 : expand_expr->in();
369
370 auto expanded_extents = expand_expr->expanded_extents();
371 if (substitute_->isA<Int>()) {
372 for (auto i : c10::irange(expanded_extents.size())) {
373 if (!expanded_extents[i]->sameAs(substitute_)) {
374 expanded_extents[i] = substitute_;
375 }
376 }
377 }
378 expr_ = IrBuilder::create<ExpandOp>(
379 expand_expr->container(), out, in, expanded_extents);
380 }
381
382 void handle(ShiftOp* shift_expr) final {
383 auto out =
384 reference_->sameAs(shift_expr->out()) ? substitute_ : shift_expr->out();
385 auto in =
386 reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in();
387
388 expr_ = IrBuilder::create<ShiftOp>(
389 shift_expr->container(),
390 out,
391 in,
392 shift_expr->offsets(),
393 shift_expr->padWidth());
394 }
395
396 void handle(GatherOp* gather_expr) final {
397 auto out = reference_->sameAs(gather_expr->out()) ? substitute_
398 : gather_expr->out();
399 auto in =
400 reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in();
401
402 expr_ = IrBuilder::create<GatherOp>(
403 gather_expr->container(),
404 out,
405 in,
406 gather_expr->windowShape(),
407 gather_expr->padWidth());
408 }
409
410 void handle(ViewAsScalar* expr) final {
411 TORCH_INTERNAL_ASSERT(
412 substitute_->isA<TensorView>(),
413 "All args to view must be TensorView, but received a non-TensorView for replacement: ",
414 substitute_);
415 auto in = reference_->sameAs(expr->in()) ? substitute_->as<TensorView>()
416 : expr->in();
417 auto out = reference_->sameAs(expr->out()) ? substitute_->as<TensorView>()
418 : expr->out();
419 expr_ = IrBuilder::create<ViewAsScalar>(
420 expr->container(), out, in, expr->vector_id(), expr->index());
421 }
422
423 void handle(ViewOp* view_expr) final {
424 TORCH_INTERNAL_ASSERT(
425 substitute_->isA<TensorView>(),
426 "All args to view must be TensorView, but received a non-TensorView for replacement: ",
427 substitute_);
428 auto in = reference_->sameAs(view_expr->in())
429 ? substitute_->as<TensorView>()
430 : view_expr->in();
431 auto out = reference_->sameAs(view_expr->out())
432 ? substitute_->as<TensorView>()
433 : view_expr->out();
434 expr_ = IrBuilder::create<ViewOp>(view_expr->container(), out, in);
435 }
436
437 void handle(WelfordOp* welford_expr) final {
438 auto out_avg = reference_->sameAs(welford_expr->outAvg())
439 ? substitute_->as<TensorView>()
440 : welford_expr->outAvg();
441 auto out_var = reference_->sameAs(welford_expr->outVar())
442 ? substitute_->as<TensorView>()
443 : welford_expr->outVar();
444 auto out_N = reference_->sameAs(welford_expr->outN())
445 ? substitute_->as<TensorView>()
446 : welford_expr->outN();
447 auto in_avg = reference_->sameAs(welford_expr->inAvg())
448 ? substitute_->as<TensorView>()
449 : welford_expr->inAvg();
450 auto in_var =
451 welford_expr->inVar() && reference_->sameAs(welford_expr->inVar())
452 ? substitute_->as<TensorView>()
453 : welford_expr->inVar();
454 auto in_N = reference_->sameAs(welford_expr->inN()) ? substitute_
455 : welford_expr->inN();
456 auto init_avg =
457 welford_expr->initAvg() && reference_->sameAs(welford_expr->initAvg())
458 ? substitute_->as<TensorView>()
459 : welford_expr->initAvg();
460 auto init_var =
461 welford_expr->initVar() && reference_->sameAs(welford_expr->initVar())
462 ? substitute_->as<TensorView>()
463 : welford_expr->initVar();
464 auto init_N =
465 welford_expr->initN() && reference_->sameAs(welford_expr->initN())
466 ? substitute_
467 : welford_expr->initN();
468 expr_ = IrBuilder::create<WelfordOp>(
469 welford_expr->container(),
470 out_avg,
471 out_var,
472 out_N,
473 in_avg,
474 in_var,
475 in_N,
476 init_avg,
477 init_var,
478 init_N,
479 welford_expr->isAllreduce());
480 }
481
482 void handle(LoadStoreOp* ldst_expr) final {
483 TORCH_INTERNAL_ASSERT(
484 substitute_->isA<TensorView>(),
485 "All args to view must be TensorView, but received a non-TensorView for replacement: ",
486 substitute_);
487 auto in = reference_->sameAs(ldst_expr->in())
488 ? substitute_->as<TensorView>()
489 : ldst_expr->in();
490 auto out = reference_->sameAs(ldst_expr->out())
491 ? substitute_->as<TensorView>()
492 : ldst_expr->out();
493 expr_ = IrBuilder::create<LoadStoreOp>(
494 ldst_expr->container(), ldst_expr->opType(), out, in);
495 }
496
497 void handle(MmaOp* mma_expr) final {
498 TORCH_INTERNAL_ASSERT(
499 substitute_->isA<TensorView>(),
500 "All args to MmaOp must be TensorView, but received a non-TensorView for replacement: ",
501 substitute_);
502 auto in_a = reference_->sameAs(mma_expr->inA())
503 ? substitute_->as<TensorView>()
504 : mma_expr->inA();
505 auto in_b = reference_->sameAs(mma_expr->inB())
506 ? substitute_->as<TensorView>()
507 : mma_expr->inB();
508 auto out = reference_->sameAs(mma_expr->out())
509 ? substitute_->as<TensorView>()
510 : mma_expr->out();
511 auto init = reference_->sameAs(mma_expr->init())
512 ? substitute_->as<TensorView>()
513 : mma_expr->init();
514 expr_ = IrBuilder::create<MmaOp>(
515 mma_expr->container(), out, in_a, in_b, init, mma_expr->options());
516 }
517
518 private:
519 Val* reference_ = nullptr;
520 Val* substitute_ = nullptr;
521 Expr* expr_ = nullptr;
522};
523
524} // namespace ValReplacement
525
526Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute) {
527 FusionGuard fg(expr->fusion());
528 return ValReplacement::SubstituteInExpr::subsitute(
529 expr, reference, substitute);
530}
531
532TensorView* rfactorHelper(
533 TensorView* reduction_tv,
534 const std::vector<int>& axes) {
535 TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
536 const bool has_multiple_tvs = reduction_tv->definition()->inputs().size() > 1;
537 if (!has_multiple_tvs) {
538 return reduction_tv->rFactor(axes);
539 }
540
541 std::vector<TensorView*> out_tvs;
542 std::transform(
543 reduction_tv->definition()->outputs().begin(),
544 reduction_tv->definition()->outputs().end(),
545 std::back_inserter(out_tvs),
546 [](Val* val) { return val->as<TensorView>(); });
547
548 auto rf_tvs = reduction_tv->rFactor(axes, out_tvs);
549
550 return rf_tvs.at(std::distance(
551 out_tvs.begin(),
552 std::find(out_tvs.begin(), out_tvs.end(), reduction_tv)));
553}
554
namespace {

// Returns a copy of `entries` with duplicate pointers removed, preserving
// the order of first occurrence.
template <typename T>
std::vector<T*> uniqueEntries(const std::vector<T*>& entries) {
  std::vector<T*> result;
  std::unordered_set<T*> seen;
  result.reserve(entries.size());
  for (auto* entry : entries) {
    if (seen.insert(entry).second) {
      result.push_back(entry);
    }
  }
  return result;
}

} // namespace
570
571// Return immediate producers of val
572std::vector<Val*> producerValsOf(Val* val) {
573 if (val->definition() == nullptr) {
574 return {};
575 }
576 auto producer_vals = val->definition()->inputs();
577 return uniqueEntries<Val>({producer_vals.begin(), producer_vals.end()});
578}
579
580// Return immediate consumers of val
581std::vector<Val*> consumerValsOf(Val* val) {
582 std::vector<Val*> consumer_vals;
583 for (auto use_expr : val->uses()) {
584 auto outputs = use_expr->outputs();
585 consumer_vals.insert(consumer_vals.end(), outputs.begin(), outputs.end());
586 }
587 return uniqueEntries<Val>(consumer_vals);
588}
589
590// Return immediate siblings of val
591std::vector<Val*> siblingValsOf(Val* val) {
592 std::vector<Val*> sibling_vals;
593 auto def = val->definition();
594 if (def != nullptr) {
595 auto outs = def->outputs();
596 for (auto sibling_val : outs) {
597 if (sibling_val == val) {
598 continue;
599 }
600 sibling_vals.emplace_back(sibling_val);
601 }
602 }
603 return sibling_vals;
604}
605
606// Return immediate producers of val
607std::vector<Val*> producerValsOf(const std::vector<Val*>& vals) {
608 std::vector<Val*> all_producer_vals;
609 for (auto val : vals) {
610 auto producer_vals = producerValsOf(val);
611 all_producer_vals.insert(
612 all_producer_vals.end(), producer_vals.begin(), producer_vals.end());
613 }
614
615 return uniqueEntries<Val>(all_producer_vals);
616}
617
618// Return immediate consumers of val
619std::vector<Val*> consumerValsOf(const std::vector<Val*>& vals) {
620 std::vector<Val*> all_consumer_vals;
621 for (auto val : vals) {
622 auto consumer_vals = consumerValsOf(val);
623 all_consumer_vals.insert(
624 all_consumer_vals.end(), consumer_vals.begin(), consumer_vals.end());
625 }
626
627 return uniqueEntries<Val>(all_consumer_vals);
628}
629
630std::vector<TensorView*> producerTvsOf(TensorView* tv) {
631 auto producer_vals = producerValsOf(tv);
632 auto producer_tvs = ir_utils::filterByType<TensorView>(producer_vals);
633 return {producer_tvs.begin(), producer_tvs.end()};
634}
635
636std::vector<TensorView*> consumerTvsOf(TensorView* tv) {
637 auto consumer_vals = consumerValsOf(tv);
638 auto consumer_tvs = ir_utils::filterByType<TensorView>(consumer_vals);
639 return {consumer_tvs.begin(), consumer_tvs.end()};
640}
641
642std::vector<TensorView*> siblingTvsOf(TensorView* tv) {
643 auto sibling_vals = siblingValsOf(tv);
644 auto sibling_tvs = ir_utils::filterByType<TensorView>(sibling_vals);
645 return {sibling_tvs.begin(), sibling_tvs.end()};
646}
647
648std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs) {
649 std::vector<TensorView*> all_producer_tvs;
650 for (auto tv : tvs) {
651 auto producer_tvs = producerTvsOf(tv);
652 all_producer_tvs.insert(
653 all_producer_tvs.end(), producer_tvs.begin(), producer_tvs.end());
654 }
655
656 return uniqueEntries<TensorView>(all_producer_tvs);
657}
658
659std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs) {
660 std::vector<TensorView*> all_consumer_tvs;
661 for (auto tv : tvs) {
662 auto consumer_tvs = consumerTvsOf(tv);
663 all_consumer_tvs.insert(
664 all_consumer_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
665 }
666
667 return uniqueEntries<TensorView>(all_consumer_tvs);
668}
669
670std::vector<TensorView*> inputTvsOf(TensorView* tv) {
671 return inputTvsOf(std::vector<TensorView*>{tv});
672}
673
674std::vector<TensorView*> outputTvsOf(TensorView* tv) {
675 return outputTvsOf(std::vector<TensorView*>{tv});
676}
677
678std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs) {
679 auto inp_vals = IterVisitor::getInputsTo({tvs.begin(), tvs.end()});
680 auto filtered = ir_utils::filterByType<TensorView>(inp_vals);
681 std::vector<TensorView*> inp_tvs(filtered.begin(), filtered.end());
682 return uniqueEntries<TensorView>(inp_tvs);
683}
684
685std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs) {
686 auto out_vals = DependencyCheck::getAllOutputsOf({tvs.begin(), tvs.end()});
687 auto filtered = ir_utils::filterByType<TensorView>(out_vals);
688 std::vector<TensorView*> out_tvs(filtered.begin(), filtered.end());
689 return uniqueEntries<TensorView>(out_tvs);
690}
691
692std::vector<TensorView*> allTvs(Fusion* fusion) {
693 auto used_vals = fusion->usedMathVals();
694 auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
695
696 // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing
697 // is having an input disconnected from outputs, and these iter domains are
698 // being checked in compute at maps in scheduling logic. This shouldn't hurt
699 // AFAICT.
700 auto tv_inputs = ir_utils::filterByType<TensorView>(fusion->inputs());
701
702 std::vector<TensorView*> all_tvs({used_tvs.begin(), used_tvs.end()});
703 // Sometimes inputs are not connected to outputs, however, we still include
704 // them when returning allTvs because they are registered as an input.
705 all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end());
706
707 // all_tvs has duplicates, to deduplicate it and return
708 return uniqueEntries<TensorView>(all_tvs);
709}
710
711std::vector<TensorView*> allTvsExcept(
712 Fusion* fusion,
713 const std::unordered_set<TensorView*>& except) {
714 auto all_tvs = allTvs(fusion);
715 std::vector<TensorView*> result;
716 for (auto tv : all_tvs) {
717 if (except.count(tv) == 0) {
718 result.emplace_back(tv);
719 }
720 }
721 return result;
722}
723
724std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) {
725 std::vector<Expr*> red_ops;
726
727 auto isReduction = [&ignore_trivial](Val* out_val) {
728 if (out_val == nullptr || !out_val->isA<TensorView>()) {
729 return false;
730 }
731 auto out_tv = out_val->as<TensorView>();
732 return std::any_of(
733 out_tv->getRootDomain().begin(),
734 out_tv->getRootDomain().end(),
735 [&ignore_trivial](IterDomain* id) {
736 return id->isReduction() &&
737 !(ignore_trivial && id->isTrivialReduction());
738 });
739 };
740
741 for (auto expr : fusion->exprs()) {
742 bool is_reduction = false;
743 if (expr->isA<ReductionOp>()) {
744 is_reduction = isReduction(expr->as<ReductionOp>()->out());
745 } else if (expr->isA<GroupedReductionOp>()) {
746 is_reduction = std::any_of(
747 expr->as<GroupedReductionOp>()->outputs().begin(),
748 expr->as<GroupedReductionOp>()->outputs().end(),
749 isReduction);
750 } else if (expr->isA<WelfordOp>()) {
751 is_reduction = isReduction(expr->as<WelfordOp>()->outAvg());
752 }
753 if (is_reduction) {
754 red_ops.push_back(expr);
755 }
756 }
757
758 return red_ops;
759}
760
namespace {

// Mutator that walks an entire fusion and replaces every occurrence of a
// key of replacement_map with its mapped value. All of the work happens in
// the constructor; the object carries no useful state afterwards.
class ValReplacementMutator : private OptOutMutator {
 public:
  ValReplacementMutator(
      Fusion* fusion,
      const std::unordered_map<Val*, Val*>& replacement_map)
      : replacement_map_(replacement_map) {
    FusionGuard fg(fusion);

    // Welford makes this a little annoying since it holds a count which is
    // typically not used by anything else. If we don't grab that count, then it
    // would be a tensorview that doesn't get updated extents. Therefore, first
    // grab all leaves towards outputs and grab stmts from there.
    auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true);

    // Some fusions, such as standalone rand_like, can have disconnected DAG, so
    // we need some mechanism to make sure our replacement set is as complete as
    // possible
    // TODO: I think we need a more general mechanism to support disconnected
    // DAG
    // Gather any fusion inputs the leaf-based traversal missed.
    std::vector<Val*> more;
    for (auto v : fusion->inputs()) {
      if (std::find(stmts.begin(), stmts.end(), v) == stmts.end()) {
        more.emplace_back(v);
      }
    }
    auto more_stmts = StmtSort::getStmts(fusion, more, true);
    // Process the extra statements before the main ones; the combined list
    // is mutated in a single ordered pass.
    more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end());

    for (auto stmt : more_stmts) {
      mutate(stmt);
    }
  }

 private:
  using OptOutMutator::mutate;

  // Register the mapped replacement when val is a key of the map;
  // otherwise fall through to the default mutation behavior.
  void mutate(Val* val) final {
    if (replacement_map_.find(val) == replacement_map_.end()) {
      return OptOutMutator::mutate(val);
    }
    auto replaced_val = replacement_map_.at(val);
    registerMutation(val, replaced_val);
  }

  // Returns, in traversal order, the expression outputs that no other
  // expression consumes (the leaves of the DAG).
  std::vector<Val*> allLeafOuts(Fusion* fusion) {
    auto exprs = StmtSort::getExprs(fusion, true);
    std::unordered_set<Val*> inputs;
    std::unordered_set<Val*> outputs;
    std::vector<Val*> ordered_outputs;
    for (auto expr : exprs) {
      inputs.insert(expr->inputs().begin(), expr->inputs().end());
      outputs.insert(expr->outputs().begin(), expr->outputs().end());
      ordered_outputs.insert(
          ordered_outputs.end(),
          expr->outputs().begin(),
          expr->outputs().end());
    }
    // Anything that is also consumed as an input is not a leaf.
    for (auto input : inputs) {
      outputs.erase(input);
    }

    std::vector<Val*> ordered_leaf_outs;
    for (auto out : ordered_outputs) {
      if (outputs.find(out) != outputs.end()) {
        ordered_leaf_outs.push_back(out);
      }
    }
    return ordered_leaf_outs;
  }

  const std::unordered_map<Val*, Val*>& replacement_map_;
};

} // namespace
837
838void replaceValue(
839 Fusion* fusion,
840 const std::unordered_map<Val*, Val*>& replacement_map) {
841 ValReplacementMutator(fusion, replacement_map);
842}
843
844Val* getReductionInitValOf(TensorView* tv) {
845 auto def = tv->definition();
846 if (def == nullptr) {
847 return nullptr;
848 }
849
850 Val* init = nullptr;
851 if (auto rop = dynamic_cast<ReductionOp*>(def)) {
852 init = rop->init();
853 } else if (auto grop = dynamic_cast<GroupedReductionOp*>(def)) {
854 int output_idx = grop->getExprIndexOfOutput(tv);
855 init = grop->initVal(output_idx);
856 } else if (auto wop = dynamic_cast<WelfordOp*>(def)) {
857 return wop->getInitValOfOutput(tv);
858 } else if (auto gwop = dynamic_cast<GroupedWelfordOp*>(def)) {
859 init = gwop->getInitValOfOutput(tv);
860 } else if (auto mma = dynamic_cast<MmaOp*>(def)) {
861 init = mma->init();
862 }
863
864 return init;
865}
866
867// TODO: Should mma be in here? Should we return true if it's a trivial
868// reduction?
869bool isReductionOp(const Expr* expr) {
870 // Note that GridReduction inherits ReductionOp
871 return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
872 expr->isA<WelfordOp>() || expr->isA<GroupedWelfordOp>() ||
873 expr->isA<kir::GridWelford>() || expr->isA<kir::GroupedGridWelford>();
874}
875
876bool isReductionTvOp(const Expr* expr) {
877 return ir_utils::isTvOp(expr) && isReductionOp(expr);
878}
879
880std::vector<ViewOp*> getViewOps(Fusion* fusion) {
881 auto all_exprs = fusion->exprs();
882
883 auto all_view_ops = ir_utils::filterByType<ViewOp>(all_exprs);
884
885 std::vector<ViewOp*> view_ops;
886
887 std::copy_if(
888 all_view_ops.begin(),
889 all_view_ops.end(),
890 std::back_inserter(view_ops),
891 [](ViewOp* view) {
892 return std::any_of(
893 view->outputs().begin(), view->outputs().end(), [](Val* v) {
894 if (!v->isA<TensorView>()) {
895 return false;
896 }
897 return v->as<TensorView>()->hasRFactor();
898 });
899 });
900
901 return view_ops;
902}
903
namespace {

// Rebuilds an index expression with values replaced per replacement_map.
// Traversal threads the most recently produced value through
// last_visited_val_; each expr handler reads it after recursing into its
// operands.
struct ReplaceValInIndexVal : public OptInDispatch {
 public:
  //! Apply replacements to index as specified in
  //! replacement_map. index is assumed to consist only from Int and
  //! NamedScalar
  static Val* replace(
      Val* index,
      const std::unordered_map<Val*, Val*>& replacement_map) {
    ReplaceValInIndexVal replace_index_val(replacement_map);
    replace_index_val.handle(index);
    // Return the original index if not replaced
    if (replace_index_val.is_replaced_) {
      return replace_index_val.last_visited_val_;
    } else {
      return index;
    }
  }

 private:
  ReplaceValInIndexVal(const std::unordered_map<Val*, Val*>& replacement_map)
      : replacement_map_(replacement_map) {}

  // NOTE(review): base is OptInDispatch but the using-declaration names
  // OptOutDispatch — presumably OptInDispatch derives from OptOutDispatch;
  // confirm against the dispatch headers.
  using OptOutDispatch::handle;

  void handle(Val* val) override {
    TORCH_INTERNAL_ASSERT(
        val->isA<Int>() || val->isA<NamedScalar>() || val->isA<kir::IntPair>(),
        "Invalid Val type: ",
        val->toString());

    // if val appears in the replacement map, stop traversing and set
    // the current val with the replacement
    auto it = replacement_map_.find(val);
    if (it != replacement_map_.end()) {
      last_visited_val_ = it->second;
      is_replaced_ = true;
      return;
    }

    // Recursively traverse its defining expr
    auto def = val->definition();
    if (def != nullptr) {
      switch (def->etype()) {
        case ExprType::UnaryOp:
        case ExprType::BinaryOp:
        case ExprType::Swizzle2DInt:
        case ExprType::PairSelect:
          handle(val->definition());
          break;
        default:
          TORCH_INTERNAL_ASSERT(
              false, "Unexpected definition: ", def->toString())
      }
      // last_visited_val_ is set in the expr handlers
    } else {
      // Leaf value: keep it as-is.
      last_visited_val_ = val;
    }
  }

  // Clone expression after recurisvely replacing inputs
  void handle(UnaryOp* uop) override {
    handle(uop->in());
    auto inp = last_visited_val_;
    TORCH_INTERNAL_ASSERT(uop->out()->isA<Int>());
    auto out = IrBuilder::create<Int>(c10::nullopt);
    IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, inp);
    last_visited_val_ = out;
  }

  // Clone expression after recurisvely replacing inputs
  void handle(BinaryOp* bop) override {
    handle(bop->lhs());
    auto lhs = last_visited_val_;
    handle(bop->rhs());
    auto rhs = last_visited_val_;
    TORCH_INTERNAL_ASSERT(bop->out()->isA<Int>());
    auto out = IrBuilder::create<Int>(c10::nullopt);
    IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs);
    last_visited_val_ = out;
  }

  // Clone expression after recurisvely replacing inputs
  void handle(kir::Swizzle2DInt* swizzle_2d) override {
    handle(swizzle_2d->inX());
    auto in_x = last_visited_val_;
    handle(swizzle_2d->inY());
    auto in_y = last_visited_val_;
    auto out = IrBuilder::create<kir::IntPair>();

    // Extents are assumed constant in swizzle so no need to
    // duplicate their graphs.
    IrBuilder::create<kir::Swizzle2DInt>(
        out,
        in_x,
        in_y,
        swizzle_2d->extentX(),
        swizzle_2d->extentY(),
        swizzle_2d->swizzleType());
    last_visited_val_ = out;
  }

  // Clone expression after recursively replacing the pair input.
  void handle(kir::PairSelect* pair_select) override {
    handle(pair_select->in()->asVal());
    auto in = last_visited_val_;
    TORCH_INTERNAL_ASSERT(pair_select->out()->isA<Int>());
    auto out = IrBuilder::create<Int>(c10::nullopt);
    IrBuilder::create<kir::PairSelect>(
        out, in->as<kir::IntPair>(), pair_select->selection());
    last_visited_val_ = out;
  }

 private:
  const std::unordered_map<Val*, Val*>& replacement_map_;
  // Result of the most recent handle() call.
  Val* last_visited_val_ = nullptr;
  // True once any value was found in the replacement map.
  bool is_replaced_ = false;
};

} // namespace
1024
1025Val* replaceValInIndexVal(
1026 Val* index,
1027 const std::unordered_map<Val*, Val*>& replacement_map) {
1028 return ReplaceValInIndexVal::replace(index, replacement_map);
1029}
1030
1031} // namespace ir_utils
1032} // namespace cuda
1033} // namespace fuser
1034} // namespace jit
1035} // namespace torch
1036