// Generated from "/code/pytorch/third_party/nvfuser/runtime/tensorcore.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* tensorcore_cu = R"(
// Utility macro for this file
#define DEVICE_INLINE __device__ inline

// MMA instruction wrappers:
// The wrappers are subroutines that implement a matrix multiply of size
// A(M,K) X B(K,N) = C(M,N).
// The naming of the wrappers follows the same naming convention
// as the mma instructions.
// All the mma macros follow the namespace and naming pattern
// Arch::M(M-dim)N(N-dim)K(K-dim)(Layout), e.g.
// Volta::M16N16K4TT,
// with the dimensions describing the size of the sub-matrices being
// multiplied by this wrapper.
// See [Operand Layout Convention] in mma_type.h for details on the layout
// notation.
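//
// A minimal usage sketch (hedged: fragment names and the acc_stride value
// are illustrative rather than taken from nvfuser's actual codegen;
// acc_stride = 4 keeps all indices inside a self-contained 8-float
// accumulator fragment):
//
//   Array<__half, 4, 4> A_frag, B_frag; // per-thread operands, loaded elsewhere
//   Array<float, 8, 8> C_frag;          // per-thread accumulator
//   Volta::initM16N16K4TT<4>(&C_frag);
//   Volta::M16N16K4TT<4>(&C_frag, &A_frag, &B_frag);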
namespace Volta {

namespace util {
// MMA instruction wrappers (sm_70+):
// The instruction wrappers below are quarter-warp macros, which nvfuser
// currently doesn't explicitly model, so they are only meant to be used
// as building blocks in the warp-level mma macros.

// 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate
// per thread register:
//   A[4] x B[4] -> C[8]
DEVICE_INLINE void mmaM8n8k4tt(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
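  // Pack the operands for inline PTX: each 32-bit unsigned register holds
  // two packed __half values, and the fp32 accumulator is passed through as
  // raw 32-bit words.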
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);

  asm("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n"
      : "=r"(_C[0]),
        "=r"(_C[1]),
        "=r"(_C[2]),
        "=r"(_C[3]),
        "=r"(_C[4]),
        "=r"(_C[5]),
        "=r"(_C[6]),
        "=r"(_C[7])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_C[0]),
        "r"(_C[1]),
        "r"(_C[2]),
        "r"(_C[3]),
        "r"(_C[4]),
        "r"(_C[5]),
        "r"(_C[6]),
        "r"(_C[7]));
}

DEVICE_INLINE void mmaM8n8k4tn(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);

  asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n"
      : "=r"(_C[0]),
        "=r"(_C[1]),
        "=r"(_C[2]),
        "=r"(_C[3]),
        "=r"(_C[4]),
        "=r"(_C[5]),
        "=r"(_C[6]),
        "=r"(_C[7])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_C[0]),
        "r"(_C[1]),
        "r"(_C[2]),
        "r"(_C[3]),
        "r"(_C[4]),
        "r"(_C[5]),
        "r"(_C[6]),
        "r"(_C[7]));
}

DEVICE_INLINE void mmaM8n8k4nt(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);

  asm("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n"
      : "=r"(_C[0]),
        "=r"(_C[1]),
        "=r"(_C[2]),
        "=r"(_C[3]),
        "=r"(_C[4]),
        "=r"(_C[5]),
        "=r"(_C[6]),
        "=r"(_C[7])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_C[0]),
        "r"(_C[1]),
        "r"(_C[2]),
        "r"(_C[3]),
        "r"(_C[4]),
        "r"(_C[5]),
        "r"(_C[6]),
        "r"(_C[7]));
}

// TODO: in a follow-up, lift this part onto iterdomain ops,
// once the swizzle ops are ready.
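// accToMma gathers the eight accumulator values this thread owns,
// interleaving pairs from two rows acc_stride floats apart, into the
// contiguous fragment order the mma instruction expects.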
template <int acc_stride>
DEVICE_INLINE Array<float, 8, 8> accToMma(float* _C) {
  float C_data[8] = {
      _C[0],
      _C[1],
      _C[acc_stride],
      _C[acc_stride + 1],
      _C[2],
      _C[3],
      _C[acc_stride + 2],
      _C[acc_stride + 3],
  };

  return *reinterpret_cast<Array<float, 8, 8>*>(&C_data[0]);
}

template <int acc_stride>
DEVICE_INLINE void mmaToAcc(float* _C, Array<float, 8, 8>& C) {
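  // Inverse of accToMma: scatter the mma result fragment back into
  // nvfuser's accumulator layout.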
  float* C_data = reinterpret_cast<float*>(&C);
  _C[0] = C_data[0];
  _C[1] = C_data[1];
  _C[acc_stride] = C_data[2];
  _C[acc_stride + 1] = C_data[3];
  _C[2] = C_data[4];
  _C[3] = C_data[5];
  _C[acc_stride + 2] = C_data[6];
  _C[acc_stride + 3] = C_data[7];
}

// Zero-initialize the accumulator by scattering an all-zero fragment
// through mmaToAcc. Should be able to lift this with a transpose op as well.
template <int acc_stride>
DEVICE_INLINE void initM16N16K4(Array<float, 8, 8>& accumulator) {
  float* _C = reinterpret_cast<float*>(&accumulator);
  float zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  mmaToAcc<acc_stride>(_C, *reinterpret_cast<Array<float, 8, 8>*>(&zeros[0]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void M16N16K4TT(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  float* _C = reinterpret_cast<float*>(C);
  Array<float, 8, 8> C_data = util::accToMma<acc_stride>(_C);
  util::mmaM8n8k4tt(&C_data, A, B);
  util::mmaToAcc<acc_stride>(_C, C_data);
}

template <int acc_stride>
DEVICE_INLINE void M16N16K4TN(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  float* _C = reinterpret_cast<float*>(C);
  Array<float, 8, 8> C_data = util::accToMma<acc_stride>(_C);
  util::mmaM8n8k4tn(&C_data, A, B);
  util::mmaToAcc<acc_stride>(_C, C_data);
}

template <int acc_stride>
DEVICE_INLINE void M16N16K4NT(
    Array<float, 8, 8>* C,
    Array<__half, 4, 4>* A,
    Array<__half, 4, 4>* B) {
  float* _C = reinterpret_cast<float*>(C);
  Array<float, 8, 8> C_data = util::accToMma<acc_stride>(_C);
  util::mmaM8n8k4nt(&C_data, A, B);
  util::mmaToAcc<acc_stride>(_C, C_data);
}

// Same initialization for now; it will be different in the interleaved
// macros.
template <int acc_stride>
DEVICE_INLINE void initM16N16K4TT(Array<float, 8, 8>* accumulator) {
  util::initM16N16K4<acc_stride>(*accumulator);
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K4TN(Array<float, 8, 8>* accumulator) {
  util::initM16N16K4<acc_stride>(*accumulator);
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K4NT(Array<float, 8, 8>* accumulator) {
  util::initM16N16K4<acc_stride>(*accumulator);
}

} // namespace Volta

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))

namespace Turing {

namespace util {
// MMA instruction wrappers (sm_75+):
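// Turing (sm_75) has no native m16n8k16 HMMA instruction, so the k16
// wrapper below is composed of two m16n8k8 mma ops accumulating into the
// same C fragment.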
DEVICE_INLINE void m16n8k16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);
  const unsigned* _D = reinterpret_cast<const unsigned*>(C);
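  // _D aliases _C: the accumulator fragment is both an input and the output
  // of each mma, implementing C = A * B + C in place.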

  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_B[0]),
        "r"(_D[0]),
        "r"(_D[1]),
        "r"(_D[2]),
        "r"(_D[3]));
  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
      : "r"(_A[2]),
        "r"(_A[3]),
        "r"(_B[1]),
        "r"(_D[0]),
        "r"(_D[1]),
        "r"(_D[2]),
        "r"(_D[3]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
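  // Zero the four fp32 accumulator values this thread owns for one m16n8
  // tile; they are stored as two adjacent pairs, acc_stride floats apart.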
  float* _C = reinterpret_cast<float*>(accumulator);
  _C[0] = 0;
  _C[1] = 0;
  _C[acc_stride] = 0;
  _C[acc_stride + 1] = 0;
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N8K16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  // TODO: in a follow-up, lift this fused swizzle onto iterdomain.
  float* _C = reinterpret_cast<float*>(C);
  float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};

  util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);

  _C[0] = C_data[0];
  _C[1] = C_data[1];
  _C[acc_stride] = C_data[2];
  _C[acc_stride + 1] = C_data[3];
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
  float* _C = reinterpret_cast<float*>(accumulator);
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
    Array<float, 8, 8>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 8, 8>* B) {
  float* _C = reinterpret_cast<float*>(C);
  __half* _B = reinterpret_cast<__half*>(B);
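  // Compute the n16 tile as two n8 tiles side by side: _B[0] and _B[4] are
  // the two 4-element halves of the B fragment, and the C halves start at
  // _C[0] and _C[2] to match the acc_stride indexing in M16N8K16TN.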
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}
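
// A minimal usage sketch (hedged: fragment names and the acc_stride value
// are illustrative rather than taken from nvfuser's actual codegen;
// acc_stride = 4 makes the two n8 tiles cover all eight floats of a
// contiguous fragment without overlap):
//
//   Array<__half, 8, 8> A_frag, B_frag; // per-thread operands, loaded elsewhere
//   Array<float, 8, 8> C_frag;          // per-thread accumulator
//   Turing::initM16N16K16TN<4>(&C_frag);
//   Turing::M16N16K16TN<4>(&C_frag, &A_frag, &B_frag);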

} // namespace Turing

#endif // Arch 75

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

namespace Ampere {

namespace util {
// MMA instruction wrappers (sm_80+): sm_80 provides a native m16n8k16
// instruction, so a single mma.sync covers the full k16 tile.
DEVICE_INLINE void m16n8k16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
  unsigned* _C = reinterpret_cast<unsigned*>(C);
  const unsigned* _D = reinterpret_cast<const unsigned*>(C);
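  // As in the Turing wrapper, _D aliases _C so the mma accumulates in place.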

  asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
      : "r"(_A[0]),
        "r"(_A[1]),
        "r"(_A[2]),
        "r"(_A[3]),
        "r"(_B[0]),
        "r"(_B[1]),
        "r"(_D[0]),
        "r"(_D[1]),
        "r"(_D[2]),
        "r"(_D[3]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
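  // Same layout as the Turing init: zero the two adjacent pairs of fp32
  // accumulator values, acc_stride floats apart.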
  float* _C = reinterpret_cast<float*>(accumulator);
  _C[0] = 0;
  _C[1] = 0;
  _C[acc_stride] = 0;
  _C[acc_stride + 1] = 0;
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N8K16TN(
    Array<float, 4, 4>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 4, 4>* B) {
  // TODO: in a follow-up, lift this fused swizzle onto iterdomain.
  float* _C = reinterpret_cast<float*>(C);
  float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};

  util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);

  _C[0] = C_data[0];
  _C[1] = C_data[1];
  _C[acc_stride] = C_data[2];
  _C[acc_stride + 1] = C_data[3];
}

template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
  float* _C = reinterpret_cast<float*>(accumulator);
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
  initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
    Array<float, 8, 8>* C,
    Array<__half, 8, 8>* A,
    Array<__half, 8, 8>* B) {
  float* _C = reinterpret_cast<float*>(C);
  __half* _B = reinterpret_cast<__half*>(B);
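  // As in the Turing wrapper, split the n16 tile into two n8 tiles: _B[0]
  // and _B[4] are the two halves of the B fragment, _C[0] and _C[2] the
  // matching halves of the C fragment.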
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
  M16N8K16TN<acc_stride>(
      reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
      A,
      reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}

} // namespace Ampere

#endif // Arch 80

#undef DEVICE_INLINE
)";

} // namespace nvfuser_resources