#pragma once

#include <ir_all_nodes.h>
#include <ir_base_nodes.h>
#include <parallel_type_bitmap.h>
#include <type.h>
#include <utils.h>

#include <c10/macros/Export.h>
#include <c10/util/Optional.h>

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class IrBuilderPasskey;

// Abstract nodes
class Val;
class Expr;

// Values
class Bool;
class Double;
class Int;
class NamedScalar;

class IterDomain;
class TensorDomain;
class TensorView;

// Expressions
class UnaryOp;
class BinaryOp;
class TernaryOp;
class RNGOp;
class ReductionOp;
class WelfordOp;
class BroadcastOp;

namespace kir {
class Kernel;

// Values
class Predicate;
class TensorIndex;

// Expressions
class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class InitMagicZero;
class UpdateMagicZero;
class ForLoop;
class IfThenElse;
class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class GroupedGridWelford;
class AllocateFusedReduction;

// Expr container
class Scope;

class TORCH_CUDA_CU_API Predicate final : public Val {
 public:
  explicit Predicate(
      IrBuilderPasskey passkey,
      PredicateType ptype,
      const Expr* expr = nullptr,
      Bool* thread_pred = nullptr);

  explicit Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop);

  explicit Predicate(IrBuilderPasskey passkey, Bool* value);

  PredicateType predicate_type() const {
    return ptype_;
  }

  const Expr* expr() const {
    TORCH_INTERNAL_ASSERT(
        ptype_ != PredicateType::Unswitch &&
        ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual);
    return expr_;
  }

  Bool* thread_pred() const {
    TORCH_INTERNAL_ASSERT(
        ptype_ == PredicateType::Inline ||
        ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift ||
        ptype_ == PredicateType::Padding ||
        ptype_ == PredicateType::ReductionWrite);
    return thread_pred_;
  }

  ForLoop* unrolled_loop() const {
    TORCH_INTERNAL_ASSERT(ptype_ == PredicateType::Unswitch);
    return unrolled_loop_;
  }

  bool hasValue() const {
    return value_ != nullptr;
  }

  Bool* value() const {
    TORCH_INTERNAL_ASSERT(
        value_ != nullptr,
        "The conditional expression for this Predicate is invalid.");
    return value_;
  }

  void setValue(Bool* value) {
    TORCH_INTERNAL_ASSERT(value != nullptr, "The Bool expression is invalid.");
    value_ = value;
  }

  bool isConst() const final {
    return hasValue() && value_->isConst();
  }

 private:
  PredicateType ptype_ = PredicateType::Manual;

  // For PredicateCompute::getInlinePredicate,
  // ShiftPredicateInserter::getShiftPredicate and getPaddingPredicate
  const Expr* expr_ = nullptr;

  // For PredicateCompute::getInlinePredicate
  Bool* thread_pred_ = nullptr;

  // For ParallelType::Unswitch - UnswitchPredicate::get
  ForLoop* unrolled_loop_ = nullptr;

  // The Bool conditional value
  // The value is nullptr until the lower_predicate pass
  Bool* value_ = nullptr;
};

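// Usage note (illustrative sketch, not part of this header's API surface):
// consumers should check hasValue() before calling value(), since value_
// stays null until the lower_predicate pass fills it in. The
// IrBuilder::create<>() factory below is assumed from the surrounding
// codegen, and the names are hypothetical; the call is shown only for
// illustration.
//
//   auto* pred = IrBuilder::create<Predicate>(
//       PredicateType::Inline, expr, thread_pred);
//   if (pred->hasValue()) {
//     Bool* cond = pred->value(); // asserts if called while still null
//   }
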
class TORCH_CUDA_CU_API TensorIndex final : public Val {
 public:
  TensorIndex(
      IrBuilderPasskey,
      const TensorView* view,
      std::vector<Val*> indices);

  std::vector<Val*>::size_type nDims() const {
    return indices_.size();
  }

  Val* index(int i) const;

  const std::vector<Val*>& indices() const {
    return indices_;
  }

  TensorView* view() const {
    TORCH_INTERNAL_ASSERT(view_ != nullptr);
    return const_cast<TensorView*>(view_); // NOLINT
  }

 private:
  const TensorView* view_ = nullptr;
  std::vector<Val*> indices_;
};

//! Allocate is a lower-level node that describes a buffer of memory
//! required as an intermediate within a kernel. Its extent is the size
//! expression of the buffer, generated from the TensorView that describes
//! the output of an operation.
class TORCH_CUDA_CU_API Allocate final : public Expr {
 public:
  //! Allocation of a multi-dimensional buffer
  //!
  //! param shape Size of each dimension
  explicit Allocate(
      IrBuilderPasskey passkey,
      Val* buffer,
      MemoryType memory_type,
      std::vector<Val*> shape = {},
      bool zero_init = false);

  //! Allocation of a non-dimensional buffer
  //!
  //! param size Size of allocation
  explicit Allocate(
      IrBuilderPasskey passkey,
      Val* buffer,
      MemoryType memory_type,
      Val* size,
      bool zero_init = false);

  Expr* shallowCopy() const override;

  Val* buffer() const {
    return buffer_;
  }

  MemoryType memoryType() const {
    return memory_type_;
  }

  Val* size() const {
    return size_;
  }

  const std::vector<Val*>& shape() const {
    return shape_;
  }

  bool zeroInit() const {
    return zero_init_;
  }

  const Allocate* alias() const {
    return alias_;
  }

  void setAlias(const Allocate* alias) {
    TORCH_INTERNAL_ASSERT(alias != this);
    TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type_);
    alias_ = alias;
  }

 private:
  Val* buffer_ = nullptr;
  MemoryType memory_type_ = MemoryType::Local;
  //! Size of each dimension
  std::vector<Val*> shape_;
  bool zero_init_ = false;
  //! Total size
  Val* size_ = nullptr;

  // This alias tracks the next Allocate node in a linked chain of aliases
  // If the alias is nullptr, then the Allocate node uses memory in the kernel
  const Allocate* alias_ = nullptr;
};

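// Usage note (illustrative sketch): allocating a zero-initialized
// multi-dimensional buffer via the first constructor above. "smem_tv",
// "dim0" and "dim1" are hypothetical values, and the IrBuilder::create<>()
// factory is assumed from the surrounding codegen.
//
//   auto* alloc = IrBuilder::create<Allocate>(
//       smem_tv, // buffer
//       MemoryType::Shared, // memory_type
//       std::vector<Val*>{dim0, dim1}, // shape, one extent per dimension
//       /*zero_init=*/true);
//   // alloc->size() then holds the total size derived from the shape.
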
// Sync represents __syncthreads barrier for block level coordination.
//
// TODO(kir): change name to SyncThreads as we could have other barriers.
//
class TORCH_CUDA_CU_API BlockSync final : public Expr {
 public:
  explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false);

  Expr* shallowCopy() const override;

  bool isWarHazardSync() const {
    return war_sync_;
  }

 private:
  // TODO: war_sync_ is only used for testing/validation purposes.
  bool war_sync_ = false;
};

// CpAsyncWait represents wait intrinsics for cp.async
class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
 public:
  explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0);

  Expr* shallowCopy() const override;

  //! Returns the remaining number of stages that are not synchronized
  //! after this op.
  unsigned int keepStages() const {
    return keep_stages_;
  }

 private:
  //! Number of stages to leave un-synchronized by this op.
  unsigned int keep_stages_ = 0;
};

// CpAsyncCommit represents commit intrinsics for cp.async.
// A commit intrinsic communicates the delimiter of a transaction group
// to the async load hardware. For example usage, see [Circular buffer].
class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
 public:
  explicit CpAsyncCommit(IrBuilderPasskey passkey);

  Expr* shallowCopy() const override;
};

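// Usage note (illustrative sketch): in a double-buffered pipeline the
// lowering passes are expected to emit a CpAsyncCommit after issuing the
// loads of a stage and a CpAsyncWait before consuming it, with keep_stages
// controlling how many committed stages may remain in flight. "scope" is a
// hypothetical Scope and IrBuilder::create<>() is assumed from the codegen.
//
//   scope.push_back(IrBuilder::create<CpAsyncCommit>());
//   // ... later, before reading the prefetched tile:
//   scope.push_back(IrBuilder::create<CpAsyncWait>(/*keep_stages=*/1));
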
// Synchronize all blocks in the device; implies a cooperative group launch
// is required.
class TORCH_CUDA_CU_API GridSync final : public Expr {
 public:
  explicit GridSync(
      IrBuilderPasskey passkey,
      ParallelTypeBitmap sync_dims,
      Val* sync_buffer);

  Expr* shallowCopy() const override;

  ParallelTypeBitmap syncDims() const {
    return sync_dims_;
  }

  Val* syncBuffer() const {
    return sync_buffer_;
  }

 private:
  ParallelTypeBitmap sync_dims_;
  Val* sync_buffer_ = nullptr;
};

// Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero
// in helpers.cu
class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
 public:
  explicit InitMagicZero(IrBuilderPasskey passkey);

  Expr* shallowCopy() const override;
};

// Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero
// in helpers.cu
class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
 public:
  explicit UpdateMagicZero(IrBuilderPasskey passkey);

  Expr* shallowCopy() const override;
};

// TODO(kir): promote to IR node
class TORCH_CUDA_CU_API Scope {
 public:
  explicit Scope(Expr* owner) : owner_(owner) {}

  const std::vector<Expr*>& exprs() const {
    return exprs_;
  }

  bool empty() const {
    return exprs_.empty();
  }

  auto size() const {
    return exprs_.size();
  }

  auto& operator[](size_t i) {
    return exprs_[i];
  }

  auto& operator[](size_t i) const {
    return exprs_[i];
  }

  // Insert expr before expression at pos
  void insert(size_t pos, Expr* expr);

  // Insert expr before ref
  void insert_before(Expr* ref, Expr* expr);

  // Insert expr after ref
  void insert_after(Expr* ref, Expr* expr);

  void push_back(Expr* e) {
    exprs_.push_back(e);
  }

  // Erase expr at pos
  void erase(size_t pos);

  // Erase expr ref
  void erase(Expr* ref);

  bool contains(Expr* expr) const;

  void clear();

  Expr* owner() const {
    return owner_;
  }

 private:
  // Insert expr before pos
  void insert(std::vector<Expr*>::const_iterator pos, Expr* expr);

  // Erase expr at pos
  void erase(std::vector<Expr*>::const_iterator pos);

 private:
  std::vector<Expr*> exprs_;

  //! Owner expression of this scope, e.g., IfThenElse
  Expr* owner_ = nullptr;
};

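// Usage note (illustrative sketch): Scope is an ordered list of Expr*, so
// passes typically find a reference expression and splice new exprs around
// it. "for_loop", "ref_expr" and "sync" are hypothetical nodes.
//
//   Scope& body = for_loop->body();
//   if (body.contains(ref_expr)) {
//     body.insert_before(ref_expr, sync); // place the sync just before ref
//   } else {
//     body.push_back(sync); // or append at the end of the scope
//   }
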
//! ForLoop provides scoping around an int iterator from 0 to range. Exprs
//! placed in its body are considered inside the scope of the for loop. In the
//! future the implementation should look quite different so that we can do
//! proper dependency analysis like in Fusion.
//!
//! TODO(kir): this is not a real expression
//!
//! ForLoop may represent a part of an iteration domain represented
//! by iter_domain_. In that case, the loop extent field, extent_, may
//! be smaller than the extent of iter_domain_.
class TORCH_CUDA_CU_API ForLoop final : public Expr {
 public:
  //! By default, start and stop are the same as those of iter_domain.
  //! Step is one by default.
  //!
  //! TODO: cleaner way to set options?
  ForLoop(
      IrBuilderPasskey passkey,
      IterDomain* iter_domain,
      Val* index,
      Val* start,
      Val* stop,
      Val* step,
      bool vectorize,
      Val* vectorize_shift,
      bool unroll_required,
      DoubleBufferLoopStage double_buffer_loop_stage);

  ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain);

  ForLoop(IrBuilderPasskey passkey, const ForLoop* other);

  Expr* shallowCopy() const override;

  Val* index() const {
    return index_;
  }

  Val* start() const;

  Val* stop() const;

  Val* step() const;

  Val* vectorize_shift() const {
    return vectorize_shift_;
  }

  IterDomain* iter_domain() const {
    return iter_domain_;
  }

  // TODO: Return pointer instead of reference to be more consistent
  Scope& body() {
    return body_;
  }

  const Scope& body() const {
    return body_;
  }

  bool vectorize() const {
    return vectorize_;
  }

  //! True if unrolled (i.e., "#pragma unroll" is attached)
  bool isUnrolled() const;

  //! True if unrolling is required
  bool isUnrollRequired() const {
    return unroll_required_;
  }

  //! Set unrolling required
  void requireUnroll() {
    unroll_required_ = true;
  }

  //! True if no actual for-loop is materialized
  bool isTrivial() const;

  //! Returns the stage of a double buffered iterdomain
  //! that this for loop materializes.
  auto doubleBufferLoopStage() const {
    return double_buffer_loop_stage_;
  }

 private:
  //! Returns if a loop could be unrolled.
  bool isUnrollable() const;

 private:
  IterDomain* const iter_domain_ = nullptr;

  Val* index_ = nullptr;
  Val* start_ = nullptr;
  Val* stop_ = nullptr;
  Val* step_ = nullptr;

  // vectorize is true when the for-loop contains a vectorize set
  // the flag is used to omit the for-loop from the kernel
  bool vectorize_ = false;
  // [pre | vectorize | post] <= inner-most, merged root domain
  // shift_ is applied to vectorize and post sections.
  Val* vectorize_shift_ = nullptr;

  //! True if unroll is required for avoiding stack allocation
  bool unroll_required_ = false;

  Scope body_;

  //! Tracks if this for loop is implementing a stage of
  //! a double buffered iterdomain.
  DoubleBufferLoopStage double_buffer_loop_stage_ =
      DoubleBufferLoopStage::NotApplicable;
};

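// Usage note (illustrative sketch): the single-argument constructor above
// derives index/start/stop/step from the IterDomain (step defaults to one).
// "id" and "expr" are hypothetical nodes and IrBuilder::create<>() is
// assumed from the surrounding codegen.
//
//   auto* loop = IrBuilder::create<ForLoop>(id);
//   loop->body().push_back(expr);
//   // isTrivial() reports when no C++ for-loop is actually materialized
//   // (e.g., a size-one or fully parallelized domain), in which case only
//   // the body is emitted.
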
//! IfThenElse provides scoping for a boolean condition. Exprs placed in its
//! body are considered inside the scope of the if statement. In the future the
//! implementation should look quite different so that we can do proper
//! dependency analysis like in Fusion.
//!
//! TODO(kir): this is not a real expression
//!
class TORCH_CUDA_CU_API IfThenElse final : public Expr {
 public:
  explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond);

  Expr* shallowCopy() const override;

  Scope& thenBody() {
    return then_body_;
  }
  const Scope& thenBody() const {
    return then_body_;
  }

  Scope& elseBody() {
    return else_body_;
  }

  const Scope& elseBody() const {
    return else_body_;
  }

  bool hasElse() const {
    return !else_body_.empty();
  }

 private:
  Scope then_body_;
  Scope else_body_;
};

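// Usage note (illustrative sketch): predicated execution is expressed by
// wrapping exprs in an IfThenElse whose condition is a Predicate. "pred" and
// "guarded_expr" are hypothetical nodes and IrBuilder::create<>() is assumed
// from the surrounding codegen.
//
//   auto* ite = IrBuilder::create<IfThenElse>(pred);
//   ite->thenBody().push_back(guarded_expr);
//   // elseBody() is left empty, so hasElse() returns false here.
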
//! Grid reduction operation
//!
//! This node is used only after lowering a fusion to explicitly mark a grid
//! reduction and the buffer allocation needed to do it.
//!
//! This node provides FusionExecutor the information it needs to allocate the
//! reduction and sync buffers.
class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
 public:
  GridReduction(
      IrBuilderPasskey passkey,
      BinaryOpType reduction_op_type,
      Val* init,
      Val* out,
      Val* in,
      Allocate* reduction_buffer,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances,
      bool is_allreduce = false);

  Expr* shallowCopy() const override;

  Allocate* reduction_buffer() const {
    return reduction_buffer_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which instance of entering this grid reduction is this iteration?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // How many times will this grid reduction be entered
  Val* entrances() const {
    return entrances_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GridReduction* withThreadPredicate(
      const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GridReduction>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  Allocate* reduction_buffer_ = nullptr;
  Allocate* sync_buffer_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
};

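// Usage note (illustrative sketch): because the runtime gridReduce helper
// takes thread predicates as template flags, lowering attaches them through
// withThreadPredicate() rather than Expr::predicate_. Note that
// withThreadPredicate() returns a shallow copy with the predicate set; the
// original node is not modified. "grid_reduction" and "pred_map" below are
// hypothetical names.
//
//   GridReduction* predicated =
//       grid_reduction->withThreadPredicate(pred_map);
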
class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp {
 public:
  GroupedGridReduction(
      IrBuilderPasskey passkey,
      std::vector<BinaryOpType> reduction_op_type,
      std::vector<Val*> init,
      std::vector<Val*> out,
      std::vector<Val*> in,
      std::vector<Allocate*> reduction_buffers,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances,
      Val* buffer_stride,
      bool is_allreduce = false);

  Expr* shallowCopy() const override;

  const std::vector<Allocate*>& reduction_buffers() const {
    return reduction_buffers_;
  }

  Allocate* reduction_buffer(size_t i) const {
    return reduction_buffers_.at(i);
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which instance of entering this grid reduction is this iteration?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // How many times will this grid reduction be entered
  Val* entrances() const {
    return entrances_;
  }

  Val* buffer_stride() const {
    return buffer_stride_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GroupedGridReduction* withThreadPredicate(
      const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GroupedGridReduction>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  std::vector<Allocate*> reduction_buffers_;
  Allocate* sync_buffer_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
  // Stride of reduction buffers
  Val* buffer_stride_ = nullptr;
};

//! Grid broadcast operation
//!
//! This node is used only after lowering a fusion to explicitly mark a grid
//! broadcast and the buffer allocation needed to do it.
//!
//! This node provides FusionExecutor the information it needs to allocate the
//! broadcast and sync buffers.
class TORCH_CUDA_CU_API GridBroadcast final : public Expr {
 public:
  GridBroadcast(
      IrBuilderPasskey passkey,
      BroadcastOp* broadcast_op,
      Allocate* broadcast_buffer,
      Allocate* sync_buffer);

  Expr* shallowCopy() const override;

  BroadcastOp* broadcast_op() const {
    return broadcast_op_;
  }

  Allocate* broadcast_buffer() const {
    return broadcast_buffer_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

 private:
  BroadcastOp* broadcast_op_ = nullptr;
  Allocate* broadcast_buffer_ = nullptr;
  Allocate* sync_buffer_ = nullptr;
};

//! Grid welford operation
//!
//! This node is used only after lowering a fusion to explicitly mark a grid
//! reduction and the buffer allocation needed to do it.
//!
//! This node provides FusionExecutor the information it needs to allocate the
//! reduction and sync buffers.
//!
//! TODO: Make this a subclass of WelfordOp
class TORCH_CUDA_CU_API GridWelford final : public Expr {
 public:
  GridWelford(
      IrBuilderPasskey passkey,
      WelfordOp* welford_op,
      Allocate* var_buffer,
      Allocate* avg_buffer,
      Allocate* n_buffer,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances);

  Expr* shallowCopy() const override;

  WelfordOp* welford_op() const {
    return welford_op_;
  }

  Allocate* var_buffer() const {
    return var_buffer_;
  }

  Allocate* avg_buffer() const {
    return avg_buffer_;
  }

  Allocate* N_buffer() const {
    return n_buffer_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which instance of entering this grid reduction is this iteration?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // How many times will this grid reduction be entered
  Val* entrances() const {
    return entrances_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GridWelford>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  WelfordOp* welford_op_ = nullptr;
  Allocate* var_buffer_ = nullptr;
  Allocate* avg_buffer_ = nullptr;
  Allocate* n_buffer_ = nullptr;
  Allocate* sync_buffer_ = nullptr;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
};

class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp {
 public:
  // input, output and init vals are vectors of triplets
  GroupedGridWelford(
      IrBuilderPasskey passkey,
      std::vector<WelfordTriplet> output_vals,
      std::vector<WelfordTriplet> input_vals,
      std::vector<WelfordTriplet> init_vals,
      std::array<std::vector<Allocate*>, 3> reduction_buffers,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances,
      Val* buffer_stride,
      bool is_allreduce = false);

  Expr* shallowCopy() const override;

  const std::array<std::vector<Allocate*>, 3>& reduction_buffers() const {
    return reduction_buffers_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which instance of entering this grid reduction is this iteration?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // How many times will this grid reduction be entered
  Val* entrances() const {
    return entrances_;
  }

  Val* buffer_stride() const {
    return buffer_stride_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GroupedGridWelford* withThreadPredicate(
      const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GroupedGridWelford>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  std::array<std::vector<Allocate*>, 3> reduction_buffers_;
  Allocate* sync_buffer_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
  // Stride of reduction buffers
  Val* buffer_stride_ = nullptr;
};

// Allocate an instance of the fused reduction class.
class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr {
 public:
  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GridReduction* grid_reduction);

  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GridWelford* grid_welford);

  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GroupedGridReduction* grouped_grid_reduction);

  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GroupedGridWelford* grouped_grid_welford);

  Expr* shallowCopy() const override;

  Expr* gridExpr() const {
    return grid_expr_;
  }

  TensorIndex* out() const;

  const ParallelTypeBitmap& threadPredicate() const;

 private:
  //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford
  Expr* grid_expr_ = nullptr;
};

//! An IR node consisting of a pair of integers
//! to facilitate definition of 2D swizzle operators.
//! All swizzle 2D ops take two inputs and output
//! an integer pair.
//! TODO:
//! currently this IR node is only allowed as input
//! to the new PairSelect node. Follow-ups may
//! build it out to support out-of-line
//! definition of the pair alone.
class TORCH_CUDA_CU_API IntPair : public Val {
 public:
  IntPair(IrBuilderPasskey passkey);
};

//! An IR node marking selection of first or second
//! value from a pair of integers, e.g.:
//! Pair(X,Y) -> X or Y.
//! This IR node is used to facilitate generation
//! of inline 2D swizzle math.
class TORCH_CUDA_CU_API PairSelect : public Expr {
 public:
  //! Indicates which value from the input
  //! integer pair to output.
  enum class Selection { X = 0, Y };

  PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }

  IntPair* in() const {
    return in_;
  }

  auto selection() const {
    return selection_;
  }

 private:
  Val* const out_ = nullptr;
  IntPair* const in_ = nullptr;
  Selection selection_;
};

//! An integer IR node that will be generated
//! using custom integer swizzle functions
//! from the CUDA runtime.
//! Most supported swizzle functions require
//! the size of each dimension to be defined, so
//! all operators take the extents as inputs.
class TORCH_CUDA_CU_API Swizzle2DInt : public Expr {
 public:
  Swizzle2DInt(
      IrBuilderPasskey,
      IntPair* out,
      Val* in_x,
      Val* in_y,
      Val* extent_x,
      Val* extent_y,
      Swizzle2DType swizzle_type);

  Expr* shallowCopy() const override;

  IntPair* out() const {
    return out_;
  }

  Val* inX() const {
    return in_x_;
  }

  Val* inY() const {
    return in_y_;
  }

  Val* extentX() const {
    return extent_x_;
  }

  Val* extentY() const {
    return extent_y_;
  }

  const auto& swizzleType() const {
    return swizzle_type_;
  }

 private:
  IntPair* const out_ = nullptr;

  Val* const in_x_ = nullptr;
  Val* const in_y_ = nullptr;
  Val* const extent_x_ = nullptr;
  Val* const extent_y_ = nullptr;
  Swizzle2DType swizzle_type_;
};

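// Usage note (illustrative sketch): the three nodes above compose to emit
// inline 2D swizzle math. A Swizzle2DInt consumes the two coordinates and
// extents and defines an IntPair, and PairSelect extracts one component of
// that pair. All names below are hypothetical, IrBuilder::create<>() is
// assumed from the surrounding codegen, and the particular swizzle type is
// just an example.
//
//   IntPair* pair = IrBuilder::create<IntPair>();
//   IrBuilder::create<Swizzle2DInt>(
//       pair, in_x, in_y, extent_x, extent_y, Swizzle2DType::XOR);
//   IrBuilder::create<PairSelect>(out_x, pair, PairSelect::Selection::X);
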
} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch