1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_cuda.cc |
22 | */ |
23 | |
24 | #include "codegen_cuda.h" |
25 | |
26 | #include <tvm/arith/analyzer.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/tir/index_map.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | |
31 | #include <cmath> |
32 | #include <string> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | #include "literal/cuda_half_t.h" |
37 | #include "ptx.h" |
38 | |
39 | namespace tvm { |
40 | namespace codegen { |
41 | |
42 | CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__" ; } |
43 | |
44 | void CodeGenCUDA::Init(bool output_ssa) { |
45 | CodeGenC::Init(output_ssa); |
46 | vid_global_barrier_state_ = name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state); |
47 | vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect" ); |
48 | ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); |
49 | } |
50 | |
51 | void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void" ; } |
52 | |
class ThreadIdxExtractor : public tir::StmtVisitor {
54 | private: |
void VisitStmt_(const AttrStmtNode* op) final {
56 | if (op->attr_key == tir::attr::thread_extent) { |
57 | IterVar iv = Downcast<IterVar>(op->node); |
58 | if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x" ) { |
59 | threadIdx_x_ext = op->value; |
60 | } |
61 | if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y" ) { |
62 | threadIdx_y_ext = op->value; |
63 | } |
64 | if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z" ) { |
65 | threadIdx_z_ext = op->value; |
66 | } |
67 | } |
68 | StmtVisitor::VisitStmt_(op); |
69 | } |
70 | |
71 | public: |
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
75 | }; |
76 | |
void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
ThreadIdxExtractor extractor;
79 | extractor(f->body); |
80 | arith::Analyzer analyzer; |
81 | PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * |
82 | extractor.threadIdx_z_ext); |
83 | if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) { |
84 | if (threadIdx_ext_int->value == 1) { |
// Unable to extract the number of threads per block, so skip emitting __launch_bounds__.
86 | return; |
87 | } |
88 | stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")" ; |
89 | } |
90 | } |
91 | |
92 | std::string CodeGenCUDA::Finish() { |
93 | if (enable_fp16_) { |
94 | decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n" ; |
95 | decl_stream << "#include <cuda_fp16.h>\n" ; |
96 | decl_stream << "__device__ half max" |
97 | << "(half a, half b)\n" |
98 | << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n" ; |
99 | decl_stream << "__device__ half min(half a, half b)\n" |
100 | << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n" ; |
101 | decl_stream << "#else\n" ; |
102 | decl_stream << _cuda_half_t_def; |
103 | decl_stream << "#endif\n\n" ; |
104 | decl_stream << _cuda_half_util; |
105 | } |
106 | |
107 | if (enable_bf16_) { |
108 | decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n" ; |
109 | decl_stream << "#include <cuda_bf16.h>\n" ; |
110 | decl_stream << "__device__ nv_bfloat16 max" |
111 | << "(nv_bfloat16 a, nv_bfloat16 b)\n" |
112 | << "{\n return __hgt(a, b) ? a : b;\n}\n" ; |
113 | decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n" |
114 | << "{\n return __hlt(a, b) ? a : b;\n}\n" ; |
115 | decl_stream << "#endif\n\n" ; |
116 | decl_stream << _cuda_bfloat16_util; |
117 | } |
118 | |
119 | if (enable_warp_shuffle_) { |
120 | decl_stream << _cuda_warp_intrinsic_util; |
121 | } |
122 | |
123 | if (enable_int8_) { |
124 | decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n" ; |
125 | decl_stream << "#include <sm_61_intrinsics.h>\n" ; |
126 | decl_stream << "#endif\n" ; |
127 | } |
128 | |
129 | if (need_math_constants_h_) { |
130 | decl_stream << "#include <math_constants.h>\n" ; |
131 | } |
132 | |
133 | if (need_mma_h_) { |
134 | decl_stream << "#include <mma.h>\n" ; |
135 | } |
136 | |
137 | decl_stream << "\n#ifdef _WIN32\n" ; |
138 | decl_stream << " using uint = unsigned int;\n" ; |
139 | decl_stream << " using uchar = unsigned char;\n" ; |
140 | decl_stream << " using ushort = unsigned short;\n" ; |
141 | decl_stream << " using int64_t = long long;\n" ; |
142 | decl_stream << " using uint64_t = unsigned long long;\n" ; |
143 | decl_stream << "#else\n" ; |
144 | decl_stream << " #define uint unsigned int\n" ; |
145 | decl_stream << " #define uchar unsigned char\n" ; |
146 | decl_stream << " #define ushort unsigned short\n" ; |
147 | decl_stream << " #define int64_t long long\n" ; |
148 | decl_stream << " #define uint64_t unsigned long long\n" ; |
149 | decl_stream << "#endif\n" ; |
150 | |
151 | return CodeGenC::Finish(); |
152 | } |
153 | |
154 | void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { |
155 | ICHECK(is_const_int(op->min, 0)); |
156 | if (op->kind == tir::ForKind::kUnrolled) { |
157 | PrintIndent(); |
158 | stream << "#pragma unroll\n" ; |
159 | } |
160 | CodeGenC::VisitStmt_(op); |
161 | } |
162 | |
163 | void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { |
164 | ICHECK(!var_idmap_.count(iv->var.get())); |
165 | var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); |
166 | } |
167 | |
168 | void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) |
169 | int lanes = t.lanes(); |
170 | if (t.is_handle()) { |
171 | ICHECK(t.is_scalar()) << "do not yet support vector types" ; |
172 | os << "void*" ; |
173 | return; |
174 | } |
175 | |
176 | if (t.is_void()) { |
177 | os << "void" ; |
178 | return; |
179 | } |
180 | |
181 | bool fail = false; |
182 | if (t.is_float()) { |
183 | switch (t.bits()) { |
184 | case 16: |
185 | enable_fp16_ = true; |
186 | if (t.is_scalar()) { |
187 | os << "half" ; |
188 | } else if (lanes <= 8) { |
189 | // Emit CUDA code to access fp16 vector elements. |
190 | // |
191 | // half4 is stored as uint2 |
192 | // |
193 | // h4.x is emitted as *(half2*)(&(u2.x)).x |
194 | // h4.y is emitted as *(half2*)(&(u2.x)).y |
195 | // h4.z is emitted as *(half2*)(&(u2.y)).x |
196 | // h4.w is emitted as *(half2*)(&(u2.y)).y |
197 | // |
198 | ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type" ; |
199 | os << "uint" << lanes / 2; |
200 | } else { |
201 | fail = true; |
202 | } |
203 | break; |
204 | case 32: |
205 | if (lanes <= 4) { |
206 | os << "float" ; |
207 | } else if (lanes <= 8) { |
208 | // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8. |
209 | // |
210 | // float8 is stored as ulonglong4 |
211 | // |
212 | // f8.v1 is emitted as *(float2*)(&(ul4.x)).x |
213 | // f8.v2 is emitted as *(float2*)(&(ul4.x)).y |
214 | // |
ICHECK_EQ(lanes % 2, 0) << "only support an even number of lanes for float type with lanes > 4";
216 | os << "ulonglong" << lanes / 2; |
217 | } else { |
218 | fail = true; |
219 | } |
220 | break; |
221 | case 64: |
222 | os << "double" ; |
223 | break; |
224 | default: |
225 | fail = true; |
226 | break; |
227 | } |
228 | if (!fail && (t.is_scalar() || t.bits() == 16)) return; |
229 | if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; |
230 | if (!fail && (lanes >= 2 && lanes <= 4)) { |
231 | os << lanes; |
232 | return; |
233 | } |
234 | } else if (t.is_bfloat16()) { |
235 | enable_bf16_ = true; |
236 | if (t.is_scalar()) { |
237 | os << "nv_bfloat16" ; |
238 | } else if (lanes <= 8) { |
ICHECK_EQ(lanes % 2, 0) << "only support an even number of lanes for bfloat16 type";
240 | os << "uint" << lanes / 2; |
241 | } else { |
242 | fail = true; |
243 | } |
244 | if (!fail) return; |
245 | } else if (t == DataType::Bool()) { |
246 | os << "bool" ; |
247 | return; |
248 | } else if (t.is_vector_bool()) { |
249 | // CUDA does not support bool vectors. |
250 | // Use ushort vectors to represent instead. |
251 | int n = t.lanes(); |
252 | if (n <= 4) { |
253 | os << "ushort" << n; |
254 | return; |
255 | } |
256 | } else if (t.is_uint() || t.is_int()) { |
257 | if (t.is_uint()) { |
258 | os << "u" ; |
259 | } |
260 | switch (t.bits()) { |
261 | case 1: { |
262 | if (t.is_scalar()) { |
263 | os << "int" ; |
264 | return; |
265 | } else if (t.lanes() == 8) { |
266 | os << "int8_t" ; |
267 | return; |
268 | } else if (t.lanes() == 16) { |
269 | os << "int16_t" ; |
270 | return; |
271 | } else if (t.lanes() == 32) { |
272 | os << "int" ; |
273 | return; |
274 | } else { |
275 | LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!" ; |
276 | } |
277 | } |
278 | case 4: { |
279 | if (t.is_scalar()) { |
280 | os << "int" ; |
281 | return; |
282 | } else if (t.lanes() == 4) { |
283 | os << "int16_t" ; |
284 | return; |
285 | } else if (t.lanes() == 8) { |
// Pack 8 4-bit ints directly into a 32-bit integer.
287 | os << "int" ; |
288 | return; |
289 | } else if (t.lanes() == 16) { |
290 | os << "int2" ; |
291 | return; |
292 | } else if (t.lanes() == 32) { |
293 | os << "int4" ; |
294 | return; |
295 | } else if (t.lanes() == 64) { |
296 | os << "int8" ; |
297 | return; |
298 | } else { |
299 | LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!" ; |
300 | } |
301 | } |
302 | case 8: { |
303 | if (t.lanes() == 4) { |
// Pack 4 8-bit ints directly into a 32-bit integer.
305 | enable_int8_ = true; |
306 | |
307 | // We use int for int8x4 instead of char4 because using char4 is |
308 | // likely to produce extra instructions to pack four int8 elements |
309 | // into 32-bit data. |
310 | os << "int" ; |
311 | return; |
312 | } else if (t.lanes() == 8) { |
313 | enable_int8_ = true; |
314 | os << "int2" ; |
315 | return; |
316 | } else if (t.lanes() == 16) { |
317 | enable_int8_ = true; |
318 | os << "int4" ; |
319 | return; |
320 | } else if (!t.is_uint() && t.is_scalar()) { |
321 | os << "signed char" ; |
322 | break; |
323 | } else { |
324 | os << "char" ; |
325 | break; |
326 | } |
327 | } |
328 | case 16: { |
329 | if (t.is_scalar()) { |
330 | os << "short" ; |
331 | } else if (t.lanes() <= 4) { |
332 | os << "short" << lanes; |
333 | } else if (t.lanes() <= 8) { |
334 | // Emit CUDA code to access int16 vector elements. |
335 | // |
336 | // short4 is stored as int2 |
337 | // |
338 | // s4.x is emitted as *(short2*)(&(i2.x)).x |
339 | // s4.y is emitted as *(short2*)(&(i2.x)).y |
340 | // s4.z is emitted as *(short2*)(&(i2.y)).x |
341 | // s4.w is emitted as *(short2*)(&(i2.y)).y |
342 | // |
ICHECK_EQ(t.lanes() % 2, 0) << "only support an even number of lanes for short type with lanes > 4";
344 | os << "int" << t.lanes() / 2; |
345 | } else { |
346 | fail = true; |
347 | } |
348 | if (!fail) { |
349 | return; |
350 | } |
351 | break; |
352 | } |
353 | case 32: { |
354 | if (t.is_scalar()) { |
355 | os << "int" ; |
356 | } else if (t.lanes() <= 4) { |
357 | os << "int" << t.lanes(); |
358 | } else if (t.lanes() <= 8) { |
359 | // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. |
360 | // |
361 | // int8 is stored as longlong4 |
362 | // |
363 | // i8.v1 is emitted as *(int2*)(&(l4.x)).x |
364 | // i8.v2 is emitted as *(int2*)(&(l4.x)).y |
365 | // |
ICHECK_EQ(lanes % 2, 0) << "only support an even number of lanes for int32 type with lanes > 4";
367 | os << "longlong" << lanes / 2; |
368 | } else { |
369 | fail = true; |
370 | } |
371 | if (!fail) { |
372 | return; |
373 | } |
374 | break; |
375 | } |
376 | case 64: { |
377 | if (t.is_scalar()) { |
378 | os << "int64_t" ; |
379 | } else if (t.lanes() == 2) { |
380 | os << "longlong2" ; |
381 | } else if (t.lanes() == 3) { |
382 | os << "longlong3" ; |
383 | } else if (t.lanes() == 4) { |
384 | os << "longlong4" ; |
385 | } |
386 | return; |
387 | } |
388 | default: |
389 | fail = true; |
390 | break; |
391 | } |
392 | if (!fail && lanes == 1) { |
393 | return; |
394 | } |
395 | if (!fail && (lanes >= 2 && lanes <= 4)) { |
396 | os << lanes; |
397 | return; |
398 | } |
399 | } |
400 | LOG(FATAL) << "Cannot convert type " << t << " to CUDA type" ; |
401 | } |
402 | |
403 | void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, |
404 | std::ostream& os) { // NOLINT(*) |
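// CUDA has no built-in arithmetic on its vector types, so unpack the op into
// one scalar statement per lane. Illustrative sketch of the emitted code for
// a float2 add (the SSA names are generated, not verbatim):
//   float2 _1;
//   _1.x = (a.x+b.x);
//   _1.y = (a.y+b.y);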
// Declare the result.
406 | std::string sret = name_supply_->FreshName("_" ); |
407 | this->PrintIndent(); |
408 | this->PrintType(t, stream); |
409 | stream << ' ' << sret << ";\n" ; |
410 | int ssa_scope = BeginScope(); |
411 | { |
412 | // Unpack into individual ops. |
413 | std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); |
414 | std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); |
415 | |
416 | for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { |
417 | std::ostringstream value_temp; |
418 | if (isalpha(op[0])) { |
419 | value_temp << op << "(" ; |
420 | PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); |
421 | value_temp << ", " ; |
422 | PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); |
423 | value_temp << ")" ; |
424 | } else { |
425 | value_temp << "(" ; |
426 | PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); |
427 | value_temp << op; |
428 | PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); |
429 | value_temp << ")" ; |
430 | } |
431 | PrintVecElemStore(sret, t, i, value_temp.str()); |
432 | } |
433 | } |
434 | EndScope(ssa_scope); |
435 | os << sret; |
436 | } |
437 | |
438 | void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, |
439 | std::ostream& os) { // NOLINT(*) |
440 | if (t.is_scalar()) { |
441 | os << vec; |
442 | return; |
443 | } |
444 | |
445 | static const char access[] = {'x', 'y', 'z', 'w'}; |
446 | ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); |
447 | if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
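// 8-bit lanes are packed into 32-bit words, so a lane is read with a shift
// and a cast. Illustrative sketch: lane 5 of a char8 stored as an int2 v is
// emitted as
//   ((char)(v.y >> 8))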
448 | std::string type_name = t.is_int() ? "char" : "unsigned char" ; |
449 | if (t.lanes() == 2 || t.lanes() == 3) { |
450 | os << vec << "." << access[i % t.lanes()]; |
451 | } else { |
452 | std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); |
453 | os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))" ; |
454 | } |
455 | } else if (t.is_float16()) { |
456 | os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
457 | } else if (t.is_bfloat16()) { |
458 | os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
459 | } else if (t.lanes() > 4 && t.lanes() <= 8) { |
460 | std::string type_name; |
461 | if (t.bits() == 16) { |
462 | if (t.is_int()) { |
463 | type_name = "short" ; |
464 | } else if (t.is_uint()) { |
465 | type_name = "ushort" ; |
466 | } |
467 | } else if (t.bits() == 32) { |
468 | if (t.is_int()) { |
469 | type_name = "int" ; |
470 | } else if (t.is_uint()) { |
471 | type_name = "uint" ; |
472 | } else if (t.is_float()) { |
473 | type_name = "float" ; |
474 | } |
475 | } |
476 | ICHECK(!type_name.empty()); |
477 | os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
478 | } else { |
479 | os << vec << "." << access[i]; |
480 | } |
481 | } |
482 | |
483 | void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, |
484 | const std::string& value) { |
485 | this->PrintIndent(); |
486 | static const char access[] = {'x', 'y', 'z', 'w'}; |
487 | ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); |
488 | if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
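// Lanes are written with a read-modify-write on the packed word.
// Illustrative sketch: storing `value` into lane 1 of a char4 held in an
// int v emits
//   v = v & ~(0x000000ff << 8) | (value << 8);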
489 | if (t.lanes() == 2 || t.lanes() == 3) { |
490 | stream << vec << '.' << access[i % t.lanes()] << "=" |
491 | << "(" << value << ");\n" ; |
492 | } else { |
493 | std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); |
494 | stream << ac << "=" ; |
495 | // Do not read the first undef lane. |
496 | if (i != 0) { |
497 | stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |" ; |
498 | } |
499 | stream << "(" << value << " << " << i % 4 * 8 << ");\n" ; |
500 | } |
501 | } else if (t.is_float16()) { |
502 | stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " |
503 | << value << ";\n" ; |
504 | } else if (t.is_bfloat16()) { |
505 | stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] |
506 | << " = " << value << ";\n" ; |
507 | } else if (t.lanes() > 4 && t.lanes() <= 8) { |
508 | std::string type_name; |
509 | if (t.bits() == 16) { |
510 | if (t.is_int()) { |
511 | type_name = "short" ; |
512 | } else if (t.is_uint()) { |
513 | type_name = "ushort" ; |
514 | } |
515 | } else if (t.bits() == 32) { |
516 | if (t.is_int()) { |
517 | type_name = "int" ; |
518 | } else if (t.is_uint()) { |
519 | type_name = "uint" ; |
520 | } else if (t.is_float()) { |
521 | type_name = "float" ; |
522 | } |
523 | } |
524 | ICHECK(!type_name.empty()); |
525 | stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" |
526 | << access[i % 2] << " = " << value << ";\n" ; |
527 | } else { |
528 | stream << vec << "." << access[i] << " = " << value << ";\n" ; |
529 | } |
530 | } |
531 | |
532 | void CodeGenCUDA::PrintStorageSync(const CallNode* op) { |
533 | const std::string& sync = op->args[0].as<StringImmNode>()->value; |
534 | if (sync == "warp" ) { |
// Do nothing.
536 | } else if (sync == "shared" || sync == "shared.dyn" ) { |
537 | this->PrintIndent(); |
538 | this->stream << "__syncthreads();\n" ; |
539 | } else if (sync == "global" ) { |
540 | if (!need_global_barrier_) { |
541 | need_global_barrier_ = true; |
542 | this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_ |
543 | << ";\n" ; |
544 | } |
545 | // global synchronizer |
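// Illustrative sketch of the emitted barrier sequence (identifiers are
// generated names, not verbatim):
//   __threadfence_system();
//   if (is_load) {
//     atomicAdd(&barrier_state, 1);
//     volatile unsigned* pf = &barrier_state;
//     barrier_expect += num_blocks;
//     while (pf[0] < barrier_expect);
//   }
//   __syncthreads();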
546 | std::string is_load = PrintExpr(op->args[1]); |
547 | std::string num_blocks = PrintExpr(op->args[2]); |
548 | this->PrintIndent(); |
// In theory only a __threadfence() is needed here, but we observed
// correctness problems with the weaker fence, so use __threadfence_system().
551 | this->stream << "__threadfence_system();\n" ; |
552 | this->PrintIndent(); |
553 | this->stream << "if (" << is_load << ") {\n" ; |
554 | int wb = this->BeginScope(); |
555 | this->PrintIndent(); |
556 | this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n" ; |
557 | this->PrintIndent(); |
558 | std::string ptr = name_supply_->FreshName("pf" ); |
559 | this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n" ; |
560 | this->PrintIndent(); |
561 | this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n" ; |
562 | this->PrintIndent(); |
563 | this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n" ; |
564 | this->EndScope(wb); |
565 | this->PrintIndent(); |
566 | this->stream << "}\n" ; |
567 | this->PrintIndent(); |
568 | this->stream << "__syncthreads();\n" ; |
569 | } |
570 | } |
571 | |
572 | void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) |
573 | ICHECK_NE(scope, "global" ) << "Cannot allocate global memory when targeting CUDA. You must pass " |
574 | "all global arrays as input instead" ; |
575 | if (scope == "shared" ) { |
576 | os << "__shared__ " ; |
577 | } else if (scope == "shared.dyn" ) { |
578 | os << "extern __shared__ " ; |
579 | } |
580 | } |
581 | |
582 | std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType target) { |
583 | if (from == target) return value; |
584 | std::ostringstream os; |
585 | os << "((" ; |
586 | this->PrintType(target, os); |
587 | os << ")" ; |
588 | if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { |
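// CUDA has no direct half -> 8-bit integer conversion, so bounce through
// (u)int first. Illustrative sketch: casting half x to int8 emits
//   ((signed char)(int)x)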
589 | os << "(" ; |
590 | if (target.is_uint()) { |
591 | os << "u" ; |
592 | } |
593 | os << "int)" ; |
594 | } |
595 | os << value << ")" ; |
596 | return os.str(); |
597 | } |
598 | |
599 | void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { |
600 | DataType from_ty = op->value.dtype(); |
601 | DataType target_ty = op->dtype; |
602 | ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); |
603 | |
604 | // Emit simple C-style type conversion. |
605 | if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); |
606 | |
607 | // We could emit make_float4 like calls, but the emitted code looks |
608 | // too compact to read. Emit this as vectorized unary ops. |
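// Illustrative sketch: a float4 -> int4 cast of v is emitted lane by lane as
//   int4 _1;
//   _1.x = (int)(v.x);
//   _1.y = (int)(v.y);
//   _1.z = (int)(v.z);
//   _1.w = (int)(v.w);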
609 | std::string sret = name_supply_->FreshName("_" ); |
610 | this->PrintIndent(); |
611 | this->PrintType(target_ty, stream); |
612 | stream << ' ' << sret << ";\n" ; |
613 | { |
614 | std::string src = SSAGetID(PrintExpr(op->value), from_ty); |
615 | for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { |
616 | std::ostringstream val; |
617 | val << "(" ; |
618 | PrintType(target_ty.element_of(), val); |
619 | val << ")(" ; |
620 | PrintVecElemLoad(src, from_ty, i, val); |
621 | val << ")" ; |
622 | PrintVecElemStore(sret, target_ty, i, val.str()); |
623 | } |
624 | } |
625 | os << sret; |
626 | } |
627 | |
628 | void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, |
629 | bool skip_first_arg, std::ostream& os) { // NOLINT(*) |
630 | DataType ret_dtype = GetRuntimeDataType(ret_type); |
631 | if (ret_dtype.is_vector()) { |
632 | // |
633 | // Emit an unsupported vector call |
634 | // |
635 | // v = intrin_f((float4*)A[0], (float4*)B[0]) |
636 | // |
637 | // as |
638 | // |
639 | // float4 __ret; |
640 | // { |
641 | // float4 __arg0 = ((float4*)A)[0]; |
642 | // float4 __arg1 = ((float4*)B)[0]; |
643 | // __ret.x = intrin_f(__arg0.x, __arg1.x); |
644 | // __ret.y = intrin_f(__arg0.y, __arg1.y); |
645 | // __ret.z = intrin_f(__arg0.z, __arg1.z); |
646 | // __ret.w = intrin_f(__arg0.w, __arg1.w); |
647 | // } |
648 | // v = __ret; |
649 | // |
650 | // Declare the result vector. |
651 | std::string sret = name_supply_->FreshName("_" ); |
652 | this->PrintIndent(); |
653 | this->PrintType(ret_dtype, stream); |
654 | stream << ' ' << sret << ";\n" ; |
655 | { |
656 | // Load arguments. |
657 | std::vector<std::string> sargs; |
658 | size_t arg_begin = static_cast<size_t>(skip_first_arg); |
659 | for (size_t i = arg_begin; i < args.size(); ++i) { |
660 | std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); |
661 | sargs.push_back(std::move(val)); |
662 | } |
663 | |
664 | // Emit a scalar call for each lane. |
665 | for (int i = 0; i < ret_dtype.lanes(); ++i) { |
666 | std::ostringstream scall; |
667 | scall << global_symbol << "(" ; |
668 | for (size_t j = 0; j < sargs.size(); ++j) { |
669 | if (j > 0) scall << ", " ; |
670 | PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); |
671 | } |
672 | scall << ")" ; |
673 | PrintVecElemStore(sret, ret_dtype, i, scall.str()); |
674 | } |
675 | } |
676 | os << sret; |
677 | } else { |
678 | CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); |
679 | } |
680 | } |
681 | |
682 | void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { |
683 | if (auto* ptr_op = op->op.as<OpNode>()) { |
684 | Op call_op = GetRef<Op>(ptr_op); |
685 | // This is only for backward compatibility with __shfl_{up/down}. |
// A macro will be used to replace *_sync calls with legacy ones.
687 | if (op_need_warp_shuffle_.get(call_op, false)) { |
688 | enable_warp_shuffle_ = true; |
689 | } |
690 | } |
691 | |
692 | if (op->op.same_as(builtin::tvm_fill_fragment())) { |
693 | need_mma_h_ = true; |
694 | ICHECK_EQ(op->args.size(), 6U); |
695 | os << "nvcuda::wmma::fill_fragment(" ; |
696 | this->PrintExpr(op->args[0], os); |
697 | os << "[" ; |
698 | this->PrintExpr(op->args[4], os); |
699 | os << "], " ; |
700 | this->PrintExpr(op->args[5], os); |
701 | os << ")" ; |
702 | } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { |
703 | need_mma_h_ = true; |
704 | ICHECK_EQ(op->args.size(), 8U); |
705 | os << "nvcuda::wmma::load_matrix_sync(" ; |
706 | this->PrintExpr(op->args[0], os); |
707 | os << "[" ; |
708 | this->PrintExpr(op->args[4], os); |
709 | os << "], " ; |
710 | this->PrintExpr(op->args[5], os); |
711 | os << ", " ; |
712 | this->PrintExpr(op->args[6], os); |
713 | os << ")" ; |
714 | } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { |
715 | need_mma_h_ = true; |
716 | ICHECK_EQ(op->args.size(), 8U); |
717 | os << "nvcuda::wmma::store_matrix_sync(" ; |
718 | this->PrintExpr(op->args[5], os); |
719 | os << ", " ; |
720 | this->PrintExpr(op->args[0], os); |
721 | os << "[" ; |
722 | this->PrintExpr(op->args[4], os); |
723 | os << "], " ; |
724 | this->PrintExpr(op->args[6], os); |
725 | if (const StringImmNode* str = op->args[7].as<StringImmNode>()) { |
726 | os << ", nvcuda::wmma::mem_" << str->value; |
727 | } else { |
728 | LOG(FATAL) << "Invalid parameters" ; |
729 | } |
730 | os << ")" ; |
731 | } else if (op->op.same_as(builtin::tvm_mma_sync())) { |
732 | need_mma_h_ = true; |
733 | ICHECK_EQ(op->args.size(), 8U); |
734 | os << "nvcuda::wmma::mma_sync(" ; |
735 | for (int i = 0; i < 4; ++i) { |
736 | this->PrintExpr(op->args[i * 2], os); |
737 | os << "[" ; |
738 | this->PrintExpr(op->args[i * 2 + 1], os); |
739 | os << "]" << ((i < 3) ? ", " : ")" ); |
740 | } |
741 | } else if (op->op.same_as(builtin::tvm_bmma_sync())) { |
742 | need_mma_h_ = true; |
743 | ICHECK_EQ(op->args.size(), 8U); |
744 | os << "nvcuda::wmma::bmma_sync(" ; |
745 | for (int i = 0; i < 4; ++i) { |
746 | this->PrintExpr(op->args[i * 2], os); |
747 | os << "[" ; |
748 | this->PrintExpr(op->args[i * 2 + 1], os); |
749 | os << "]" << ((i < 3) ? ", " : ")" ); |
750 | } |
751 | } else if (op->op.same_as(builtin::ptx_mma())) { |
752 | // arg 0: shape: mXnXkX |
753 | // arg 1: A layout: row/col |
754 | // arg 2: B layout: row/col |
755 | // arg 3: A precision: fp16, fp64, ... |
756 | // arg 4: B precision: fp16, fp64, ... |
757 | // arg 5: C precision: fp32, fp64, ... |
758 | // arg 6: A multiplicand |
759 | // arg 7: A multiplicand index |
760 | // arg 8: B multiplicand |
761 | // arg 9: B multiplicand index |
762 | // arg 10: C accumulator |
763 | // arg 11: C accumulator index |
764 | // arg 12: saturate |
765 | // arg 13: (optional) 1-bit operator (xor or and) |
766 | ICHECK(op->args.size() == 13U || op->args.size() == 14U); |
767 | std::string shape = Downcast<StringImm>(op->args[0])->value; |
768 | std::string A_layout = Downcast<StringImm>(op->args[1])->value; |
769 | std::string B_layout = Downcast<StringImm>(op->args[2])->value; |
770 | std::string A_dtype = Downcast<StringImm>(op->args[3])->value; |
771 | std::string B_dtype = Downcast<StringImm>(op->args[4])->value; |
772 | std::string C_dtype = Downcast<StringImm>(op->args[5])->value; |
773 | std::string a_ref = this->PrintExpr(op->args[6]); |
774 | std::string a_bias = this->PrintExpr(op->args[7]); |
775 | std::string b_ref = this->PrintExpr(op->args[8]); |
776 | std::string b_bias = this->PrintExpr(op->args[9]); |
777 | std::string c_ref = this->PrintExpr(op->args[10]); |
778 | std::string c_bias = this->PrintExpr(op->args[11]); |
779 | bool saturate = Downcast<Bool>(op->args[12])->value; |
780 | std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "" ; |
781 | std::string asm_code = |
782 | PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, |
783 | b_bias, c_ref, c_bias, "" , "" , "" , bit_op, false, saturate); |
784 | |
785 | this->stream << asm_code; |
786 | } else if (op->op.same_as(builtin::ptx_mma_sp())) { |
787 | // arg 0: shape: mXnXkX |
788 | // arg 1: A layout: row/col |
789 | // arg 2: B layout: row/col |
790 | // arg 3: A precision: fp16, fp32, ... |
791 | // arg 4: B precision: fp16, fp32, ... |
792 | // arg 5: C precision: fp16, fp32, ... |
793 | // arg 6: A multiplicand pointer |
794 | // arg 7: A multiplicand index |
795 | // arg 8: B multiplicand pointer |
796 | // arg 9: B multiplicand index |
797 | // arg 10: C accumulator pointer |
798 | // arg 11: C accumulator index |
799 | // arg 12: metadata |
800 | // arg 13: metadata index |
801 | // arg 14: sparse_selector |
802 | // arg 15: saturate |
803 | ICHECK_EQ(op->args.size(), 16U); |
804 | std::string shape = Downcast<StringImm>(op->args[0])->value; |
805 | std::string A_layout = Downcast<StringImm>(op->args[1])->value; |
806 | std::string B_layout = Downcast<StringImm>(op->args[2])->value; |
807 | std::string A_dtype = Downcast<StringImm>(op->args[3])->value; |
808 | std::string B_dtype = Downcast<StringImm>(op->args[4])->value; |
809 | std::string C_dtype = Downcast<StringImm>(op->args[5])->value; |
810 | std::string a_ref = this->PrintExpr(op->args[6]); |
811 | std::string a_offset = this->PrintExpr(op->args[7]); |
812 | std::string b_ref = this->PrintExpr(op->args[8]); |
813 | std::string b_offset = this->PrintExpr(op->args[9]); |
814 | std::string c_ref = this->PrintExpr(op->args[10]); |
815 | std::string c_offset = this->PrintExpr(op->args[11]); |
816 | std::string metadata = this->PrintExpr(op->args[12]); |
817 | std::string metadata_offset = this->PrintExpr(op->args[13]); |
818 | std::string sparse_selector = this->PrintExpr(op->args[14]); |
819 | bool saturate = Downcast<Bool>(op->args[15])->value; |
820 | std::string asm_code = PrintMMAAssembly( |
821 | shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, |
822 | c_ref, c_offset, metadata, metadata_offset, sparse_selector, "" , true, saturate); |
823 | this->stream << asm_code; |
824 | } else if (op->op.same_as(builtin::ptx_ldmatrix())) { |
825 | // arg 0: whether the matrix is loaded in column major format or not. |
826 | // arg 1: number of matrices to load. |
827 | // arg 2: The data type in the matrix, .b16 is the only accepted data type. |
828 | // arg 3: pointer to local buffer. |
829 | // arg 4: The offset of the element to store in the local buffer. |
830 | // arg 5: pointer to the shared memory buffer to load. |
831 | // arg 6: The offset of the start element of the row to load in shared memory. |
832 | ICHECK_EQ(op->args.size(), 7U); |
833 | bool trans = Downcast<Bool>(op->args[0])->value; |
834 | int num = Downcast<Integer>(op->args[1])->value; |
835 | std::string type = Downcast<StringImm>(op->args[2])->value; |
836 | std::string local_ptr = this->PrintExpr(op->args[3]); |
837 | std::string local_elem_offset = this->PrintExpr(op->args[4]); |
838 | std::string smem_ptr = this->PrintExpr(op->args[5]); |
839 | if (trans && op->dtype.bits() == 8) { |
// ldmatrix assumes that a matrix element is 16 bit, so it cannot properly transpose an
// int8 matrix; fall back to an explicit element-wise copy loop instead.
842 | std::string smem_stride = this->PrintExpr(op->args[6]); |
843 | ICHECK(num == 4); |
844 | os << "for (int i = 0; i < 16; ++i) {\n" ; |
845 | os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr |
846 | << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + |
847 | "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n" ; |
848 | os << "}\n" ; |
849 | } else { |
850 | std::string smem_elem_offset = this->PrintExpr(op->args[6]); |
851 | this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, |
852 | smem_ptr, smem_elem_offset); |
853 | } |
854 | } else if (op->op.same_as(builtin::mma_store())) { |
855 | int m = Downcast<Integer>(op->args[0])->value; |
856 | int n = Downcast<Integer>(op->args[1])->value; |
857 | std::string dst = this->PrintExpr(op->args[2]); |
858 | std::string src = this->PrintExpr(op->args[3]); |
859 | std::string src_offset = this->PrintExpr(op->args[4]); |
860 | PrimExpr stride = op->args[5]; |
861 | |
862 | ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now" ; |
863 | |
// Each thread in a warp holds a certain number of elements of an MMA output.
// For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements
// in its registers, so conceptually the warp-level memory is organized as a 32x8 block.
// A map from the 16x16 tile to the 32x8 block of memory is specified by the index map below.

// To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this
// map to determine the output location of each of the 8 elements.
871 | |
872 | const auto* index_map_func = |
873 | runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout" ); |
874 | ICHECK(index_map_func); |
875 | |
876 | auto inverse_index_map = |
877 | IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}); |
878 | auto indices_16x16 = inverse_index_map->final_indices; |
879 | |
880 | // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine. |
881 | // FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them |
882 | // to the plain ones here. |
883 | class LowerFloorDivMod : public ExprMutator { |
884 | public: |
885 | PrimExpr VisitExpr_(const FloorDivNode* op) { |
886 | return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); |
887 | } |
888 | PrimExpr VisitExpr_(const FloorModNode* op) { |
889 | return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); |
890 | } |
891 | }; |
892 | |
893 | auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); |
894 | |
895 | var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x" ; |
896 | var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id" ; |
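// Illustrative sketch of the emitted CUDA (dst_ind is the inverse index map
// evaluated at (threadIdx.x, local_id) and flattened with the stride):
//   for (int local_id = 0; local_id < 8; ++local_id) {
//     dst[dst_ind] = src[src_offset + local_id];
//   }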
897 | |
898 | os << "for (int local_id = 0; local_id < 8; ++local_id) {\n" ; |
899 | os << dst << "[" + this->PrintExpr(dst_ind) + "]" |
900 | << " = " << src << "[" << src_offset << " + local_id];\n" ; |
901 | os << "}\n" ; |
902 | |
903 | } else if (op->op.same_as(builtin::mma_fill())) { |
904 | std::string num_elem = this->PrintExpr(op->args[0]); |
905 | std::string dst = this->PrintExpr(op->args[1]); |
906 | std::string dst_offset = this->PrintExpr(op->args[2]); |
907 | |
908 | os << "for (int i = 0; i < " << num_elem << "; ++i) {\n" ; |
909 | os << dst << "[" << dst_offset << " + i] = 0.0;" ; |
910 | os << "}\n" ; |
911 | } else if (op->op.same_as(builtin::ptx_cp_async())) { |
912 | std::string dst = this->PrintExpr(op->args[0]); |
913 | std::string dst_offset = this->PrintExpr(op->args[1]); |
914 | std::string src = this->PrintExpr(op->args[2]); |
915 | std::string src_offset = this->PrintExpr(op->args[3]); |
916 | std::string size = this->PrintExpr(op->args[4]); |
917 | this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); |
918 | } else if (op->op.same_as(builtin::ptx_commit_group())) { |
919 | this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n" ; |
920 | } else if (op->op.same_as(builtin::ptx_wait_group())) { |
921 | std::string N = this->PrintExpr(op->args[0]); |
922 | this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n" ; |
923 | } else { |
924 | CodeGenC::VisitExpr_(op, os); |
925 | } |
926 | } |
927 | |
928 | void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { |
929 | if (op->attr_key == tir::attr::fragment_shape) { |
930 | const VarNode* buffer = op->node.as<VarNode>(); |
931 | const StringImmNode* shape_str = op->value.as<StringImmNode>(); |
932 | fragment_shapes[buffer] = shape_str->value; |
933 | } else if (op->attr_key == tir::attr::fragment_layout) { |
934 | const VarNode* buffer = op->node.as<VarNode>(); |
935 | const StringImmNode* layout_str = op->value.as<StringImmNode>(); |
936 | fragment_layouts[buffer] = layout_str->value; |
937 | } else if (op->attr_key == tir::attr::async_commit_queue_scope) { |
938 | const IntImmNode* queue_id = op->value.as<IntImmNode>(); |
939 | ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0." ; |
940 | this->VisitStmt(op->body); |
941 | auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); |
942 | this->VisitExpr(commit_group, this->stream); |
943 | return; |
944 | } else if (op->attr_key == tir::attr::async_wait_queue_scope) { |
945 | auto wait_attrs = GetAsyncWaitAttributes(op); |
946 | auto queue_id = wait_attrs.first.as<IntImmNode>(); |
947 | ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0." ; |
948 | auto wait_cnt = wait_attrs.second; |
949 | auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); |
950 | this->VisitExpr(wait_group, this->stream); |
951 | auto inner = op->body.as<AttrStmtNode>(); |
952 | ICHECK(inner); |
953 | this->VisitStmt(inner->body); |
954 | return; |
955 | } |
956 | CodeGenC::VisitStmt_(op); |
957 | } |
958 | |
959 | void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { |
960 | ICHECK(!is_zero(op->condition)); |
961 | std::string vid = AllocVarID(op->buffer_var.get()); |
962 | |
963 | this->PrintIndent(); |
964 | std::string scope = GetPtrStorageScope(op->buffer_var); |
965 | const VarNode* buffer = op->buffer_var.as<VarNode>(); |
966 | if (scope.find("wmma." ) == 0) { |
967 | if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b" ) { |
968 | ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || |
969 | op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || |
970 | op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || |
971 | op->dtype == DataType::BFloat(16)) |
972 | << "Matrix_a and matrix_b only support half or char or unsigned char " |
973 | << "or uint4 or int4 or int1 type for now" ; |
974 | } else { |
975 | ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || |
976 | op->dtype == DataType::Int(32)) |
977 | << "Accumulator only support half, float and int type for now" ; |
978 | } |
979 | PrintWmmaScope(scope, op->dtype, buffer, stream); |
980 | } else { |
981 | PrintStorageScope(scope, stream); |
982 | PrintType(op->dtype, stream); |
983 | } |
984 | |
985 | if (scope == "shared.dyn" ) { |
986 | stream << ' ' << vid << "[];\n" ; |
987 | } else { |
988 | size_t constant_size = op->ConstantAllocationSize(); |
989 | ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now" ; |
990 | |
991 | if (scope.find("wmma." ) == 0) { |
992 | constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); |
993 | } |
994 | if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || |
995 | op->dtype == DataType::Int(1)) && |
996 | scope == "shared" ) { |
997 | constant_size = constant_size / (32 / op->dtype.bits()); |
998 | } |
999 | stream << ' ' << vid << '[' << constant_size << "];\n" ; |
1000 | } |
1001 | |
1002 | RegisterHandleType(op->buffer_var.get(), op->dtype); |
1003 | this->PrintStmt(op->body); |
1004 | } |
1005 | |
1006 | void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { |
1007 | if (is_const_int(op->value)) return; |
1008 | const CallNode* call = op->value.as<CallNode>(); |
1009 | if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { |
1010 | PrintIndent(); |
1011 | stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n" ; |
1012 | PrintIndent(); |
1013 | stream << "if (threadIdx.x == 0) {\n" ; |
1014 | PrintIndent(); |
1015 | stream << " " << vid_global_barrier_expect_ << " = 0;\n" ; |
1016 | PrintIndent(); |
1017 | stream << "}\n" ; |
1018 | } else { |
1019 | CodeGenC::VisitStmt_(op); |
1020 | } |
1021 | } |
1022 | |
1023 | void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { |
1024 | CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed." ; |
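// Illustrative sketch: a ramp with base b, stride 2, and 4 int32 lanes emits
//   (make_int4((b)+(2*0), (b)+(2*1), (b)+(2*2), (b)+(2*3)))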
1025 | os << "(make_" ; |
1026 | PrintType(op->dtype, os); |
1027 | os << "(" ; |
1028 | for (int i = 0; i < op->lanes; i++) { |
1029 | os << "(" << PrintExpr(op->base) << ")" |
1030 | << "+(" << PrintExpr(op->stride) << "*" << i << ")" ; |
1031 | if (i != op->lanes - 1) os << ", " ; |
1032 | } |
1033 | os << "))" ; |
1034 | } |
1035 | |
1036 | void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) |
1037 | if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { |
1038 | // make_int8x4 |
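// Illustrative sketch: broadcasting (int8)5 to 4 lanes packs the byte into
// one word, 0x05050505, emitted as "(int)84215045".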
1039 | const int64_t* p = as_const_int(op->value); |
1040 | ICHECK(p); |
1041 | int64_t v = *p & 0xFF; |
1042 | v = (v << 24) | (v << 16) | (v << 8) | v; |
1043 | if (op->dtype.is_uint()) { |
1044 | os << "(uint)" << v; |
1045 | } else { |
1046 | os << "(int)" << v; |
1047 | } |
1048 | return; |
1049 | } |
1050 | |
1051 | if (op->dtype.is_float16()) { |
1052 | std::string v = PrintExpr(op->value); |
1053 | os << "make_" ; |
1054 | PrintType(op->dtype, os); |
1055 | os << '('; |
1056 | for (int i = 0; i < op->lanes / 2; ++i) { |
1057 | if (i != 0) os << ", " ; |
1058 | os << "__pack_half2(" << v << ", " << v << ")" ; |
1059 | } |
1060 | os << ')'; |
1061 | return; |
1062 | } |
1063 | |
1064 | if (op->dtype.is_bfloat16()) { |
1065 | std::string v = PrintExpr(op->value); |
1066 | os << "make_" ; |
1067 | PrintType(op->dtype, os); |
1068 | os << '('; |
1069 | for (int i = 0; i < op->lanes / 2; ++i) { |
1070 | if (i != 0) os << ", " ; |
1071 | os << "__pack_nv_bfloat162(" << v << ", " << v << ")" ; |
1072 | } |
1073 | os << ')'; |
1074 | return; |
1075 | } |
1076 | |
1077 | if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { |
1078 | bool fail = false; |
1079 | const int64_t* p = as_const_int(op->value); |
1080 | ICHECK(p); |
1081 | int64_t v = *p & 0xF; |
1082 | |
1083 | if (op->lanes == 4) { |
1084 | v = (v << 12) | (v << 8) | (v << 4) | v; |
1085 | if (op->dtype.is_uint()) { |
1086 | os << "(uint16_t)" << v; |
1087 | } else { |
1088 | os << "(int16_t)" << v; |
1089 | } |
1090 | } else { |
1091 | v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; |
1092 | if (op->lanes == 8) { |
1093 | if (op->dtype.is_uint()) { |
1094 | os << "(uint)" << v; |
1095 | } else { |
1096 | os << "(int)" << v; |
1097 | } |
1098 | } else if (op->lanes == 16 || op->lanes == 32) { |
1099 | os << "make_" ; |
1100 | PrintType(op->dtype, os); |
1101 | os << '('; |
1102 | for (int i = 0; i < op->lanes / 8; ++i) { |
1103 | if (i != 0) os << ", " ; |
1104 | if (op->dtype.is_uint()) { |
1105 | os << "(uint)" << v; |
1106 | } else { |
1107 | os << "(int)" << v; |
1108 | } |
1109 | } |
1110 | os << ')'; |
1111 | } else { |
1112 | fail = true; |
1113 | } |
1114 | } |
1115 | |
1116 | if (!fail) { |
1117 | return; |
1118 | } |
1119 | } |
1120 | |
1121 | std::string v = PrintExpr(op->value); |
1122 | os << "make_" ; |
1123 | PrintType(op->dtype, os); |
1124 | os << '('; |
1125 | for (int i = 0; i < op->lanes; ++i) { |
1126 | if (i != 0) os << ", " ; |
1127 | os << v; |
1128 | } |
1129 | os << ')'; |
1130 | } |
1131 | |
1132 | void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) { |
1133 | std::vector<std::string> to_shuffle(op->vectors.size()); |
1134 | for (int i = 0, e = op->vectors.size(); i < e; ++i) { |
1135 | ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!" ; |
1136 | to_shuffle[i] = PrintExpr(op->vectors[i]); |
1137 | } |
1138 | os << "make_" ; |
1139 | PrintType(op->dtype, os); |
1140 | os << '('; |
1141 | for (int i = 0, e = op->indices.size(); i < e; ++i) { |
1142 | const int64_t* val = as_const_int(op->indices[i]); |
1143 | ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size()); |
1144 | if (i != 0) os << ", " ; |
1145 | os << to_shuffle[*val]; |
1146 | } |
1147 | os << ')'; |
1148 | } |
1149 | |
1150 | void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { |
1151 | // Non-vector cases. |
1152 | if (!op->dtype.is_vector()) { |
1153 | CodeGenC::VisitExpr_(op, os); |
1154 | return; |
1155 | } |
1156 | |
1157 | // Codegen vector condition case by serializing the select op. |
1158 | ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && |
1159 | op->dtype.lanes() == op->condition.dtype().lanes()); |
1160 | |
1161 | std::string r_var = name_supply_->FreshName("_" ); |
1162 | this->PrintIndent(); |
1163 | this->PrintType(op->dtype, stream); |
1164 | stream << ' ' << r_var << ";\n" ; |
1165 | { |
1166 | std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); |
1167 | std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); |
1168 | std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); |
1169 | |
// The condition is stored as a ushort vector.
1171 | int lanes = op->dtype.lanes(); |
1172 | DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes); |
1173 | |
1174 | for (int i = 0; i < lanes; ++i) { |
1175 | std::ostringstream item; |
1176 | item << "(bool(" ; |
1177 | PrintVecElemLoad(c_var, memory_ty, i, item); |
1178 | item << ")?" ; |
1179 | PrintVecElemLoad(t_var, op->dtype, i, item); |
1180 | item << ':'; |
1181 | PrintVecElemLoad(f_var, op->dtype, i, item); |
1182 | item << ')'; |
1183 | PrintVecElemStore(r_var, op->dtype, i, item.str()); |
1184 | } |
1185 | } |
1186 | os << r_var; |
1187 | } |
1188 | |
1189 | inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) |
1190 | // Type code is kBFloat |
1191 | if (op->dtype.is_bfloat16()) { |
1192 | os << "__float2bfloat16_rn" ; |
1193 | os << '(' << std::scientific << op->value << 'f' << ')'; |
1194 | return; |
1195 | } |
1196 | // Type code is kFloat |
1197 | switch (op->dtype.bits()) { |
1198 | case 64: |
1199 | case 32: { |
1200 | std::ostringstream temp; |
1201 | if (std::isinf(op->value)) { |
1202 | if (op->value < 0) { |
1203 | temp << "-" ; |
1204 | } |
1205 | temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF" ); |
1206 | p->need_math_constants_h_ = true; |
1207 | } else if (std::isnan(op->value)) { |
1208 | temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN" ); |
1209 | p->need_math_constants_h_ = true; |
1210 | } else { |
1211 | temp << std::scientific << op->value; |
1212 | if (op->dtype.bits() == 32) temp << 'f'; |
1213 | } |
1214 | p->MarkConst(temp.str()); |
1215 | os << temp.str(); |
1216 | break; |
1217 | } |
1218 | case 16: { |
1219 | os << "__float2half_rn" << '('; |
1220 | FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); |
1221 | PrintConst(const_f32.get(), os, p); |
1222 | os << ')'; |
1223 | break; |
1224 | } |
1225 | default: |
1226 | LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n" ; |
1227 | } |
1228 | } |
1229 | |
1230 | void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) |
1231 | PrintConst(op, os, this); |
1232 | } |
1233 | |
1234 | void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, |
1235 | std::ostream& os) { |
1236 | std::stringstream type; |
1237 | PrintType(t, type); |
1238 | std::string shape_str = fragment_shapes.at(variable); |
1239 | if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { |
1240 | type.str(std::string()); |
1241 | if (t.is_int()) { |
1242 | if (t.bits() == 4) { |
1243 | type << "nvcuda::wmma::experimental::precision::s4" ; |
1244 | } else if (t.bits() == 1) { |
1245 | type << "nvcuda::wmma::experimental::precision::b1" ; |
1246 | } else { |
1247 | LOG(FATAL) << "Unhandled interger type for wmma fragment!" ; |
1248 | } |
1249 | } else if (t.is_uint()) { |
1250 | if (t.bits() == 4) { |
1251 | type << "nvcuda::wmma::experimental::precision::u4" ; |
1252 | } else { |
1253 | LOG(FATAL) << "Unhandled interger type for wmma fragment!" ; |
1254 | } |
1255 | } |
1256 | } |
1257 | if (scope == "wmma.matrix_a" ) { |
1258 | need_mma_h_ = true; |
1259 | std::string layout_str = fragment_layouts[variable]; |
1260 | ICHECK_NE(layout_str, "" ) << "Layout must be defined for matrix_a" ; |
1261 | os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str() |
1262 | << ", nvcuda::wmma::" << layout_str << ">" ; |
1263 | } else if (scope == "wmma.matrix_b" ) { |
1264 | need_mma_h_ = true; |
1265 | std::string layout_str = fragment_layouts[variable]; |
1266 | ICHECK_NE(layout_str, "" ) << "Layout must be defined for matrix_b" ; |
1267 | os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str() |
1268 | << ", nvcuda::wmma::" << layout_str << ">" ; |
1269 | } else if (scope == "wmma.accumulator" ) { |
1270 | need_mma_h_ = true; |
1271 | os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str() |
1272 | << ">" ; |
1273 | } |
1274 | } |
1275 | |
1276 | int stoi(const std::string& str) { |
1277 | try { |
1278 | return std::stoi(str); |
1279 | } catch (std::invalid_argument& e) { |
1280 | LOG(FATAL) << "Cannot convert \"" << str << "\" to int" ; |
1281 | throw; |
1282 | } |
1283 | } |
1284 | |
1285 | int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, |
1286 | int32_t size) { |
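// Fragment shape strings have the form "m, n, k". Illustrative sketch: a
// "wmma.matrix_a" buffer of 4096 elements with shape "16, 16, 16" holds
// 4096 / 16 / 16 = 16 fragments.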
1287 | std::string shape_str = fragment_shapes.at(variable); |
1288 | size_t m, n, k; |
1289 | size_t last_pos = 0, pos = 0; |
1290 | pos = shape_str.find(", " , last_pos); |
1291 | m = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos)); |
1292 | last_pos = pos + 2; |
1293 | pos = shape_str.find(", " , last_pos); |
1294 | n = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos)); |
1295 | last_pos = pos + 2; |
1296 | k = tvm::codegen::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); |
1297 | if (scope == "wmma.matrix_a" ) { |
1298 | return size / m / k; |
1299 | } else if (scope == "wmma.matrix_b" ) { |
1300 | return size / n / k; |
1301 | } else if (scope == "wmma.accumulator" ) { |
1302 | return size / m / n; |
1303 | } |
1304 | return 0; |
1305 | } |
1306 | |
1307 | void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, |
1308 | std::ostream& os) { |
1309 | // Cast away volatile qualifier for fp16 types. That is, only loads and |
1310 | // stores are volatile. The loaded objects are not marked as volatile. |
1311 | // |
1312 | if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { |
1313 | os << "(" ; |
1314 | PrintType(op->dtype, os); |
1315 | os << ")(" << value << ")" ; |
1316 | } else { |
1317 | os << value; |
1318 | } |
1319 | } |
1320 | |
1321 | void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, |
1322 | std::ostream& os) { |
1323 | ICHECK_GT(t.lanes(), 1); |
1324 | if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
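// Build the packed 32-bit word lane by lane with shift-and-mask, OR-ing the
// lanes together. Illustrative sketch for lanes v0, v1, ... of a char4:
//   ((0x000000ff << 0) & (v0 << 0))|((0x000000ff << 8) & (v1 << 8))|...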
1325 | if (!(t.lanes() == 2 || t.lanes() == 3)) { |
1326 | if (i != 0) { |
1327 | os << "|" ; |
1328 | } |
1329 | os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))" ; |
1330 | return; |
1331 | } |
1332 | } |
1333 | |
1334 | if (t.is_float16()) { |
1335 | if (i == 0) { |
1336 | os << "make_" ; |
1337 | PrintType(t, os); |
1338 | os << '('; |
1339 | } |
1340 | if (i % 2 == 0) { |
1341 | os << "__pack_half2(" << value; |
1342 | } else { |
1343 | os << "," << value << ")" ; |
1344 | if (i != t.lanes() - 1) { |
1345 | os << "," ; |
1346 | } else { |
1347 | os << ")" ; |
1348 | } |
1349 | } |
1350 | return; |
1351 | } |
1352 | |
1353 | if (t.is_bfloat16()) { |
1354 | if (i == 0) { |
1355 | os << "make_" ; |
1356 | PrintType(t, os); |
1357 | os << '('; |
1358 | } |
1359 | if (i % 2 == 0) { |
1360 | os << "__pack_bfloat162(" << value; |
1361 | } else { |
1362 | os << "," << value << ")" ; |
1363 | if (i != t.lanes() - 1) { |
1364 | os << "," ; |
1365 | } else { |
1366 | os << ")" ; |
1367 | } |
1368 | } |
1369 | return; |
1370 | } |
1371 | |
1372 | if (i == 0) { |
1373 | os << "make_" ; |
1374 | PrintType(t, os); |
1375 | os << "(" ; |
1376 | } |
1377 | os << value; |
1378 | if (i != t.lanes() - 1) { |
1379 | os << "," ; |
1380 | } else { |
1381 | os << ")" ; |
1382 | } |
1383 | return; |
1384 | } |
1385 | |
1386 | } // namespace codegen |
1387 | } // namespace tvm |
1388 | |