#include "taichi/codegen/cuda/codegen_cuda.h"

#include <vector>
#include <set>
#include <functional>

#include "taichi/common/core.h"
#include "taichi/util/io.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
#include "taichi/util/lang_util.h"
#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/rhi/cuda/cuda_context.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/util/action_recorder.h"
#include "taichi/analysis/offline_cache_util.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/transforms.h"
#include "taichi/codegen/codegen_utils.h"

namespace taichi::lang {

using namespace llvm;

// NVVM IR Spec:
// https://docs.nvidia.com/cuda/archive/10.0/pdf/NVVM_IR_Specification.pdf

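// TaskCodeGenCUDA lowers a single offloaded task to NVVM-compatible LLVM IR.
// It overrides the generic LLVM code generator where CUDA needs special
// handling: vprintf-based printing, libdevice math calls, fast reductions,
// and per-task grid/block configuration.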
class TaskCodeGenCUDA : public TaskCodeGenLLVM {
 public:
  using IRVisitor::visit;

  explicit TaskCodeGenCUDA(const CompileConfig &config,
                           TaichiLLVMContext &tlctx,
                           Kernel *kernel,
                           IRNode *ir = nullptr)
      : TaskCodeGenLLVM(config, tlctx, kernel, ir) {
  }

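  // Prints a single tagged value for debugging. f32 values are widened to
  // f64 because vprintf follows C variadic-argument promotion rules.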
  llvm::Value *create_print(std::string tag,
                            DataType dt,
                            llvm::Value *value) override {
    std::string format = data_type_format(dt);
    if (value->getType() == llvm::Type::getFloatTy(*llvm_context)) {
      value =
          builder->CreateFPExt(value, llvm::Type::getDoubleTy(*llvm_context));
    }
    return create_print("[cuda codegen debug] " + tag + " " + format + "\n",
                        {value->getType()}, {value});
  }

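  // Packs the arguments into a stack-allocated struct and forwards the format
  // string plus a pointer to that struct to CUDA's vprintf.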
  llvm::Value *create_print(const std::string &format,
                            const std::vector<llvm::Type *> &types,
                            const std::vector<llvm::Value *> &values) {
    auto stype = llvm::StructType::get(*llvm_context, types, false);
    auto value_arr = builder->CreateAlloca(stype);
    for (int i = 0; i < values.size(); i++) {
      auto value_ptr = builder->CreateGEP(
          stype, value_arr, {tlctx->get_constant(0), tlctx->get_constant(i)});
      builder->CreateStore(values[i], value_ptr);
    }
    return LLVMModuleBuilder::call(
        builder.get(), "vprintf",
        builder->CreateGlobalStringPtr(format, "format_string"),
        builder->CreateBitCast(value_arr,
                               llvm::Type::getInt8PtrTy(*llvm_context)));
  }

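  // Widens values so that they match vprintf's expectations for variadic
  // arguments: f16/f32 become f64, i8 becomes i16, and u8 becomes u16.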
  std::tuple<llvm::Value *, llvm::Type *> create_value_and_type(
      llvm::Value *value,
      DataType dt) {
    auto value_type = tlctx->get_data_type(dt);
    if (dt->is_primitive(PrimitiveTypeID::f32) ||
        dt->is_primitive(PrimitiveTypeID::f16)) {
      value_type = tlctx->get_data_type(PrimitiveType::f64);
      value = builder->CreateFPExt(value, value_type);
    }
    if (dt->is_primitive(PrimitiveTypeID::i8)) {
      value_type = tlctx->get_data_type(PrimitiveType::i16);
      value = builder->CreateSExt(value, value_type);
    }
    if (dt->is_primitive(PrimitiveTypeID::u8)) {
      value_type = tlctx->get_data_type(PrimitiveType::u16);
      value = builder->CreateZExt(value, value_type);
    }
    return std::make_tuple(value, value_type);
  }

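  // Lowers `print()` to a single vprintf call. Tensor arguments are flattened
  // element by element, and string arguments are emitted as "%s" entries.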
  void visit(PrintStmt *stmt) override {
    TI_ASSERT_INFO(stmt->contents.size() < 32,
                   "CUDA `print()` doesn't support more than 32 entries");

    std::vector<llvm::Type *> types;
    std::vector<llvm::Value *> values;

    std::string formats;
    size_t num_contents = 0;
    for (auto const &content : stmt->contents) {
      if (std::holds_alternative<Stmt *>(content)) {
        auto arg_stmt = std::get<Stmt *>(content);

        formats += data_type_format(arg_stmt->ret_type);

        auto value = llvm_val[arg_stmt];
        auto value_type = value->getType();
        if (arg_stmt->ret_type->is<TensorType>()) {
          auto dtype = arg_stmt->ret_type->cast<TensorType>();
          num_contents += dtype->get_num_elements();
          auto elem_type = dtype->get_element_type();
          for (int i = 0; i < dtype->get_num_elements(); ++i) {
            llvm::Value *elem_value;
            if (codegen_vector_type(compile_config)) {
              TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type));
              elem_value = builder->CreateExtractElement(value, i);
            } else {
              TI_ASSERT(llvm::dyn_cast<llvm::ArrayType>(value_type));
              elem_value = builder->CreateExtractValue(value, i);
            }
            auto [casted_value, elem_value_type] =
                create_value_and_type(elem_value, elem_type);
            types.push_back(elem_value_type);
            values.push_back(casted_value);
          }
        } else {
          num_contents++;
          auto [val, dtype] = create_value_and_type(value, arg_stmt->ret_type);
          types.push_back(dtype);
          values.push_back(val);
        }
      } else {
        num_contents += 1;
        auto arg_str = std::get<std::string>(content);

        auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
        auto char_type =
            llvm::Type::getInt8Ty(*tlctx->get_this_thread_context());
        auto value_type = llvm::PointerType::get(char_type, 0);

        types.push_back(value_type);
        values.push_back(value);
        formats += "%s";
      }
      TI_ASSERT_INFO(num_contents < 32,
                     "CUDA `print()` doesn't support more than 32 entries");
    }

    llvm_val[stmt] = create_print(formats, types, values);
  }

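  // Unary ops without a direct LLVM counterpart are lowered to libdevice
  // (__nv_*) calls. f16 inputs are promoted to f32 first and the result is
  // truncated back to f16 afterwards.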
  void emit_extra_unary(UnaryOpStmt *stmt) override {
    // functions from libdevice
    auto input = llvm_val[stmt->operand];
    auto input_taichi_type = stmt->operand->ret_type;
    if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) {
      // Promote to f32 since we don't have f16 support for extra unary ops in
      // libdevice.
      input =
          builder->CreateFPExt(input, llvm::Type::getFloatTy(*llvm_context));
      input_taichi_type = PrimitiveType::f32;
    }

    auto op = stmt->op_type;

#define UNARY_STD(x)                                                        \
  else if (op == UnaryOpType::x) {                                          \
    if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {            \
      llvm_val[stmt] = call("__nv_" #x "f", input);                         \
    } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {     \
      llvm_val[stmt] = call("__nv_" #x, input);                             \
    } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {     \
      llvm_val[stmt] = call(#x, input);                                     \
    } else {                                                                \
      TI_NOT_IMPLEMENTED                                                    \
    }                                                                       \
  }
    if (op == UnaryOpType::abs) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = call("__nv_fabsf", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_fabs", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
        llvm_val[stmt] = call("__nv_abs", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) {
        llvm_val[stmt] = call("__nv_llabs", input);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (op == UnaryOpType::sqrt) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = call("__nv_sqrtf", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_sqrt", input);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (op == UnaryOpType::logic_not) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
        llvm_val[stmt] = call("logic_not_i32", input);
      } else {
        TI_NOT_IMPLEMENTED
      }
    }
    UNARY_STD(exp)
    UNARY_STD(log)
    UNARY_STD(tan)
    UNARY_STD(tanh)
    UNARY_STD(sgn)
    UNARY_STD(acos)
    UNARY_STD(asin)
    UNARY_STD(cos)
    UNARY_STD(sin)
    else {
      TI_P(unary_op_type_name(op));
      TI_NOT_IMPLEMENTED
    }
#undef UNARY_STD
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      // Convert back to f16.
      llvm_val[stmt] = builder->CreateFPTrunc(
          llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
    }
  }

  // Not all reduction statements can be optimized.
  // If the operation cannot be optimized, this function returns nullptr.
  llvm::Value *optimized_reduction(AtomicOpStmt *stmt) override {
    if (!stmt->is_reduction) {
      return nullptr;
    }
    TI_ASSERT(stmt->val->ret_type->is<PrimitiveType>());
    PrimitiveTypeID prim_type =
        stmt->val->ret_type->cast<PrimitiveType>()->type;

    std::unordered_map<PrimitiveTypeID,
                       std::unordered_map<AtomicOpType, std::string>>
        fast_reductions;

    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::add] = "reduce_add_i32";
    fast_reductions[PrimitiveTypeID::f32][AtomicOpType::add] = "reduce_add_f32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::min] = "reduce_min_i32";
    fast_reductions[PrimitiveTypeID::f32][AtomicOpType::min] = "reduce_min_f32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::max] = "reduce_max_i32";
    fast_reductions[PrimitiveTypeID::f32][AtomicOpType::max] = "reduce_max_f32";

    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::bit_and] =
        "reduce_and_i32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::bit_or] =
        "reduce_or_i32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::bit_xor] =
        "reduce_xor_i32";

    AtomicOpType op = stmt->op_type;
    if (fast_reductions.find(prim_type) == fast_reductions.end()) {
      return nullptr;
    }
    TI_ASSERT(fast_reductions.at(prim_type).find(op) !=
              fast_reductions.at(prim_type).end());
    return call(fast_reductions.at(prim_type).at(op), llvm_val[stmt->dest],
                llvm_val[stmt->val]);
  }

  void visit(RangeForStmt *for_stmt) override {
    create_naive_range_for(for_stmt);
  }

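  // Wraps the loop body into a separate function and hands it to the runtime's
  // gpu_parallel_range_for, together with the TLS prologue/epilogue.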
  void create_offload_range_for(OffloadedStmt *stmt) override {
    auto tls_prologue = create_xlogue(stmt->tls_prologue);

    llvm::Function *body;
    {
      auto guard = get_function_creation_guard(
          {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
           get_tls_buffer_type(), tlctx->get_data_type<int>()});

      auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
      loop_vars_llvm[stmt].push_back(loop_var);
      builder->CreateStore(get_arg(2), loop_var);
      stmt->body->accept(this);

      body = guard.body;
    }

    auto epilogue = create_xlogue(stmt->tls_epilogue);

    auto [begin, end] = get_range_for_bounds(stmt);
    call("gpu_parallel_range_for", get_arg(0), begin, end, tls_prologue, body,
         epilogue, tlctx->get_constant(stmt->tls_size));
  }

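  // Builds the body function for a mesh-for. Inside the body, threads stride
  // over the owned elements of a patch by the block dimension (ntid.x);
  // gpu_parallel_mesh_for then runs the body over all mesh patches.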
  void create_offload_mesh_for(OffloadedStmt *stmt) override {
    auto tls_prologue = create_mesh_xlogue(stmt->tls_prologue);

    llvm::Function *body;
    {
      auto guard = get_function_creation_guard(
          {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
           get_tls_buffer_type(), tlctx->get_data_type<int>()});

      for (int i = 0; i < stmt->mesh_prologue->size(); i++) {
        auto &s = stmt->mesh_prologue->statements[i];
        s->accept(this);
      }

      if (stmt->bls_prologue) {
        stmt->bls_prologue->accept(this);
        call("block_barrier");  // "__syncthreads()"
      }

      auto loop_test_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_test", func);
      auto loop_body_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_body", func);
      auto func_exit =
          llvm::BasicBlock::Create(*llvm_context, "func_exit", func);
      auto i32_ty = llvm::Type::getInt32Ty(*llvm_context);
      auto loop_index = create_entry_block_alloca(i32_ty);
      llvm::Value *thread_idx =
          builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {});
      llvm::Value *block_dim = builder->CreateIntrinsic(
          Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {});
      builder->CreateStore(thread_idx, loop_index);
      builder->CreateBr(loop_test_bb);

      {
        builder->SetInsertPoint(loop_test_bb);
        auto cond = builder->CreateICmp(
            llvm::CmpInst::Predicate::ICMP_SLT,
            builder->CreateLoad(i32_ty, loop_index),
            llvm_val[stmt->owned_num_local.find(stmt->major_from_type)
                         ->second]);
        builder->CreateCondBr(cond, loop_body_bb, func_exit);
      }

      {
        builder->SetInsertPoint(loop_body_bb);
        loop_vars_llvm[stmt].push_back(loop_index);
        for (int i = 0; i < stmt->body->size(); i++) {
          auto &s = stmt->body->statements[i];
          s->accept(this);
        }
        builder->CreateStore(
            builder->CreateAdd(builder->CreateLoad(i32_ty, loop_index),
                               block_dim),
            loop_index);
        builder->CreateBr(loop_test_bb);
        builder->SetInsertPoint(func_exit);
      }

      if (stmt->bls_epilogue) {
        call("block_barrier");  // "__syncthreads()"
        stmt->bls_epilogue->accept(this);
      }

      body = guard.body;
    }

    auto tls_epilogue = create_mesh_xlogue(stmt->tls_epilogue);

    call("gpu_parallel_mesh_for", get_arg(0),
         tlctx->get_constant(stmt->mesh->num_patches), tls_prologue, body,
         tls_epilogue, tlctx->get_constant(stmt->tls_size));
  }

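  // Garbage collection for an SNode is split into three CUDA kernels
  // (gather_list, reinit_lists, zero_fill); each gets its own grid/block
  // configuration below.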
  void emit_cuda_gc(OffloadedStmt *stmt) {
    auto snode_id = tlctx->get_constant(stmt->snode->id);
    {
      init_offloaded_task_function(stmt, "gather_list");
      call("gc_parallel_0", get_context(), snode_id);
      finalize_offloaded_task_function();
      current_task->grid_dim = compile_config.saturating_grid_dim;
      current_task->block_dim = 64;
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
    {
      init_offloaded_task_function(stmt, "reinit_lists");
      call("gc_parallel_1", get_context(), snode_id);
      finalize_offloaded_task_function();
      current_task->grid_dim = 1;
      current_task->block_dim = 1;
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
    {
      init_offloaded_task_function(stmt, "zero_fill");
      call("gc_parallel_2", get_context(), snode_id);
      finalize_offloaded_task_function();
      current_task->grid_dim = compile_config.saturating_grid_dim;
      current_task->block_dim = 64;
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
  }

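  // Same three-kernel structure as emit_cuda_gc, but it calls the
  // gc_rc_parallel_* runtime functions and does not take an SNode id.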
  void emit_cuda_gc_rc(OffloadedStmt *stmt) {
    {
      init_offloaded_task_function(stmt, "gather_list");
      call("gc_rc_parallel_0", get_context());
      finalize_offloaded_task_function();
      current_task->grid_dim = compile_config.saturating_grid_dim;
      current_task->block_dim = 64;
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
    {
      init_offloaded_task_function(stmt, "reinit_lists");
      call("gc_rc_parallel_1", get_context());
      finalize_offloaded_task_function();
      current_task->grid_dim = 1;
      current_task->block_dim = 1;
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
    {
      init_offloaded_task_function(stmt, "zero_fill");
      call("gc_rc_parallel_2", get_context());
      finalize_offloaded_task_function();
      current_task->grid_dim = compile_config.saturating_grid_dim;
      current_task->block_dim = 64;
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
  }

  bool kernel_argument_by_val() const override {
    return true;  // on CUDA, pass the argument by value
  }

  llvm::Value *create_intrinsic_load(llvm::Value *ptr,
                                     llvm::Type *ty) override {
    // Issue an "__ldg" instruction to cache data in the read-only data cache.
    auto intrin = ty->isFloatingPointTy() ? llvm::Intrinsic::nvvm_ldg_global_f
                                          : llvm::Intrinsic::nvvm_ldg_global_i;
    return builder->CreateIntrinsic(
        intrin, {ty, llvm::PointerType::get(ty, 0)},
        {ptr, tlctx->get_constant(ty->getScalarSizeInBits())});
  }

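  // Loads whose source is a GetChStmt honor the per-SNode mem_access_opt
  // hints; SNodes flagged read_only go through the read-only cache path.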
  void visit(GlobalLoadStmt *stmt) override {
    if (auto get_ch = stmt->src->cast<GetChStmt>()) {
      bool should_cache_as_read_only = current_offload->mem_access_opt.has_flag(
          get_ch->output_snode, SNodeAccessFlag::read_only);
      create_global_load(stmt, should_cache_as_read_only);
    } else {
      create_global_load(stmt, false);
    }
  }

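  // Allocates the block-local storage (BLS) buffer as a CUDA shared-memory
  // global (address space 3), sized by the offloaded task.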
  void create_bls_buffer(OffloadedStmt *stmt) {
    auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context),
                                     stmt->bls_size);
    bls_buffer = new GlobalVariable(
        *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr,
        "bls_buffer", nullptr, llvm::GlobalVariable::NotThreadLocal,
        3 /*addrspace=shared*/);
    bls_buffer->setAlignment(llvm::MaybeAlign(8));
  }

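  // Dispatches an offloaded task to the matching emitter and decides its
  // launch configuration (grid/block dimensions) before recording it in
  // offloaded_tasks.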
  void visit(OffloadedStmt *stmt) override {
    if (stmt->bls_size > 0)
      create_bls_buffer(stmt);
#if defined(TI_WITH_CUDA)
    TI_ASSERT(current_offload == nullptr);
    current_offload = stmt;
    using Type = OffloadedStmt::TaskType;
    if (stmt->task_type == Type::gc) {
      // gc has 3 kernels, so we treat it specially
      emit_cuda_gc(stmt);
    } else if (stmt->task_type == Type::gc_rc) {
      emit_cuda_gc_rc(stmt);
    } else {
      init_offloaded_task_function(stmt);
      if (stmt->task_type == Type::serial) {
        stmt->body->accept(this);
      } else if (stmt->task_type == Type::range_for) {
        create_offload_range_for(stmt);
      } else if (stmt->task_type == Type::struct_for) {
        create_offload_struct_for(stmt);
      } else if (stmt->task_type == Type::mesh_for) {
        create_offload_mesh_for(stmt);
      } else if (stmt->task_type == Type::listgen) {
        emit_list_gen(stmt);
      } else {
        TI_NOT_IMPLEMENTED
      }
      finalize_offloaded_task_function();
      current_task->grid_dim = stmt->grid_dim;
      if (stmt->task_type == Type::range_for) {
        if (stmt->const_begin && stmt->const_end) {
          int num_threads = stmt->end_value - stmt->begin_value;
          int grid_dim = ((num_threads % stmt->block_dim) == 0)
                             ? (num_threads / stmt->block_dim)
                             : (num_threads / stmt->block_dim) + 1;
          grid_dim = std::max(grid_dim, 1);
          current_task->grid_dim = std::min(stmt->grid_dim, grid_dim);
        }
      }
      if (stmt->task_type == Type::listgen) {
        int query_max_block_per_sm;
        CUDADriver::get_instance().device_get_attribute(
            &query_max_block_per_sm,
            CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR, nullptr);
        int num_SMs;
        CUDADriver::get_instance().device_get_attribute(
            &num_SMs, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, nullptr);
        current_task->grid_dim = num_SMs * query_max_block_per_sm;
      }
      current_task->block_dim = stmt->block_dim;
      TI_ASSERT(current_task->grid_dim != 0);
      TI_ASSERT(current_task->block_dim != 0);
      offloaded_tasks.push_back(*current_task);
      current_task = nullptr;
    }
    current_offload = nullptr;
#else
    TI_NOT_IMPLEMENTED
#endif
  }

  void visit(ExternalFuncCallStmt *stmt) override {
    if (stmt->type == ExternalFuncCallStmt::BITCODE) {
      TaskCodeGenLLVM::visit_call_bitcode(stmt);
    } else {
      TI_NOT_IMPLEMENTED
    }
  }

  void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
    const auto arg_id = stmt->arg_id;
    const auto axis = stmt->axis;
    llvm_val[stmt] =
        call("RuntimeContext_get_extra_args", get_context(),
             tlctx->get_constant(arg_id), tlctx->get_constant(axis));
  }

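  // atan2 and pow are lowered to libdevice calls here; every other binary op
  // falls through to the generic LLVM lowering.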
  void visit(BinaryOpStmt *stmt) override {
    auto op = stmt->op_type;
    if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) {
      return TaskCodeGenLLVM::visit(stmt);
    }

    auto ret_type = stmt->ret_type;

    llvm::Value *lhs = llvm_val[stmt->lhs];
    llvm::Value *rhs = llvm_val[stmt->rhs];

    // This branch handles atan2 and pow, which rely on runtime.cpp functions
    // for **real** types. There is no f16 support there, so promotion to f32
    // is necessary.
    if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      lhs = builder->CreateFPExt(lhs, llvm::Type::getFloatTy(*llvm_context));
    }
    if (stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      rhs = builder->CreateFPExt(rhs, llvm::Type::getFloatTy(*llvm_context));
    }
    if (ret_type->is_primitive(PrimitiveTypeID::f16)) {
      ret_type = PrimitiveType::f32;
    }

    if (op == BinaryOpType::atan2) {
      if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = call("__nv_atan2f", lhs, rhs);
      } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_atan2", lhs, rhs);
      } else {
        TI_P(data_type_name(ret_type));
        TI_NOT_IMPLEMENTED
      }
    } else {
      // Note that ret_type here cannot be integral because pow with an
      // integral exponent has been demoted in the demote_operations pass.
      if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = call("__nv_powf", lhs, rhs);
      } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_pow", lhs, rhs);
      } else {
        TI_P(data_type_name(ret_type));
        TI_NOT_IMPLEMENTED
      }
    }

    // Convert back to f16 if applicable.
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      llvm_val[stmt] = builder->CreateFPTrunc(
          llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
    }
  }

 private:
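  // Supplies the SPMD parameters used by the base class: the per-thread index
  // from the tid.x special register and the block dimension from ntid.x.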
  std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
    auto thread_idx =
        builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {});
    auto block_dim =
        builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {});
    return std::make_tuple(thread_idx, block_dim);
  }
};

LLVMCompiledTask KernelCodeGenCUDA::compile_task(
    const CompileConfig &config,
    std::unique_ptr<llvm::Module> &&module,
    OffloadedStmt *stmt) {
  TaskCodeGenCUDA gen(config, get_taichi_llvm_context(), kernel, stmt);
  return gen.run_compilation();
}

FunctionType KernelCodeGenCUDA::compile_to_function() {
  TI_AUTO_PROF
  CUDAModuleToFunctionConverter converter{
      &get_taichi_llvm_context(),
      get_llvm_program(prog)->get_runtime_executor()};
  return converter.convert(this->kernel, compile_kernel_to_module());
}

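// JIT-compiles the LLVM module into a CUDA module and returns a launcher.
// The launcher stages host ndarray arguments into temporary device buffers,
// launches each offloaded task in order, and copies any staged data back to
// the host afterwards.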
FunctionType CUDAModuleToFunctionConverter::convert(
    const std::string &kernel_name,
    const std::vector<LlvmLaunchArgInfo> &args,
    LLVMCompiledKernel data) const {
  auto &mod = data.module;
  auto &tasks = data.tasks;
#ifdef TI_WITH_CUDA
  auto jit = tlctx_->jit.get();
  auto cuda_module =
      jit->add_module(std::move(mod), executor_->get_config().gpu_max_reg);

  return [cuda_module, kernel_name, args, offloaded_tasks = tasks,
          executor = this->executor_](RuntimeContext &context) {
    CUDAContext::get_instance().make_current();
    std::vector<void *> arg_buffers(args.size(), nullptr);
    std::vector<void *> device_buffers(args.size(), nullptr);
    std::vector<DeviceAllocation> temporary_devallocs(args.size());

    bool transferred = false;
    for (int i = 0; i < (int)args.size(); i++) {
      if (args[i].is_array) {
        const auto arr_sz = context.array_runtime_sizes[i];
        // Note: both numpy and PyTorch support arrays/tensors with zeros in
        // their shapes, e.g., shape=(0) or shape=(100, 0, 200), which makes
        // `arr_sz` zero.
        if (arr_sz == 0) {
          continue;
        }
        arg_buffers[i] = context.get_arg<void *>(i);
        if (context.device_allocation_type[i] ==
            RuntimeContext::DevAllocType::kNone) {
          unsigned int attr_val = 0;
          uint32_t ret_code = CUDADriver::get_instance().mem_get_attribute.call(
              &attr_val, CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
              (void *)arg_buffers[i]);

          if (ret_code != CUDA_SUCCESS || attr_val != CU_MEMORYTYPE_DEVICE) {
            // Copy to a device buffer if the arg is on the host.
            // - ret_code != CUDA_SUCCESS:
            //   arg_buffers[i] is not on the device.
            // - attr_val != CU_MEMORYTYPE_DEVICE:
            //   The CUDA driver is aware of arg_buffers[i], but it might be
            //   on the host.
            // See the CUDA driver API `cuPointerGetAttribute` for details.
            transferred = true;

            auto result_buffer = context.result_buffer;
            DeviceAllocation devalloc =
                executor->allocate_memory_ndarray(arr_sz, result_buffer);
            device_buffers[i] = executor->get_ndarray_alloc_info_ptr(devalloc);
            temporary_devallocs[i] = devalloc;

            CUDADriver::get_instance().memcpy_host_to_device(
                (void *)device_buffers[i], arg_buffers[i], arr_sz);
          } else {
            device_buffers[i] = arg_buffers[i];
          }
          // device_buffers[i] now holds a raw pointer on the CUDA device.
          context.set_arg(i, (uint64)device_buffers[i]);

        } else if (arr_sz > 0) {
          // arg_buffers[i] is a DeviceAllocation*.
          // TODO: Unwrapping DeviceAllocation* can be done in TaskCodeGenLLVM
          // since it's shared by CPU and CUDA.
          DeviceAllocation *ptr =
              static_cast<DeviceAllocation *>(arg_buffers[i]);
          device_buffers[i] = executor->get_ndarray_alloc_info_ptr(*ptr);
          // We compare arg_buffers[i] and device_buffers[i] later to check
          // whether a transfer happened.
          // TODO: this logic can be improved but I'll leave it to a followup
          // PR.
          arg_buffers[i] = device_buffers[i];

          // device_buffers[i] holds the raw pointer unwrapped from
          // arg_buffers[i].
          context.set_arg(i, (uint64)device_buffers[i]);
        }
      }
    }
    if (transferred) {
      CUDADriver::get_instance().stream_synchronize(nullptr);
    }
    CUDADriver::get_instance().context_set_limit(
        CU_LIMIT_STACK_SIZE, executor->get_config().cuda_stack_limit);

    for (auto task : offloaded_tasks) {
      TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
               task.block_dim);
      cuda_module->launch(task.name, task.grid_dim, task.block_dim, 0,
                          {&context}, {});
    }

    // Copy data back to the host.
    if (transferred) {
      CUDADriver::get_instance().stream_synchronize(nullptr);
      for (int i = 0; i < (int)args.size(); i++) {
        if (device_buffers[i] != arg_buffers[i]) {
          CUDADriver::get_instance().memcpy_device_to_host(
              arg_buffers[i], (void *)device_buffers[i],
              context.array_runtime_sizes[i]);
          executor->deallocate_memory_ndarray(temporary_devallocs[i]);
        }
      }
    }
  };
#else
  TI_ERROR("No CUDA");
  return nullptr;
#endif  // TI_WITH_CUDA
}

}  // namespace taichi::lang