1 | #include <fusion.h> |
2 | #include <ir_all_nodes.h> |
3 | #include <mma_type.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | namespace fuser { |
8 | namespace cuda { |
9 | |
10 | MmaOp* MmaOptions::mmaOp() const { |
11 | TORCH_INTERNAL_ASSERT( |
12 | accumulator_tv != nullptr && accumulator_tv->definition() != nullptr, |
13 | "Invalid accumulator_tv." ); |
14 | auto mma_op = dynamic_cast<MmaOp*>(accumulator_tv->definition()); |
15 | TORCH_INTERNAL_ASSERT( |
16 | mma_op != nullptr, "accumulator tv not an output of mma op" ); |
17 | return mma_op; |
18 | } |
19 | |
20 | MmaBuilder::MmaBuilder( |
21 | MmaOptions::MacroType macro, |
22 | MatMulTileOptions gemm_tile) { |
23 | option_.macro = macro; |
24 | // Calculate accumulator stride, will be removed once transpose swizzle ready |
25 | int outer_stride = gemm_tile.warp_tile.n / gemm_tile.instruction_tile.n; |
26 | switch (macro) { |
27 | // Numbers depend on actual output layout of mma instruction |
28 | case MmaOptions::MacroType::Volta_16_16_4: |
29 | option_.accumulator_stride = outer_stride * 4; |
30 | break; |
31 | case MmaOptions::MacroType::Turing_16_8_16: |
32 | case MmaOptions::MacroType::Ampere_16_8_16: |
33 | option_.accumulator_stride = outer_stride * 2; |
34 | break; |
35 | case MmaOptions::MacroType::Ampere_16_16_16: |
36 | case MmaOptions::MacroType::Turing_16_16_16: |
37 | option_.accumulator_stride = outer_stride * 4; |
38 | break; |
39 | default: |
40 | TORCH_CHECK(false, "unsupported macro" ); |
41 | break; |
42 | } |
43 | } |
44 | |
45 | MmaBuilder& MmaBuilder::layout(MmaOptions::MmaInputLayout layout) { |
46 | option_.operand_layout = layout; |
47 | return *this; |
48 | } |
49 | |
50 | MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) { |
51 | option_.operand = a_or_b; |
52 | return *this; |
53 | } |
54 | |
55 | // TODO: validate op config |
56 | MmaOptions MmaBuilder::build() const { |
57 | TORCH_CHECK( |
58 | option_.accumulator_tv != nullptr, |
59 | "Please configure accumulator tv before using swizzle options." ) |
60 | return option_; |
61 | } |
62 | |
63 | void MmaBuilder::configureMma(TensorView* mma_output) const { |
64 | TORCH_CHECK( |
65 | mma_output->definition(), |
66 | "configureMma: invalid for input tensor " , |
67 | mma_output); |
68 | auto mma = dynamic_cast<MmaOp*>(mma_output->definition()); |
69 | TORCH_CHECK(mma, "configureMma: invalid for non-mma output: " , mma_output); |
70 | mma->configureOptions(option_); |
71 | } |
72 | |
73 | void MmaBuilder::accumulatorTv(TensorView* tv) { |
74 | TORCH_CHECK( |
75 | tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register" ); |
76 | TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv" ); |
77 | TORCH_CHECK( |
78 | tv->definition()->isA<MmaOp>(), |
79 | "Requires mma op output for reduction tv" ); |
80 | option_.accumulator_tv = tv; |
81 | } |
82 | |
83 | namespace { |
84 | |
85 | // Utility to get ldmatrix direction a mma layout and operand |
86 | LoadStoreOpType getLdMatrixType(MmaOptions options) { |
87 | bool transpose = false; |
88 | switch (options.macro) { |
89 | case MmaOptions::MacroType::Turing_16_8_16: |
90 | case MmaOptions::MacroType::Ampere_16_8_16: |
91 | case MmaOptions::MacroType::Ampere_16_16_16: |
92 | case MmaOptions::MacroType::Turing_16_16_16: |
93 | // Turing mma assumes TN as default |
94 | transpose = (options.operand == MmaOptions::Operand::A && |
95 | !isOperandTransposed(options)) || |
96 | (options.operand == MmaOptions::Operand::B && |
97 | isOperandTransposed(options)); |
98 | break; |
99 | default: |
100 | TORCH_INTERNAL_ASSERT(false, "unsupported op with ldmatrix" ); |
101 | break; |
102 | } |
103 | return transpose ? LoadStoreOpType::LdMatrixTranspose |
104 | : LoadStoreOpType::LdMatrix; |
105 | } |
106 | |
107 | } // namespace |
108 | |
109 | LoadStoreOpType MmaBuilder::ldMatrix() const { |
110 | return getLdMatrixType(option_); |
111 | } |
112 | |
113 | bool isVolta(MmaOptions::MacroType macro) { |
114 | return macro == MmaOptions::MacroType::Volta_16_16_4; |
115 | } |
116 | |
117 | bool isTuring(MmaOptions::MacroType macro) { |
118 | return macro == MmaOptions::MacroType::Turing_16_8_16 || |
119 | macro == MmaOptions::MacroType::Turing_16_16_16; |
120 | } |
121 | |
122 | bool isAmpere(MmaOptions::MacroType macro) { |
123 | return macro == MmaOptions::MacroType::Ampere_16_8_16 || |
124 | macro == MmaOptions::MacroType::Ampere_16_16_16; |
125 | } |
126 | |
127 | int getOutputRegisterSize(MmaOptions::MacroType macro) { |
128 | switch (macro) { |
129 | case MmaOptions::MacroType::Volta_16_16_4: |
130 | case MmaOptions::MacroType::Ampere_16_16_16: |
131 | case MmaOptions::MacroType::Turing_16_16_16: |
132 | return 8; |
133 | break; |
134 | case MmaOptions::MacroType::Turing_16_8_16: |
135 | case MmaOptions::MacroType::Ampere_16_8_16: |
136 | return 4; |
137 | break; |
138 | default: |
139 | TORCH_INTERNAL_ASSERT(false, "unknown macro" ); |
140 | break; |
141 | } |
142 | return -1; |
143 | } |
144 | |
145 | int getInputARegisterSize(MmaOptions::MacroType macro) { |
146 | switch (macro) { |
147 | case MmaOptions::MacroType::Volta_16_16_4: |
148 | return 4; |
149 | break; |
150 | case MmaOptions::MacroType::Turing_16_8_16: |
151 | case MmaOptions::MacroType::Turing_16_16_16: |
152 | case MmaOptions::MacroType::Ampere_16_8_16: |
153 | case MmaOptions::MacroType::Ampere_16_16_16: |
154 | return 8; |
155 | break; |
156 | default: |
157 | TORCH_INTERNAL_ASSERT(false, "unknown macro" ); |
158 | break; |
159 | } |
160 | return -1; |
161 | } |
162 | |
163 | int getInputBRegisterSize(MmaOptions::MacroType macro) { |
164 | switch (macro) { |
165 | case MmaOptions::MacroType::Volta_16_16_4: |
166 | return 4; |
167 | break; |
168 | case MmaOptions::MacroType::Turing_16_8_16: |
169 | case MmaOptions::MacroType::Ampere_16_8_16: |
170 | return 4; |
171 | case MmaOptions::MacroType::Turing_16_16_16: |
172 | case MmaOptions::MacroType::Ampere_16_16_16: |
173 | return 8; |
174 | default: |
175 | TORCH_INTERNAL_ASSERT(false, "unknown macro" ); |
176 | break; |
177 | } |
178 | return -1; |
179 | } |
180 | |
181 | bool isOperandTransposed(MmaOptions options) { |
182 | switch (options.operand) { |
183 | case MmaOptions::Operand::A: |
184 | return options.operand_layout == MmaOptions::MmaInputLayout::TT || |
185 | options.operand_layout == MmaOptions::MmaInputLayout::TN; |
186 | case MmaOptions::Operand::B: |
187 | return options.operand_layout == MmaOptions::MmaInputLayout::TT || |
188 | options.operand_layout == MmaOptions::MmaInputLayout::NT; |
189 | default: |
190 | TORCH_CHECK(false, "isOperandTransposed: please specify operand" ); |
191 | } |
192 | return false; |
193 | } |
194 | |
195 | std::string toString(MmaOptions::MmaInputLayout input_layout) { |
196 | std::stringstream ss; |
197 | switch (input_layout) { |
198 | case MmaOptions::MmaInputLayout::TT: |
199 | ss << "TT" ; |
200 | break; |
201 | case MmaOptions::MmaInputLayout::TN: |
202 | ss << "TN" ; |
203 | break; |
204 | case MmaOptions::MmaInputLayout::NT: |
205 | ss << "NT" ; |
206 | break; |
207 | default: |
208 | TORCH_INTERNAL_ASSERT(false, "unsupported operand layout" ); |
209 | } |
210 | return ss.str(); |
211 | } |
212 | |
213 | std::string toString(MmaOptions::MacroType mt) { |
214 | std::stringstream ss; |
215 | switch (mt) { |
216 | case MmaOptions::MacroType::NoMMA: |
217 | ss << "NoOp" ; |
218 | break; |
219 | case MmaOptions::MacroType::Volta_16_16_4: |
220 | ss << "M16N16K4" ; |
221 | break; |
222 | case MmaOptions::MacroType::Turing_16_8_16: |
223 | case MmaOptions::MacroType::Ampere_16_8_16: |
224 | ss << "M16N8K16" ; |
225 | break; |
226 | case MmaOptions::MacroType::Turing_16_16_16: |
227 | case MmaOptions::MacroType::Ampere_16_16_16: |
228 | ss << "M16N16K16" ; |
229 | break; |
230 | default: |
231 | TORCH_INTERNAL_ASSERT(false, "undefined mma type" ); |
232 | break; |
233 | } |
234 | return ss.str(); |
235 | } |
236 | |
237 | } // namespace cuda |
238 | } // namespace fuser |
239 | } // namespace jit |
240 | } // namespace torch |
241 | |