/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_cuda.cc
 */

#include "codegen_cuda.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "literal/cuda_half_t.h"
#include "ptx.h"

namespace tvm {
namespace codegen {

CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }

void CodeGenCUDA::Init(bool output_ssa) {
  CodeGenC::Init(output_ssa);
  vid_global_barrier_state_ = name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
  ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; }

class ThreadIdxExtractor : public tir::StmtVisitor {
 private:
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") {
        threadIdx_x_ext = op->value;
      }
      if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") {
        threadIdx_y_ext = op->value;
      }
      if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") {
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

 public:
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
  ThreadIdxExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
  PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                                             extractor.threadIdx_z_ext);
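  // Illustrative example (numbers are hypothetical): if the body binds threadIdx.x with extent
  // 64 and threadIdx.y with extent 4, the simplified product is 256 and the kernel signature
  // gets " __launch_bounds__(256)" appended below.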
  if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
    if (threadIdx_ext_int->value == 1) {
      // unable to extract the number of threads per block, hence directly return
      return;
    }
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
  }
}

std::string CodeGenCUDA::Finish() {
  if (enable_fp16_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
    decl_stream << "#include <cuda_fp16.h>\n";
    decl_stream << "__device__ half max"
                << "(half a, half b)\n"
                << "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
    decl_stream << "__device__ half min(half a, half b)\n"
                << "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
    decl_stream << "#else\n";
    decl_stream << _cuda_half_t_def;
    decl_stream << "#endif\n\n";
    decl_stream << _cuda_half_util;
  }

  if (enable_bf16_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
    decl_stream << "#include <cuda_bf16.h>\n";
    decl_stream << "__device__ nv_bfloat16 max"
                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
                << "{\n  return __hgt(a, b) ? a : b;\n}\n";
    decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
                << "{\n  return __hlt(a, b) ? a : b;\n}\n";
    decl_stream << "#endif\n\n";
    decl_stream << _cuda_bfloat16_util;
  }

  if (enable_warp_shuffle_) {
    decl_stream << _cuda_warp_intrinsic_util;
  }

  if (enable_int8_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
    decl_stream << "#include <sm_61_intrinsics.h>\n";
    decl_stream << "#endif\n";
  }

  if (need_math_constants_h_) {
    decl_stream << "#include <math_constants.h>\n";
  }

  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }

  decl_stream << "\n#ifdef _WIN32\n";
  decl_stream << "  using uint = unsigned int;\n";
  decl_stream << "  using uchar = unsigned char;\n";
  decl_stream << "  using ushort = unsigned short;\n";
  decl_stream << "  using int64_t = long long;\n";
  decl_stream << "  using uint64_t = unsigned long long;\n";
  decl_stream << "#else\n";
  decl_stream << "  #define uint unsigned int\n";
  decl_stream << "  #define uchar unsigned char\n";
  decl_stream << "  #define ushort unsigned short\n";
  decl_stream << "  #define int64_t long long\n";
  decl_stream << "  #define uint64_t unsigned long long\n";
  decl_stream << "#endif\n";

  return CodeGenC::Finish();
}

void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
  ICHECK(is_const_int(op->min, 0));
  if (op->kind == tir::ForKind::kUnrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  CodeGenC::VisitStmt_(op);
}

void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
  ICHECK(!var_idmap_.count(iv->var.get()));
  var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}

void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16:
        enable_fp16_ = true;
        if (t.is_scalar()) {
          os << "half";
        } else if (lanes <= 8) {
          // Emit CUDA code to access fp16 vector elements.
          //
          // half4 is stored as uint2
          //
          // h4.x is emitted as *(half2*)(&(u2.x)).x
          // h4.y is emitted as *(half2*)(&(u2.x)).y
          // h4.z is emitted as *(half2*)(&(u2.y)).x
          // h4.w is emitted as *(half2*)(&(u2.y)).y
          //
          ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
          os << "uint" << lanes / 2;
        } else {
          fail = true;
        }
        break;
      case 32:
        if (lanes <= 4) {
          os << "float";
        } else if (lanes <= 8) {
          // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
          //
          // float8 is stored as ulonglong4
          //
          // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
          // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
          //
          ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
          os << "ulonglong" << lanes / 2;
        } else {
          fail = true;
        }
        break;
      case 64:
        os << "double";
        break;
      default:
        fail = true;
        break;
    }
    if (!fail && (t.is_scalar() || t.bits() == 16)) return;
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
    enable_bf16_ = true;
    if (t.is_scalar()) {
      os << "nv_bfloat16";
    } else if (lanes <= 8) {
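      // Like fp16, bf16 vectors are stored packed two elements per 32-bit uint: a 4-lane bf16
      // vector is emitted as uint2 and its elements are accessed through nv_bfloat162 casts
      // (see PrintVecElemLoad / PrintVecElemStore).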
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
      os << "uint" << lanes / 2;
    } else {
      fail = true;
    }
    if (!fail) return;
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
      case 1: {
        if (t.is_scalar()) {
          os << "int";
          return;
        } else if (t.lanes() == 8) {
          os << "int8_t";
          return;
        } else if (t.lanes() == 16) {
          os << "int16_t";
          return;
        } else if (t.lanes() == 32) {
          os << "int";
          return;
        } else {
          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
        }
      }
      case 4: {
        if (t.is_scalar()) {
          os << "int";
          return;
        } else if (t.lanes() == 4) {
          os << "int16_t";
          return;
        } else if (t.lanes() == 8) {
          // Pack eight 4-bit ints directly into a 32-bit integer.
          os << "int";
          return;
        } else if (t.lanes() == 16) {
          os << "int2";
          return;
        } else if (t.lanes() == 32) {
          os << "int4";
          return;
        } else if (t.lanes() == 64) {
          os << "int8";
          return;
        } else {
          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
        }
      }
      case 8: {
        if (t.lanes() == 4) {
          // Pack four 8-bit ints directly into a 32-bit integer.
          enable_int8_ = true;

          // We use int for int8x4 instead of char4 because using char4 is
          // likely to produce extra instructions to pack four int8 elements
          // into 32-bit data.
          os << "int";
          return;
        } else if (t.lanes() == 8) {
          enable_int8_ = true;
          os << "int2";
          return;
        } else if (t.lanes() == 16) {
          enable_int8_ = true;
          os << "int4";
          return;
        } else if (!t.is_uint() && t.is_scalar()) {
          os << "signed char";
          break;
        } else {
          os << "char";
          break;
        }
      }
      case 16: {
        if (t.is_scalar()) {
          os << "short";
        } else if (t.lanes() <= 4) {
          os << "short" << lanes;
        } else if (t.lanes() <= 8) {
          // Emit CUDA code to access int16 vector elements.
          //
          // short4 is stored as int2
          //
          // s4.x is emitted as *(short2*)(&(i2.x)).x
          // s4.y is emitted as *(short2*)(&(i2.x)).y
          // s4.z is emitted as *(short2*)(&(i2.y)).x
          // s4.w is emitted as *(short2*)(&(i2.y)).y
          //
          ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for short type with lanes > 4";
          os << "int" << t.lanes() / 2;
        } else {
          fail = true;
        }
        if (!fail) {
          return;
        }
        break;
      }
      case 32: {
        if (t.is_scalar()) {
          os << "int";
        } else if (t.lanes() <= 4) {
          os << "int" << t.lanes();
        } else if (t.lanes() <= 8) {
          // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
          //
          // int8 is stored as longlong4
          //
          // i8.v1 is emitted as *(int2*)(&(l4.x)).x
          // i8.v2 is emitted as *(int2*)(&(l4.x)).y
          //
          ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
          os << "longlong" << lanes / 2;
        } else {
          fail = true;
        }
        if (!fail) {
          return;
        }
        break;
      }
      case 64: {
        if (t.is_scalar()) {
          os << "int64_t";
        } else if (t.lanes() == 2) {
          os << "longlong2";
        } else if (t.lanes() == 3) {
          os << "longlong3";
        } else if (t.lanes() == 4) {
          os << "longlong4";
        }
        return;
      }
      default:
        fail = true;
        break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                   std::ostream& os) {  // NOLINT(*)
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
                                   std::ostream& os) {  // NOLINT(*)
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
  ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
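      // Wider 8-bit vectors live packed in 32-bit words, so a lane is extracted by shifting
      // and truncating; e.g. lane 5 of a char8 vector (emitted as int2) becomes
      // ((char)(v.y >> 8)).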
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
  } else if (t.is_bfloat16()) {
    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
  } else {
    os << vec << "." << access[i];
  }
}

void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
                                    const std::string& value) {
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
  ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
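      // Read-modify-write on the packed 32-bit word: clear the target byte with a mask, then OR
      // in the shifted value; lane 0 skips the read because the word starts out undefined.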
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
           << value << ";\n";
  } else if (t.is_bfloat16()) {
    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
           << " = " << value << ";\n";
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
  const std::string& sync = op->args[0].as<StringImmNode>()->value;
  if (sync == "warp") {
    // Do nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  } else if (sync == "global") {
    if (!need_global_barrier_) {
      need_global_barrier_ = true;
      this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_
                        << ";\n";
    }
    // global synchronizer
    std::string is_load = PrintExpr(op->args[1]);
    std::string num_blocks = PrintExpr(op->args[2]);
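    // Emitted pattern (sketch): fence, then each participating block atomically bumps the
    // barrier counter and spins on a volatile read of it until all expected blocks have
    // arrived, and finally the threads of the block re-sync.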
    this->PrintIndent();
    // In theory only a threadfence is needed here,
    // but we observed correctness problems when using threadfence alone.
    this->stream << "__threadfence_system();\n";
    this->PrintIndent();
    this->stream << "if (" << is_load << ") {\n";
    int wb = this->BeginScope();
    this->PrintIndent();
    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
    this->PrintIndent();
    std::string ptr = name_supply_->FreshName("pf");
    this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n";
    this->PrintIndent();
    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
    this->PrintIndent();
    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n";
    this->EndScope(wb);
    this->PrintIndent();
    this->stream << "}\n";
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  }
}

void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
  ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
                                "all global arrays as input instead";
  if (scope == "shared") {
    os << "__shared__ ";
  } else if (scope == "shared.dyn") {
    os << "extern __shared__ ";
  }
}

std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType target) {
  if (from == target) return value;
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
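  // For fp16 -> 8-bit integer casts the emitted expression narrows through a 32-bit (u)int
  // first, e.g. ((signed char)(int)x), rather than casting the half directly to a char type.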
  if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) {
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
  os << value << ")";
  return os.str();
}

void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
  if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);

  // We could emit make_float4-like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
  {
    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
      std::ostringstream val;
      val << "(";
      PrintType(target_ty.element_of(), val);
      val << ")(";
      PrintVecElemLoad(src, from_ty, i, val);
      val << ")";
      PrintVecElemStore(sret, target_ty, i, val.str());
    }
  }
  os << sret;
}

void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
                                  bool skip_first_arg, std::ostream& os) {  // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (ret_dtype.is_vector()) {
    //
    // Emit an unsupported vector call
    //
    // v = intrin_f((float4*)A[0], (float4*)B[0])
    //
    // as
    //
    // float4 __ret;
    // {
    //   float4 __arg0 = ((float4*)A)[0];
    //   float4 __arg1 = ((float4*)B)[0];
    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
    //   __ret.y = intrin_f(__arg0.y, __arg1.y);
    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
    // }
    // v = __ret;
    //
    // Declare the result vector.
    std::string sret = name_supply_->FreshName("_");
    this->PrintIndent();
    this->PrintType(ret_dtype, stream);
    stream << ' ' << sret << ";\n";
    {
      // Load arguments.
      std::vector<std::string> sargs;
      size_t arg_begin = static_cast<size_t>(skip_first_arg);
      for (size_t i = arg_begin; i < args.size(); ++i) {
        std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
        sargs.push_back(std::move(val));
      }

      // Emit a scalar call for each lane.
      for (int i = 0; i < ret_dtype.lanes(); ++i) {
        std::ostringstream scall;
        scall << global_symbol << "(";
        for (size_t j = 0; j < sargs.size(); ++j) {
          if (j > 0) scall << ", ";
          PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
        }
        scall << ")";
        PrintVecElemStore(sret, ret_dtype, i, scall.str());
      }
    }
    os << sret;
  } else {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
  }
}

void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
  if (auto* ptr_op = op->op.as<OpNode>()) {
    Op call_op = GetRef<Op>(ptr_op);
    // This is only for backward compatibility with __shfl_{up/down}.
    // A macro will be used to replace *_sync calls to legacy ones.
    if (op_need_warp_shuffle_.get(call_op, false)) {
      enable_warp_shuffle_ = true;
    }
  }

  if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
    if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::ptx_mma())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp64, ...
    // arg 4: B precision: fp16, fp64, ...
    // arg 5: C precision: fp32, fp64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index
    // arg 12: saturate
    // arg 13: (optional) 1-bit operator (xor or and)
    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    bool saturate = Downcast<Bool>(op->args[12])->value;
    std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
    std::string asm_code =
        PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
                         b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);

    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp32, ...
    // arg 4: B precision: fp16, fp32, ...
    // arg 5: C precision: fp16, fp32, ...
    // arg 6: A multiplicand pointer
    // arg 7: A multiplicand index
    // arg 8: B multiplicand pointer
    // arg 9: B multiplicand index
    // arg 10: C accumulator pointer
    // arg 11: C accumulator index
    // arg 12: metadata
    // arg 13: metadata index
    // arg 14: sparse_selector
    // arg 15: saturate
    ICHECK_EQ(op->args.size(), 16U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_offset = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    std::string metadata = this->PrintExpr(op->args[12]);
    std::string metadata_offset = this->PrintExpr(op->args[13]);
    std::string sparse_selector = this->PrintExpr(op->args[14]);
    bool saturate = Downcast<Bool>(op->args[15])->value;
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
        c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
    // arg 0: whether the matrix is loaded in column major format or not.
    // arg 1: number of matrices to load.
    // arg 2: The data type in the matrix, .b16 is the only accepted data type.
    // arg 3: pointer to local buffer.
    // arg 4: The offset of the element to store in the local buffer.
    // arg 5: pointer to the shared memory buffer to load.
    // arg 6: The offset of the start element of the row to load in shared memory.
    ICHECK_EQ(op->args.size(), 7U);
    bool trans = Downcast<Bool>(op->args[0])->value;
    int num = Downcast<Integer>(op->args[1])->value;
    std::string type = Downcast<StringImm>(op->args[2])->value;
    std::string local_ptr = this->PrintExpr(op->args[3]);
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    if (trans && op->dtype.bits() == 8) {
      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an
      // int8 matrix.
      std::string smem_stride = this->PrintExpr(op->args[6]);
      ICHECK(num == 4);
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
         << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n";
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                              smem_ptr, smem_elem_offset);
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
    int n = Downcast<Integer>(op->args[1])->value;
    std::string dst = this->PrintExpr(op->args[2]);
    std::string src = this->PrintExpr(op->args[3]);
    std::string src_offset = this->PrintExpr(op->args[4]);
    PrimExpr stride = op->args[5];

    ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now";

    // Each thread in a warp holds a certain number of elements of an MMA output.
    // For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements
    // in its registers. So conceptually, a warp memory is organized as a 32x8 block.
    // A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below.

    // To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this
    // map to determine the output location for each of the 8 elements.

    const auto* index_map_func =
        runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
    ICHECK(index_map_func);

    auto inverse_index_map =
        IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)});
    auto indices_16x16 = inverse_index_map->final_indices;

    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are
    // fine. FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually
    // replace them with the plain ones here.
    class LowerFloorDivMod : public ExprMutator {
     public:
      PrimExpr VisitExpr_(const FloorDivNode* op) {
        return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
      PrimExpr VisitExpr_(const FloorModNode* op) {
        return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
    };

    auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);

    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
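    // The emitted loop then looks roughly like
    //   for (int local_id = 0; local_id < 8; ++local_id) {
    //     dst[row(threadIdx.x, local_id) * stride + col(threadIdx.x, local_id)] =
    //         src[src_offset + local_id];
    //   }
    // where row/col come from inverting the 16x16 -> 32x8 ldmatrix layout map.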

    os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
    os << dst << "[" + this->PrintExpr(dst_ind) + "]"
       << " = " << src << "[" << src_offset << " + local_id];\n";
    os << "}\n";

  } else if (op->op.same_as(builtin::mma_fill())) {
    std::string num_elem = this->PrintExpr(op->args[0]);
    std::string dst = this->PrintExpr(op->args[1]);
    std::string dst_offset = this->PrintExpr(op->args[2]);

    os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
    os << dst << "[" << dst_offset << " + i] = 0.0;";
    os << "}\n";
  } else if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    std::string N = this->PrintExpr(op->args[0]);
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::fragment_shape) {
    const VarNode* buffer = op->node.as<VarNode>();
    const StringImmNode* shape_str = op->value.as<StringImmNode>();
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
    const VarNode* buffer = op->node.as<VarNode>();
    const StringImmNode* layout_str = op->value.as<StringImmNode>();
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
    const IntImmNode* queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
    auto wait_cnt = wait_attrs.second;
    auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
  const VarNode* buffer = op->buffer_var.as<VarNode>();
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
             op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
             op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
             op->dtype == DataType::BFloat(16))
          << "Matrix_a and matrix_b only support half, bfloat16, char, unsigned char, "
          << "int4, uint4 or int1 types for now";
    } else {
      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
             op->dtype == DataType::Int(32))
          << "Accumulator only supports half, float and int types for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
  } else {
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
    ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";

    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
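    // Sub-byte element types in shared memory are stored packed into 32-bit words (see
    // PrintType above), so scale the element count down accordingly, e.g. eight int4 values
    // per emitted int.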
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
    stream << ' ' << vid << '[' << constant_size << "];\n";
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  const CallNode* call = op->value.as<CallNode>();
  if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
    PrintIndent();
    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
    PrintIndent();
    stream << "if (threadIdx.x == 0) {\n";
    PrintIndent();
    stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
    PrintIndent();
    stream << "}\n";
  } else {
    CodeGenC::VisitStmt_(op);
  }
}

void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
  CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
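  // Emits e.g. (make_int4((base)+(stride*0), (base)+(stride*1), (base)+(stride*2),
  // (base)+(stride*3))) for a 4-lane int ramp.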
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < op->lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
    if (i != op->lanes - 1) os << ", ";
  }
  os << "))";
}

void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) {
    // make_int8x4
    const int64_t* p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
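    // Replicate the low byte into all four bytes of the 32-bit word,
    // e.g. broadcasting 0x01 yields 0x01010101 (printed as a decimal literal).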
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < op->lanes / 2; ++i) {
      if (i != 0) os << ", ";
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < op->lanes / 2; ++i) {
      if (i != 0) os << ", ";
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
    const int64_t* p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (op->lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
      if (op->lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (op->lanes == 16 || op->lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < op->lanes / 8; ++i) {
          if (i != 0) os << ", ";
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < op->lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << ')';
}

void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
  std::vector<std::string> to_shuffle(op->vectors.size());
  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
    ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
    to_shuffle[i] = PrintExpr(op->vectors[i]);
  }
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0, e = op->indices.size(); i < e; ++i) {
    const int64_t* val = as_const_int(op->indices[i]);
    ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
    if (i != 0) os << ", ";
    os << to_shuffle[*val];
  }
  os << ')';
}

void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
  // Non-vector cases.
  if (!op->dtype.is_vector()) {
    CodeGenC::VisitExpr_(op, os);
    return;
  }

  // Codegen vector condition case by serializing the select op.
  ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype &&
         op->dtype.lanes() == op->condition.dtype().lanes());

  std::string r_var = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(op->dtype, stream);
  stream << ' ' << r_var << ";\n";
  {
    std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
    std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
    std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);

    // The condition is stored as a ushort vector.
    int lanes = op->dtype.lanes();
    DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);
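    // Each lane is emitted as a scalar ternary, e.g. r.x = (bool(c.x)?t.x:f.x);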

    for (int i = 0; i < lanes; ++i) {
      std::ostringstream item;
      item << "(bool(";
      PrintVecElemLoad(c_var, memory_ty, i, item);
      item << ")?";
      PrintVecElemLoad(t_var, op->dtype, i, item);
      item << ':';
      PrintVecElemLoad(f_var, op->dtype, i, item);
      item << ')';
      PrintVecElemStore(r_var, op->dtype, i, item.str());
    }
  }
  os << r_var;
}

inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) {  // NOLINT(*)
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "__float2bfloat16_rn";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
    case 64:
    case 32: {
      std::ostringstream temp;
      if (std::isinf(op->value)) {
        if (op->value < 0) {
          temp << "-";
        }
        temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
        p->need_math_constants_h_ = true;
      } else if (std::isnan(op->value)) {
        temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
        p->need_math_constants_h_ = true;
      } else {
        temp << std::scientific << op->value;
        if (op->dtype.bits() == 32) temp << 'f';
      }
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
    case 16: {
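      // fp16 literals are emitted by converting a float literal, e.g. __float2half_rn(1.000000e+00f).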
      os << "__float2half_rn" << '(';
      FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
      PrintConst(const_f32.get(), os, p);
      os << ')';
      break;
    }
    default:
      LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}

void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  PrintConst(op, os, this);
}

void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
                                 std::ostream& os) {
  std::stringstream type;
  PrintType(t, type);
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    need_mma_h_ = true;
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
       << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.matrix_b") {
    need_mma_h_ = true;
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
       << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.accumulator") {
    need_mma_h_ = true;
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
       << ">";
  }
}

int stoi(const std::string& str) {
  try {
    return std::stoi(str);
  } catch (std::invalid_argument& e) {
    LOG(FATAL) << "Cannot convert \"" << str << "\" to int";
    throw;
  }
}

int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
                                         int32_t size) {
  std::string shape_str = fragment_shapes.at(variable);
  size_t m, n, k;
  size_t last_pos = 0, pos = 0;
  pos = shape_str.find(", ", last_pos);
  m = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos));
  last_pos = pos + 2;
  pos = shape_str.find(", ", last_pos);
  n = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos));
  last_pos = pos + 2;
  k = tvm::codegen::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
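  // Illustrative example: with shape_str "16, 16, 16", a matrix_a buffer of 4096 elements
  // corresponds to 4096 / (16 * 16) = 16 fragments.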
  if (scope == "wmma.matrix_a") {
    return size / m / k;
  } else if (scope == "wmma.matrix_b") {
    return size / n / k;
  } else if (scope == "wmma.accumulator") {
    return size / m / n;
  }
  return 0;
}

void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
                                      std::ostream& os) {
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) {
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
                                       std::ostream& os) {
  ICHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
      return;
    }
  }

  if (t.is_float16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_half2(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (t.is_bfloat16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_bfloat162(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value;
  if (i != t.lanes() - 1) {
    os << ",";
  } else {
    os << ")";
  }
  return;
}

}  // namespace codegen
}  // namespace tvm