1#include <fusion.h>
2#include <ir_all_nodes.h>
3#include <mma_type.h>
4
5namespace torch {
6namespace jit {
7namespace fuser {
8namespace cuda {
9
10MmaOp* MmaOptions::mmaOp() const {
11 TORCH_INTERNAL_ASSERT(
12 accumulator_tv != nullptr && accumulator_tv->definition() != nullptr,
13 "Invalid accumulator_tv.");
14 auto mma_op = dynamic_cast<MmaOp*>(accumulator_tv->definition());
15 TORCH_INTERNAL_ASSERT(
16 mma_op != nullptr, "accumulator tv not an output of mma op");
17 return mma_op;
18}
19
20MmaBuilder::MmaBuilder(
21 MmaOptions::MacroType macro,
22 MatMulTileOptions gemm_tile) {
23 option_.macro = macro;
24 // Calculate accumulator stride, will be removed once transpose swizzle ready
25 int outer_stride = gemm_tile.warp_tile.n / gemm_tile.instruction_tile.n;
26 switch (macro) {
27 // Numbers depend on actual output layout of mma instruction
28 case MmaOptions::MacroType::Volta_16_16_4:
29 option_.accumulator_stride = outer_stride * 4;
30 break;
31 case MmaOptions::MacroType::Turing_16_8_16:
32 case MmaOptions::MacroType::Ampere_16_8_16:
33 option_.accumulator_stride = outer_stride * 2;
34 break;
35 case MmaOptions::MacroType::Ampere_16_16_16:
36 case MmaOptions::MacroType::Turing_16_16_16:
37 option_.accumulator_stride = outer_stride * 4;
38 break;
39 default:
40 TORCH_CHECK(false, "unsupported macro");
41 break;
42 }
43}
44
// Records the operand layout (TT/TN/NT) on the options being built.
// Returns *this for fluent chaining.
MmaBuilder& MmaBuilder::layout(MmaOptions::MmaInputLayout layout) {
  option_.operand_layout = layout;
  return *this;
}
49
// Records which operand (A or B) the built options describe.
// Returns *this for fluent chaining.
MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
  option_.operand = a_or_b;
  return *this;
}
54
55// TODO: validate op config
56MmaOptions MmaBuilder::build() const {
57 TORCH_CHECK(
58 option_.accumulator_tv != nullptr,
59 "Please configure accumulator tv before using swizzle options.")
60 return option_;
61}
62
63void MmaBuilder::configureMma(TensorView* mma_output) const {
64 TORCH_CHECK(
65 mma_output->definition(),
66 "configureMma: invalid for input tensor ",
67 mma_output);
68 auto mma = dynamic_cast<MmaOp*>(mma_output->definition());
69 TORCH_CHECK(mma, "configureMma: invalid for non-mma output: ", mma_output);
70 mma->configureOptions(option_);
71}
72
73void MmaBuilder::accumulatorTv(TensorView* tv) {
74 TORCH_CHECK(
75 tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
76 TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
77 TORCH_CHECK(
78 tv->definition()->isA<MmaOp>(),
79 "Requires mma op output for reduction tv");
80 option_.accumulator_tv = tv;
81}
82
83namespace {
84
85// Utility to get ldmatrix direction a mma layout and operand
86LoadStoreOpType getLdMatrixType(MmaOptions options) {
87 bool transpose = false;
88 switch (options.macro) {
89 case MmaOptions::MacroType::Turing_16_8_16:
90 case MmaOptions::MacroType::Ampere_16_8_16:
91 case MmaOptions::MacroType::Ampere_16_16_16:
92 case MmaOptions::MacroType::Turing_16_16_16:
93 // Turing mma assumes TN as default
94 transpose = (options.operand == MmaOptions::Operand::A &&
95 !isOperandTransposed(options)) ||
96 (options.operand == MmaOptions::Operand::B &&
97 isOperandTransposed(options));
98 break;
99 default:
100 TORCH_INTERNAL_ASSERT(false, "unsupported op with ldmatrix");
101 break;
102 }
103 return transpose ? LoadStoreOpType::LdMatrixTranspose
104 : LoadStoreOpType::LdMatrix;
105}
106
107} // namespace
108
// Returns the ldmatrix variant (plain or transposed) appropriate for the
// currently configured macro, operand, and layout.
LoadStoreOpType MmaBuilder::ldMatrix() const {
  return getLdMatrixType(option_);
}
112
113bool isVolta(MmaOptions::MacroType macro) {
114 return macro == MmaOptions::MacroType::Volta_16_16_4;
115}
116
117bool isTuring(MmaOptions::MacroType macro) {
118 return macro == MmaOptions::MacroType::Turing_16_8_16 ||
119 macro == MmaOptions::MacroType::Turing_16_16_16;
120}
121
122bool isAmpere(MmaOptions::MacroType macro) {
123 return macro == MmaOptions::MacroType::Ampere_16_8_16 ||
124 macro == MmaOptions::MacroType::Ampere_16_16_16;
125}
126
127int getOutputRegisterSize(MmaOptions::MacroType macro) {
128 switch (macro) {
129 case MmaOptions::MacroType::Volta_16_16_4:
130 case MmaOptions::MacroType::Ampere_16_16_16:
131 case MmaOptions::MacroType::Turing_16_16_16:
132 return 8;
133 break;
134 case MmaOptions::MacroType::Turing_16_8_16:
135 case MmaOptions::MacroType::Ampere_16_8_16:
136 return 4;
137 break;
138 default:
139 TORCH_INTERNAL_ASSERT(false, "unknown macro");
140 break;
141 }
142 return -1;
143}
144
145int getInputARegisterSize(MmaOptions::MacroType macro) {
146 switch (macro) {
147 case MmaOptions::MacroType::Volta_16_16_4:
148 return 4;
149 break;
150 case MmaOptions::MacroType::Turing_16_8_16:
151 case MmaOptions::MacroType::Turing_16_16_16:
152 case MmaOptions::MacroType::Ampere_16_8_16:
153 case MmaOptions::MacroType::Ampere_16_16_16:
154 return 8;
155 break;
156 default:
157 TORCH_INTERNAL_ASSERT(false, "unknown macro");
158 break;
159 }
160 return -1;
161}
162
163int getInputBRegisterSize(MmaOptions::MacroType macro) {
164 switch (macro) {
165 case MmaOptions::MacroType::Volta_16_16_4:
166 return 4;
167 break;
168 case MmaOptions::MacroType::Turing_16_8_16:
169 case MmaOptions::MacroType::Ampere_16_8_16:
170 return 4;
171 case MmaOptions::MacroType::Turing_16_16_16:
172 case MmaOptions::MacroType::Ampere_16_16_16:
173 return 8;
174 default:
175 TORCH_INTERNAL_ASSERT(false, "unknown macro");
176 break;
177 }
178 return -1;
179}
180
181bool isOperandTransposed(MmaOptions options) {
182 switch (options.operand) {
183 case MmaOptions::Operand::A:
184 return options.operand_layout == MmaOptions::MmaInputLayout::TT ||
185 options.operand_layout == MmaOptions::MmaInputLayout::TN;
186 case MmaOptions::Operand::B:
187 return options.operand_layout == MmaOptions::MmaInputLayout::TT ||
188 options.operand_layout == MmaOptions::MmaInputLayout::NT;
189 default:
190 TORCH_CHECK(false, "isOperandTransposed: please specify operand");
191 }
192 return false;
193}
194
195std::string toString(MmaOptions::MmaInputLayout input_layout) {
196 std::stringstream ss;
197 switch (input_layout) {
198 case MmaOptions::MmaInputLayout::TT:
199 ss << "TT";
200 break;
201 case MmaOptions::MmaInputLayout::TN:
202 ss << "TN";
203 break;
204 case MmaOptions::MmaInputLayout::NT:
205 ss << "NT";
206 break;
207 default:
208 TORCH_INTERNAL_ASSERT(false, "unsupported operand layout");
209 }
210 return ss.str();
211}
212
213std::string toString(MmaOptions::MacroType mt) {
214 std::stringstream ss;
215 switch (mt) {
216 case MmaOptions::MacroType::NoMMA:
217 ss << "NoOp";
218 break;
219 case MmaOptions::MacroType::Volta_16_16_4:
220 ss << "M16N16K4";
221 break;
222 case MmaOptions::MacroType::Turing_16_8_16:
223 case MmaOptions::MacroType::Ampere_16_8_16:
224 ss << "M16N8K16";
225 break;
226 case MmaOptions::MacroType::Turing_16_16_16:
227 case MmaOptions::MacroType::Ampere_16_16_16:
228 ss << "M16N16K16";
229 break;
230 default:
231 TORCH_INTERNAL_ASSERT(false, "undefined mma type");
232 break;
233 }
234 return ss.str();
235}
236
237} // namespace cuda
238} // namespace fuser
239} // namespace jit
240} // namespace torch
241