// Program, context for Taichi program execution

#include "program.h"

#include "taichi/ir/statements.h"
#include "taichi/program/extension.h"
#include "taichi/codegen/cpu/codegen_cpu.h"
#include "taichi/struct/struct.h"
#include "taichi/runtime/wasm/aot_module_builder_impl.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/system/unified_allocator.h"
#include "taichi/system/timeline.h"
#include "taichi/ir/snode.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/program/snode_expr_utils.h"
#include "taichi/math/arithmetic.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/codegen/llvm/struct_llvm.h"
#endif

#ifdef TI_WITH_CC
#include "taichi/codegen/cc/cc_program.h"
#endif
#ifdef TI_WITH_VULKAN
#include "taichi/runtime/program_impls/vulkan/vulkan_program.h"
#include "taichi/rhi/vulkan/vulkan_loader.h"
#endif
#ifdef TI_WITH_OPENGL
#include "taichi/runtime/program_impls/opengl/opengl_program.h"
#include "taichi/rhi/opengl/opengl_api.h"
#endif
#ifdef TI_WITH_DX11
#include "taichi/runtime/program_impls/dx/dx_program.h"
#include "taichi/rhi/dx/dx_api.h"
#endif
#ifdef TI_WITH_DX12
#include "taichi/runtime/program_impls/dx12/dx12_program.h"
#include "taichi/rhi/dx12/dx12_api.h"
#endif
#ifdef TI_WITH_METAL
#include "taichi/runtime/program_impls/metal/metal_program.h"
#include "taichi/rhi/metal/metal_api.h"
#endif  // TI_WITH_METAL

#if defined(_M_X64) || defined(__x86_64)
// For _MM_SET_FLUSH_ZERO_MODE
#include <xmmintrin.h>
#endif  // defined(_M_X64) || defined(__x86_64)

namespace taichi::lang {
std::atomic<int> Program::num_instances_;

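// A hypothetical lifecycle sketch (not part of this file): a host embedding
// Taichi constructs a single Program, materializes the runtime, then builds
// SNode trees and kernels through it, e.g.
//
//   taichi::lang::Program prog(Arch::x64);
//   prog.materialize_runtime();
//   // ... add_snode_tree(...), compile and launch kernels ...
//   prog.finalize();  // also invoked by ~Program()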
Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
  TI_TRACE("Program initializing...");

  // For performance considerations and correctness of QuantFloatType
  // operations, we force floating-point operations to flush to zero on all
  // backends (including CPUs).
#if defined(_M_X64) || defined(__x86_64)
  _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
#endif  // defined(_M_X64) || defined(__x86_64)
#if defined(__arm64__) || defined(__aarch64__)
  // Enforce flush to zero on arm64 CPUs
  // https://developer.arm.com/documentation/100403/0201/register-descriptions/advanced-simd-and-floating-point-registers/aarch64-register-descriptions/fpcr--floating-point-control-register?lang=en
  std::uint64_t fpcr;
  __asm__ __volatile__("");
  __asm__ __volatile__("MRS %0, FPCR" : "=r"(fpcr));
  __asm__ __volatile__("");
  __asm__ __volatile__("MSR FPCR, %0"
                       :
                       : "ri"(fpcr | (1 << 24)));  // Bit 24 is FZ
  __asm__ __volatile__("");
#endif  // defined(__arm64__) || defined(__aarch64__)
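  // With FZ enabled, operations that would produce a subnormal result return
  // zero instead. A minimal illustration (assuming <limits> is available):
  //
  //   float x = std::numeric_limits<float>::min() / 2.0f;
  //   // x is a subnormal (~5.9e-39f) by default, but 0.0f once FZ is set.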
  auto &config = compile_config_;
  config = default_compile_config;
  config.arch = desired_arch;
  config.fit();

  profiler = make_profiler(config.arch, config.kernel_profiler);
  if (arch_uses_llvm(config.arch)) {
#ifdef TI_WITH_LLVM
    if (config.arch != Arch::dx12) {
      program_impl_ =
          std::make_unique<LlvmProgramImpl>(config, profiler.get());
    } else {
      // NOTE: use Dx12ProgramImpl to avoid using LlvmRuntimeExecutor for dx12.
#ifdef TI_WITH_DX12
      TI_ASSERT(directx12::is_dx12_api_available());
      program_impl_ = std::make_unique<Dx12ProgramImpl>(config);
#else
      TI_ERROR("This taichi is not compiled with DX12");
#endif
    }
#else
    TI_ERROR("This taichi is not compiled with LLVM");
#endif
  } else if (config.arch == Arch::metal) {
#ifdef TI_WITH_METAL
    TI_ASSERT(metal::is_metal_api_available());
    program_impl_ = std::make_unique<MetalProgramImpl>(config);
#else
    TI_ERROR("This taichi is not compiled with Metal");
#endif
  } else if (config.arch == Arch::vulkan) {
#ifdef TI_WITH_VULKAN
    TI_ASSERT(vulkan::is_vulkan_api_available());
    program_impl_ = std::make_unique<VulkanProgramImpl>(config);
#else
    TI_ERROR("This taichi is not compiled with Vulkan");
#endif
  } else if (config.arch == Arch::dx11) {
#ifdef TI_WITH_DX11
    TI_ASSERT(directx11::is_dx_api_available());
    program_impl_ = std::make_unique<Dx11ProgramImpl>(config);
#else
    TI_ERROR("This taichi is not compiled with DX11");
#endif
  } else if (config.arch == Arch::opengl) {
#ifdef TI_WITH_OPENGL
    TI_ASSERT(opengl::initialize_opengl(false));
    program_impl_ = std::make_unique<OpenglProgramImpl>(config);
#else
    TI_ERROR("This taichi is not compiled with OpenGL");
#endif
  } else if (config.arch == Arch::gles) {
#ifdef TI_WITH_OPENGL
    TI_ASSERT(opengl::initialize_opengl(true));
    program_impl_ = std::make_unique<OpenglProgramImpl>(config);
#else
    TI_ERROR("This taichi is not compiled with OpenGL");
#endif
  } else if (config.arch == Arch::cc) {
#ifdef TI_WITH_CC
    program_impl_ = std::make_unique<CCProgramImpl>(config);
#else
    TI_ERROR("No C backend detected.");
#endif
  } else {
    TI_NOT_IMPLEMENTED
  }

  // program_impl_ should be set in the if-else branches above.
  TI_ASSERT(program_impl_);

  Device *compute_device = program_impl_->get_compute_device();
  // Must have handled all the arch fallback logic by this point.
  memory_pool_ = std::make_unique<MemoryPool>(config.arch, compute_device);
  TI_ASSERT_INFO(num_instances_ == 0, "Only one instance at a time");
  total_compilation_time_ = 0;
  num_instances_ += 1;
  SNode::counter = 0;

  result_buffer = nullptr;
  finalized_ = false;

  if (!is_extension_supported(config.arch, Extension::assertion)) {
    if (config.check_out_of_bound) {
      TI_WARN("Out-of-bound access checking is not supported on arch={}",
              arch_name(config.arch));
      config.check_out_of_bound = false;
    }
  }

  Timelines::get_instance().set_enabled(config.timeline);

  TI_TRACE("Program ({}) arch={} initialized.", fmt::ptr(this),
           arch_name(config.arch));
}

TypeFactory &Program::get_type_factory() {
  TI_WARN(
      "Program::get_type_factory() will be deprecated; please use "
      "TypeFactory::get_instance() instead.");
  return TypeFactory::get_instance();
}

Function *Program::create_function(const FunctionKey &func_key) {
  TI_TRACE("Creating function {}...", func_key.get_full_name());
  functions_.emplace_back(std::make_unique<Function>(this, func_key));
  TI_ASSERT(function_map_.count(func_key) == 0);
  function_map_[func_key] = functions_.back().get();
  return functions_.back().get();
}

FunctionType Program::compile(const CompileConfig &compile_config,
                              Kernel &kernel) {
  auto start_t = Time::get_time();
  TI_AUTO_PROF;
  auto ret = program_impl_->compile(compile_config, &kernel);
  TI_ASSERT(ret);
  total_compilation_time_ += Time::get_time() - start_t;
  return ret;
}

void Program::materialize_runtime() {
  program_impl_->materialize_runtime(memory_pool_.get(), profiler.get(),
                                     &result_buffer);
}

void Program::destroy_snode_tree(SNodeTree *snode_tree) {
  TI_ASSERT(arch_uses_llvm(compile_config().arch) ||
            compile_config().arch == Arch::vulkan ||
            compile_config().arch == Arch::dx11 ||
            compile_config().arch == Arch::dx12);
  program_impl_->destroy_snode_tree(snode_tree);
  free_snode_tree_ids_.push(snode_tree->id());
}

SNodeTree *Program::add_snode_tree(std::unique_ptr<SNode> root,
                                   bool compile_only) {
  const int id = allocate_snode_tree_id();
  auto tree = std::make_unique<SNodeTree>(id, std::move(root));
  tree->root()->set_snode_tree_id(id);
  if (compile_only) {
    program_impl_->compile_snode_tree_types(tree.get());
  } else {
    program_impl_->materialize_snode_tree(tree.get(), result_buffer);
  }
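  // Slot bookkeeping: recycled ids (returned by destroy_snode_tree) are
  // reused in place. For example, with snode_trees_.size() == 3 and id 1
  // previously freed, the new tree overwrites slot 1; with no free ids,
  // id == 3 and the tree is appended at the back.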
  if (id < snode_trees_.size()) {
    snode_trees_[id] = std::move(tree);
  } else {
    TI_ASSERT(id == snode_trees_.size());
    snode_trees_.push_back(std::move(tree));
  }
  return snode_trees_[id].get();
}

SNode *Program::get_snode_root(int tree_id) {
  return snode_trees_[tree_id]->root();
}

void Program::check_runtime_error() {
  program_impl_->check_runtime_error(result_buffer);
}

void Program::synchronize() {
  program_impl_->synchronize();
}

StreamSemaphore Program::flush() {
  return program_impl_->flush();
}

int Program::get_snode_tree_size() {
  return snode_trees_.size();
}

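// Note: the generated reader takes one i32 scalar parameter per active index
// and returns the element as snode->dt; e.g. a 2-D f32 place yields an
// accessor with signature (i32, i32) -> f32 that returns field[i, j].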
Kernel &Program::get_snode_reader(SNode *snode) {
  TI_ASSERT(snode->type == SNodeType::place);
  auto kernel_name = fmt::format("snode_reader_{}", snode->id);
  auto &ker = kernel([snode, this](Kernel *kernel) {
    ExprGroup indices;
    for (int i = 0; i < snode->num_active_indices; i++) {
      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
      argload_expr->type_check(&this->compile_config());
      indices.push_back(std::move(argload_expr));
    }
    ASTBuilder &builder = kernel->context->builder();
    auto ret = Stmt::make<FrontendReturnStmt>(ExprGroup(
        builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices)));
    builder.insert(std::move(ret));
  });
  ker.name = kernel_name;
  ker.is_accessor = true;
  for (int i = 0; i < snode->num_active_indices; i++)
    ker.insert_scalar_param(PrimitiveType::i32);
  ker.insert_ret(snode->dt);
  ker.finalize_rets();
  return ker;
}

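// Symmetrically, the generated writer takes one i32 parameter per active
// index plus a trailing value parameter of the element's compute type; e.g.
// a 2-D f32 place yields an accessor (i32, i32, f32) that performs
// field[i, j] = v.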
Kernel &Program::get_snode_writer(SNode *snode) {
  TI_ASSERT(snode->type == SNodeType::place);
  auto kernel_name = fmt::format("snode_writer_{}", snode->id);
  auto &ker = kernel([snode, this](Kernel *kernel) {
    ExprGroup indices;
    for (int i = 0; i < snode->num_active_indices; i++) {
      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
      argload_expr->type_check(&this->compile_config());
      indices.push_back(std::move(argload_expr));
    }
    ASTBuilder &builder = kernel->context->builder();
    auto expr =
        builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices);
    builder.insert_assignment(
        expr,
        Expr::make<ArgLoadExpression>(snode->num_active_indices,
                                      snode->dt->get_compute_type()),
        expr->tb);
  });
  ker.name = kernel_name;
  ker.is_accessor = true;
  for (int i = 0; i < snode->num_active_indices; i++)
    ker.insert_scalar_param(PrimitiveType::i32);
  ker.insert_scalar_param(snode->dt);
  return ker;
}

uint64 Program::fetch_result_uint64(int i) {
  return program_impl_->fetch_result_uint64(i, result_buffer);
}

void Program::finalize() {
  if (finalized_) {
    return;
  }
  TI_TRACE("Program finalizing...");

  synchronize();
  memory_pool_->terminate();
  if (arch_uses_llvm(compile_config().arch)) {
    program_impl_->finalize();
  }

  Stmt::reset_counter();

  finalized_ = true;
  num_instances_ -= 1;
  program_impl_->dump_cache_data_to_disk();
  compile_config_ = default_compile_config;
  TI_TRACE("Program ({}) finalized.", fmt::ptr(this));
}

int Program::default_block_dim(const CompileConfig &config) {
  if (arch_is_cpu(config.arch)) {
    return config.default_cpu_block_dim;
  } else {
    return config.default_gpu_block_dim;
  }
}

void Program::print_memory_profiler_info() {
  program_impl_->print_memory_profiler_info(snode_trees_, result_buffer);
}

std::size_t Program::get_snode_num_dynamically_allocated(SNode *snode) {
  return program_impl_->get_snode_num_dynamically_allocated(snode,
                                                            result_buffer);
}

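// Hypothetical usage sketch: create_ndarray(PrimitiveType::f32, {128, 128},
// ExternalArrayLayout::kAOS) returns a zero-filled 128x128 f32 ndarray owned
// by this Program; pass zero_fill = false when the caller overwrites the
// contents anyway. (Argument spellings here are illustrative.)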
Ndarray *Program::create_ndarray(const DataType type,
                                 const std::vector<int> &shape,
                                 ExternalArrayLayout layout,
                                 bool zero_fill) {
  auto arr = std::make_unique<Ndarray>(this, type, shape, layout);
  if (zero_fill) {
    Arch arch = compile_config().arch;
    if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu) {
      fill_ndarray_fast_u32(arr.get(), /*data=*/0);
    } else if (arch != Arch::dx12) {
      // Device API support for the dx12 backend is not complete yet.
      Stream *stream =
          program_impl_->get_compute_device()->get_compute_stream();
      auto [cmdlist, res] = stream->new_command_list_unique();
      TI_ASSERT(res == RhiResult::success);
      cmdlist->buffer_fill(arr->ndarray_alloc_.get_ptr(0),
                           arr->get_element_size() * arr->get_nelement(),
                           /*data=*/0);
      stream->submit_synced(cmdlist.get());
    }
  }
  auto arr_ptr = arr.get();
  ndarrays_.insert({arr_ptr, std::move(arr)});
  return arr_ptr;
}

void Program::delete_ndarray(Ndarray *ndarray) {
  // [Note] Ndarray memory deallocation
  // An ndarray's memory allocation is managed by Taichi, and Python controls
  // it indirectly. For example, when an ndarray gets GC-ed in Python, it
  // signals Taichi to free its allocation. However, Taichi makes sure that
  // **no pending kernel still needs the ndarray** before it actually frees
  // the memory. When `ti.reset()` is called, all ndarrays allocated in this
  // program are gone and no longer valid in Python.
  // This isn't the ideal implementation: ndarrays should be managed by the
  // Taichi runtime instead of this giant Program, and each should be freed
  // when:
  // - Python GC signals Taichi that it's no longer useful, and
  // - all kernels using it have been executed.
  if (ndarrays_.count(ndarray) &&
      !program_impl_->used_in_kernel(ndarray->ndarray_alloc_.alloc_id)) {
    ndarrays_.erase(ndarray);
  }
}

Texture *Program::create_texture(const DataType type,
                                 int num_channels,
                                 const std::vector<int> &shape) {
  BufferFormat buffer_format = type_channels2buffer_format(type, num_channels);
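  // The shape maps to (width, height, depth) extents, padding missing
  // trailing dimensions with 1; e.g. shape {256, 128} yields a 256x128x1
  // texture.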
  if (shape.size() == 1) {
    textures_.push_back(
        std::make_unique<Texture>(this, buffer_format, shape[0], 1, 1));
  } else if (shape.size() == 2) {
    textures_.push_back(
        std::make_unique<Texture>(this, buffer_format, shape[0], shape[1], 1));
  } else if (shape.size() == 3) {
    textures_.push_back(std::make_unique<Texture>(
        this, buffer_format, shape[0], shape[1], shape[2]));
  } else {
    TI_ERROR("Texture shape invalid");
  }
  return textures_.back().get();
}

intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) {
  uint64_t *data_ptr{nullptr};
  if (arch_is_cpu(compile_config().arch) ||
      compile_config().arch == Arch::cuda ||
      compile_config().arch == Arch::amdgpu) {
    // For the LLVM backends, device allocation is a physical pointer.
    data_ptr =
        program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_);
  }

  return reinterpret_cast<intptr_t>(data_ptr);
}

void Program::fill_ndarray_fast_u32(Ndarray *ndarray, uint32_t val) {
  // This is a temporary solution to bypass the device API.
  // Should be moved to CommandList once available in CUDA.
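  // The element count below is expressed in 32-bit words; e.g. a hypothetical
  // ndarray of 8 f32 elements covers (8 * 4) / sizeof(uint32_t) = 8 words.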
  program_impl_->fill_ndarray(
      ndarray->ndarray_alloc_,
      ndarray->get_nelement() * ndarray->get_element_size() / sizeof(uint32_t),
      val);
}

Program::~Program() {
  finalize();
}

DeviceCapabilityConfig translate_devcaps(const std::vector<std::string> &caps) {
  // Each device capability assignment is written in one of two forms:
  // - `spirv_version=1.3`
  // - `spirv_has_int8`
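  // For example, {"spirv_version=1.4", "spirv_has_int8"} sets spirv_version
  // to 0x10400 and spirv_has_int8 to 1 (bare flags default to 1).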
  DeviceCapabilityConfig cfg{};
  for (const std::string &cap : caps) {
    std::string_view key;
    std::string_view value;
    size_t ieq = cap.find('=');
    if (ieq == std::string::npos) {
      key = cap;
    } else {
      key = std::string_view(cap.c_str(), ieq);
      value = std::string_view(cap.c_str() + ieq + 1);
    }
    DeviceCapability devcap = str2devcap(key);
    switch (devcap) {
      case DeviceCapability::spirv_version: {
        if (value == "1.3") {
          cfg.set(devcap, 0x10300);
        } else if (value == "1.4") {
          cfg.set(devcap, 0x10400);
        } else if (value == "1.5") {
          cfg.set(devcap, 0x10500);
        } else {
          TI_ERROR(
              "'{}' is not a valid value of device capability `spirv_version`",
              value);
        }
        break;
      }
      default:
        cfg.set(devcap, 1);
        break;
    }
  }

  // Assign default caps that are always present.
  if (!cfg.contains(DeviceCapability::spirv_version)) {
    cfg.set(DeviceCapability::spirv_version, 0x10300);
  }
  return cfg;
}

std::unique_ptr<AotModuleBuilder> Program::make_aot_module_builder(
    Arch arch,
    const std::vector<std::string> &caps) {
  DeviceCapabilityConfig cfg = translate_devcaps(caps);
  // FIXME: This couples the runtime backend with the target AOT backend,
  // e.g. building a Metal AOT module requires running on the macOS platform.
  // Consider decoupling this part.
  if (arch == Arch::wasm) {
    // TODO(PGZXB): Dispatch to the LlvmProgramImpl.
#ifdef TI_WITH_LLVM
    auto *llvm_prog = dynamic_cast<LlvmProgramImpl *>(program_impl_.get());
    TI_ASSERT(llvm_prog != nullptr);
    return std::make_unique<wasm::AotModuleBuilderImpl>(
        compile_config(), *llvm_prog->get_llvm_context());
#else
    TI_NOT_IMPLEMENTED
#endif
  }
  if (arch_uses_llvm(compile_config().arch) ||
      compile_config().arch == Arch::metal ||
      compile_config().arch == Arch::vulkan ||
      compile_config().arch == Arch::opengl ||
      compile_config().arch == Arch::gles ||
      compile_config().arch == Arch::dx12) {
    return program_impl_->make_aot_module_builder(cfg);
  }
  return nullptr;
}

int Program::allocate_snode_tree_id() {
  if (free_snode_tree_ids_.empty()) {
    return snode_trees_.size();
  } else {
    int id = free_snode_tree_ids_.top();
    free_snode_tree_ids_.pop();
    return id;
  }
}

void Program::prepare_runtime_context(RuntimeContext *ctx) {
  ctx->result_buffer = result_buffer;
  program_impl_->prepare_runtime_context(ctx);
}

void Program::enqueue_compute_op_lambda(
    std::function<void(Device *device, CommandList *cmdlist)> op,
    const std::vector<ComputeOpImageRef> &image_refs) {
  program_impl_->enqueue_compute_op_lambda(op, image_refs);
}

}  // namespace taichi::lang