#pragma once

#include <c10/macros/Export.h>

#include <fusion.h>
#include <ir_base_nodes.h>
#include <mma_type.h>
#include <parallel_type_bitmap.h>

//! Nodes in here should generally not be used by users. They should be behind
//! the scenes and users shouldn't have to be aware of what they do to use the
//! code generator
//!
//! \todo improve implementation bool IterDomain::sameAs(const IterDomain*)
//! \todo Add testing of sameAs functions for these nodes
//!

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class ViewTransform;
class Scope;
class IrCloner;
struct AnalyzeViewResult;

//! Returns true if both v1 and v2 are scalars, are the same type of scalars,
//! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both
//! vals are `Int` will dispatch to v1->as<Int>()->sameAs(v2.as<Int>())
bool areEqualScalars(Val* v1, Val* v2);

class TORCH_CUDA_CU_API FullOp : public Expr {
 public:
  FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype);

  FullOp(const FullOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  bool sameAs(const Statement* other) const override;

  DataType dtype() const {
    return dtype_;
  }

  Val* getFillValue() const {
    return fill_value_;
  }

 private:
  const DataType dtype_;
  Val* fill_value_;
};

class TORCH_CUDA_CU_API ARangeOp : public Expr {
 public:
  ARangeOp(
      IrBuilderPasskey,
      Val* out,
      Val* start,
      Val* end,
      Val* step,
      DataType dtype,
      Val* linear_index = nullptr);

  ARangeOp(const ARangeOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  bool sameAs(const Statement* other) const override;

  DataType dtype() const {
    return dtype_;
  }

  Val* start() const {
    return start_;
  }

  Val* end() const {
    return end_;
  }

  Val* step() const {
    return step_;
  }

  Val* getLinearLogicalIndex() const {
    return linear_index_;
  }

  void setLinearIndex(Val* index) {
    linear_index_ = index;
  }

 private:
  const DataType dtype_;
  Val* start_;
  Val* end_;
  Val* step_;
  Val* linear_index_ = nullptr;
};

// Tensor factory for generating identity matrices like
//
// [[1, 0, 0],
//  [0, 1, 0],
//  [0, 0, 1]]
//
// or
//
// [[1, 0, 0],
//  [0, 1, 0],
//  [0, 0, 1],
//  [0, 0, 0]]
//
// or
//
// [[1, 0, 0, 0],
//  [0, 1, 0, 0],
//  [0, 0, 1, 0]]
class TORCH_CUDA_CU_API EyeOp : public Expr {
 public:
  EyeOp(
      IrBuilderPasskey,
      Val* out,
      DataType dtype,
      Val* index1 = nullptr,
      Val* index2 = nullptr);

  EyeOp(const EyeOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  bool sameAs(const Statement* other) const override;

  DataType dtype() const {
    return dtype_;
  }

  Val* getIndex1() const {
    return index1_;
  }

  void setIndex1(Val* index) {
    index1_ = index;
  }

  Val* getIndex2() const {
    return index2_;
  }

  void setIndex2(Val* index) {
    index2_ = index;
  }

 private:
  const DataType dtype_;
  Val* index1_ = nullptr;
  Val* index2_ = nullptr;
};
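
// A minimal usage sketch of the factory expressions above (illustrative only;
// `full`, `arange`, and `eye` are assumed to be the user-facing factory
// helpers that create these Exprs, and their exact signatures may differ
// between versions):
//
//   TensorView* ones = full({IrBuilder::create<Int>(4)},       // shape
//                           IrBuilder::create<Double>(1.0),    // fill value
//                           DataType::Float);                  // -> FullOp
//   TensorView* ramp = arange(IrBuilder::create<Int>(0),       // start
//                             IrBuilder::create<Int>(8),       // end
//                             IrBuilder::create<Int>(2),       // step
//                             DataType::Int);                  // -> ARangeOp
//   TensorView* id3 = eye(IrBuilder::create<Int>(3),           // size
//                         DataType::Float);                    // -> EyeOp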

//! A specialization for Unary operations. Unary operations take in a single
//! input and produce a single output. Examples include:
//!  1) Casting operation i.e. float(a_val)
//!  2) Negation i.e. val * -1
//!  3) Reduction across a dimension i.e. val.sum(axis=2)
//!  4) split/merge
class TORCH_CUDA_CU_API UnaryOp : public Expr {
 public:
  UnaryOp(
      IrBuilderPasskey,
      UnaryOpType type,
      Val* out,
      Val* in,
      int rng_offset = -1);

  UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }
  Val* in() const {
    return in_;
  }

  UnaryOpType getUnaryOpType() const {
    return unary_op_type_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const UnaryOpType unary_op_type_;
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;
};

//! A specialization for Binary operations. Binary operations take in two
//! inputs and produce a single output. Examples include:
//!  1) Add/mul/div/mod/sub (A * B)
//!  2) LT (A < B)
class TORCH_CUDA_CU_API BinaryOp : public Expr {
 public:
  BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs);

  BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }
  Val* lhs() const {
    return lhs_;
  }
  Val* rhs() const {
    return rhs_;
  }

  BinaryOpType getBinaryOpType() const {
    return binary_op_type_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const BinaryOpType binary_op_type_;
  Val* const out_ = nullptr;
  Val* const lhs_ = nullptr;
  Val* const rhs_ = nullptr;
};

//! A specialization for random number generator (RNG) operations. RNG
//! operations take in no tensor input and produce a single output.
class TORCH_CUDA_CU_API RNGOp : public Expr {
 public:
  RNGOp(
      IrBuilderPasskey,
      RNGOpType type,
      Val* out,
      DataType dtype,
      std::vector<Val*> parameters = {},
      int rng_offset = 0,
      Val* philox_index = nullptr);

  RNGOp(const RNGOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  RNGOpType getRNGOpType() const {
    return rng_op_type_;
  }

  DataType dtype() const {
    return dtype_;
  }

  int getRNGOffset() const {
    return rng_offset_;
  }

  void setRNGOffset(int val) {
    rng_offset_ = val;
  }

  const std::vector<Val*>& getParameters() const {
    return parameters_;
  }

  const std::vector<Val*>& getShape() const {
    return shape_;
  }

  Val* getPhiloxIndex() const {
    return philox_index_;
  }

  void setPhiloxIndex(Val* index) {
    philox_index_ = index;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const RNGOpType rng_op_type_;
  const DataType dtype_;
  std::vector<Val*> parameters_;
  std::vector<Val*> shape_;
  int rng_offset_ = -1;
  // The index used to feed philox's subsequence and component
  Val* philox_index_ = nullptr;
};

//! Broadcast in to match out. is_broadcast_dims are relative to out, with
//! is_broadcast_dims.size() == out->nDims().
class TORCH_CUDA_CU_API BroadcastOp : public Expr {
 public:
  //! \param out The output tensor
  //! \param in The input tensor
  //! \param is_broadcast_dims True when output dim is a new broadcast domain
  BroadcastOp(
      IrBuilderPasskey,
      Val* out,
      Val* in,
      std::vector<bool> is_broadcast_dims);

  BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }
  Val* in() const {
    return in_;
  }

  bool isBroadcastDim(size_t dim) const {
    return is_broadcast_dims_.at(dim);
  }

  const std::vector<bool>& getBroadcastDimFlags() const {
    return is_broadcast_dims_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;

  //! The same list passed to the broadcast arithmetic op. Each
  //! element corresponds to an IterDomain of the output tensor and is
  //! true when the IterDomain is a new broadcast domain. Note
  //! that the output tensor may have other broadcast domains whose
  //! flags are false because the input tensor may already have
  //! broadcast domains.
  const std::vector<bool> is_broadcast_dims_;
};

//! Reduction operation. Out is first initialized to init. Then
//! reduction_op_type is used to update out as out = reductionOp(out, in).
//! Output's axes marked as reduction will be reduced to produce an output
//! tensor. The output tensor's size will be the size of all
//! non-reduction/non-broadcast dimensions.
class TORCH_CUDA_CU_API ReductionOp : public Expr {
 public:
  ReductionOp(
      IrBuilderPasskey,
      BinaryOpType reduction_op_type,
      Val* init,
      Val* out,
      Val* in,
      bool is_allreduce = false,
      ExprType expr_type = ExprType::ReductionOp);

  ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }
  Val* in() const {
    return in_;
  }
  Val* init() const {
    return init_;
  }

  BinaryOpType getReductionOpType() const {
    return reduction_op_type_;
  }

  bool isAllreduce() const {
    return is_allreduce_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const BinaryOpType reduction_op_type_;
  Val* const init_ = nullptr;
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;
  //! True if broadcast is fused
  bool is_allreduce_ = false;
};

//! Grouped reduction operation for horizontal fusions. It works like
//! batched GEMMs in the sense that multiple independent reductions are
//! performed together. The main benefit is that when reducing tensors across
//! thread blocks, a single grid sync can be done for all individual
//! reductions. As grid sync is very expensive, this can yield a
//! significant performance benefit.
class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
 public:
  GroupedReductionOp(
      IrBuilderPasskey,
      std::vector<BinaryOpType> reduction_op_type,
      std::vector<Val*> init,
      std::vector<Val*> out,
      std::vector<Val*> in,
      bool is_allreduce = false,
      ExprType expr_type = ExprType::GroupedReductionOp);

  GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  //! Number of expressions grouped horizontally. It does not reflect
  //! iteration grouping.
  size_t numExprs() const {
    return reduction_op_types_.size();
  }

  const std::vector<Val*>& initVals() const {
    return init_vals_;
  }

  Val* initVal(size_t index) const {
    return init_vals_.at(index);
  }

  const std::vector<BinaryOpType>& getReductionOpTypes() const {
    return reduction_op_types_;
  }

  BinaryOpType getReductionOpType(size_t index) const {
    return reduction_op_types_.at(index);
  }

  bool isAllreduce() const {
    return is_allreduce_;
  }

  //! Return the index of the corresponding reduction expression for
  //! a given output val.
  int getExprIndexOfOutput(Val* output_val) const;

  bool sameAs(const Statement* other) const override;

 private:
  //! Reduction ops of grouped reductions
  const std::vector<BinaryOpType> reduction_op_types_;
  //! Initial values of grouped reductions
  const std::vector<Val*> init_vals_;
  //! True if using the fused reduction kernel
  bool is_allreduce_ = false;
};

//! Average, variance and N (count) vals for Welford
class TORCH_CUDA_CU_API WelfordTriplet {
 public:
  //! Names of the Welford triplet vals
  enum class ValName { Avg, Var, N };

  WelfordTriplet() = default;

  WelfordTriplet(Val* avg, Val* var, Val* N) : vals_({avg, var, N}) {}

  Val* const& avg() const {
    return get(ValName::Avg);
  }

  Val*& avg() {
    return get(ValName::Avg);
  }

  TensorView* avgTv() const {
    TORCH_INTERNAL_ASSERT(avg()->isA<TensorView>());
    return avg()->as<TensorView>();
  }

  Val* const& var() const {
    return get(ValName::Var);
  }

  Val*& var() {
    return get(ValName::Var);
  }

  TensorView* varTv() const {
    TORCH_INTERNAL_ASSERT(var()->isA<TensorView>());
    return var()->as<TensorView>();
  }

  Val* const& N() const {
    return get(ValName::N);
  }

  Val*& N() {
    return get(ValName::N);
  }

  TensorView* NTv() const {
    TORCH_INTERNAL_ASSERT(N()->isA<TensorView>());
    return N()->as<TensorView>();
  }

  //! Get the i-th val. Ordering is defined by ValName.
  Val* const& get(int i) const {
    return vals_.at(i);
  }

  //! Get the i-th val. Ordering is defined by ValName.
  Val*& get(int i) {
    return vals_.at(i);
  }

  Val* const& get(ValName name) const {
    return get(valNameToIndex(name));
  }

  Val*& get(ValName name) {
    return get(valNameToIndex(name));
  }

  //! Get the name of a given val in this triplet. None is returned if
  //! not found.
  c10::optional<ValName> getNameOf(Val* val) const;

  //! Return a new triplet with outputs produced by a function applied
  //! to each of this triplet
  template <typename Func>
  WelfordTriplet transform(Func func) const {
    return WelfordTriplet(func(avg()), func(var()), func(N()));
  }

  bool sameAs(const WelfordTriplet& other) const;

  WelfordTriplet clone(IrCloner* ir_cloner) const;

  //! Clone a vector of triplets
  static std::vector<WelfordTriplet> clone(
      const std::vector<WelfordTriplet>& src,
      IrCloner* ir_cloner);

  auto begin() {
    return vals_.begin();
  }

  auto begin() const {
    return vals_.begin();
  }

  auto end() {
    return vals_.end();
  }

  auto end() const {
    return vals_.end();
  }

 private:
  //! Convert a given val name to an index
  static int valNameToIndex(ValName name) {
    return static_cast<int>(name);
  }

  //! Convert a given index to a name
  static ValName indexToValName(int index) {
    TORCH_INTERNAL_ASSERT(index >= 0 && index < 3, "Invalid index: ", index);
    return static_cast<ValName>(index);
  }

 private:
  //! Holds avg, var and N in this order
  std::array<Val*, 3> vals_ = {{nullptr, nullptr, nullptr}};
};
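
// A brief sketch of how transform() can be used. This example is illustrative
// only; it assumes `ir_cloner` is an IrCloner* available at the call site
// (e.g. inside a copy constructor) and simply maps each of avg/var/N through
// the cloner:
//
//   WelfordTriplet cloned = triplet.transform(
//       [ir_cloner](Val* v) { return ir_cloner->clone(v); });
//
// The function is applied to avg(), var() and N() in that order, producing a
// new triplet without mutating the original.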

//! Welford Scan operation.
class TORCH_CUDA_CU_API WelfordOp : public Expr {
 public:
  WelfordOp(
      IrBuilderPasskey,
      const WelfordTriplet& output,
      const WelfordTriplet& input,
      const WelfordTriplet& init,
      bool is_fused = false);

  WelfordOp(
      IrBuilderPasskey,
      Val* out_avg,
      Val* out_var,
      Val* out_N,
      Val* in_avg,
      Val* in_var,
      Val* in_N,
      Val* init_avg,
      Val* init_var,
      Val* init_N,
      bool is_fused = false);

  WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return output().avg();
  }

  Val* in() const {
    return input().avg();
  }

  bool sameAs(const Statement* const other) const override;

  const WelfordTriplet& output() const {
    return output_;
  }

  Val* outAvg() const {
    return output().avg();
  }

  Val* outVar() const {
    return output().var();
  }

  Val* outN() const {
    return output().N();
  }

  const WelfordTriplet& input() const {
    return input_;
  }

  Val* inAvg() const {
    return input().avg();
  }

  Val* inVar() const {
    return input().var();
  }

  Val* inN() const {
    return input().N();
  }

  const WelfordTriplet& init() const {
    return init_;
  }

  Val* initAvg() const {
    return init().avg();
  }

  Val* initVar() const {
    return init().var();
  }

  Val* initN() const {
    return init().N();
  }

  bool singleValue() const {
    return inN()->isOneInt();
  }

  bool hasInit() const {
    return !initN()->isZeroInt();
  }

  bool isAllreduce() const {
    return is_allreduce_;
  }

  std::vector<Val*> getInitVals() const;

  //! Return the init val for an output val
  Val* getInitValOfOutput(Val* output_val) const;

 private:
  const WelfordTriplet output_;
  const WelfordTriplet input_;
  const WelfordTriplet init_;
  //! True if using the fused reduction kernel (not implemented yet)
  bool is_allreduce_ = false;
};

class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr {
 public:
  GroupedWelfordOp(
      IrBuilderPasskey,
      std::vector<WelfordTriplet> output_vals,
      std::vector<WelfordTriplet> input_vals,
      std::vector<WelfordTriplet> init_vals,
      bool is_allreduce = false,
      ExprType expr_type = ExprType::GroupedWelfordOp);

  GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  //! Number of expressions grouped horizontally. It does not reflect
  //! iteration grouping. As horizontal grouping is not supported,
  //! this always returns 1.
  size_t numExprs() const {
    return 1;
  }

  Val* out(size_t index) const {
    return outAvg(index);
  }

  Val* in(size_t index) const {
    return inAvg(index);
  }

  bool sameAs(const Statement* const other) const override;

  const std::vector<WelfordTriplet>& outputVals() const {
    return output_vals_;
  }

  const std::vector<WelfordTriplet>& inputVals() const {
    return input_vals_;
  }

  const std::vector<WelfordTriplet>& initVals() const {
    return init_vals_;
  }

  Val* outAvg(size_t index) const {
    return outputVals().at(index).avg();
  }

  Val* outVar(size_t index) const {
    return outputVals().at(index).var();
  }

  Val* outN(size_t index) const {
    return outputVals().at(index).N();
  }

  Val* inAvg(size_t index) const {
    return inputVals().at(index).avg();
  }

  Val* inVar(size_t index) const {
    return inputVals().at(index).var();
  }

  Val* inN(size_t index) const {
    return inputVals().at(index).N();
  }

  Val* initAvg(size_t index) const {
    return initVals().at(index).avg();
  }

  Val* initVar(size_t index) const {
    return initVals().at(index).var();
  }

  Val* initN(size_t index) const {
    return initVals().at(index).N();
  }

  //! Return the index of the corresponding welford expression for
  //! a given output val
  int getExprIndexOfOutput(Val* output_val) const;

  //! Return the init val for an output val
  Val* getInitValOfOutput(Val* output_val) const;

  bool singleValue(size_t index) const {
    return inN(index)->isOneInt();
  }

  bool hasInit(size_t index) const {
    return !initN(index)->isZeroInt();
  }

  bool isAllreduce() const {
    return is_allreduce_;
  }

 private:
  const std::vector<WelfordTriplet> output_vals_;
  const std::vector<WelfordTriplet> input_vals_;
  const std::vector<WelfordTriplet> init_vals_;
  //! True if using the fused reduction kernel
  bool is_allreduce_ = false;
};

//! Fused Matmul operation
class TORCH_CUDA_CU_API MmaOp : public Expr {
 public:
  // This is a temporary data structure for the
  // scheduling-specific parameters that we still need
  // to store on an mma node. Eventually only the mma
  // macro type will stay on the IR node after
  // additional cleanup.
  struct OptionsInMma {
    MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA;
    MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT;
    int accumulator_stride = 0;

    bool operator==(const OptionsInMma& other) const {
      return macro == other.macro && operand_layout == other.operand_layout &&
          accumulator_stride == other.accumulator_stride;
    }
  };

  MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);

  MmaOp(
      IrBuilderPasskey,
      Val* out,
      Val* in_a,
      Val* in_b,
      Val* init,
      OptionsInMma options);

  MmaOp(const MmaOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }

  Val* inA() const {
    return in_a_;
  }

  Val* inB() const {
    return in_b_;
  }

  Val* init() const {
    return init_;
  }

  const auto& options() const {
    TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this);
    return options_.value();
  }

  bool sameAs(const Statement* const other) const override;

  auto accStride() const {
    TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this);
    return options_->accumulator_stride;
  }

  void configureOptions(MmaOptions options) {
    options_ = OptionsInMma();
    TORCH_INTERNAL_ASSERT(
        options.macro != MmaOptions::MacroType::NoMMA,
        "Un-configured mma type from options.");
    TORCH_INTERNAL_ASSERT(
        options.accumulator_stride > 0, "Un-configured accumulator stride.");
    options_->accumulator_stride = options.accumulator_stride;
    options_->macro = options.macro;
    options_->operand_layout = options.operand_layout;
  }

 private:
  Val* const out_ = nullptr;
  Val* const in_a_ = nullptr;
  Val* const in_b_ = nullptr;
  Val* const init_ = nullptr;
  c10::optional<OptionsInMma> options_ = c10::nullopt;
};
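
// A rough sketch of how an MmaOp ends up configured during scheduling. The
// names `opts` and `mma_expr` are hypothetical and the scheduler entry points
// may differ between versions; the point is only that configureOptions()
// copies macro, operand layout and accumulator stride from an MmaOptions:
//
//   // MmaOptions opts = ...;                    // built during matmul scheduling
//   // opts.macro = MmaOptions::MacroType::...;  // a concrete mma macro
//   // opts.accumulator_stride = ...;            // > 0 once accumulators are laid out
//   // mma_expr->configureOptions(opts);         // stores an OptionsInMma on the op
//   // TORCH_INTERNAL_ASSERT(
//   //     mma_expr->options().macro != MmaOptions::MacroType::NoMMA);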

class TORCH_CUDA_CU_API TransposeOp : public Expr {
 public:
  TransposeOp(
      IrBuilderPasskey,
      TensorView* out,
      TensorView* in,
      std::vector<int64_t> new2old);

  TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  TensorView* out() const {
    return out_;
  }

  TensorView* in() const {
    return in_;
  }

  const std::vector<int64_t>& new2old() const {
    return new2old_;
  }

  std::vector<int64_t> old2new() const;

 private:
  TensorView* const out_ = nullptr;
  TensorView* const in_ = nullptr;
  const std::vector<int64_t> new2old_;
};

class TORCH_CUDA_CU_API ExpandOp : public Expr {
 public:
  ExpandOp(
      IrBuilderPasskey,
      TensorView* out,
      TensorView* in,
      std::vector<Val*> _expanded_extents);

  ExpandOp(const ExpandOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  TensorView* out() const {
    return out_;
  }

  TensorView* in() const {
    return in_;
  }

  const std::vector<Val*>& expanded_extents() const {
    return expanded_extents_;
  }

 private:
  TensorView* const out_ = nullptr;
  TensorView* const in_ = nullptr;
  std::vector<Val*> expanded_extents_;
};

class TORCH_CUDA_CU_API TernaryOp : public Expr {
 public:
  TernaryOp(
      IrBuilderPasskey,
      TernaryOpType type,
      Val* out,
      Val* in1,
      Val* in2,
      Val* in3);

  TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }

  Val* in1() const {
    return in1_;
  }
  Val* in2() const {
    return in2_;
  }
  Val* in3() const {
    return in3_;
  }

  TernaryOpType getTernaryOpType() const {
    return ternary_op_type_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const TernaryOpType ternary_op_type_;
  Val* const out_ = nullptr;
  Val* const in1_ = nullptr;
  Val* const in2_ = nullptr;
  Val* const in3_ = nullptr;
};

//! Shift
class TORCH_CUDA_CU_API ShiftOp : public Expr {
 public:
  //! \param out
  //! \param in
  //! \param offsets
  ShiftOp(
      IrBuilderPasskey,
      Val* out,
      Val* in,
      std::vector<int> offsets,
      std::vector<int> pad_width);

  ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }
  Val* in() const {
    return in_;
  }

  int offset(size_t dim) const {
    return offsets_.at(dim);
  }

  const std::vector<int>& offsets() const {
    return offsets_;
  }

  const std::vector<int>& padWidth() const {
    return pad_width_;
  }

  bool hasPadding() const {
    return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto p) {
      return p > 0;
    });
  }

  bool sameAs(const Statement* other) const override;

 private:
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;
  //! Each of the root axes is shifted by the corresponding value of
  //! offsets_. The sign of each value indicates the direction of
  //! shifting.
  const std::vector<int> offsets_;
  const std::vector<int> pad_width_;
};

//! Gather a window around each element.
class TORCH_CUDA_CU_API GatherOp : public Expr {
 public:
  GatherOp(
      IrBuilderPasskey,
      Val* out,
      Val* in,
      std::vector<int> window_shape,
      std::vector<std::vector<int>> pad_width);

  GatherOp(const GatherOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }
  Val* in() const {
    return in_;
  }

  const auto& windowShape() const {
    return window_shape_;
  }

  //! Returns the gather axis that corresponds to an input axis
  int gatherAxis(int axis) const;

  const auto& padWidth() const {
    return pad_width_;
  }

  bool hasPadding() const {
    return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto& p) {
      return p[0] > 0 || p[1] > 0;
    });
  }

  bool sameAs(const Statement* other) const override;

 private:
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;
  //! Shape of a window gathered for each element.
  std::vector<int> window_shape_;
  //! The size of zero-padding of each axis.
  std::vector<std::vector<int>> pad_width_;
};

class TORCH_CUDA_CU_API ViewAsScalar : public Expr {
 public:
  ViewAsScalar(
      IrBuilderPasskey,
      Val* out,
      Val* in,
      IterDomain* vector_id,
      Val* index = nullptr);

  ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }

  Val* in() const {
    return in_;
  }

  IterDomain* vector_id() const {
    return vector_id_;
  }

  Val* index() const {
    return index_;
  }

 private:
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;

  // The IterDomain of type VectorComponent newly appended to the output
  IterDomain* vector_id_ = nullptr;

  // The index that vector_id_ is lowered into
  Val* index_ = nullptr;
};

class TORCH_CUDA_CU_API ViewOp : public Expr {
 public:
  ViewOp(IrBuilderPasskey, TensorView* out, TensorView* in);

  ViewOp(const ViewOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  TensorView* out() const {
    return out_;
  }

  TensorView* in() const {
    return in_;
  }

 private:
  TensorView* const out_ = nullptr;
  TensorView* const in_ = nullptr;
};

//! This operator explicitly models data movement between
//! state spaces on GPU. Currently the modeled state spaces include
//! global memory, shared memory and register.
//!
//! The main usage of this op is to facilitate generation of hardware
//! accelerated memory ops, i.e. ldmatrix, cp.async and more to come.
class TORCH_CUDA_CU_API LoadStoreOp : public Expr {
 public:
  LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in);

  LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }

  Val* in() const {
    return in_;
  }

  LoadStoreOpType opType() const {
    return load_store_type_;
  }

 private:
  LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix;
  Val* const out_ = nullptr;
  Val* const in_ = nullptr;
};

// Convenience utility to initialize IterDomains without having to sort through
// all the default values. Intended to be used with
// IterDomain::IterDomain(IrBuilderPasskey IterDomainBuildArgs)
class TORCH_CUDA_CU_API IterDomainBuilder {
 public:
  // Match legacy constructor
  IterDomainBuilder(Val* _start, Val* _extent);

  // Grab all the parameters from id to set the IterDomainBuilder
  IterDomainBuilder(const IterDomain* id);

  // Resets defaults for rfactor, is padded dim, padded to size, and is mma
  // swizzle which should only be set during scheduling.
  IterDomainBuilder& resetSchedulingParams();

  // Resets is_rfactor_domain
  IterDomainBuilder& resetRfactor();

  IterDomainBuilder& start(Val* _start);
  IterDomainBuilder& extent(Val* _extent);
  IterDomainBuilder& expanded_extent(Val* _expanded_extent);
  IterDomainBuilder& stop_offset(Val* _stop_offset);
  IterDomainBuilder& parallel_type(ParallelType _parallel_type);
  IterDomainBuilder& iter_type(IterType _iter_type);
  IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
  IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
  IterDomainBuilder& padded_to_size(c10::optional<int64_t> _padded_to_size);
  IterDomainBuilder& is_mma_swizzled(bool _is_mma_swizzled);

  IterDomain* build() const;

  // Must have start and extent at least
  IterDomainBuilder() = delete;

  Val* start_ = nullptr;
  Val* extent_ = nullptr;
  Val* expanded_extent_ = nullptr;
  Val* stop_offset_ = nullptr;
  ParallelType parallel_type_ = ParallelType::Serial;
  IterType iter_type_ = IterType::Iteration;

  // Only relevant at scheduling time or compile time.
  bool is_rfactor_domain_ = false;
  bool is_padded_dimension_ = false;
  c10::optional<int64_t> padded_to_size_ = c10::nullopt;
  bool is_mma_swizzled_ = false;
};
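
// A minimal usage sketch of the builder (assuming a FusionGuard is active so
// that IrBuilder::create can allocate Vals into the current container; exact
// scalar types may differ):
//
//   IterDomain* id =
//       IterDomainBuilder(
//           IrBuilder::create<Int>(0),            // start
//           IrBuilder::create<Int>(128))          // extent
//           .iter_type(IterType::Iteration)       // default, shown for clarity
//           .parallel_type(ParallelType::Serial)  // default, shown for clarity
//           .build();
//
// Each setter returns *this, so options can be chained before build().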

// Friends for direct access to split
class TensorDomain;
class ReplayTransformations;
class IndexReferenceReplay;
//! Simply a representation of an annotated 1D iterable from start to extent.
//! TensorDomains, which represent how to iterate over a tensor, are made up of
//! IterDomains that form an ND iterable. We directly set parallelization
//! strategies on IterDomains.
class TORCH_CUDA_CU_API IterDomain : public Val {
 public:
  IterDomain(IrBuilderPasskey, const IterDomainBuilder& args);

  // Legacy constructor. TODO: should start moving to use the
  // IterDomainBuildArgs constructor. Same as the above but can set the offset
  // of the stop point.
  IterDomain(
      IrBuilderPasskey,
      Val* start,
      Val* extent,
      Val* expanded_extent,
      Val* stop_offset,
      ParallelType parallel_type,
      IterType iter_type,
      bool is_rfactor_domain,
      bool is_padded_dimension,
      c10::optional<int64_t> padded_to_size_,
      bool is_mma_swizzled);

  IterDomain(const IterDomain* src, IrCloner* ir_cloner);

  bool sameAs(const Statement* other) const override;

  //! Returns a new IterDomain matching properties of this
  //!
  //! This does NOT copy the is_rfactor_domain flag.
  IterDomain* cloneWithoutRFactor() const;

  //! Clone a vector of domains
  static std::vector<IterDomain*> clone(
      const std::vector<IterDomain*>& domains);

  static IterDomain* merge(IterDomain* outer, IterDomain* inner);

  //! start_offset and stop_offset define a partial split. Only root
  //! domains are allowed to have non-zero start and stop offsets.
  static std::pair<IterDomain*, IterDomain*> split(
      IterDomain* in,
      Val* factor,
      bool inner_split,
      Val* start_offset = nullptr,
      Val* stop_offset = nullptr);

  //! trim_out_of_bounds controls how the values outside start and stop
  //! positions are treated. The option is only valid with root
  //! domains as non-root domains do not have valid start and stop
  //! positions.
  //!
  //! \param trim_out_of_bounds Trims [0, start_] and [-stop_offset_, extent_]
  static std::pair<IterDomain*, IterDomain*> split(
      IterDomain* in,
      Val* factor,
      bool inner_split,
      bool trim_out_of_bounds);

  bool isReduction() const {
    return getIterType() == IterType::Reduction;
  }

  bool isRFactorProduct() const {
    return is_rfactor_domain_;
  }

  bool isBroadcast() const {
    return getIterType() == IterType::Broadcast;
  }

  bool isGather() const {
    return getIterType() == IterType::Gather;
  }

  bool isStride() const {
    return getIterType() == IterType::Stride;
  }

  bool isVectorComponent() const {
    return getIterType() == IterType::VectorComponent;
  }

  bool isParallelized() const {
    return getParallelType() != ParallelType::Serial;
  }

  //! Return if this iter domain is mapped to a grid dimension
  bool isBlockDim() const {
    return isParallelTypeBlockDim(getParallelType());
  }

  //! Return if this iter domain is mapped to a block dimension
  bool isThreadDim() const {
    return isParallelTypeThreadDim(getParallelType());
  }

  //! Return if this iter domain is either mapped to a block or grid dimension
  bool isThread() const {
    return (isBlockDim() || isThreadDim());
  }

  void parallelize(ParallelType t);

  ParallelType getParallelType() const {
    return parallel_type_;
  }

  IterType getIterType() const {
    return iter_type_;
  }

  Val* start() const {
    return start_;
  }

  Val* stop() const;

  Val* stopOffset() const;

  Val* extent() const {
    TORCH_INTERNAL_ASSERT(extent_ != nullptr);
    return extent_;
  }

  bool hasExpandedExtent() const {
    return expanded_extent_ != nullptr;
  }

  // Returns the expanded extent of a strided broadcast entry.
  Val* expandedExtent() const {
    TORCH_INTERNAL_ASSERT(
        hasExpandedExtent(),
        "Requested expanded extent, but none found on this dimension.");
    return expanded_extent_;
  }

  Val* getMaybeExpandedExtent() const {
    if (hasExpandedExtent()) {
      return expandedExtent();
    }
    return extent();
  }

  //! Dimension padding interface:
  //! 2 modes are currently supported:
  //!
  //! - mode 1: if to_size is given as a positive number,
  //!   the dimension will be padded to that size so that
  //!   this iterdomain has a compile-time constant
  //!   size, and it is the scheduler's responsibility
  //!   to ensure no input larger than the padded size
  //!   will be observed
  //!
  //! - mode 2: if no to_size is given, this dimension
  //!   is "dynamically" padded to the next multiple
  //!   of the warp size, i.e. 17 padded to 32, 33 padded to 64,
  //!   based on the given input.
  void padToMultipleOfWarp(c10::optional<int64_t> maybe_to_size = {}) {
    // Currently only restricted to TIDx to generate warp reduce
    TORCH_CHECK(
        parallel_type_ == ParallelType::TIDx,
        "padToMultipleOfWarp : warp padding only supported on TIDx parallel dimension");
    is_padded_dimension_ = true;
    if (maybe_to_size.has_value()) {
      if (maybe_to_size.value() > 0) {
        padded_to_size_ = maybe_to_size.value();
      }
    }
  }
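
  // Illustrative sketch of the two padding modes (assuming `id` is a
  // TIDx-parallelized IterDomain; this is not a definitive recipe, just how
  // the interface above is meant to be driven):
  //
  //   id->parallelize(ParallelType::TIDx);
  //   id->padToMultipleOfWarp();     // mode 2: runtime extent 17 behaves as 32
  //   // or
  //   id->padToMultipleOfWarp(128);  // mode 1: padded to a compile-time 128
  //   // afterwards:
  //   //   id->hasPaddingToMultipleOfWarp() == true
  //   //   id->getMaybeSizeAfterPadding()   == 128 (only in mode 1)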

  //! Indicates if this iterdomain had padding
  //! applied, dynamically or statically
  bool hasPaddingToMultipleOfWarp() const {
    return is_padded_dimension_;
  }

  //! Returns a concrete value if this iterdomain
  //! has been padded to a static size.
  c10::optional<int64_t> getMaybeSizeAfterPadding() const {
    return padded_to_size_;
  }

  //! True if range of iteration domain isn't across the full extent
  bool maybePartial() const;

  //! Check if IterDomain is a broadcast axis with compile-time
  //! known extent. This is the case with all size-1 IterDomains on
  //! a TensorView's root domain when the TensorView is created.
  bool isImplicitBroadcast() const {
    return isBroadcast() && extent()->isOneInt();
  }

  //! Check if IterDomain is a reduction axis with size of 1, i.e.
  //! a "squeeze" operator, or solely derived from such axes.
  bool isTrivialReduction() const;

  //! Split for stride by a given factor. It effectively does an inner
  //! split by the factor and sets the inner domain as a Stride
  //! domain.
  std::pair<IterDomain*, IterDomain*> stridedSplit(int factor);

  // TODO: Remove
  bool isSimple() const {
    return definition() == nullptr;
  }

  //! Marks that this id represents an
  //! instruction loop, mma use only.
  //!
  //! An instruction loop can be considered a generalization of
  //! vectorization. It also represents a loop that's implemented
  //! by an instruction, should not be realized by codegen, and
  //! cannot be inlined with other loops.
  //! As an example, if an mma macro, call it mma_eg, implements:
  //!  for m in M
  //!    for n in N
  //!      for k in K
  //!        C[m,n] += A[m,k]*B[k,n],
  //! then the generated code should simply be:
  //!  mma_eg(C,A,B)
  //! without the 3 level loopnest, i.e. they're instruction loops.
  //!
  //! In the actual mma macros, the loopnest implemented is a
  //! transformed version of the above to match the mma swizzle,
  //! so each macro implies a different implicit loopnest.
  //! WarpMmaSwizzler will label the instruction loops case-by-case.
  bool isMma() const {
    return parallel_type_ == ParallelType::Mma;
  }

  //! Applies 2D swizzle on a rectangular tile defined by
  //! a pair of iterdomains.
  static std::pair<IterDomain*, IterDomain*> swizzle(
      Swizzle2DType swizzle_type,
      IterDomain* in_x,
      IterDomain* in_y,
      SwizzleMode swizzle_mode = SwizzleMode::Data);

  bool isMmaSwizzled() const {
    return is_mma_swizzled_;
  }

  //! Used by WarpMmaSwizzler, this is a utility for WarpMmaSwizzler
  //! to lock the thread swizzled iterdomains.
  //! Only true for the iterdomains produced by WarpMmaSwizzler.
  //! Mma ops require specific swizzle patterns,
  //! and this label utility is to prevent any further transform on the
  //! iterdomains involved in the swizzle so that the pattern remains correct
  //! in the generated code.
  //!
  //! Note:
  //!   Used only through WarpMmaSwizzler, and mma validation relies on this
  //!   flag being set on the correct iterdomains.
  void toMmaSwizzled() {
    is_mma_swizzled_ = true;
  }

 protected:
  friend TensorDomain;
  friend ReplayTransformations;
  friend IndexReferenceReplay;

 private:
  //! Valid range is defined as [start:-stop_offset]
  Val* const start_ = nullptr;
  Val* const extent_ = nullptr;

  // Broadcast dimensions are assumed to be size 1 for the sake of code
  // generation. If a user though calls `expand` on a tensor, that dimension is
  // still considered a broadcast dimension. However if we ever output that
  // dimension it should be a size dictated by the `expand` operation, and have
  // a stride of zero. Since this extent is important to track, but not
  // necessarily generate code for (still want loops on broadcast to be of size
  // 1), we simply store it separately from extent_. Having an expanded_extent_
  // is only allowed with broadcasted dimensions. Only in this instance does it
  // make sense to have an expanded_extent_, because it's used when users are
  // expecting return tensors to have a physical domain. If a user simply
  // "broadcasts" an operation, an expanded extent is not needed.
  Val* const expanded_extent_ = nullptr;

  //! Distance of stop from the end
  Val* const stop_offset_ = nullptr;
  ParallelType parallel_type_ = ParallelType::Serial;
  IterType iter_type_ = IterType::Iteration;
  bool is_rfactor_domain_ = false;
  bool is_padded_dimension_ = false;
  c10::optional<int64_t> padded_to_size_ = c10::nullopt;

  // TODO: Remove; only used in kernel IR because IterDomains don't maintain
  // definitions of split/merge.
  bool is_simple_ = true;

  //! Tracks if this id represents a thread swizzled loop or
  //! models an implicit loop within instructions. Should not make
  //! any changes once an id is warp mapped.
  bool is_mma_swizzled_ = false;
};
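
// A short sketch of the static transforms above (illustrative only; it
// assumes `id` is an IterDomain with extent 128 owned by an active Fusion,
// and that Int is the integer scalar Val type):
//
//   // Inner split by 32: produces {outer: ceilDiv(128, 32), inner: 32}.
//   auto split_pair =
//       IterDomain::split(id, IrBuilder::create<Int>(32), /*inner_split=*/true);
//   IterDomain* outer = split_pair.first;
//   IterDomain* inner = split_pair.second;
//
//   // Merging them back yields a single domain with extent outer * inner.
//   IterDomain* merged = IterDomain::merge(outer, inner);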

//! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
//! logical axis in its associated tensor. TensorDomain does not directly hold
//! the Tensor it is associated with, and in theory could be associated with
//! multiple tensors. TensorDomain's primary responsibility is to provide a
//! mechanism to access history of transformations that were used to generate
//! it. This is done through the normal interaction of Expr/Val in Fusion. i.e.
//! if we want to know the previous operation generating a particular
//! TensorDomain we can simply call:
//!
//!   FusionGuard::getCurFusion()->definition(a_tensor_domain)
//!
//! which should give us an operation in the list [split, merge] or similar
//! operations that take in a TensorDomain, applies a transformation and outputs
//! a tensor domain.
class TORCH_CUDA_CU_API TensorDomain : public Val {
 public:
  explicit TensorDomain(
      IrBuilderPasskey,
      std::vector<IterDomain*> root_domain,
      std::vector<bool> contiguity = std::vector<bool>());

  TensorDomain(
      IrBuilderPasskey,
      std::vector<IterDomain*> root_domain,
      std::vector<IterDomain*> domain,
      std::vector<bool> contiguity = std::vector<bool>());

  TensorDomain(
      IrBuilderPasskey,
      std::vector<IterDomain*> root_domain,
      std::vector<IterDomain*> rfactor_domain,
      std::vector<IterDomain*> domain,
      std::vector<bool> contiguity = std::vector<bool>());

  TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);

  bool operator==(const TensorDomain& other) const;
  bool operator!=(const TensorDomain& other) const {
    return !(*this == other);
  }

  std::vector<IterDomain*>::size_type nDims() const {
    return domain_.size();
  }

  bool sameAs(const Statement* other) const override;

  static bool sameAs(
      const std::vector<IterDomain*>& lhs,
      const std::vector<IterDomain*>& rhs);

  const std::vector<IterDomain*>& domain() const {
    return domain_;
  }

  const std::vector<bool>& contiguity() const {
    return contiguity_;
  }

  void setContiguity(const std::vector<bool>& contig);

  std::string getContiguityString() const {
    std::stringstream ss;
    for (auto b : contiguity()) {
      ss << (b ? "t" : "f");
    }
    return ss.str();
  }

  bool hasReduction() const;
  bool hasBlockReduction() const;
  bool hasGridReduction() const;
  bool hasBlockBroadcast() const;
  bool hasGridBroadcast() const;
  bool hasBroadcast() const;
  bool hasRFactor() const;

  // Returns if the rfactor domain only consists of ids of iter type.
  bool hasViewLikeRFactor() const;

  bool hasVectorize() const;

  c10::optional<unsigned int> getReductionAxis() const;

  const std::vector<IterDomain*>& noReductions() const {
    return no_reduction_domain_;
  }

  const std::vector<IterDomain*>& noBroadcasts() const {
    return no_bcast_domain_;
  }

  const std::vector<IterDomain*>& getRootDomain() const {
    return root_domain_;
  };

  const std::vector<IterDomain*>& getRFactorDomain() const {
    return rfactor_domain_;
  };

  // If rfactor domain exists in domain() return it, otherwise return root
  // domain.
  const std::vector<IterDomain*>& getMaybeRFactorDomain() const {
    return hasRFactor() ? getRFactorDomain() : getRootDomain();
  }

  void resetDomains() {
    no_reduction_domain_ = noReductions(domain_);
    no_bcast_domain_ = noBroadcasts(domain_);
    has_nontrivial_reduction_ = hasNontrivialReduction(domain_);
  }

  // i here is an int, as we want to accept negative values and ::size_type
  // can be unsigned.
  IterDomain* axis(int i) const;

  size_t posOf(IterDomain* id) const;

  //! Returns the position of a root domain
  size_t rootPosOf(IterDomain* id) const;

  //! Split "axis" into 2 axes.
  //! inner_split dictates if the factor section of the split should be inside
  //! the remainder or outside.
  //! e.g. split(0, 4, inner_split = true) will result in:
  //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
  //! e.g. split(0, 4, inner_split = false) will result in:
  //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
  void split(
      int axis_,
      Val* factor,
      bool inner_split,
      bool trim_out_of_bounds = false);

  // Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting
  // axis is by default placed at original position axis_o
  void merge(int axis_o, int axis_i);

  // Reorder axes according to map[old_pos] = new_pos
  void reorder(const std::unordered_map<int, int>& old2new);

  //! Applies 2D swizzle on a rectangular tile defined by
  //! a pair of iterdomains contained in this domain.
  void swizzle(
      Swizzle2DType swizzle_type,
      int x,
      int y,
      SwizzleMode swizzle_mode = SwizzleMode::Data);

  // Transform TensorView according to merge and split transformations
  TensorDomain* view(const AnalyzeViewResult& view_analysis);

  TensorDomain* flatten(int64_t start_dim, int64_t end_dim);

  static std::vector<IterDomain*> orderedAs(
      const std::vector<IterDomain*>& td,
      const std::unordered_map<int, int>& old2new);

  static std::vector<IterDomain*> noReductions(const std::vector<IterDomain*>&);
  static std::vector<IterDomain*> noBroadcasts(const std::vector<IterDomain*>&);

  static bool hasBroadcast(const std::vector<IterDomain*>&);
  static bool hasReduction(const std::vector<IterDomain*>&);
  static bool hasNontrivialReduction(const std::vector<IterDomain*>&);

  // pair is in order where second is the consumer of first
  std::pair<TensorDomain*, TensorDomain*> rFactor(const std::vector<int>& axes);

 private:
  const std::vector<IterDomain*> root_domain_;
  std::vector<IterDomain*> domain_;
  std::vector<IterDomain*> no_bcast_domain_;
  std::vector<IterDomain*> no_reduction_domain_;
  const std::vector<IterDomain*> rfactor_domain_;
  std::vector<bool> contiguity_;
  bool has_nontrivial_reduction_;
};
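
// A schematic example of how these transforms compose (illustrative only;
// scheduling is normally driven through TensorView, which forwards to its
// TensorDomain, and the 2D shape below is an assumption):
//
//   // Starting from a 2D domain tv[I0, I1], with td pointing at its domain:
//   td->split(1, IrBuilder::create<Int>(4), /*inner_split=*/true);
//   //   -> tv[I0, ceilDiv(I1, 4), 4]
//   td->merge(0, 1);
//   //   -> tv[I0 * ceilDiv(I1, 4), 4]
//   td->reorder({{0, 1}, {1, 0}});  // map[old_pos] = new_pos
//   //   -> tv[4, I0 * ceilDiv(I1, 4)]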

//! Representation of a split of an IterDomain by "factor".
//! inner_split dictates if the factor section of the split should be inside
//! the remainder or outside.
class TORCH_CUDA_CU_API Split : public Expr {
 public:
  // start_offset and stop_offset are used to express a partial
  // split. Only the partial domain from start_offset to stop_offset
  // is split and the outer sub-regions are ignored. Note that both
  // start_offset and stop_offset are distances from the left end and
  // right end, respectively.
  Split(
      IrBuilderPasskey,
      IterDomain* outer,
      IterDomain* inner,
      IterDomain* in,
      Val* factor,
      bool inner_split = true,
      Val* start_offset = nullptr,
      Val* stop_offset = nullptr);

  Split(const Split* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  IterDomain* outer() const {
    return outer_;
  }
  IterDomain* inner() const {
    return inner_;
  }
  IterDomain* in() const {
    return in_;
  }
  Val* factor() const {
    return factor_;
  }

  bool innerSplit() const {
    return inner_split_;
  }

  Val* startOffset() const {
    TORCH_INTERNAL_ASSERT(start_offset_ != nullptr);
    return start_offset_;
  }

  Val* stopOffset() const {
    TORCH_INTERNAL_ASSERT(stop_offset_ != nullptr);
    return stop_offset_;
  }

  //! Utility function to compute the split extent.
  static Val* extent(Val* in_extent, Val* start_offset, Val* stop_offset);

  bool sameAs(const Statement* other) const override;

 private:
  IterDomain* const outer_ = nullptr;
  IterDomain* const inner_ = nullptr;
  IterDomain* const in_ = nullptr;
  Val* const factor_ = nullptr;
  bool inner_split_ = true;
  //! Start position of the input domain. Non-zero means partial
  //! split. Elements until this offset are ignored.
  Val* const start_offset_ = nullptr;
  //! Offset from extent of the input domain. Non-zero means partial
  //! split. Elements after this offset are ignored.
  Val* const stop_offset_ = nullptr;
};
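
// For reference, a partial split only divides the sub-range
// [start_offset, extent - stop_offset) of the input domain. A sketch of the
// quantity Split::extent() is expected to compute (assuming both offsets are
// non-null; the actual implementation may simplify away zero offsets):
//
//   // effective extent that gets divided by the split factor
//   Val* effective = sub(sub(in_extent, start_offset), stop_offset);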

//! Merge the IterDomains outer and inner into one domain; outer and inner
//! dictate which will be traversed first (inner). Both IterDomains must be of
//! the same iter or reduction type, as well as the same parallelization
//! strategy if there is one.
class TORCH_CUDA_CU_API Merge : public Expr {
 public:
  Merge(
      IrBuilderPasskey,
      IterDomain* out,
      IterDomain* outer,
      IterDomain* inner);

  Merge(const Merge* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  IterDomain* out() const {
    return out_;
  }
  IterDomain* outer() const {
    return outer_;
  }
  IterDomain* inner() const {
    return inner_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  IterDomain* const out_ = nullptr;
  IterDomain* const outer_ = nullptr;
  IterDomain* const inner_ = nullptr;
};

//! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains.
class TORCH_CUDA_CU_API Swizzle2D : public Expr {
 public:
  Swizzle2D(
      IrBuilderPasskey,
      IterDomain* out_x,
      IterDomain* out_y,
      IterDomain* in_x,
      IterDomain* in_y,
      Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle,
      SwizzleMode swizzle_mode = SwizzleMode::Data);

  Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);

  Expr* shallowCopy() const override;

  IterDomain* outX() const {
    return out_x_;
  }

  IterDomain* outY() const {
    return out_y_;
  }

  IterDomain* inX() const {
    return in_x_;
  }

  IterDomain* inY() const {
    return in_y_;
  }

  auto swizzleType() const {
    return swizzle_type_;
  }

  auto swizzleMode() const {
    return swizzle_mode_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  // Output iterdomain pair corresponding
  // to the original input iterdomain pair.
  IterDomain* const out_x_ = nullptr;
  IterDomain* const out_y_ = nullptr;

  // Input iterdomain pair.
  IterDomain* const in_x_ = nullptr;
  IterDomain* const in_y_ = nullptr;

  // The type of predefined 1-to-1 functions
  // used for swizzling math.
  Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle;

  // Swizzle mode of this swizzle instance.
  // [Note on swizzle mode]
  // In the current implementation we support two modes of
  // swizzle math, namely, data mode and loop mode.
  // `Data` mode swizzling is a swizzle that will change the
  // data layout in shared memory, likely in global memory buffers
  // as well in the future. See also IndexSwizzle in index_compute.cpp.
  //
  // Most important use cases are transpose bank conflict removal, and mma
  // swizzled shared memory layout. Example illustrated in the 1D case:
  //
  //   for (int i = 0; i < I; i++) {
  //     // This is a `Data` mode swizzle.
  //     Tshared[swizzled(i)] = Tin[i];
  //   }
  //   // Now Tshared holds swizzled data, i.e. the data layout of
  //   // Tshared does not map to Tin with affine relationships.
  //
  //   for (int i = 0; i < I; i++) {
  //     Tout = Tshared[swizzled(i)];
  //   }
  //
  // `Loop` mode swizzling does not affect the data layout of any buffer
  // but only permutes the iteration order of a serial or parallel loop.
  // This is useful when we want to designate a non-affine mapping of thread
  // to data or we want to generate non-affine loops.
  // Example illustrated in the 1D case:
  //
  //   for (int i = 0; i < I; i++) {
  //     // This is a `Loop` mode swizzle
  //     Tshared[swizzled(i)] = Tin[swizzled(i)];
  //   }
  //   // Now Tshared holds normal data, i.e. it still has
  //   // the same data layout as if the swizzle wasn't there.
  //
  //   // Consumers of Tshared do not need to know about the
  //   // loop swizzle in the previous op if not inlined.
  //   for (int i = 0; i < I; i++) {
  //     Tout = Tshared[i];
  //   }
  //
  // TODO: Loop swizzles eventually will be piped through in all mappings
  // and replay of the fusion IR infrastructure.
  SwizzleMode swizzle_mode_ = SwizzleMode::Data;
};

//! Integer value which has a special name
//!
//! These could be:
//! - threadIdx.x
//! - blockIdx.y
//! - blockDim.z
//! - T3.stride[2]
//!
class TORCH_CUDA_CU_API NamedScalar : public Val {
 public:
  NamedScalar(IrBuilderPasskey passkey, std::string name, DataType dtype);

  NamedScalar(const NamedScalar* src, IrCloner* ir_cloner);

  const std::string& name() const {
    return name_;
  }

  bool sameAs(const Statement* other) const override;

  //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
  //! WARNING: Only works with Fusion container at the moment
  static NamedScalar* getParallelDim(ParallelType p_type);

  //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x)
  //! WARNING: Only works with Fusion container at the moment
  static NamedScalar* getParallelIndex(ParallelType p_type);

  //! Return the parallel type of this NamedScalar if it is an extent of a
  //! parallel dimension
  c10::optional<ParallelType> getParallelDim() const;

  //! Return the parallel type of this NamedScalar if it is an index of a
  //! parallel dimension
  c10::optional<ParallelType> getParallelIndex() const;

 private:
  std::string name_;
};
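
// A small illustration of the parallel-dimension helpers (illustrative; the
// string names follow the CUDA builtins quoted in the comments above):
//
//   NamedScalar* bdimx = NamedScalar::getParallelDim(ParallelType::TIDx);
//   // bdimx->name() is expected to be "blockDim.x"
//   NamedScalar* tidx = NamedScalar::getParallelIndex(ParallelType::TIDx);
//   // tidx->name() is expected to be "threadIdx.x"
//   // and the reverse query recovers the parallel type:
//   //   bdimx->getParallelDim() == ParallelType::TIDx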

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch