#pragma once
#include <c10/macros/Export.h>
#include <fusion.h>

#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Utility data structure for recording gemm tiles
struct GemmTile {
  int m, n, k;
  GemmTile(int m_, int n_, int k_) : m(m_), n(n_), k(k_) {}

  bool operator==(const GemmTile& other) const {
    return m == other.m && n == other.n && k == other.k;
  }

  GemmTile operator/(const GemmTile& other) const {
    return GemmTile(m / other.m, n / other.n, k / other.k);
  }

  std::vector<int> toVector() const {
    return {m, n, k};
  }
};
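
//! Illustrative sketch (not part of this header's interface): dividing one
//! GemmTile by another yields the per-dimension tile counts, e.g. how many
//! warp tiles fit in a CTA tile:
//!
//!   GemmTile cta(128, 128, 32);
//!   GemmTile warp(64, 64, 32);
//!   GemmTile warps_per_cta = cta / warp; // {m=2, n=2, k=1}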

//! Utility data structure for recording the matmul tile hierarchy:
//! CTA tile, warp tile, and instruction tile.
struct TORCH_CUDA_CU_API MatMulTileOptions {
  GemmTile cta_tile = GemmTile(128, 128, 32);
  GemmTile warp_tile = GemmTile(64, 64, 32);
  GemmTile instruction_tile = GemmTile(16, 8, 16);

  MatMulTileOptions() = default;
  MatMulTileOptions(
      GemmTile cta_tile_,
      GemmTile warp_tile_,
      GemmTile instruction_tile_)
      : cta_tile(cta_tile_),
        warp_tile(warp_tile_),
        instruction_tile(instruction_tile_) {}

  bool operator==(const MatMulTileOptions& other) const {
    return cta_tile == other.cta_tile && warp_tile == other.warp_tile &&
        instruction_tile == other.instruction_tile;
  }
};
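
//! Illustrative sketch (derived from the defaults above, not a requirement):
//! the three tiles form a hierarchy, each level evenly dividing the level
//! above it:
//!
//!   MatMulTileOptions tiles; // cta 128x128x32, warp 64x64x32, instr 16x8x16
//!   GemmTile warps_per_cta = tiles.cta_tile / tiles.warp_tile; // 2,2,1
//!   GemmTile instrs_per_warp =
//!       tiles.warp_tile / tiles.instruction_tile; // 4,8,2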

//! Information for configuring and lowering mma ops
struct MmaOptions {
  //! Type of mma intrinsic macro to use.
  //!  This determines which mma intrinsic from the runtime strings
  //!  is generated to implement the mma op. The current plan
  //!  is to have exactly one macro for each
  //!  (arch, datatype, operand layout) triple, though there
  //!  exist multiple possibilities for some cases, e.g. for Turing and fp16
  //!  one can use 16_8_8 or 16_8_16.
  //!  Will consider adding more choices that the scheduler can pick from
  //!  when our perf target becomes more fine grained, which is more likely in
  //!  latency bound kernels.
  enum class MacroType {
    NoMMA = 0,
    Volta_16_16_4,
    Ampere_16_8_16,
    Ampere_16_16_16,
    Turing_16_8_16,
    Turing_16_16_16,
    Ampere_16_8_8 // placeholder for tf32
  };

  //! [Operand Layout Convention]
  //! Operand layout, T=transposed/row_major, N=normal/col_major.
  //! We don't support calling NN mma directly since it implies
  //! a fused transpose. The user needs to swap the operands and use
  //! TT mma to make the transpose explicit.
  //! Ordered by position of K:
  //!  NT : K,M x K,N -> K,M,N
  //!  TT : M,K x K,N -> M,K,N
  //!  TN : M,K x N,K -> M,N,K
  enum class MmaInputLayout { NT = 0, TT, TN };
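
  //! Illustrative example of the convention above (an assumption for
  //! exposition, not additional API): with the TN layout, A is [M,K] and
  //! B is [N,K], so a 16_8_16 macro (M=16, N=8, K=16) consumes a 16x16
  //! tile of A and an 8x16 tile of B per instruction.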

  //! Utility to annotate which input of mma this option struct describes
  enum class Operand { Accumulator = 0, A, B };

  //! Utility to annotate which mma macro this config uses.
  MacroType macro = MacroType::NoMMA;

  //! Utility to annotate the transposition of operands.
  MmaInputLayout operand_layout = MmaInputLayout::TT;

  //! Utility to annotate which input of mma this option struct describes.
  Operand operand = Operand::A;

  //! Accumulator register stride; will be removed once the swizzle op
  //! is introduced and the output can be labeled with a transpose swizzle.
  int accumulator_stride = 0;

  bool operator==(const MmaOptions& other) const {
    return macro == other.macro && operand_layout == other.operand_layout &&
        operand == other.operand &&
        accumulator_stride == other.accumulator_stride;
  }

  // The accumulator TensorView register supplied by the
  // scheduler interface. Each mma builder is responsible
  // for the parameters of one mma op, so the options struct
  // needs a pointer to keep track of which mma op it
  // is describing.
  // Tracking mma expressions directly would not be stable, as the
  // expression can get deleted by mutation passes.
  TensorView* accumulator_tv = nullptr;

  //! Returns the mma op that this options parameter list
  //! is describing. See comment on accumulator_tv.
  MmaOp* mmaOp() const;
};
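
//! Illustrative sketch (not part of this header's interface): options built
//! by MmaBuilder carry the accumulator tv, through which the described op
//! can be recovered:
//!
//!   MmaOptions options = /* result of MmaBuilder::build() */;
//!   MmaOp* mma = options.mmaOp(); // resolved via accumulator_tv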

//! User interface for configuring the mma and mma-related
//! operators by specifying the mma instruction tile type,
//! the input data layout, and the operand position of a tensor.
class TORCH_CUDA_CU_API MmaBuilder {
 public:
  //! Initializes an mma builder for the given mma instruction type.
  //! TODO: the mma implementation is generic and should not have
  //! a strong dependency on the actual matmul tiling shapes. The
  //! MatMulTileOptions provided here is a WAR for the mma format and
  //! should be removed once there is support for labeling swizzles
  //! on iterdomains.
  MmaBuilder(MmaOptions::MacroType macro, MatMulTileOptions gemm_tile);

  //! User configuration function:
  //!  Specifies the input matrix layout for the mma instruction.
  //!  See [Operand Layout Convention].
  MmaBuilder& layout(MmaOptions::MmaInputLayout layout);

  //! User configuration function:
  //!  Specifies which operand of the mma op this builder is generating
  //!  parameters for, i.e. A or B. This is useful when generating
  //!  data swizzles for different operands of the mma.
  //!  - Operand::Accumulator means the parameters describe the accumulator
  //!    of the mma op.
  //!  - This option is ignored when configuring the mma operator itself.
  MmaBuilder& operand(MmaOptions::Operand a_or_b);

  //! Generates the matching ldmatrix instruction type for the
  //! specified mma option.
  LoadStoreOpType ldMatrix() const;

  //! Stores the accumulator tv register reference in the mma builder,
  //! so the corresponding mma op does not have to be matched automatically.
  void accumulatorTv(TensorView* tv);

  //! Fills in mma options at scheduling time.
  //! Each mma op in Fusion IR must be configured once before lowering.
  //! Mma options are configuration parameters used in lowering to mma
  //! intrinsics, mainly the type of mma macro to use and the input data
  //! layout, etc.
  //!
  //! TODO: This step will very likely be removed in a follow-up PR. All of
  //! the options configured here could actually be inferred from fusion IR
  //! once we are feature complete.
  void configureMma(TensorView* mma_output) const;

  //! Exports all the parameters with the user's configurations applied.
  MmaOptions build() const;

 private:
  MmaOptions option_;
};
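
//! Illustrative usage sketch (assumes a fusion where `mma_out` is the output
//! TensorView of an mma op; exact scheduler integration may differ):
//!
//!   MatMulTileOptions gemm_tile; // default tile hierarchy
//!   MmaBuilder mma_builder(
//!       MmaOptions::MacroType::Ampere_16_8_16, gemm_tile);
//!   mma_builder.layout(MmaOptions::MmaInputLayout::TN);
//!   mma_builder.configureMma(mma_out);
//!   mma_builder.accumulatorTv(mma_out);
//!   auto options_a = mma_builder.operand(MmaOptions::Operand::A).build();
//!   auto ldmatrix_a = mma_builder.operand(MmaOptions::Operand::A).ldMatrix();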

//! GPU arch check for macro type
bool isVolta(MmaOptions::MacroType macro);
bool isTuring(MmaOptions::MacroType macro);
bool isAmpere(MmaOptions::MacroType macro);

//! Returns true if the given option describes a transposed operand
bool isOperandTransposed(MmaOptions options);

// Unpacked constants from macro type:
//   exact numbers are defined by each individual instruction.
int getOutputRegisterSize(MmaOptions::MacroType macro);
int getInputARegisterSize(MmaOptions::MacroType macro);
int getInputBRegisterSize(MmaOptions::MacroType macro);
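
// Illustrative note (the values below are an assumption based on PTX
// fragment sizes for fp16 mma, not guaranteed by this header): for the
// 16_8_16 macros each thread would hold 8 elements of A, 4 of B, and 4
// accumulators, e.g.
//
//   int out_size =
//       getOutputRegisterSize(MmaOptions::MacroType::Ampere_16_8_16);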

// MMA stringify utils
std::string toString(MmaOptions::MacroType macro);
std::string toString(MmaOptions::MmaInputLayout input_layout);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch