#include <ir_builder.h>
#include <kernel.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <type.h>

#include <algorithm> // std::all_of, std::find, std::remove_if
#include <iostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace kir {

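// A Predicate can be built in three ways, matching the three constructors
// below: with an explicit PredicateType plus the expression being predicated,
// from an unrolled ForLoop (PredicateType::Unswitch), or from an
// already-computed Bool value (PredicateType::Manual). A usage sketch, purely
// illustrative and assuming a ForLoop* `unrolled_loop` produced during
// lowering:
//
//   auto* pred = IrBuilder::create<Predicate>(unrolled_loop);
//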
Predicate::Predicate(
    IrBuilderPasskey passkey,
    PredicateType ptype,
    const Expr* expr,
    Bool* thread_pred)
    : Val(passkey, ValType::Predicate, DataType::Bool),
      ptype_(ptype),
      expr_(expr),
      thread_pred_(thread_pred) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(
      ptype != PredicateType::Unswitch && ptype != PredicateType::Manual);
}

Predicate::Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop)
    : Val(passkey, ValType::Predicate, DataType::Bool),
      ptype_(PredicateType::Unswitch),
      unrolled_loop_(unrolled_loop) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr);
}

Predicate::Predicate(IrBuilderPasskey passkey, Bool* value)
    : Val(passkey, ValType::Predicate, DataType::Bool),
      ptype_(PredicateType::Manual),
      value_(value) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(value != nullptr);
}

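// TensorIndex drops indices that are statically known to be zero, so e.g.
// {i0, 0, i1} is stored as {i0, i1}. If all indices are zero, a single zero
// is kept so that the index list is never empty.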
TensorIndex::TensorIndex(
    IrBuilderPasskey passkey,
    const TensorView* view,
    std::vector<Val*> indices)
    : Val(passkey, ValType::TensorIndex, view->getDataType().value()),
      view_(view),
      indices_(indices) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(
      std::all_of(
          indices.begin(),
          indices.end(),
          [](Val* v) { return v->dtype() == DataType::Int; }),
      "Cannot index with a value other than an int.");
  indices_.erase(
      std::remove_if(
          indices_.begin(),
          indices_.end(),
          [](Val* index) { return index->isZeroInt(); }),
      indices_.end());
  // If indices becomes empty, just put one ZeroInt
  if (indices_.empty()) {
    indices_.push_back(FusionGuard::getCurFusion()->zeroVal());
  }
}

Val* TensorIndex::index(int i) const {
  TORCH_INTERNAL_ASSERT(
      nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
  if (i < 0) {
    i += nDims();
  }
  TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims()));
  return indices_[i];
}

BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync)
    : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

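// The shallowCopy() implementations below all follow the same pattern: build
// a new node of the same type referring to the same member pointers (children
// are shared, not cloned) and copy the predicates from the original
// expression.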
Expr* BlockSync::shallowCopy() const {
  auto result = IrBuilder::create<BlockSync>(war_sync_);
  result->copyPredicatesFrom(this);
  return result;
}

GridSync::GridSync(
    IrBuilderPasskey passkey,
    ParallelTypeBitmap sync_dims,
    Val* sync_buffer)
    : Expr(passkey, ExprType::GridSync),
      sync_dims_(sync_dims),
      sync_buffer_(sync_buffer) {}

Expr* GridSync::shallowCopy() const {
  auto result = IrBuilder::create<GridSync>(sync_dims_, sync_buffer_);
  result->copyPredicatesFrom(this);
  return result;
}

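// CpAsyncWait and CpAsyncCommit model the cp.async commit/wait pair used for
// asynchronous global-to-shared-memory copies. Our understanding is that
// keep_stages bounds how many committed cp.async groups may still be in
// flight when the wait completes (cp.async.wait_group-style semantics); the
// actual PTX is emitted by the code generator, not here.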
CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
    : Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* CpAsyncWait::shallowCopy() const {
  auto result = IrBuilder::create<CpAsyncWait>(keep_stages_);
  result->copyPredicatesFrom(this);
  return result;
}

CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
    : Expr(passkey, ExprType::CpAsyncCommit) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* CpAsyncCommit::shallowCopy() const {
  auto result = IrBuilder::create<CpAsyncCommit>();
  result->copyPredicatesFrom(this);
  return result;
}

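// InitMagicZero and UpdateMagicZero mark where the "magic zero" variable
// (typically named nvfuser_zero in the generated kernel) is declared and
// refreshed. It is a value that is always zero at runtime but opaque to the
// compiler, added to selected indices as a hint for index arithmetic; the
// actual code generation happens in later passes, these nodes only mark the
// spots.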
InitMagicZero::InitMagicZero(IrBuilderPasskey passkey)
    : Expr(passkey, ExprType::InitMagicZero) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* InitMagicZero::shallowCopy() const {
  auto result = IrBuilder::create<InitMagicZero>();
  result->copyPredicatesFrom(this);
  return result;
}

UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey)
    : Expr(passkey, ExprType::UpdateMagicZero) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* UpdateMagicZero::shallowCopy() const {
  auto result = IrBuilder::create<UpdateMagicZero>();
  result->copyPredicatesFrom(this);
  return result;
}

namespace {

bool isIntegralScalar(const Val* val) {
  return val->isScalar() && val->getDataType().has_value() &&
      isIntegralType(val->getDataType().value());
}

} // namespace

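// IntPair is a scalar holding a pair of integer indices. In this file it is
// the output type of Swizzle2DInt and the input of PairSelect, which extracts
// the X or Y component.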
IntPair::IntPair(IrBuilderPasskey passkey)
    : Val(passkey, ValType::IntPair, DataType::Index) {}

PairSelect::PairSelect(
    IrBuilderPasskey passkey,
    Val* out,
    IntPair* in,
    PairSelect::Selection selection)
    : Expr(passkey, ExprType::PairSelect),
      out_{out},
      in_{in},
      selection_(selection) {
  addOutput(out);
  addInput(in);
  TORCH_INTERNAL_ASSERT(isIntegralScalar(out), "Integer only for this op");
}

Expr* PairSelect::shallowCopy() const {
  auto result = IrBuilder::create<PairSelect>(out_, in_, selection_);
  result->copyPredicatesFrom(this);
  return result;
}

Swizzle2DInt::Swizzle2DInt(
    IrBuilderPasskey passkey,
    IntPair* out,
    Val* in_x,
    Val* in_y,
    Val* extent_x,
    Val* extent_y,
    Swizzle2DType swizzle_type)
    : Expr(passkey, ExprType::Swizzle2DInt),
      out_{out},
      in_x_{in_x},
      in_y_{in_y},
      extent_x_(extent_x),
      extent_y_(extent_y),
      swizzle_type_(swizzle_type) {
  TORCH_INTERNAL_ASSERT(isIntegralScalar(in_x), "Integer only for this op");
  TORCH_INTERNAL_ASSERT(isIntegralScalar(in_y), "Integer only for this op");

  addOutput(out);
  addInput(in_x);
  addInput(in_y);
  addInput(extent_x);
  addInput(extent_y);
}

Expr* Swizzle2DInt::shallowCopy() const {
  auto result = IrBuilder::create<Swizzle2DInt>(
      out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_);
  result->copyPredicatesFrom(this);
  return result;
}

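// Scope is an ordered list of expressions. The relative-insertion helpers
// below assert if the reference expression is not in this scope. A usage
// sketch, with hypothetical Expr* values `store` and `sync` (illustrative
// only):
//
//   scope.insert_before(store, sync); // place the sync right before the store
//   scope.insert_after(store, sync);  // or right after it
//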
void Scope::insert(std::vector<Expr*>::const_iterator pos, Expr* expr) {
  exprs_.insert(pos, expr);
}

void Scope::insert_before(Expr* ref, Expr* expr) {
  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
  TORCH_INTERNAL_ASSERT(
      it != exprs_.end(),
      "Tried to insert ",
      expr,
      " before the reference: ",
      ref,
      " however the reference was not found in this scope.");
  insert(it, expr);
}

void Scope::insert_after(Expr* ref, Expr* expr) {
  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
  TORCH_INTERNAL_ASSERT(
      it != exprs_.end(),
      "Tried to insert ",
      expr,
      " after the reference: ",
      ref,
      " however the reference was not found in this scope.");
  insert(it + 1, expr);
}

void Scope::insert(size_t pos, Expr* expr) {
  const auto it = exprs_.begin() + pos;
  insert(it, expr);
}

void Scope::erase(std::vector<Expr*>::const_iterator pos) {
  // Remove the scope of the expr if this is the scope
  C10_UNUSED auto expr = *pos;
  exprs_.erase(pos);
}

void Scope::erase(Expr* ref) {
  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
  if (it != exprs_.end()) {
    erase(it);
  }
}

void Scope::erase(size_t pos) {
  TORCH_INTERNAL_ASSERT(pos < size());
  erase(exprs_.begin() + pos);
}

bool Scope::contains(Expr* expr) const {
  const auto it = std::find(exprs_.begin(), exprs_.end(), expr);
  return it != exprs_.end();
}

void Scope::clear() {
  exprs_.clear();
}

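// The main ForLoop constructor fills in defaults: when start is not given and
// the loop is thread-parallelized, it starts at the parallel index (e.g.
// threadIdx.x); when step is not given, it is the parallel dimension (e.g.
// blockDim.x) for thread loops and one otherwise.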
ForLoop::ForLoop(
    IrBuilderPasskey passkey,
    IterDomain* iter_domain,
    Val* index,
    Val* start,
    Val* stop,
    Val* step,
    bool vectorize,
    Val* vectorize_shift,
    bool unroll_required,
    DoubleBufferLoopStage double_buffer_loop_stage)
    : Expr(passkey, ExprType::ForLoop),
      iter_domain_{iter_domain},
      index_(index),
      start_(start),
      stop_(stop),
      step_(step),
      vectorize_(vectorize),
      vectorize_shift_(vectorize_shift),
      unroll_required_(unroll_required),
      body_(this),
      double_buffer_loop_stage_(double_buffer_loop_stage) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int);
  addInput(index);
  addInput(iter_domain);
  if (start_ == nullptr && iter_domain->isThread()) {
    start_ = NamedScalar::getParallelIndex(iter_domain->getParallelType());
  }
  if (step_ == nullptr) {
    if (iter_domain->isThread()) {
      step_ = NamedScalar::getParallelDim(iter_domain->getParallelType());
    } else {
      step_ = FusionGuard::getCurFusion()->oneVal();
    }
  }
}

ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain)
    : ForLoop(
          passkey,
          iter_domain,
          GpuLower::current()->caMap()->getIndexVariable(iter_domain),
          nullptr,
          nullptr,
          nullptr,
          !iter_domain->isBroadcast() &&
              isParallelTypeVectorize(iter_domain->getParallelType()),
          nullptr,
          false,
          DoubleBufferLoopStage::NotApplicable) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other)
    : ForLoop(
          passkey,
          other->iter_domain(),
          other->index(),
          other->start(),
          other->stop(),
          other->step(),
          other->vectorize(),
          other->vectorize_shift(),
          other->isUnrollRequired(),
          other->doubleBufferLoopStage()) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* ForLoop::shallowCopy() const {
  auto result = IrBuilder::create<ForLoop>(
      iter_domain_,
      index_,
      start_,
      stop_,
      step_,
      vectorize_,
      vectorize_shift_,
      unroll_required_,
      double_buffer_loop_stage_);
  result->body_ = body_;
  result->copyPredicatesFrom(this);
  return result;
}

bool ForLoop::isUnrollable() const {
  // Start and stop must be constant, must not be a broadcast
  // dimension, cannot be bound to a parallel dimension, must not be
  // vectorized.
  return start()->isConstScalar() && stop()->isConstScalar() &&
      !iter_domain()->isThread() && !iter_domain()->isBroadcast() &&
      !vectorize();
}

bool ForLoop::isUnrolled() const {
  if (isUnrollRequired() && !isUnrollable()) {
    TORCH_WARN(
        "Unroll required but not possible. Register allocation disabled. Loop index: ",
        index_->toString());
    return false;
  }

  // Size-one loop will not be materialized as a loop, so return false
  if (start()->isZeroInt() && stop()->isOneInt()) {
    return false;
  }

  // Unroll if required.
  if (isUnrollRequired()) {
    return true;
  }

  // Don't unroll if not possible
  if (!isUnrollable()) {
    return false;
  }

  // Unrolling is technically possible but avoided
  if (iter_domain()->getParallelType() == ParallelType::Unswitch) {
    // Use ParallelType::Unroll if unrolling is desired. Note that
    // unswitched size-one loops are not unrolled as they are not
    // materialized as actual for-loops.
    return false;
  }

  return true;
}

Val* ForLoop::start() const {
  if (start_ != nullptr) {
    return start_;
  } else {
    // clang-tidy complains without this
    TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
    return iter_domain_->start();
  }
}

Val* ForLoop::stop() const {
  if (stop_ != nullptr) {
    return stop_;
  } else {
    // clang-tidy complains without this
    TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
    return iter_domain_->extent();
  }
}

Val* ForLoop::step() const {
  TORCH_INTERNAL_ASSERT(step_ != nullptr);
  return step_;
}

bool ForLoop::isTrivial() const {
  // These loops are not materialized
  if (vectorize() || iter_domain()->isBroadcast() ||
      iter_domain()->isStride() || iter_domain()->isMma()) {
    return true;
  }

  // By default, a parallelized loop would look like:
  //
  //   for (int x = threadIdx.x; x < stop; x += blockDim.x) {
  //     do_some_comp(x);
  //   }
  //
  // When stop is guaranteed to be smaller or equal to the number of
  // threads, the for-loop is not necessary. In the above case, we
  // would just generate the loop body without the for clause but
  // references to the loop index replaced by the loop start value.
  //
  // When the loop end is the same as the IterDomain extent, the
  // assumption can be safely made. This is more conservative than
  // necessary since the loop stop value just needs to be <= the
  // IterDomain extent. However, at this point, this conservative
  // analysis seems sufficient.
  if (stop() == iter_domain()->extent() && iter_domain()->isThread()) {
    return true;
  }

  // Extent-1 loop: for (int i = 0; i < 1; ++i) {
  if (start()->isZeroInt() && stop()->isOneInt() && step()->isOneInt()) {
    return true;
  }

  // Another extent-1 loop: for (int i = N - 1; i < N; ++i) {
  if (start()->definition() != nullptr &&
      start()->definition()->isA<BinaryOp>() &&
      start()->definition()->as<BinaryOp>()->getBinaryOpType() ==
          BinaryOpType::Sub &&
      start()->definition()->as<BinaryOp>()->lhs() == stop() &&
      start()->definition()->as<BinaryOp>()->rhs()->isOneInt()) {
    return true;
  }

  return false;
}

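// IfThenElse owns two scopes (then/else). The condition is registered both as
// the expression's predicate and as an input.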
IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond)
    : Expr(passkey, ExprType::IfThenElse), then_body_(this), else_body_(this) {
  setPredicate(cond);
  addInput(cond);
}

Expr* IfThenElse::shallowCopy() const {
  auto result = IrBuilder::create<IfThenElse>(predicate());
  result->then_body_ = then_body_;
  result->else_body_ = else_body_;
  result->setWritePredicate(writePredicate());
  return result;
}

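// Allocate computes size_ as the product of the shape extents. When no shape
// is given, the shape is taken from the non-reduction axes of the allocated
// TensorView; e.g. a TensorView with domain {I0, I1} allocated without an
// explicit shape gets size_ = I0 * I1.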
Allocate::Allocate(
    IrBuilderPasskey passkey,
    Val* buffer,
    MemoryType memory_type,
    std::vector<Val*> shape,
    bool zero_init)
    : Expr(passkey, ExprType::Allocate),
      buffer_(buffer),
      memory_type_(memory_type),
      shape_(std::move(shape)),
      zero_init_(zero_init) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  if (!shape_.empty()) {
    TORCH_INTERNAL_ASSERT(
        (shape_.size() == 1 && shape_[0]->isOneInt()) ||
        buffer_->isA<TensorView>());
  } else {
    TORCH_INTERNAL_ASSERT(buffer_->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(
        buffer_->as<TensorView>()->getMemoryType() == memory_type_);
    const auto domain = buffer_->as<TensorView>()->domain();
    for (auto axis : domain->noReductions()) {
      shape_.push_back(axis->extent());
    }
  }

  for (auto s : shape_) {
    if (size_ == nullptr) {
      size_ = s;
    } else {
      size_ = IrBuilder::mulExpr(size_, s);
    }
  }

  if (size_ == nullptr) {
    size_ = FusionGuard::getCurFusion()->oneVal();
  }

  addInput(size_);
}

Allocate::Allocate(
    IrBuilderPasskey passkey,
    Val* buffer,
    MemoryType memory_type,
    Val* size,
    bool zero_init)
    : Allocate(
          passkey,
          buffer,
          memory_type,
          size == nullptr ? std::vector<Val*>{} : std::vector<Val*>{size},
          zero_init) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* Allocate::shallowCopy() const {
  auto result =
      IrBuilder::create<Allocate>(buffer_, memory_type_, shape_, zero_init_);
  result->copyPredicatesFrom(this);
  return result;
}

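// GridReduction extends ReductionOp with the work buffer that stages
// per-block partial results and the sync buffer used for inter-block
// synchronization. Our reading of entrance_index/entrances is that they
// account for a grid reduction being entered multiple times when nested in
// outer loops; both buffers are allocated by earlier lowering passes.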
GridReduction::GridReduction(
    IrBuilderPasskey passkey,
    BinaryOpType reduction_op_type,
    Val* init,
    Val* out,
    Val* in,
    Allocate* reduction_buffer,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances,
    bool is_allreduce)
    : ReductionOp(
          passkey,
          reduction_op_type,
          init,
          out,
          in,
          is_allreduce,
          ExprType::GridReduction),
      reduction_buffer_(reduction_buffer),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GridReduction::shallowCopy() const {
  auto result = IrBuilder::create<GridReduction>(
      getReductionOpType(),
      init(),
      out(),
      in(),
      reduction_buffer_,
      sync_buffer_,
      entrance_index_,
      entrances_,
      isAllreduce());
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

GroupedGridReduction::GroupedGridReduction(
    IrBuilderPasskey passkey,
    std::vector<BinaryOpType> reduction_op_types,
    std::vector<Val*> init_vals,
    std::vector<Val*> outputs,
    std::vector<Val*> inputs,
    std::vector<Allocate*> reduction_buffers,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances,
    Val* buffer_stride,
    bool is_allreduce)
    : GroupedReductionOp(
          passkey,
          std::move(reduction_op_types),
          std::move(init_vals),
          std::move(outputs),
          std::move(inputs),
          is_allreduce,
          ExprType::GroupedGridReduction),
      reduction_buffers_(std::move(reduction_buffers)),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances),
      buffer_stride_(buffer_stride) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GroupedGridReduction::shallowCopy() const {
  auto result = IrBuilder::create<GroupedGridReduction>(
      getReductionOpTypes(),
      initVals(),
      outputs(),
      inputs(),
      reduction_buffers_,
      sync_buffer_,
      entrance_index_,
      entrances_,
      buffer_stride_,
      isAllreduce());
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

GridBroadcast::GridBroadcast(
    IrBuilderPasskey passkey,
    BroadcastOp* broadcast_op,
    Allocate* broadcast_buffer,
    Allocate* sync_buffer)
    : Expr(passkey, ExprType::GridBroadcast),
      broadcast_op_(broadcast_op),
      broadcast_buffer_(broadcast_buffer),
      sync_buffer_(sync_buffer) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GridBroadcast::shallowCopy() const {
  auto result = IrBuilder::create<GridBroadcast>(
      broadcast_op_, broadcast_buffer_, sync_buffer_);
  result->copyPredicatesFrom(this);
  return result;
}

GridWelford::GridWelford(
    IrBuilderPasskey passkey,
    WelfordOp* welford_op,
    Allocate* var_buffer,
    Allocate* avg_buffer,
    Allocate* n_buffer,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances)
    : Expr(passkey, ExprType::GridWelford),
      welford_op_(welford_op),
      var_buffer_(var_buffer),
      avg_buffer_(avg_buffer),
      n_buffer_(n_buffer),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GridWelford::shallowCopy() const {
  auto result = IrBuilder::create<GridWelford>(
      welford_op_,
      var_buffer_,
      avg_buffer_,
      n_buffer_,
      sync_buffer_,
      entrance_index_,
      entrances_);
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

GroupedGridWelford::GroupedGridWelford(
    IrBuilderPasskey passkey,
    std::vector<WelfordTriplet> output_vals,
    std::vector<WelfordTriplet> input_vals,
    std::vector<WelfordTriplet> init_vals,
    std::array<std::vector<Allocate*>, 3> reduction_buffers,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances,
    Val* buffer_stride,
    bool is_allreduce)
    : GroupedWelfordOp(
          passkey,
          std::move(output_vals),
          std::move(input_vals),
          std::move(init_vals),
          is_allreduce,
          ExprType::GroupedGridWelford),
      reduction_buffers_(std::move(reduction_buffers)),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances),
      buffer_stride_(buffer_stride) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GroupedGridWelford::shallowCopy() const {
  auto result = IrBuilder::create<GroupedGridWelford>(
      outputVals(),
      inputVals(),
      initVals(),
      reduction_buffers_,
      sync_buffer_,
      entrance_index_,
      entrances_,
      buffer_stride_,
      isAllreduce());
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

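// AllocateFusedReduction wraps one of the four fused (allreduce-style) grid
// operations. The four constructors differ only in the concrete type stored
// in grid_expr_; shallowCopy(), out(), and threadPredicate() below dispatch on
// that type.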
AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GridReduction* grid_reduction)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grid_reduction) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GridWelford* grid_welford)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grid_welford) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GroupedGridReduction* grouped_grid_reduction)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grouped_grid_reduction) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GroupedGridWelford* grouped_grid_welford)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grouped_grid_welford) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* AllocateFusedReduction::shallowCopy() const {
  if (grid_expr_->isA<GridReduction>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GridReduction>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  } else if (grid_expr_->isA<GridWelford>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GridWelford>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  } else if (grid_expr_->isA<GroupedGridReduction>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GroupedGridReduction>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  } else if (grid_expr_->isA<GroupedGridWelford>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GroupedGridWelford>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  }
  TORCH_INTERNAL_ASSERT(
      false, "Unknown reduction type in AllocateFusedReduction::shallowCopy");
}

TensorIndex* AllocateFusedReduction::out() const {
  TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr);
  if (grid_expr_->isA<GridReduction>() ||
      grid_expr_->isA<GroupedGridReduction>()) {
    return grid_expr_->outputs().at(0)->as<kir::TensorIndex>();
  } else if (auto grid_welford = dynamic_cast<GridWelford*>(grid_expr_)) {
    return grid_welford->welford_op()->out()->as<kir::TensorIndex>();
  } else if (
      auto grouped_grid_welford =
          dynamic_cast<GroupedGridWelford*>(grid_expr_)) {
    return grouped_grid_welford->out(0)->as<kir::TensorIndex>();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Invalid grid expression: ", grid_expr_->toString());
  }
}

const ParallelTypeBitmap& AllocateFusedReduction::threadPredicate() const {
  TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr);
  if (auto grid_reduction = dynamic_cast<GridReduction*>(grid_expr_)) {
    return grid_reduction->threadPredicate();
  } else if (auto grid_welford = dynamic_cast<GridWelford*>(grid_expr_)) {
    return grid_welford->threadPredicate();
  } else if (
      auto grouped_grid_reduction =
          dynamic_cast<GroupedGridReduction*>(grid_expr_)) {
    return grouped_grid_reduction->threadPredicate();
  } else if (
      auto grouped_grid_welford =
          dynamic_cast<GroupedGridWelford*>(grid_expr_)) {
    return grouped_grid_welford->threadPredicate();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Invalid grid expression: ", grid_expr_->toString());
  }
}

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch