// Generated from "/code/pytorch/third_party/nvfuser/runtime/tensorcore.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* tensorcore_cu = R"(
// Utility macro for this file
#define DEVICE_INLINE __device__ inline

// MMA instruction wrappers:
// The wrappers are subroutines that implement matrix multiplications of
// size A(M,K) x B(K,N) = C(M,N).
// The naming of the wrappers follows the naming convention of the mma
// instructions: all of the mma macros are namespaced and named as
//   Arch::M(M-dim)N(N-dim)K(K-dim)(Layout), e.g. Volta::M16N16K4TT,
// with the dimensions describing the size of the sub-matrices being
// multiplied by the wrapper.
// See [Operand Layout Convention] in mma_type.h for details on the layout
// notation.
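//
// As a rough usage sketch (the fragment variables, the loop, and acc_stride
// below are hypothetical placeholders; how operands are staged is decided by
// the generated kernel, not by this file):
//
//   Array<float, 8, 8> acc;
//   Volta::initM16N16K4TT<acc_stride>(&acc);
//   for (int k = 0; k < K; k += 4) {
//     // a_frag, b_frag: per-thread Array<__half, 4, 4> operand fragments
//     Volta::M16N16K4TT<acc_stride>(&acc, &a_frag, &b_frag);
//   }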
namespace Volta {

namespace util {
// MMA instruction wrappers (sm_70+):
// The instruction wrappers below are quarter-warp macros, which nvfuser
// currently doesn't model explicitly, so they are only meant to be used as
// building blocks in the warp-level mma macros.

// 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate
// per thread register:
//   A[4] x B[4] -> C[8]
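// Each Array<__half, 4> operand packs into two 32-bit registers (two halves
// per register) and the fp32 accumulator occupies eight 32-bit registers,
// matching the {%8,%9} / {%10,%11} / {%0..%7} operand groups in the inline
// asm below.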
DEVICE_INLINE void mmaM8n8k4tt(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);

  asm("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n"
      : "=r"(_C[0]),
        "=r"(_C[1]),
        "=r"(_C[2]),
        "=r"(_C[3]),
        "=r"(_C[4]),
        "=r"(_C[5]),
        "=r"(_C[6]),
        "=r"(_C[7])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_C[0]),
        "r"(_C[1]),
        "r"(_C[2]),
        "r"(_C[3]),
        "r"(_C[4]),
        "r"(_C[5]),
        "r"(_C[6]),
        "r"(_C[7]));
}

// Same shape as mmaM8n8k4tt above, but issues the row.col (TN) variant of
// the instruction.
DEVICE_INLINE void mmaM8n8k4tn(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);

  asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n"
      : "=r"(_C[0]),
        "=r"(_C[1]),
        "=r"(_C[2]),
        "=r"(_C[3]),
        "=r"(_C[4]),
        "=r"(_C[5]),
        "=r"(_C[6]),
        "=r"(_C[7])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_C[0]),
        "r"(_C[1]),
        "r"(_C[2]),
        "r"(_C[3]),
        "r"(_C[4]),
        "r"(_C[5]),
        "r"(_C[6]),
        "r"(_C[7]));
}

// Same shape again, issuing the col.row (NT) variant of the instruction.
DEVICE_INLINE void mmaM8n8k4nt(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);

  asm("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n"
      : "=r"(_C[0]),
        "=r"(_C[1]),
        "=r"(_C[2]),
        "=r"(_C[3]),
        "=r"(_C[4]),
        "=r"(_C[5]),
        "=r"(_C[6]),
        "=r"(_C[7])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_C[0]),
        "r"(_C[1]),
        "r"(_C[2]),
        "r"(_C[3]),
        "r"(_C[4]),
        "r"(_C[5]),
        "r"(_C[6]),
        "r"(_C[7]));
}

// TODO: in a follow-up, lift this part onto iterdomain ops once the
// swizzle ops are ready.
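// The warp-level 16x16 accumulator keeps each thread's 8 output values as two
// groups of 4 contiguous floats separated by acc_stride. accToMma interleaves
// those groups in pairs into the contiguous 8-register order the m8n8k4 fp32
// mma expects; mmaToAcc applies the inverse mapping.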
template <int acc_stride>
DEVICE_INLINE Array<float, 8, 8> accToMma(float* _C) {
  float C_data[8] = {
      _C[0],
      _C[1],
      _C[acc_stride],
      _C[acc_stride + 1],
      _C[2],
      _C[3],
      _C[acc_stride + 2],
      _C[acc_stride + 3],
  };

  return *reinterpret_cast<Array<float, 8, 8>*>(&C_data[0]);
}

template <int acc_stride>
DEVICE_INLINE void mmaToAcc(float* _C, Array<float, 8, 8>& C) {
  float* C_data = reinterpret_cast<float*>(&C);
  _C[0] = C_data[0];
  _C[1] = C_data[1];
  _C[acc_stride] = C_data[2];
  _C[acc_stride + 1] = C_data[3];
  _C[2] = C_data[4];
  _C[3] = C_data[5];
  _C[acc_stride + 2] = C_data[6];
  _C[acc_stride + 3] = C_data[7];
}

// Should be able to lift this with a transpose op as well.
template <int acc_stride>
DEVICE_INLINE void initM16N16K4(Array<float, 8, 8>& accumulator) {
  float* _C = reinterpret_cast<float*>(&accumulator);
  float zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  mmaToAcc<acc_stride>(_C, *reinterpret_cast<Array<float, 8, 8>*>(&zeros[0]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void M16N16K4TT(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  float* _C = reinterpret_cast<float*>(C);
  Array<float, 8, 8> C_data = util::accToMma<acc_stride>(_C);
  util::mmaM8n8k4tt(&C_data, A, B);
  util::mmaToAcc<acc_stride>(_C, C_data);
}

template <int acc_stride>
DEVICE_INLINE void M16N16K4TN(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  float* _C = reinterpret_cast<float*>(C);
  Array<float, 8, 8> C_data = util::accToMma<acc_stride>(_C);
  util::mmaM8n8k4tn(&C_data, A, B);
  util::mmaToAcc<acc_stride>(_C, C_data);
}

template <int acc_stride>
DEVICE_INLINE void M16N16K4NT(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  float* _C = reinterpret_cast<float*>(C);
  Array<float, 8, 8> C_data = util::accToMma<acc_stride>(_C);
  util::mmaM8n8k4nt(&C_data, A, B);
  util::mmaToAcc<acc_stride>(_C, C_data);
}

// Same initialization for now; it will be different for the interleaved
// macros.
template <int acc_stride>
DEVICE_INLINE void initM16N16K4TT(Array<float, 8, 8>* accumulator) {
  util::initM16N16K4<acc_stride>(*accumulator);
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K4TN(Array<float, 8, 8>* accumulator) {
  util::initM16N16K4<acc_stride>(*accumulator);
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K4NT(Array<float, 8, 8>* accumulator) {
  util::initM16N16K4<acc_stride>(*accumulator);
}

} // namespace Volta

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))

namespace Turing {

namespace util {
// MMA instruction wrappers (sm_75+):
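// Note: the m16n8k16 TN shape is realized here as two m16n8k8 HMMA
// instructions accumulating into the same C fragment, since the
// single-instruction f16 m16n8k16 variant is only available from sm_80.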
DEVICE_INLINE void m16n8k16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);
  const unsigned* _D = reinterpret_cast<const unsigned*>(C);

  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_D[0]),
        "r"(_D[1]),
        "r"(_D[2]),
        "r"(_D[3]));
  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
      : "r"(_A[2]),
        "r"(_A[3]),
        "r"(_B[1]),
        "r"(_D[0]),
        "r"(_D[1]),
        "r"(_D[2]),
        "r"(_D[3]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
  float* _C = reinterpret_cast<float*>(accumulator);
  _C[0] = 0;
  _C[1] = 0;
  _C[acc_stride] = 0;
  _C[acc_stride + 1] = 0;
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N8K16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  // TODO: in a follow-up, lift this fused swizzle onto iterdomain.
  float* _C = reinterpret_cast<float*>(C);
  float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};

  util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);

  _C[0] = C_data[0];
  _C[1] = C_data[1];
  _C[acc_stride] = C_data[2];
  _C[acc_stride + 1] = C_data[3];
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
  float* _C = reinterpret_cast<float*>(accumulator);
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}

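// The 16x16x16 macro computes the 16x16 output tile as two 16x8 mma ops that
// share the A fragment, with the B operand and the accumulator each split in
// half along the N dimension.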
template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
    Array<float, 8, 8>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 8, 8>* B) {
  float* _C = reinterpret_cast<float*>(C);
  __half* _B = reinterpret_cast<__half*>(B);
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}

} // namespace Turing

#endif // Arch 75

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

namespace Ampere {

namespace util {
// MMA instruction wrappers (sm_80+):
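// Unlike the Turing helper above, this maps directly onto the native
// m16n8k16 HMMA instruction available from sm_80, so a single mma.sync
// covers the full k = 16 tile.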
DEVICE_INLINE void m16n8k16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);
  const unsigned* _D = reinterpret_cast<const unsigned*>(C);

  asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_A[2]),
        "r"(_A[3]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_D[0]),
        "r"(_D[1]),
        "r"(_D[2]),
        "r"(_D[3]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
  float* _C = reinterpret_cast<float*>(accumulator);
  _C[0] = 0;
  _C[1] = 0;
  _C[acc_stride] = 0;
  _C[acc_stride + 1] = 0;
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N8K16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  // TODO: in a follow-up, lift this fused swizzle onto iterdomain.
  float* _C = reinterpret_cast<float*>(C);
  float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};

  util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);

  _C[0] = C_data[0];
  _C[1] = C_data[1];
  _C[acc_stride] = C_data[2];
  _C[acc_stride + 1] = C_data[3];
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
  float* _C = reinterpret_cast<float*>(accumulator);
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
    Array<float, 8, 8>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 8, 8>* B) {
  float* _C = reinterpret_cast<float*>(C);
  __half* _B = reinterpret_cast<__half*>(B);
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}

} // namespace Ampere

#endif // Arch 80

#undef DEVICE_INLINE
)";

} // namespace nvfuser_resources