// Program, context for Taichi program execution

#include "program.h"

#include "taichi/ir/statements.h"
#include "taichi/program/extension.h"
#include "taichi/codegen/cpu/codegen_cpu.h"
#include "taichi/struct/struct.h"
#include "taichi/runtime/wasm/aot_module_builder_impl.h"
#include "taichi/runtime/program_impls/opengl/opengl_program.h"
#include "taichi/runtime/program_impls/metal/metal_program.h"
#include "taichi/codegen/cc/cc_program.h"
#include "taichi/platform/cuda/detect_cuda.h"
#include "taichi/system/unified_allocator.h"
#include "taichi/system/timeline.h"
#include "taichi/ir/snode.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/program/snode_expr_utils.h"
#include "taichi/math/arithmetic.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/codegen/llvm/struct_llvm.h"
#endif

#ifdef TI_WITH_VULKAN
#include "taichi/runtime/program_impls/vulkan/vulkan_program.h"
#include "taichi/rhi/vulkan/vulkan_loader.h"
#endif
#ifdef TI_WITH_OPENGL
#include "taichi/rhi/opengl/opengl_api.h"
#endif
#ifdef TI_WITH_DX11
#include "taichi/runtime/program_impls/dx/dx_program.h"
#include "taichi/rhi/dx/dx_api.h"
#endif
#ifdef TI_WITH_DX12
#include "taichi/runtime/program_impls/dx12/dx12_program.h"
#include "taichi/rhi/dx12/dx12_api.h"
#endif
#ifdef TI_WITH_METAL
#include "taichi/rhi/metal/metal_api.h"
#endif  // TI_WITH_METAL

#if defined(_M_X64) || defined(__x86_64)
// For _MM_SET_FLUSH_ZERO_MODE
#include <xmmintrin.h>
#endif  // defined(_M_X64) || defined(__x86_64)

namespace taichi::lang {
std::atomic<int> Program::num_instances_;

Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
  TI_TRACE("Program initializing...");

  // For performance considerations and correctness of QuantFloatType
  // operations, we force floating-point operations to flush to zero on all
  // backends (including CPUs).
#if defined(_M_X64) || defined(__x86_64)
  _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
#endif  // defined(_M_X64) || defined(__x86_64)
#if defined(__arm64__) || defined(__aarch64__)
  // Enforce flush-to-zero on arm64 CPUs by setting the FZ bit (bit 24) of FPCR:
  // https://developer.arm.com/documentation/100403/0201/register-descriptions/advanced-simd-and-floating-point-registers/aarch64-register-descriptions/fpcr--floating-point-control-register?lang=en
  // The empty volatile asm statements act as compiler barriers so the FPCR
  // read-modify-write sequence is not reordered.
  std::uint64_t fpcr;
  __asm__ __volatile__("");
  __asm__ __volatile__("MRS %0, FPCR" : "=r"(fpcr));
  __asm__ __volatile__("");
  __asm__ __volatile__("MSR FPCR, %0"
                       :
                       : "ri"(fpcr | (1 << 24)));  // Bit 24 is FZ
  __asm__ __volatile__("");
#endif  // defined(__arm64__) || defined(__aarch64__)
  auto &config = compile_config_;
  config = default_compile_config;
  config.arch = desired_arch;
  config.fit();

  profiler = make_profiler(config.arch, config.kernel_profiler);
  if (arch_uses_llvm(config.arch)) {
#ifdef TI_WITH_LLVM
    if (config.arch != Arch::dx12) {
      program_impl_ = std::make_unique<LlvmProgramImpl>(config, profiler.get());
    } else {
      // NOTE: Use Dx12ProgramImpl to avoid using LlvmRuntimeExecutor for dx12.
#ifdef TI_WITH_DX12
      TI_ASSERT(directx12::is_dx12_api_available());
      program_impl_ = std::make_unique<Dx12ProgramImpl>(config);
#else
      TI_ERROR("This Taichi is not compiled with DX12");
#endif
    }
#else
    TI_ERROR("This Taichi is not compiled with LLVM");
#endif
  } else if (config.arch == Arch::metal) {
#ifdef TI_WITH_METAL
    TI_ASSERT(metal::is_metal_api_available());
    program_impl_ = std::make_unique<MetalProgramImpl>(config);
#else
    TI_ERROR("This Taichi is not compiled with Metal");
#endif
  } else if (config.arch == Arch::vulkan) {
#ifdef TI_WITH_VULKAN
    TI_ASSERT(vulkan::is_vulkan_api_available());
    program_impl_ = std::make_unique<VulkanProgramImpl>(config);
#else
    TI_ERROR("This Taichi is not compiled with Vulkan");
#endif
  } else if (config.arch == Arch::dx11) {
#ifdef TI_WITH_DX11
    TI_ASSERT(directx11::is_dx_api_available());
    program_impl_ = std::make_unique<Dx11ProgramImpl>(config);
#else
    TI_ERROR("This Taichi is not compiled with DX11");
#endif
  } else if (config.arch == Arch::opengl) {
#ifdef TI_WITH_OPENGL
    TI_ASSERT(opengl::initialize_opengl(false));
    program_impl_ = std::make_unique<OpenglProgramImpl>(config);
#else
    TI_ERROR("This Taichi is not compiled with OpenGL");
#endif
  } else if (config.arch == Arch::gles) {
#ifdef TI_WITH_OPENGL
    TI_ASSERT(opengl::initialize_opengl(true));
    program_impl_ = std::make_unique<OpenglProgramImpl>(config);
#else
    TI_ERROR("This Taichi is not compiled with OpenGL");
#endif
  } else if (config.arch == Arch::cc) {
#ifdef TI_WITH_CC
    program_impl_ = std::make_unique<CCProgramImpl>(config);
#else
    TI_ERROR("This Taichi is not compiled with the C backend");
#endif
  } else {
    TI_NOT_IMPLEMENTED
  }

  // program_impl_ should have been set by one of the branches above.
  TI_ASSERT(program_impl_);

  Device *compute_device = program_impl_->get_compute_device();
  // Must have handled all the arch fallback logic by this point.
  memory_pool_ = std::make_unique<MemoryPool>(config.arch, compute_device);
  TI_ASSERT_INFO(num_instances_ == 0, "Only one instance at a time");
  total_compilation_time_ = 0;
  num_instances_ += 1;
  SNode::counter = 0;

  result_buffer = nullptr;
  finalized_ = false;

  if (!is_extension_supported(config.arch, Extension::assertion)) {
    if (config.check_out_of_bound) {
      TI_WARN("Out-of-bound access checking is not supported on arch={}",
              arch_name(config.arch));
      config.check_out_of_bound = false;
    }
  }

  Timelines::get_instance().set_enabled(config.timeline);

  TI_TRACE("Program ({}) arch={} initialized.", fmt::ptr(this),
           arch_name(config.arch));
}

TypeFactory &Program::get_type_factory() {
  TI_WARN(
      "Program::get_type_factory() will be deprecated. Please use "
      "TypeFactory::get_instance() instead.");
  return TypeFactory::get_instance();
}

Function *Program::create_function(const FunctionKey &func_key) {
  TI_TRACE("Creating function {}...", func_key.get_full_name());
  functions_.emplace_back(std::make_unique<Function>(this, func_key));
  TI_ASSERT(function_map_.count(func_key) == 0);
  function_map_[func_key] = functions_.back().get();
  return functions_.back().get();
}

FunctionType Program::compile(const CompileConfig &compile_config,
                              Kernel &kernel) {
  auto start_t = Time::get_time();
  TI_AUTO_PROF;
  auto ret = program_impl_->compile(compile_config, &kernel);
  TI_ASSERT(ret);
  total_compilation_time_ += Time::get_time() - start_t;
  return ret;
}

void Program::materialize_runtime() {
  program_impl_->materialize_runtime(memory_pool_.get(), profiler.get(),
                                     &result_buffer);
}

void Program::destroy_snode_tree(SNodeTree *snode_tree) {
  TI_ASSERT(arch_uses_llvm(compile_config().arch) ||
            compile_config().arch == Arch::vulkan ||
            compile_config().arch == Arch::dx11 ||
            compile_config().arch == Arch::dx12);
  program_impl_->destroy_snode_tree(snode_tree);
  free_snode_tree_ids_.push(snode_tree->id());
}

SNodeTree *Program::add_snode_tree(std::unique_ptr<SNode> root,
                                   bool compile_only) {
  const int id = allocate_snode_tree_id();
  auto tree = std::make_unique<SNodeTree>(id, std::move(root));
  tree->root()->set_snode_tree_id(id);
  if (compile_only) {
    program_impl_->compile_snode_tree_types(tree.get());
  } else {
    program_impl_->materialize_snode_tree(tree.get(), result_buffer);
  }
  if (id < snode_trees_.size()) {
    snode_trees_[id] = std::move(tree);
  } else {
    TI_ASSERT(id == snode_trees_.size());
    snode_trees_.push_back(std::move(tree));
  }
  return snode_trees_[id].get();
}

SNode *Program::get_snode_root(int tree_id) {
  return snode_trees_[tree_id]->root();
}

void Program::check_runtime_error() {
  program_impl_->check_runtime_error(result_buffer);
}

void Program::synchronize() {
  program_impl_->synchronize();
}

StreamSemaphore Program::flush() {
  return program_impl_->flush();
}

int Program::get_snode_tree_size() {
  return snode_trees_.size();
}

Kernel &Program::get_snode_reader(SNode *snode) {
  TI_ASSERT(snode->type == SNodeType::place);
  auto kernel_name = fmt::format("snode_reader_{}", snode->id);
  auto &ker = kernel([snode, this](Kernel *kernel) {
    ExprGroup indices;
    for (int i = 0; i < snode->num_active_indices; i++) {
      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
      argload_expr->type_check(&this->compile_config());
      indices.push_back(std::move(argload_expr));
    }
    ASTBuilder &builder = kernel->context->builder();
    auto ret = Stmt::make<FrontendReturnStmt>(ExprGroup(
        builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices)));
    builder.insert(std::move(ret));
  });
  ker.name = kernel_name;
  ker.is_accessor = true;
  for (int i = 0; i < snode->num_active_indices; i++)
    ker.insert_scalar_param(PrimitiveType::i32);
  ker.insert_ret(snode->dt);
  ker.finalize_rets();
  return ker;
}

Kernel &Program::get_snode_writer(SNode *snode) {
  TI_ASSERT(snode->type == SNodeType::place);
  auto kernel_name = fmt::format("snode_writer_{}", snode->id);
  auto &ker = kernel([snode, this](Kernel *kernel) {
    ExprGroup indices;
    for (int i = 0; i < snode->num_active_indices; i++) {
      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
      argload_expr->type_check(&this->compile_config());
      indices.push_back(std::move(argload_expr));
    }
    ASTBuilder &builder = kernel->context->builder();
    auto expr =
        builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices);
    builder.insert_assignment(
        expr,
        Expr::make<ArgLoadExpression>(snode->num_active_indices,
                                      snode->dt->get_compute_type()),
        expr->tb);
  });
  ker.name = kernel_name;
  ker.is_accessor = true;
  for (int i = 0; i < snode->num_active_indices; i++)
    ker.insert_scalar_param(PrimitiveType::i32);
  ker.insert_scalar_param(snode->dt);
  return ker;
}

uint64 Program::fetch_result_uint64(int i) {
  return program_impl_->fetch_result_uint64(i, result_buffer);
}

void Program::finalize() {
  if (finalized_) {
    return;
  }
  synchronize();
  TI_TRACE("Program finalizing...");

  memory_pool_->terminate();
  if (arch_uses_llvm(compile_config().arch)) {
    program_impl_->finalize();
  }

  Stmt::reset_counter();

  finalized_ = true;
  num_instances_ -= 1;
  program_impl_->dump_cache_data_to_disk();
  compile_config_ = default_compile_config;
  TI_TRACE("Program ({}) finalized.", fmt::ptr(this));
}

int Program::default_block_dim(const CompileConfig &config) {
  if (arch_is_cpu(config.arch)) {
    return config.default_cpu_block_dim;
  } else {
    return config.default_gpu_block_dim;
  }
}

void Program::print_memory_profiler_info() {
  program_impl_->print_memory_profiler_info(snode_trees_, result_buffer);
}

std::size_t Program::get_snode_num_dynamically_allocated(SNode *snode) {
  return program_impl_->get_snode_num_dynamically_allocated(snode,
                                                            result_buffer);
}

Ndarray *Program::create_ndarray(const DataType type,
                                 const std::vector<int> &shape,
                                 ExternalArrayLayout layout,
                                 bool zero_fill) {
  auto arr = std::make_unique<Ndarray>(this, type, shape, layout);
  if (zero_fill) {
    Arch arch = compile_config().arch;
    if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu) {
      fill_ndarray_fast_u32(arr.get(), /*data=*/0);
    } else if (arch != Arch::dx12) {
      // Device API support for the dx12 backend is not complete yet.
      Stream *stream =
          program_impl_->get_compute_device()->get_compute_stream();
      auto [cmdlist, res] = stream->new_command_list_unique();
      TI_ASSERT(res == RhiResult::success);
      cmdlist->buffer_fill(arr->ndarray_alloc_.get_ptr(0),
                           arr->get_element_size() * arr->get_nelement(),
                           /*data=*/0);
      stream->submit_synced(cmdlist.get());
    }
  }
  auto arr_ptr = arr.get();
  ndarrays_.insert({arr_ptr, std::move(arr)});
  return arr_ptr;
}

void Program::delete_ndarray(Ndarray *ndarray) {
  // [Note] Ndarray memory deallocation
  // An Ndarray's memory allocation is managed by Taichi; Python controls it
  // only indirectly. For example, when an ndarray is garbage-collected in
  // Python, it signals Taichi to free the corresponding allocation, but Taichi
  // only frees the memory once no pending kernel still needs the ndarray.
  // When `ti.reset()` is called, all ndarrays allocated in this program are
  // gone and no longer valid in Python.
  // This isn't the best implementation: ndarrays should be managed by the
  // Taichi runtime instead of this giant Program, and an ndarray should be
  // freed only when:
  // - Python GC signals Taichi that it's no longer useful, and
  // - all kernels using it have been executed.
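  // A minimal lifecycle sketch (illustrative only; `prog` is a hypothetical
  // Program pointer and the Python binding is assumed to call delete_ndarray()
  // on GC):
  //   Ndarray *arr = prog->create_ndarray(PrimitiveType::f32, /*shape=*/{128});
  //   ... launch kernels that read from or write to `arr` ...
  //   prog->delete_ndarray(arr);  // freed only once no pending kernel uses it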
  if (ndarrays_.count(ndarray) &&
      !program_impl_->used_in_kernel(ndarray->ndarray_alloc_.alloc_id)) {
    ndarrays_.erase(ndarray);
  }
}

Texture *Program::create_texture(const DataType type,
                                 int num_channels,
                                 const std::vector<int> &shape) {
  BufferFormat buffer_format = type_channels2buffer_format(type, num_channels);
  if (shape.size() == 1) {
    textures_.push_back(
        std::make_unique<Texture>(this, buffer_format, shape[0], 1, 1));
  } else if (shape.size() == 2) {
    textures_.push_back(
        std::make_unique<Texture>(this, buffer_format, shape[0], shape[1], 1));
  } else if (shape.size() == 3) {
    textures_.push_back(std::make_unique<Texture>(this, buffer_format, shape[0],
                                                  shape[1], shape[2]));
  } else {
    TI_ERROR("Invalid texture shape: expected 1, 2, or 3 dimensions, got {}",
             shape.size());
  }
  return textures_.back().get();
}

intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) {
  uint64_t *data_ptr{nullptr};
  if (arch_is_cpu(compile_config().arch) ||
      compile_config().arch == Arch::cuda ||
      compile_config().arch == Arch::amdgpu) {
    // For the LLVM backends, device allocation is a physical pointer.
    data_ptr =
        program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_);
  }

  return reinterpret_cast<intptr_t>(data_ptr);
}

void Program::fill_ndarray_fast_u32(Ndarray *ndarray, uint32_t val) {
  // This is a temporary solution that bypasses the device API. It should be
  // moved to CommandList once that is available in CUDA.
  program_impl_->fill_ndarray(
      ndarray->ndarray_alloc_,
      ndarray->get_nelement() * ndarray->get_element_size() / sizeof(uint32_t),
      val);
}

Program::~Program() {
  finalize();
}

DeviceCapabilityConfig translate_devcaps(const std::vector<std::string> &caps) {
  // Each device capability assignment is written in one of these forms:
  // - `spirv_version=1.3` (key and value)
  // - `spirv_has_int8` (key only; the capability is enabled, i.e. set to 1)
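  // A minimal usage sketch (illustrative; the accepted capability names are
  // whatever str2devcap understands in this build):
  //   auto cfg = translate_devcaps({"spirv_version=1.4", "spirv_has_int8"});
  //   // cfg now holds spirv_version = 0x10400 and spirv_has_int8 = 1.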
  DeviceCapabilityConfig cfg{};
  for (const std::string &cap : caps) {
    std::string_view key;
    std::string_view value;
    size_t ieq = cap.find('=');
    if (ieq == std::string::npos) {
      key = cap;
    } else {
      key = std::string_view(cap.c_str(), ieq);
      value = std::string_view(cap.c_str() + ieq + 1);
    }
    DeviceCapability devcap = str2devcap(key);
    switch (devcap) {
      case DeviceCapability::spirv_version: {
        if (value == "1.3") {
          cfg.set(devcap, 0x10300);
        } else if (value == "1.4") {
          cfg.set(devcap, 0x10400);
        } else if (value == "1.5") {
          cfg.set(devcap, 0x10500);
        } else {
          TI_ERROR(
              "'{}' is not a valid value for device capability `spirv_version`",
              value);
        }
        break;
      }
      default:
        cfg.set(devcap, 1);
        break;
    }
  }

  // Assign default caps (those that are always present).
  if (!cfg.contains(DeviceCapability::spirv_version)) {
    cfg.set(DeviceCapability::spirv_version, 0x10300);
  }
  return cfg;
}

std::unique_ptr<AotModuleBuilder> Program::make_aot_module_builder(
    Arch arch,
    const std::vector<std::string> &caps) {
  DeviceCapabilityConfig cfg = translate_devcaps(caps);
  // FIXME: This couples the runtime backend with the target AOT backend, e.g.,
  // to build a Metal AOT module we have to be on the macOS platform. Consider
  // decoupling this part.
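  // Illustrative call (a sketch only; `program` is a hypothetical Program
  // pointer and this assumes a build with a SPIR-V-based backend such as
  // Vulkan):
  //   auto aot_builder =
  //       program->make_aot_module_builder(Arch::vulkan, {"spirv_version=1.4"});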
  if (arch == Arch::wasm) {
    // TODO(PGZXB): Dispatch to the LlvmProgramImpl.
#ifdef TI_WITH_LLVM
    auto *llvm_prog = dynamic_cast<LlvmProgramImpl *>(program_impl_.get());
    TI_ASSERT(llvm_prog != nullptr);
    return std::make_unique<wasm::AotModuleBuilderImpl>(
        compile_config(), *llvm_prog->get_llvm_context());
#else
    TI_NOT_IMPLEMENTED
#endif
  }
  if (arch_uses_llvm(compile_config().arch) ||
      compile_config().arch == Arch::metal ||
      compile_config().arch == Arch::vulkan ||
      compile_config().arch == Arch::opengl ||
      compile_config().arch == Arch::gles ||
      compile_config().arch == Arch::dx12) {
    return program_impl_->make_aot_module_builder(cfg);
  }
  return nullptr;
}

int Program::allocate_snode_tree_id() {
  if (free_snode_tree_ids_.empty()) {
    return snode_trees_.size();
  } else {
    int id = free_snode_tree_ids_.top();
    free_snode_tree_ids_.pop();
    return id;
  }
}

void Program::prepare_runtime_context(RuntimeContext *ctx) {
  ctx->result_buffer = result_buffer;
  program_impl_->prepare_runtime_context(ctx);
}

void Program::enqueue_compute_op_lambda(
    std::function<void(Device *device, CommandList *cmdlist)> op,
    const std::vector<ComputeOpImageRef> &image_refs) {
  program_impl_->enqueue_compute_op_lambda(op, image_refs);
}

}  // namespace taichi::lang