#pragma once
#include <c10/macros/Export.h>
#include <fusion.h>

#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Utility data structure for recording gemm tiles
struct GemmTile {
  int m, n, k;
  GemmTile(int m_, int n_, int k_) : m(m_), n(n_), k(k_) {}

  bool operator==(const GemmTile& other) const {
    return m == other.m && n == other.n && k == other.k;
  }

  GemmTile operator/(const GemmTile& other) const {
    return GemmTile(m / other.m, n / other.n, k / other.k);
  }

  std::vector<int> toVector() const {
    return {m, n, k};
  }
};
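
//! For illustration, division is elementwise:
//!   GemmTile(128, 128, 32) / GemmTile(64, 64, 32) == GemmTile(2, 2, 1)
//! This is used to derive how many finer tiles fit in a coarser tile at
//! each level of the matmul tiling hierarchy below.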

//! Tile sizes at each level of the matmul tiling hierarchy
//! (CTA tile, warp tile, and instruction tile).
struct TORCH_CUDA_CU_API MatMulTileOptions {
  GemmTile cta_tile = GemmTile(128, 128, 32);
  GemmTile warp_tile = GemmTile(64, 64, 32);
  GemmTile instruction_tile = GemmTile(16, 8, 16);

  MatMulTileOptions() = default;
  MatMulTileOptions(
      GemmTile cta_tile_,
      GemmTile warp_tile_,
      GemmTile instruction_tile_)
      : cta_tile(cta_tile_),
        warp_tile(warp_tile_),
        instruction_tile(instruction_tile_) {}

  bool operator==(const MatMulTileOptions& other) const {
    return cta_tile == other.cta_tile && warp_tile == other.warp_tile &&
        instruction_tile == other.instruction_tile;
  }
};
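
//! A brief sketch of how the default tiles above compose (illustrative
//! numbers only, derived from the defaults via GemmTile::operator/):
//!   warps per CTA:              cta_tile / warp_tile         = (2, 2, 1)
//!   instruction tiles per warp: warp_tile / instruction_tile = (4, 8, 2)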

//! Information for configuring and lowering mma ops
struct MmaOptions {
  //! Type of mma intrinsic macro to use
  //!  This selects which mma intrinsic from the runtime string library
  //!  is generated to implement the mma op. The current plan
  //!  is to have exactly one macro for each
  //!  (arch, datatype, operand layout) triple, though there
  //!  exist multiple possibilities for some cases, e.g. for Turing and fp16
  //!  one can use 16_8_8 or 16_8_16.
  //!  More choices for the scheduler to pick from may be added once our
  //!  perf targets become more fine-grained, which is most likely for
  //!  latency-bound kernels.
  enum class MacroType {
    NoMMA = 0,
    Volta_16_16_4,
    Ampere_16_8_16,
    Ampere_16_16_16,
    Turing_16_8_16,
    Turing_16_16_16,
    Ampere_16_8_8 // placeholder for tf32
  };

  //! [Operand Layout Convention]
  //! Operand layout, T=transposed/row_major, N=normal/col_major
  //! We don't support calling NN mma directly since it implies
  //! a fused transpose. The user needs to swap the operands and use
  //! TT mma to make the transpose explicit.
  //! Ordered by position of K:
  //!  NT : K,M x K,N -> K,M,N
  //!  TT : M,K x K,N -> M,K,N
  //!  TN : M,K x N,K -> M,N,K
  enum class MmaInputLayout { NT = 0, TT, TN };
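
  //! For example, a plain row-major GEMM C[M,N] = A[M,K] x B[K,N] with both
  //! operands row-major corresponds to the TT layout above; an NN problem
  //! must swap its operands and use TT, per the note above.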

  //! Utility to annotate which input of mma this option struct describes
  enum class Operand { Accumulator = 0, A, B };

  //! Utility to annotate which mma macro this config uses.
  MacroType macro = MacroType::NoMMA;

  //! Utility to annotate transposition of operands
  MmaInputLayout operand_layout = MmaInputLayout::TT;

  //! Utility to annotate which input of mma this option struct describes
  Operand operand = Operand::A;

  //! Accumulator register stride; will be removed when the swizzle op
  //! is introduced and the output can be labeled with a transpose swizzle.
  int accumulator_stride = 0;

  bool operator==(const MmaOptions& other) const {
    return macro == other.macro && operand_layout == other.operand_layout &&
        operand == other.operand &&
        accumulator_stride == other.accumulator_stride;
  }

  //! The accumulator tensorview register supplied by the
  //! scheduler interface. Each mma builder is responsible
  //! for the parameters of one mma op, so the options struct
  //! needs a pointer to keep track of which mma op it
  //! is describing.
  //! Tracking mma expressions directly would not be stable, as the
  //! expression can get deleted by mutate passes.
  TensorView* accumulator_tv = nullptr;

  //! Returns the mma op that this options parameter list
  //! is describing. See the comment on accumulator_tv.
  MmaOp* mmaOp() const;
};

//! User interface for configuring the mma and mma-related
//! operators by specifying the mma instruction tile type,
//! input data layout, and the operand position of a tensor.
class TORCH_CUDA_CU_API MmaBuilder {
 public:
  //! Initializes an mma builder for the given mma instruction type.
  //! TODO: the mma implementation is generic and should not have
  //! a strong dependency on the actual matmul tiling shapes. The
  //! MatMulTileOptions provided here is a WAR for the mma format and
  //! should be removed once there is support for labeling swizzles
  //! on iterdomains.
  MmaBuilder(MmaOptions::MacroType macro, MatMulTileOptions gemm_tile);

  //! User configuration function:
  //!  Specifies the input matrix layout for the mma instruction;
  //!  see [Operand Layout Convention].
  MmaBuilder& layout(MmaOptions::MmaInputLayout layout);

  //! User configuration function:
  //!  Specifies which element in the mma op this builder is generating
  //!  parameters for, i.e. A or B. This is useful when generating
  //!  data swizzles for different elements of mma.
  //!  - Operand::Accumulator means the parameters describe the accumulator
  //!    in the mma op.
  //!  - This option is ignored when configuring the mma operator itself.
  MmaBuilder& operand(MmaOptions::Operand a_or_b);

  //! Generates the matching ldmatrix instruction type for the
  //! specified mma option.
  LoadStoreOpType ldMatrix() const;

  //! Stores the accumulator tv register reference in the mma builder,
  //! to avoid having to automatically match which mma op is being
  //! described.
  void accumulatorTv(TensorView* tv);

  //! Fills in mma options at scheduling time.
  //! Each mma op in Fusion IR must be configured once before lowering.
  //! Mma options are configuration parameters used in lowering to mma
  //! intrinsics, mainly the type of mma macro to use and the input data
  //! layout, etc.
  //!
  //! TODO: This step will very likely be removed in a follow-up PR. All of
  //! the options configured here could actually be inferred from fusion IR
  //! once we are feature complete.
  void configureMma(TensorView* mma_output) const;

  //! Exports all the parameters with the user's configurations applied.
  MmaOptions build() const;

 private:
  MmaOptions option_;
};
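
//! A minimal usage sketch (the tensor `mma_result` is a hypothetical
//! scheduler-side accumulator TensorView, not part of this interface):
//!
//!   MmaBuilder mma_builder(
//!       MmaOptions::MacroType::Turing_16_8_16, MatMulTileOptions());
//!   mma_builder.layout(MmaOptions::MmaInputLayout::TT);
//!   mma_builder.accumulatorTv(mma_result);
//!   mma_builder.configureMma(mma_result);
//!   // Per-operand parameters, e.g. the ldmatrix load type for operand A:
//!   LoadStoreOpType a_load =
//!       mma_builder.operand(MmaOptions::Operand::A).ldMatrix();
//!   MmaOptions options = mma_builder.build();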

//! GPU arch check for macro type
bool isVolta(MmaOptions::MacroType macro);
bool isTuring(MmaOptions::MacroType macro);
bool isAmpere(MmaOptions::MacroType macro);

//! Returns true if the given option describes a transposed operand
bool isOperandTransposed(MmaOptions options);

// Unpacked constants from macro type:
//  exact numbers are defined by each individual instruction.
int getOutputRegisterSize(MmaOptions::MacroType macro);
int getInputARegisterSize(MmaOptions::MacroType macro);
int getInputBRegisterSize(MmaOptions::MacroType macro);
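
// As an illustration (assuming the even per-thread distribution of the PTX
// mma.sync layouts over a 32-thread warp, so the values follow from the
// macro shape): for Turing_16_8_16, each thread would hold
// 16*8/32 = 4 accumulator values, 16*16/32 = 8 elements of A, and
// 8*16/32 = 4 elements of B.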

// MMA stringify utils
std::string toString(MmaOptions::MacroType macro);
std::string toString(MmaOptions::MmaInputLayout input_layout);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch