1 | #include "taichi/codegen/cuda/codegen_cuda.h" |
2 | |
3 | #include <vector> |
4 | #include <set> |
5 | #include <functional> |
6 | |
7 | #include "taichi/common/core.h" |
8 | #include "taichi/util/io.h" |
9 | #include "taichi/ir/ir.h" |
10 | #include "taichi/ir/statements.h" |
11 | #include "taichi/program/program.h" |
12 | #include "taichi/util/lang_util.h" |
13 | #include "taichi/rhi/cuda/cuda_driver.h" |
14 | #include "taichi/rhi/cuda/cuda_context.h" |
15 | #include "taichi/runtime/program_impls/llvm/llvm_program.h" |
16 | #include "taichi/util/action_recorder.h" |
17 | #include "taichi/analysis/offline_cache_util.h" |
18 | #include "taichi/ir/analysis.h" |
19 | #include "taichi/ir/transforms.h" |
20 | #include "taichi/codegen/codegen_utils.h" |
21 | |
22 | namespace taichi::lang { |
23 | |
24 | using namespace llvm; |
25 | |
26 | // NVVM IR Spec: |
27 | // https://docs.nvidia.com/cuda/archive/10.0/pdf/NVVM_IR_Specification.pdf |
28 | |
29 | class TaskCodeGenCUDA : public TaskCodeGenLLVM { |
30 | public: |
31 | using IRVisitor::visit; |
32 | |
33 | explicit TaskCodeGenCUDA(const CompileConfig &config, |
34 | TaichiLLVMContext &tlctx, |
35 | Kernel *kernel, |
36 | IRNode *ir = nullptr) |
37 | : TaskCodeGenLLVM(config, tlctx, kernel, ir) { |
38 | } |
39 | |
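  // Emits a tagged debug print through CUDA's vprintf. Note the float ->
  // double widening below: vprintf reads a packed varargs-style buffer, so
  // 32-bit floats must be promoted before a "%f"-family specifier consumes
  // them.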
40 | llvm::Value *create_print(std::string tag, |
41 | DataType dt, |
42 | llvm::Value *value) override { |
43 | std::string format = data_type_format(dt); |
44 | if (value->getType() == llvm::Type::getFloatTy(*llvm_context)) { |
45 | value = |
46 | builder->CreateFPExt(value, llvm::Type::getDoubleTy(*llvm_context)); |
47 | } |
48 | return create_print("[cuda codegen debug] " + tag + " " + format + "\n" , |
49 | {value->getType()}, {value}); |
50 | } |
51 | |
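  // CUDA's vprintf(format, valist) does not take true varargs; it takes a
  // pointer to a packed argument buffer. We model that buffer as a struct
  // alloca and store each value into its slot. For example (illustrative),
  // create_print("%d %f\n", {i32, double}, {a, b}) packs {a, b} into an
  // { i32, double } struct and passes its address as an i8*.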
52 | llvm::Value *create_print(const std::string &format, |
53 | const std::vector<llvm::Type *> &types, |
54 | const std::vector<llvm::Value *> &values) { |
55 | auto stype = llvm::StructType::get(*llvm_context, types, false); |
56 | auto value_arr = builder->CreateAlloca(stype); |
57 | for (int i = 0; i < values.size(); i++) { |
58 | auto value_ptr = builder->CreateGEP( |
59 | stype, value_arr, {tlctx->get_constant(0), tlctx->get_constant(i)}); |
60 | builder->CreateStore(values[i], value_ptr); |
61 | } |
    return LLVMModuleBuilder::call(
        builder.get(), "vprintf",
        builder->CreateGlobalStringPtr(format, "format_string"),
        builder->CreateBitCast(value_arr,
                               llvm::Type::getInt8PtrTy(*llvm_context)));
67 | } |
68 | |
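  // Widens a value so it matches the specifier data_type_format() produces:
  // f16/f32 are promoted to f64 (printf-style formats read doubles), and
  // 8-bit integers are sign-/zero-extended to 16 bits.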
69 | std::tuple<llvm::Value *, llvm::Type *> create_value_and_type( |
70 | llvm::Value *value, |
71 | DataType dt) { |
72 | auto value_type = tlctx->get_data_type(dt); |
73 | if (dt->is_primitive(PrimitiveTypeID::f32) || |
74 | dt->is_primitive(PrimitiveTypeID::f16)) { |
75 | value_type = tlctx->get_data_type(PrimitiveType::f64); |
76 | value = builder->CreateFPExt(value, value_type); |
77 | } |
78 | if (dt->is_primitive(PrimitiveTypeID::i8)) { |
79 | value_type = tlctx->get_data_type(PrimitiveType::i16); |
80 | value = builder->CreateSExt(value, value_type); |
81 | } |
82 | if (dt->is_primitive(PrimitiveTypeID::u8)) { |
83 | value_type = tlctx->get_data_type(PrimitiveType::u16); |
84 | value = builder->CreateZExt(value, value_type); |
85 | } |
86 | return std::make_tuple(value, value_type); |
87 | } |
88 | |
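  // Lowers a print statement to a single vprintf call: each scalar or tensor
  // element contributes one format specifier, and string literals are passed
  // through "%s". Roughly, print('x =', x) with an i32 x becomes
  // vprintf("%s%d\n", {char*, i32}).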
89 | void visit(PrintStmt *stmt) override { |
    TI_ASSERT_INFO(stmt->contents.size() < 32,
                   "CUDA `print()` doesn't support more than 32 entries");
92 | |
93 | std::vector<llvm::Type *> types; |
94 | std::vector<llvm::Value *> values; |
95 | |
96 | std::string formats; |
97 | size_t num_contents = 0; |
98 | for (auto const &content : stmt->contents) { |
99 | if (std::holds_alternative<Stmt *>(content)) { |
100 | auto arg_stmt = std::get<Stmt *>(content); |
101 | |
102 | formats += data_type_format(arg_stmt->ret_type); |
103 | |
104 | auto value = llvm_val[arg_stmt]; |
105 | auto value_type = value->getType(); |
106 | if (arg_stmt->ret_type->is<TensorType>()) { |
107 | auto dtype = arg_stmt->ret_type->cast<TensorType>(); |
108 | num_contents += dtype->get_num_elements(); |
109 | auto elem_type = dtype->get_element_type(); |
110 | for (int i = 0; i < dtype->get_num_elements(); ++i) { |
111 | llvm::Value *elem_value; |
112 | if (codegen_vector_type(compile_config)) { |
113 | TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type)); |
114 | elem_value = builder->CreateExtractElement(value, i); |
115 | } else { |
116 | TI_ASSERT(llvm::dyn_cast<llvm::ArrayType>(value_type)); |
117 | elem_value = builder->CreateExtractValue(value, i); |
118 | } |
119 | auto [casted_value, elem_value_type] = |
120 | create_value_and_type(elem_value, elem_type); |
121 | types.push_back(elem_value_type); |
122 | values.push_back(casted_value); |
123 | } |
124 | } else { |
125 | num_contents++; |
126 | auto [val, dtype] = create_value_and_type(value, arg_stmt->ret_type); |
127 | types.push_back(dtype); |
128 | values.push_back(val); |
129 | } |
130 | } else { |
131 | num_contents += 1; |
132 | auto arg_str = std::get<std::string>(content); |
133 | |
        auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
135 | auto char_type = |
136 | llvm::Type::getInt8Ty(*tlctx->get_this_thread_context()); |
137 | auto value_type = llvm::PointerType::get(char_type, 0); |
138 | |
139 | types.push_back(value_type); |
140 | values.push_back(value); |
141 | formats += "%s" ; |
142 | } |
      TI_ASSERT_INFO(num_contents < 32,
                     "CUDA `print()` doesn't support more than 32 entries");
145 | } |
146 | |
147 | llvm_val[stmt] = create_print(formats, types, values); |
148 | } |
149 | |
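  // Maps the remaining unary ops onto libdevice. libdevice names the f32
  // variant with an "f" suffix (e.g. sin -> __nv_sinf for f32, __nv_sin for
  // f64); f16 inputs are promoted to f32 first and truncated back at the end.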
  void visit(UnaryOpStmt *stmt) override {
151 | // functions from libdevice |
152 | auto input = llvm_val[stmt->operand]; |
153 | auto input_taichi_type = stmt->operand->ret_type; |
154 | if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) { |
155 | // Promote to f32 since we don't have f16 support for extra unary ops in |
156 | // libdevice. |
157 | input = |
158 | builder->CreateFPExt(input, llvm::Type::getFloatTy(*llvm_context)); |
159 | input_taichi_type = PrimitiveType::f32; |
160 | } |
161 | |
162 | auto op = stmt->op_type; |
163 | |
164 | #define UNARY_STD(x) \ |
165 | else if (op == UnaryOpType::x) { \ |
166 | if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \ |
167 | llvm_val[stmt] = call("__nv_" #x "f", input); \ |
168 | } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \ |
169 | llvm_val[stmt] = call("__nv_" #x, input); \ |
170 | } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \ |
171 | llvm_val[stmt] = call(#x, input); \ |
172 | } else { \ |
173 | TI_NOT_IMPLEMENTED \ |
174 | } \ |
175 | } |
    if (op == UnaryOpType::abs) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = call("__nv_fabsf", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_fabs", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
        llvm_val[stmt] = call("__nv_abs", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) {
        llvm_val[stmt] = call("__nv_llabs", input);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (op == UnaryOpType::sqrt) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
        llvm_val[stmt] = call("__nv_sqrtf", input);
      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_sqrt", input);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (op == UnaryOpType::logic_not) {
      if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
        llvm_val[stmt] = call("logic_not_i32", input);
      } else {
        TI_NOT_IMPLEMENTED
      }
    }
203 | UNARY_STD(exp) |
204 | UNARY_STD(log) |
205 | UNARY_STD(tan) |
206 | UNARY_STD(tanh) |
207 | UNARY_STD(sgn) |
208 | UNARY_STD(acos) |
209 | UNARY_STD(asin) |
210 | UNARY_STD(cos) |
211 | UNARY_STD(sin) |
212 | else { |
213 | TI_P(unary_op_type_name(op)); |
214 | TI_NOT_IMPLEMENTED |
215 | } |
216 | #undef UNARY_STD |
217 | if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { |
218 | // Convert back to f16. |
219 | llvm_val[stmt] = builder->CreateFPTrunc( |
220 | llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); |
221 | } |
222 | } |
223 | |
224 | // Not all reduction statements can be optimized. |
225 | // If the operation cannot be optimized, this function returns nullptr. |
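  // The reduce_*_{i32,f32} entries below are Taichi runtime helpers that
  // pre-combine values within a warp before updating the destination, which
  // is far cheaper than issuing one atomic RMW per thread.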
226 | llvm::Value *optimized_reduction(AtomicOpStmt *stmt) override { |
227 | if (!stmt->is_reduction) { |
228 | return nullptr; |
229 | } |
230 | TI_ASSERT(stmt->val->ret_type->is<PrimitiveType>()); |
231 | PrimitiveTypeID prim_type = |
232 | stmt->val->ret_type->cast<PrimitiveType>()->type; |
233 | |
234 | std::unordered_map<PrimitiveTypeID, |
235 | std::unordered_map<AtomicOpType, std::string>> |
236 | fast_reductions; |
237 | |
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::add] = "reduce_add_i32";
    fast_reductions[PrimitiveTypeID::f32][AtomicOpType::add] = "reduce_add_f32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::min] = "reduce_min_i32";
    fast_reductions[PrimitiveTypeID::f32][AtomicOpType::min] = "reduce_min_f32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::max] = "reduce_max_i32";
    fast_reductions[PrimitiveTypeID::f32][AtomicOpType::max] = "reduce_max_f32";

    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::bit_and] =
        "reduce_and_i32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::bit_or] =
        "reduce_or_i32";
    fast_reductions[PrimitiveTypeID::i32][AtomicOpType::bit_xor] =
        "reduce_xor_i32";
251 | |
252 | AtomicOpType op = stmt->op_type; |
253 | if (fast_reductions.find(prim_type) == fast_reductions.end()) { |
254 | return nullptr; |
255 | } |
256 | TI_ASSERT(fast_reductions.at(prim_type).find(op) != |
257 | fast_reductions.at(prim_type).end()); |
258 | return call(fast_reductions.at(prim_type).at(op), llvm_val[stmt->dest], |
259 | llvm_val[stmt->val]); |
260 | } |
261 | |
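  // A range-for encountered inside a task body is not a parallel loop; only
  // top-level offloaded range-fors are parallelized, so this one is emitted
  // as a plain serial loop within each CUDA thread.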
262 | void visit(RangeForStmt *for_stmt) override { |
263 | create_naive_range_for(for_stmt); |
264 | } |
265 | |
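  // The loop body becomes a separate LLVM function taking (RuntimeContext *,
  // tls buffer, loop index); the runtime's gpu_parallel_range_for then drives
  // it in grid-stride fashion, roughly:
  //   for (int i = begin + blockIdx.x * blockDim.x + threadIdx.x; i < end;
  //        i += gridDim.x * blockDim.x)
  //     body(ctx, tls, i);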
266 | void create_offload_range_for(OffloadedStmt *stmt) override { |
267 | auto tls_prologue = create_xlogue(stmt->tls_prologue); |
268 | |
269 | llvm::Function *body; |
270 | { |
271 | auto guard = get_function_creation_guard( |
272 | {llvm::PointerType::get(get_runtime_type("RuntimeContext" ), 0), |
273 | get_tls_buffer_type(), tlctx->get_data_type<int>()}); |
274 | |
275 | auto loop_var = create_entry_block_alloca(PrimitiveType::i32); |
276 | loop_vars_llvm[stmt].push_back(loop_var); |
277 | builder->CreateStore(get_arg(2), loop_var); |
278 | stmt->body->accept(this); |
279 | |
280 | body = guard.body; |
281 | } |
282 | |
283 | auto epilogue = create_xlogue(stmt->tls_epilogue); |
284 | |
285 | auto [begin, end] = get_range_for_bounds(stmt); |
286 | call("gpu_parallel_range_for" , get_arg(0), begin, end, tls_prologue, body, |
287 | epilogue, tlctx->get_constant(stmt->tls_size)); |
288 | } |
289 | |
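  // One CUDA block handles one mesh patch; within the block, the element
  // loop built below strides by blockDim.x, i.e. it behaves like
  //   for (int i = threadIdx.x; i < owned_num_local; i += blockDim.x) body;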
290 | void create_offload_mesh_for(OffloadedStmt *stmt) override { |
291 | auto tls_prologue = create_mesh_xlogue(stmt->tls_prologue); |
292 | |
293 | llvm::Function *body; |
294 | { |
295 | auto guard = get_function_creation_guard( |
296 | {llvm::PointerType::get(get_runtime_type("RuntimeContext" ), 0), |
297 | get_tls_buffer_type(), tlctx->get_data_type<int>()}); |
298 | |
299 | for (int i = 0; i < stmt->mesh_prologue->size(); i++) { |
300 | auto &s = stmt->mesh_prologue->statements[i]; |
301 | s->accept(this); |
302 | } |
303 | |
304 | if (stmt->bls_prologue) { |
305 | stmt->bls_prologue->accept(this); |
306 | call("block_barrier" ); // "__syncthreads()" |
307 | } |
308 | |
      auto loop_test_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_test", func);
      auto loop_body_bb =
          llvm::BasicBlock::Create(*llvm_context, "loop_body", func);
      auto func_exit =
          llvm::BasicBlock::Create(*llvm_context, "func_exit", func);
315 | auto i32_ty = llvm::Type::getInt32Ty(*llvm_context); |
316 | auto loop_index = create_entry_block_alloca(i32_ty); |
317 | llvm::Value *thread_idx = |
318 | builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}); |
319 | llvm::Value *block_dim = builder->CreateIntrinsic( |
320 | Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}); |
321 | builder->CreateStore(thread_idx, loop_index); |
322 | builder->CreateBr(loop_test_bb); |
323 | |
324 | { |
325 | builder->SetInsertPoint(loop_test_bb); |
326 | auto cond = builder->CreateICmp( |
327 | llvm::CmpInst::Predicate::ICMP_SLT, |
328 | builder->CreateLoad(i32_ty, loop_index), |
329 | llvm_val[stmt->owned_num_local.find(stmt->major_from_type) |
330 | ->second]); |
331 | builder->CreateCondBr(cond, loop_body_bb, func_exit); |
332 | } |
333 | |
334 | { |
335 | builder->SetInsertPoint(loop_body_bb); |
336 | loop_vars_llvm[stmt].push_back(loop_index); |
337 | for (int i = 0; i < stmt->body->size(); i++) { |
338 | auto &s = stmt->body->statements[i]; |
339 | s->accept(this); |
340 | } |
341 | builder->CreateStore( |
342 | builder->CreateAdd(builder->CreateLoad(i32_ty, loop_index), |
343 | block_dim), |
344 | loop_index); |
345 | builder->CreateBr(loop_test_bb); |
346 | builder->SetInsertPoint(func_exit); |
347 | } |
348 | |
349 | if (stmt->bls_epilogue) { |
350 | call("block_barrier" ); // "__syncthreads()" |
351 | stmt->bls_epilogue->accept(this); |
352 | } |
353 | |
354 | body = guard.body; |
355 | } |
356 | |
357 | auto tls_epilogue = create_mesh_xlogue(stmt->tls_epilogue); |
358 | |
359 | call("gpu_parallel_mesh_for" , get_arg(0), |
360 | tlctx->get_constant(stmt->mesh->num_patches), tls_prologue, body, |
361 | tls_epilogue, tlctx->get_constant(stmt->tls_size)); |
362 | } |
363 | |
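  // Garbage collection runs as three dependent stages: gather_list,
  // reinit_lists, and zero_fill. Each stage is a separate kernel launch
  // because the stages need a device-wide barrier between them, and kernel
  // boundaries are the only global synchronization point available here.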
364 | void emit_cuda_gc(OffloadedStmt *stmt) { |
365 | auto snode_id = tlctx->get_constant(stmt->snode->id); |
366 | { |
367 | init_offloaded_task_function(stmt, "gather_list" ); |
368 | call("gc_parallel_0" , get_context(), snode_id); |
369 | finalize_offloaded_task_function(); |
370 | current_task->grid_dim = compile_config.saturating_grid_dim; |
371 | current_task->block_dim = 64; |
372 | offloaded_tasks.push_back(*current_task); |
373 | current_task = nullptr; |
374 | } |
375 | { |
376 | init_offloaded_task_function(stmt, "reinit_lists" ); |
377 | call("gc_parallel_1" , get_context(), snode_id); |
378 | finalize_offloaded_task_function(); |
379 | current_task->grid_dim = 1; |
380 | current_task->block_dim = 1; |
381 | offloaded_tasks.push_back(*current_task); |
382 | current_task = nullptr; |
383 | } |
384 | { |
385 | init_offloaded_task_function(stmt, "zero_fill" ); |
386 | call("gc_parallel_2" , get_context(), snode_id); |
387 | finalize_offloaded_task_function(); |
388 | current_task->grid_dim = compile_config.saturating_grid_dim; |
389 | current_task->block_dim = 64; |
390 | offloaded_tasks.push_back(*current_task); |
391 | current_task = nullptr; |
392 | } |
393 | } |
394 | |
395 | void emit_cuda_gc_rc(OffloadedStmt *stmt) { |
396 | { |
397 | init_offloaded_task_function(stmt, "gather_list" ); |
398 | call("gc_rc_parallel_0" , get_context()); |
399 | finalize_offloaded_task_function(); |
400 | current_task->grid_dim = compile_config.saturating_grid_dim; |
401 | current_task->block_dim = 64; |
402 | offloaded_tasks.push_back(*current_task); |
403 | current_task = nullptr; |
404 | } |
405 | { |
406 | init_offloaded_task_function(stmt, "reinit_lists" ); |
407 | call("gc_rc_parallel_1" , get_context()); |
408 | finalize_offloaded_task_function(); |
409 | current_task->grid_dim = 1; |
410 | current_task->block_dim = 1; |
411 | offloaded_tasks.push_back(*current_task); |
412 | current_task = nullptr; |
413 | } |
414 | { |
415 | init_offloaded_task_function(stmt, "zero_fill" ); |
416 | call("gc_rc_parallel_2" , get_context()); |
417 | finalize_offloaded_task_function(); |
418 | current_task->grid_dim = compile_config.saturating_grid_dim; |
419 | current_task->block_dim = 64; |
420 | offloaded_tasks.push_back(*current_task); |
421 | current_task = nullptr; |
422 | } |
423 | } |
424 | |
425 | bool kernel_argument_by_val() const override { |
426 | return true; // on CUDA, pass the argument by value |
427 | } |
428 | |
429 | llvm::Value *create_intrinsic_load(llvm::Value *ptr, |
430 | llvm::Type *ty) override { |
431 | // Issue an "__ldg" instruction to cache data in the read-only data cache. |
432 | auto intrin = ty->isFloatingPointTy() ? llvm::Intrinsic::nvvm_ldg_global_f |
433 | : llvm::Intrinsic::nvvm_ldg_global_i; |
434 | return builder->CreateIntrinsic( |
435 | intrin, {ty, llvm::PointerType::get(ty, 0)}, |
436 | {ptr, tlctx->get_constant(ty->getScalarSizeInBits())}); |
437 | } |
438 | |
439 | void visit(GlobalLoadStmt *stmt) override { |
440 | if (auto get_ch = stmt->src->cast<GetChStmt>()) { |
441 | bool should_cache_as_read_only = current_offload->mem_access_opt.has_flag( |
442 | get_ch->output_snode, SNodeAccessFlag::read_only); |
443 | create_global_load(stmt, should_cache_as_read_only); |
444 | } else { |
445 | create_global_load(stmt, false); |
446 | } |
447 | } |
448 | |
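  // The block-local storage (BLS) buffer is placed in addrspace(3), which
  // NVVM maps to CUDA shared memory; it is the moral equivalent of
  //   __shared__ char bls_buffer[bls_size];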
449 | void create_bls_buffer(OffloadedStmt *stmt) { |
450 | auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context), |
451 | stmt->bls_size); |
452 | bls_buffer = new GlobalVariable( |
453 | *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, |
454 | "bls_buffer" , nullptr, llvm::GlobalVariable::NotThreadLocal, |
455 | 3 /*addrspace=shared*/); |
456 | bls_buffer->setAlignment(llvm::MaybeAlign(8)); |
457 | } |
458 | |
459 | void visit(OffloadedStmt *stmt) override { |
460 | if (stmt->bls_size > 0) |
461 | create_bls_buffer(stmt); |
462 | #if defined(TI_WITH_CUDA) |
463 | TI_ASSERT(current_offload == nullptr); |
464 | current_offload = stmt; |
465 | using Type = OffloadedStmt::TaskType; |
466 | if (stmt->task_type == Type::gc) { |
467 | // gc has 3 kernels, so we treat it specially |
468 | emit_cuda_gc(stmt); |
469 | } else if (stmt->task_type == Type::gc_rc) { |
470 | emit_cuda_gc_rc(stmt); |
471 | } else { |
472 | init_offloaded_task_function(stmt); |
473 | if (stmt->task_type == Type::serial) { |
474 | stmt->body->accept(this); |
475 | } else if (stmt->task_type == Type::range_for) { |
476 | create_offload_range_for(stmt); |
477 | } else if (stmt->task_type == Type::struct_for) { |
478 | create_offload_struct_for(stmt); |
479 | } else if (stmt->task_type == Type::mesh_for) { |
480 | create_offload_mesh_for(stmt); |
481 | } else if (stmt->task_type == Type::listgen) { |
482 | emit_list_gen(stmt); |
483 | } else { |
484 | TI_NOT_IMPLEMENTED |
485 | } |
486 | finalize_offloaded_task_function(); |
487 | current_task->grid_dim = stmt->grid_dim; |
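      // For constant loop bounds we can shrink the grid: there is no point
      // launching more than ceil(num_threads / block_dim) blocks.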
488 | if (stmt->task_type == Type::range_for) { |
489 | if (stmt->const_begin && stmt->const_end) { |
490 | int num_threads = stmt->end_value - stmt->begin_value; |
491 | int grid_dim = ((num_threads % stmt->block_dim) == 0) |
492 | ? (num_threads / stmt->block_dim) |
493 | : (num_threads / stmt->block_dim) + 1; |
494 | grid_dim = std::max(grid_dim, 1); |
495 | current_task->grid_dim = std::min(stmt->grid_dim, grid_dim); |
496 | } |
497 | } |
498 | if (stmt->task_type == Type::listgen) { |
499 | int query_max_block_per_sm; |
500 | CUDADriver::get_instance().device_get_attribute( |
501 | &query_max_block_per_sm, |
502 | CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR, nullptr); |
503 | int num_SMs; |
504 | CUDADriver::get_instance().device_get_attribute( |
505 | &num_SMs, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, nullptr); |
506 | current_task->grid_dim = num_SMs * query_max_block_per_sm; |
507 | } |
508 | current_task->block_dim = stmt->block_dim; |
509 | TI_ASSERT(current_task->grid_dim != 0); |
510 | TI_ASSERT(current_task->block_dim != 0); |
511 | offloaded_tasks.push_back(*current_task); |
512 | current_task = nullptr; |
513 | } |
514 | current_offload = nullptr; |
515 | #else |
516 | TI_NOT_IMPLEMENTED |
517 | #endif |
518 | } |
519 | |
520 | void visit(ExternalFuncCallStmt *stmt) override { |
521 | if (stmt->type == ExternalFuncCallStmt::BITCODE) { |
522 | TaskCodeGenLLVM::visit_call_bitcode(stmt); |
523 | } else { |
524 | TI_NOT_IMPLEMENTED |
525 | } |
526 | } |
527 | |
528 | void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { |
529 | const auto arg_id = stmt->arg_id; |
530 | const auto axis = stmt->axis; |
531 | llvm_val[stmt] = |
532 | call("RuntimeContext_get_extra_args" , get_context(), |
533 | tlctx->get_constant(arg_id), tlctx->get_constant(axis)); |
534 | } |
535 | |
536 | void visit(BinaryOpStmt *stmt) override { |
537 | auto op = stmt->op_type; |
538 | if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) { |
539 | return TaskCodeGenLLVM::visit(stmt); |
540 | } |
541 | |
542 | auto ret_type = stmt->ret_type; |
543 | |
544 | llvm::Value *lhs = llvm_val[stmt->lhs]; |
545 | llvm::Value *rhs = llvm_val[stmt->rhs]; |
546 | |
547 | // This branch contains atan2 and pow which use runtime.cpp function for |
548 | // **real** type. We don't have f16 support there so promoting to f32 is |
549 | // necessary. |
550 | if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::f16)) { |
551 | lhs = builder->CreateFPExt(lhs, llvm::Type::getFloatTy(*llvm_context)); |
552 | } |
553 | if (stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::f16)) { |
554 | rhs = builder->CreateFPExt(rhs, llvm::Type::getFloatTy(*llvm_context)); |
555 | } |
556 | if (ret_type->is_primitive(PrimitiveTypeID::f16)) { |
557 | ret_type = PrimitiveType::f32; |
558 | } |
559 | |
560 | if (op == BinaryOpType::atan2) { |
561 | if (ret_type->is_primitive(PrimitiveTypeID::f32)) { |
562 | llvm_val[stmt] = call("__nv_atan2f" , lhs, rhs); |
563 | } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { |
564 | llvm_val[stmt] = call("__nv_atan2" , lhs, rhs); |
565 | } else { |
566 | TI_P(data_type_name(ret_type)); |
567 | TI_NOT_IMPLEMENTED |
568 | } |
569 | } else { |
570 | // Note that ret_type here cannot be integral because pow with an |
571 | // integral exponent has been demoted in the demote_operations pass |
572 | if (ret_type->is_primitive(PrimitiveTypeID::f32)) { |
573 | llvm_val[stmt] = call("__nv_powf" , lhs, rhs); |
574 | } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { |
575 | llvm_val[stmt] = call("__nv_pow" , lhs, rhs); |
576 | } else { |
577 | TI_P(data_type_name(ret_type)); |
578 | TI_NOT_IMPLEMENTED |
579 | } |
580 | } |
581 | |
582 | // Convert back to f16 if applicable. |
583 | if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { |
584 | llvm_val[stmt] = builder->CreateFPTrunc( |
585 | llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); |
586 | } |
587 | } |
588 | |
589 | private: |
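  // Expose threadIdx.x and blockDim.x (PTX special registers tid.x / ntid.x)
  // so the shared LLVM codegen can emit SPMD loops.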
590 | std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override { |
591 | auto thread_idx = |
592 | builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}); |
593 | auto block_dim = |
594 | builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}); |
595 | return std::make_tuple(thread_idx, block_dim); |
596 | } |
597 | }; |
598 | |
599 | LLVMCompiledTask KernelCodeGenCUDA::compile_task( |
600 | const CompileConfig &config, |
601 | std::unique_ptr<llvm::Module> &&module, |
602 | OffloadedStmt *stmt) { |
603 | TaskCodeGenCUDA gen(config, get_taichi_llvm_context(), kernel, stmt); |
604 | return gen.run_compilation(); |
605 | } |
606 | |
607 | FunctionType KernelCodeGenCUDA::compile_to_function() { |
608 | TI_AUTO_PROF |
609 | CUDAModuleToFunctionConverter converter{ |
610 | &get_taichi_llvm_context(), |
611 | get_llvm_program(prog)->get_runtime_executor()}; |
612 | return converter.convert(this->kernel, compile_kernel_to_module()); |
613 | } |
614 | |
615 | FunctionType CUDAModuleToFunctionConverter::convert( |
616 | const std::string &kernel_name, |
617 | const std::vector<LlvmLaunchArgInfo> &args, |
618 | LLVMCompiledKernel data) const { |
619 | auto &mod = data.module; |
620 | auto &tasks = data.tasks; |
621 | #ifdef TI_WITH_CUDA |
622 | auto jit = tlctx_->jit.get(); |
623 | auto cuda_module = |
624 | jit->add_module(std::move(mod), executor_->get_config().gpu_max_reg); |
625 | |
626 | return [cuda_module, kernel_name, args, offloaded_tasks = tasks, |
627 | executor = this->executor_](RuntimeContext &context) { |
628 | CUDAContext::get_instance().make_current(); |
629 | std::vector<void *> arg_buffers(args.size(), nullptr); |
630 | std::vector<void *> device_buffers(args.size(), nullptr); |
631 | std::vector<DeviceAllocation> temporary_devallocs(args.size()); |
632 | |
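    // Stage host-resident array arguments: if an ndarray argument does not
    // already point to device memory, allocate a temporary device buffer,
    // copy the data in before launch, and copy it back (and free the
    // temporary buffer) after the kernels finish.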
633 | bool transferred = false; |
634 | for (int i = 0; i < (int)args.size(); i++) { |
635 | if (args[i].is_array) { |
        // Note: both numpy and PyTorch support arrays/tensors with zeros
        // in shapes, e.g., shape=(0) or shape=(100, 0, 200). This makes
        // `arr_sz` zero.
        const auto arr_sz = context.array_runtime_sizes[i];
        if (arr_sz == 0) {
          continue;
        }
640 | arg_buffers[i] = context.get_arg<void *>(i); |
641 | if (context.device_allocation_type[i] == |
642 | RuntimeContext::DevAllocType::kNone) { |
646 | unsigned int attr_val = 0; |
647 | uint32_t ret_code = CUDADriver::get_instance().mem_get_attribute.call( |
648 | &attr_val, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, |
649 | (void *)arg_buffers[i]); |
650 | |
651 | if (ret_code != CUDA_SUCCESS || attr_val != CU_MEMORYTYPE_DEVICE) { |
652 | // Copy to device buffer if arg is on host |
653 | // - ret_code != CUDA_SUCCESS: |
654 | // arg_buffers[i] is not on device |
655 | // - attr_val != CU_MEMORYTYPE_DEVICE: |
656 | // Cuda driver is aware of arg_buffers[i] but it might be on |
657 | // host. |
658 | // See CUDA driver API `cuPointerGetAttribute` for more details. |
659 | transferred = true; |
660 | |
661 | auto result_buffer = context.result_buffer; |
662 | DeviceAllocation devalloc = |
663 | executor->allocate_memory_ndarray(arr_sz, result_buffer); |
664 | device_buffers[i] = executor->get_ndarray_alloc_info_ptr(devalloc); |
665 | temporary_devallocs[i] = devalloc; |
666 | |
667 | CUDADriver::get_instance().memcpy_host_to_device( |
668 | (void *)device_buffers[i], arg_buffers[i], arr_sz); |
669 | } else { |
670 | device_buffers[i] = arg_buffers[i]; |
671 | } |
672 | // device_buffers[i] saves a raw ptr on CUDA device. |
673 | context.set_arg(i, (uint64)device_buffers[i]); |
674 | |
675 | } else if (arr_sz > 0) { |
676 | // arg_buffers[i] is a DeviceAllocation* |
677 | // TODO: Unwraps DeviceAllocation* can be done at TaskCodeGenLLVM |
678 | // since it's shared by cpu and cuda. |
679 | DeviceAllocation *ptr = |
680 | static_cast<DeviceAllocation *>(arg_buffers[i]); |
681 | device_buffers[i] = executor->get_ndarray_alloc_info_ptr(*ptr); |
682 | // We compare arg_buffers[i] and device_buffers[i] later to check |
683 | // if transfer happened. |
684 | // TODO: this logic can be improved but I'll leave it to a followup |
685 | // PR. |
686 | arg_buffers[i] = device_buffers[i]; |
687 | |
688 | // device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i] |
689 | context.set_arg(i, (uint64)device_buffers[i]); |
690 | } |
691 | } |
692 | } |
693 | if (transferred) { |
694 | CUDADriver::get_instance().stream_synchronize(nullptr); |
695 | } |
696 | CUDADriver::get_instance().context_set_limit( |
697 | CU_LIMIT_STACK_SIZE, executor->get_config().cuda_stack_limit); |
698 | |
699 | for (auto task : offloaded_tasks) { |
700 | TI_TRACE("Launching kernel {}<<<{}, {}>>>" , task.name, task.grid_dim, |
701 | task.block_dim); |
702 | cuda_module->launch(task.name, task.grid_dim, task.block_dim, 0, |
703 | {&context}, {}); |
704 | } |
705 | |
706 | // copy data back to host |
707 | if (transferred) { |
708 | CUDADriver::get_instance().stream_synchronize(nullptr); |
709 | for (int i = 0; i < (int)args.size(); i++) { |
710 | if (device_buffers[i] != arg_buffers[i]) { |
711 | CUDADriver::get_instance().memcpy_device_to_host( |
712 | arg_buffers[i], (void *)device_buffers[i], |
713 | context.array_runtime_sizes[i]); |
714 | executor->deallocate_memory_ndarray(temporary_devallocs[i]); |
715 | } |
716 | } |
717 | } |
718 | }; |
719 | #else |
720 | TI_ERROR("No CUDA" ); |
721 | return nullptr; |
722 | #endif // TI_WITH_CUDA |
723 | } |
724 | |
725 | } // namespace taichi::lang |
726 | |