#include "taichi/runtime/gfx/runtime.h"
#include "taichi/program/program.h"
#include "taichi/common/filesystem.hpp"

#include <chrono>
#include <array>
#include <iostream>
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "fp16.h"

#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
#undef TI_RUNTIME_HOST

namespace taichi::lang {
namespace gfx {

namespace {

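// Copies kernel arguments and return values between the host-side
// RuntimeContext and the device-side args/rets buffers. A blitter is created
// per launch via maybe_make() below; it is skipped (nullptr) for kernels
// without a context.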
class HostDeviceContextBlitter {
 public:
  HostDeviceContextBlitter(const KernelContextAttributes *ctx_attribs,
                           RuntimeContext *host_ctx,
                           Device *device,
                           uint64_t *host_result_buffer,
                           DeviceAllocation *device_args_buffer,
                           DeviceAllocation *device_ret_buffer)
      : ctx_attribs_(ctx_attribs),
        host_ctx_(host_ctx),
        host_result_buffer_(host_result_buffer),
        device_args_buffer_(device_args_buffer),
        device_ret_buffer_(device_ret_buffer),
        device_(device) {
  }

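  // Writes scalar arguments into the device args buffer, copies host external
  // arrays into their device allocations (when readable by the kernel), and
  // substitutes device addresses for array arguments.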
  void host_to_device(
      const std::unordered_map<int, DeviceAllocation> &ext_arrays,
      const std::unordered_map<int, size_t> &ext_arr_size) {
    if (!ctx_attribs_->has_args()) {
      return;
    }

    void *device_base{nullptr};
    TI_ASSERT(device_->map(*device_args_buffer_, &device_base) ==
              RhiResult::success);

#define TO_DEVICE(short_type, type)               \
  if (arg.dtype == PrimitiveTypeID::short_type) { \
    auto d = host_ctx_->get_arg<type>(i);         \
    reinterpret_cast<type *>(device_ptr)[0] = d;  \
    break;                                        \
  }

    for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
      const auto &arg = ctx_attribs_->args()[i];
      void *device_ptr = (uint8_t *)device_base + arg.offset_in_mem;
      do {
        if (arg.is_array) {
          if (host_ctx_->device_allocation_type[i] ==
                  RuntimeContext::DevAllocType::kNone &&
              ext_arr_size.at(i)) {
            // Only need to blit external arrays (host arrays).
            uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
            if (access & uint32_t(irpass::ExternalPtrAccess::READ)) {
              DeviceAllocation buffer = ext_arrays.at(i);
              void *device_arr_ptr{nullptr};
              TI_ASSERT(device_->map(buffer, &device_arr_ptr) ==
                        RhiResult::success);
              const void *host_ptr = host_ctx_->get_arg<void *>(i);
              std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i));
              device_->unmap(buffer);
            }
          }
          // Substitute in the device address.

          // (penguinliong) We don't check the availability of the physical
          // pointer here. It should be done before this class is needed.
          if ((host_ctx_->device_allocation_type[i] ==
                   RuntimeContext::DevAllocType::kNone ||
               host_ctx_->device_allocation_type[i] ==
                   RuntimeContext::DevAllocType::kNdarray)) {
            uint64_t addr =
                device_->get_memory_physical_pointer(ext_arrays.at(i));
            reinterpret_cast<uint64 *>(device_ptr)[0] = addr;
          }
          // We should not process the rest.
          break;
        }
        // (penguinliong) Same. The availability of short/long int types
        // depends on the kernels and compute graphs, and the check should
        // already be done during module loads.
        TO_DEVICE(i8, int8)
        TO_DEVICE(u8, uint8)
        TO_DEVICE(i16, int16)
        TO_DEVICE(u16, uint16)
        TO_DEVICE(i32, int32)
        TO_DEVICE(u32, uint32)
        TO_DEVICE(f32, float32)
        TO_DEVICE(i64, int64)
        TO_DEVICE(u64, uint64)
        TO_DEVICE(f64, float64)
        if (arg.dtype == PrimitiveTypeID::f16) {
          auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg<float>(i));
          reinterpret_cast<uint16 *>(device_ptr)[0] = d;
          break;
        }
        TI_ERROR("Device does not support arg type={}",
                 PrimitiveType::get(arg.dtype).to_string());
      } while (false);
    }

    void *device_ptr =
        (uint8_t *)device_base + ctx_attribs_->extra_args_mem_offset();
    std::memcpy(device_ptr, host_ctx_->extra_args,
                ctx_attribs_->extra_args_bytes());

    device_->unmap(*device_args_buffer_);
#undef TO_DEVICE
  }

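  // Submits `cmdlist` and synchronizes if anything needs to be read back:
  // host external arrays written by the kernel and/or return values. Returns
  // true iff a sync happened, so the caller can drop in-flight references.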
  bool device_to_host(
      CommandList *cmdlist,
      const std::unordered_map<int, DeviceAllocation> &ext_arrays,
      const std::unordered_map<int, size_t> &ext_arr_size) {
    if (ctx_attribs_->empty()) {
      return false;
    }

    bool require_sync = ctx_attribs_->rets().size() > 0;
    std::vector<DevicePtr> readback_dev_ptrs;
    std::vector<void *> readback_host_ptrs;
    std::vector<size_t> readback_sizes;

    for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
      const auto &arg = ctx_attribs_->args()[i];
      if (arg.is_array &&
          host_ctx_->device_allocation_type[i] ==
              RuntimeContext::DevAllocType::kNone &&
          ext_arr_size.at(i)) {
        uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
        if (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) {
          // Only need to blit external arrays (host arrays).
          readback_dev_ptrs.push_back(ext_arrays.at(i).get_ptr(0));
          readback_host_ptrs.push_back(host_ctx_->get_arg<void *>(i));
          readback_sizes.push_back(ext_arr_size.at(i));
          require_sync = true;
        }
      }
    }

    if (require_sync) {
      if (readback_sizes.size()) {
        StreamSemaphore command_complete_sema =
            device_->get_compute_stream()->submit(cmdlist);

        device_->wait_idle();

        // In this case `readback_data` syncs.
        TI_ASSERT(device_->readback_data(
                      readback_dev_ptrs.data(), readback_host_ptrs.data(),
                      readback_sizes.data(), int(readback_sizes.size()),
                      {command_complete_sema}) == RhiResult::success);
      } else {
        device_->get_compute_stream()->submit_synced(cmdlist);
      }

      if (!ctx_attribs_->has_rets()) {
        return true;
      }
    } else {
      return false;
    }

    void *device_base{nullptr};
    TI_ASSERT(device_->map(*device_ret_buffer_, &device_base) ==
              RhiResult::success);

#define TO_HOST(short_type, type, offset)                             \
  if (dt->is_primitive(PrimitiveTypeID::short_type)) {                \
    const type d = *(reinterpret_cast<type *>(device_ptr) + offset);  \
    host_result_buffer_[offset] =                                     \
        taichi_union_cast_with_different_sizes<uint64>(d);            \
    continue;                                                         \
  }

    for (int i = 0; i < ctx_attribs_->rets().size(); ++i) {
      // Note that we are copying the i-th return value on the device to the
      // i-th *arg* on the host context.
      const auto &ret = ctx_attribs_->rets()[i];
      void *device_ptr = (uint8_t *)device_base + ret.offset_in_mem;
      const auto dt = PrimitiveType::get(ret.dtype);
      const auto num = ret.stride / data_type_size(dt);
      for (int j = 0; j < num; ++j) {
        // (penguinliong) Again, it's the module loader's responsibility to
        // check the data type availability.
        TO_HOST(i8, int8, j)
        TO_HOST(u8, uint8, j)
        TO_HOST(i16, int16, j)
        TO_HOST(u16, uint16, j)
        TO_HOST(i32, int32, j)
        TO_HOST(u32, uint32, j)
        TO_HOST(f32, float32, j)
        TO_HOST(i64, int64, j)
        TO_HOST(u64, uint64, j)
        TO_HOST(f64, float64, j)
        if (dt->is_primitive(PrimitiveTypeID::f16)) {
          // Index into the device buffer first, then convert the
          // half-precision value (note the parentheses around the pointer
          // arithmetic, matching TO_HOST).
          const float d = fp16_ieee_to_fp32_value(
              *(reinterpret_cast<uint16 *>(device_ptr) + j));
          host_result_buffer_[j] =
              taichi_union_cast_with_different_sizes<uint64>(d);
          continue;
        }
        TI_ERROR("Device does not support return value type={}",
                 data_type_name(PrimitiveType::get(ret.dtype)));
      }
    }
#undef TO_HOST

    device_->unmap(*device_ret_buffer_);

    return true;
  }

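  // Returns nullptr when the kernel has neither arguments nor return values,
  // so callers can skip the host/device blit entirely.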
  static std::unique_ptr<HostDeviceContextBlitter> maybe_make(
      const KernelContextAttributes *ctx_attribs,
      RuntimeContext *host_ctx,
      Device *device,
      uint64_t *host_result_buffer,
      DeviceAllocation *device_args_buffer,
      DeviceAllocation *device_ret_buffer) {
    if (ctx_attribs->empty()) {
      return nullptr;
    }
    return std::make_unique<HostDeviceContextBlitter>(
        ctx_attribs, host_ctx, device, host_result_buffer, device_args_buffer,
        device_ret_buffer);
  }

 private:
  const KernelContextAttributes *const ctx_attribs_;
  RuntimeContext *const host_ctx_;
  uint64_t *const host_result_buffer_;
  DeviceAllocation *const device_args_buffer_;
  DeviceAllocation *const device_ret_buffer_;
  Device *const device_;
};

}  // namespace

constexpr size_t kGtmpBufferSize = 1024 * 1024;
constexpr size_t kListGenBufferSize = 32 << 20;

// Info for launching a compiled Taichi kernel, which consists of a series of
// Unified Device API pipelines.

CompiledTaichiKernel::CompiledTaichiKernel(const Params &ti_params)
    : ti_kernel_attribs_(*ti_params.ti_kernel_attribs),
      device_(ti_params.device) {
  input_buffers_[BufferType::GlobalTmps] = ti_params.global_tmps_buffer;
  input_buffers_[BufferType::ListGen] = ti_params.listgen_buffer;

  // Compiled_structs can be empty if loading a kernel from an AOT module, as
  // the SNodes are not re-compiled/structured. In this case, we assume a
  // single root buffer size configured from the AOT module.
  for (int root = 0; root < ti_params.num_snode_trees; ++root) {
    BufferInfo buffer = {BufferType::Root, root};
    input_buffers_[buffer] = ti_params.root_buffers[root];
  }

  const auto arg_sz = ti_kernel_attribs_.ctx_attribs.args_bytes();
  const auto ret_sz = ti_kernel_attribs_.ctx_attribs.rets_bytes();

  args_buffer_size_ = arg_sz;
  ret_buffer_size_ = ret_sz;

  if (arg_sz) {
    args_buffer_size_ += ti_kernel_attribs_.ctx_attribs.extra_args_bytes();
  }

  const auto &task_attribs = ti_kernel_attribs_.tasks_attribs;
  const auto &spirv_bins = ti_params.spirv_bins;
  TI_ASSERT(task_attribs.size() == spirv_bins.size());

  for (int i = 0; i < task_attribs.size(); ++i) {
    PipelineSourceDesc source_desc{PipelineSourceType::spirv_binary,
                                   (void *)spirv_bins[i].data(),
                                   spirv_bins[i].size() * sizeof(uint32_t)};
    auto [vp, res] = ti_params.device->create_pipeline_unique(
        source_desc, task_attribs[i].name, ti_params.backend_cache);
    pipelines_.push_back(std::move(vp));
  }
}

const TaichiKernelAttributes &CompiledTaichiKernel::ti_kernel_attribs() const {
  return ti_kernel_attribs_;
}

size_t CompiledTaichiKernel::num_pipelines() const {
  return pipelines_.size();
}

size_t CompiledTaichiKernel::get_args_buffer_size() const {
  return args_buffer_size_;
}

size_t CompiledTaichiKernel::get_ret_buffer_size() const {
  return ret_buffer_size_;
}

Pipeline *CompiledTaichiKernel::get_pipeline(int i) {
  return pipelines_[i].get();
}

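// Sets up the non-root buffers (global temporaries & list-gen) and tries to
// seed the backend pipeline cache from <repo_dir>/rhi_cache.bin if present.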
GfxRuntime::GfxRuntime(const Params &params)
    : device_(params.device),
      host_result_buffer_(params.host_result_buffer),
      profiler_(params.profiler) {
  TI_ASSERT(host_result_buffer_ != nullptr);
  current_cmdlist_pending_since_ = high_res_clock::now();
  init_nonroot_buffers();

  // Read pipeline cache from disk if available.
  std::filesystem::path cache_path(get_repo_dir());
  cache_path /= "rhi_cache.bin";
  std::vector<char> cache_data;
  if (std::filesystem::exists(cache_path)) {
    TI_TRACE("Loading pipeline cache from {}", cache_path.generic_string());
    std::ifstream cache_file(cache_path, std::ios::binary);
    cache_data.assign(std::istreambuf_iterator<char>(cache_file),
                      std::istreambuf_iterator<char>());
  } else {
    TI_TRACE("Pipeline cache not found at {}", cache_path.generic_string());
  }
  auto [cache, res] = device_->create_pipeline_cache_unique(cache_data.size(),
                                                            cache_data.data());
  if (res == RhiResult::success) {
    backend_cache_ = std::move(cache);
  }
}

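// Waits for outstanding work, persists the pipeline cache back to
// rhi_cache.bin, then releases compiled kernels and runtime buffers.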
GfxRuntime::~GfxRuntime() {
  synchronize();

  // Write pipeline cache back to disk.
  if (backend_cache_) {
    uint8_t *cache_data = (uint8_t *)backend_cache_->data();
    size_t cache_size = backend_cache_->size();
    if (cache_data) {
      std::filesystem::path cache_path =
          std::filesystem::path(get_repo_dir()) / "rhi_cache.bin";
      std::ofstream cache_file(cache_path, std::ios::binary | std::ios::trunc);
      std::ostreambuf_iterator<char> output_iterator(cache_file);
      std::copy(cache_data, cache_data + cache_size, output_iterator);
    }
    backend_cache_.reset();
  }

  {
    decltype(ti_kernels_) tmp;
    tmp.swap(ti_kernels_);
  }
  global_tmps_buffer_.reset();
  listgen_buffer_.reset();
}

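// Wraps the codegen output (TaichiKernelAttributes + per-task SPIR-V) into a
// CompiledTaichiKernel bound to the runtime's buffers, and returns a handle
// used later by launch_kernel().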
GfxRuntime::KernelHandle GfxRuntime::register_taichi_kernel(
    GfxRuntime::RegisterParams reg_params) {
  CompiledTaichiKernel::Params params;
  params.ti_kernel_attribs = &(reg_params.kernel_attribs);
  params.num_snode_trees = reg_params.num_snode_trees;
  params.device = device_;
  params.root_buffers = {};
  for (int root = 0; root < root_buffers_.size(); ++root) {
    params.root_buffers.push_back(root_buffers_[root].get());
  }
  params.global_tmps_buffer = global_tmps_buffer_.get();
  params.listgen_buffer = listgen_buffer_.get();
  params.backend_cache = backend_cache_.get();

  for (int i = 0; i < reg_params.task_spirv_source_codes.size(); ++i) {
    const auto &spirv_src = reg_params.task_spirv_source_codes[i];

    // If we can reach here, we have succeeded. Otherwise
    // std::optional::value() would have killed us.
    params.spirv_bins.push_back(std::move(spirv_src));
  }
  KernelHandle res;
  res.id_ = ti_kernels_.size();
  ti_kernels_.push_back(std::make_unique<CompiledTaichiKernel>(params));
  return res;
}

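// Records one compute dispatch per offloaded task of the kernel into the
// current command list. Work is only submitted here when a host sync is
// required (see device_to_host) or when the pending cmdlist times out.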
void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) {
  auto *ti_kernel = ti_kernels_[handle.id_].get();

#if defined(__APPLE__)
  if (profiler_) {
    const int apple_max_query_pool_count = 32;
    int task_count = ti_kernel->ti_kernel_attribs().tasks_attribs.size();
    if (task_count > apple_max_query_pool_count) {
      TI_WARN(
          "Cannot concurrently profile more than 32 tasks in a single Taichi "
          "kernel. Profiling aborted.");
      profiler_ = nullptr;
    } else if (device_->profiler_get_sampler_count() + task_count >
               apple_max_query_pool_count) {
      flush();
      device_->profiler_sync();
    }
  }
#endif

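  // Per-launch argument/return staging buffers: args are uploaded by the host
  // (uniform, host-writable), returns are read back by the host (storage,
  // host-readable).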
  std::unique_ptr<DeviceAllocationGuard> args_buffer{nullptr},
      ret_buffer{nullptr};

  if (ti_kernel->get_args_buffer_size()) {
    args_buffer = device_->allocate_memory_unique(
        {ti_kernel->get_args_buffer_size(),
         /*host_write=*/true, /*host_read=*/false,
         /*export_sharing=*/false, AllocUsage::Uniform});
  }

  if (ti_kernel->get_ret_buffer_size()) {
    ret_buffer = device_->allocate_memory_unique(
        {ti_kernel->get_ret_buffer_size(),
         /*host_write=*/false, /*host_read=*/true,
         /*export_sharing=*/false, AllocUsage::Storage});
  }

  // Create the context blitter.
  auto ctx_blitter = HostDeviceContextBlitter::maybe_make(
      &ti_kernel->ti_kernel_attribs().ctx_attribs, host_ctx, device_,
      host_result_buffer_, args_buffer.get(), ret_buffer.get());

  // `any_arrays` contains both external arrays and NDArrays.
  std::unordered_map<int, DeviceAllocation> any_arrays;
  // `ext_array_size` only holds the sizes of external arrays (host arrays),
  // since buffer size information is only needed when the buffer has to be
  // allocated and transferred by the host.
  std::unordered_map<int, size_t> ext_array_size;
  std::unordered_map<int, DeviceAllocation> textures;

  // Prepare context buffers & arrays.
  if (ctx_blitter) {
    TI_ASSERT(ti_kernel->get_args_buffer_size() ||
              ti_kernel->get_ret_buffer_size());

    int i = 0;
    const auto &args = ti_kernel->ti_kernel_attribs().ctx_attribs.args();
    for (auto &arg : args) {
      if (arg.is_array) {
        if (host_ctx->device_allocation_type[i] !=
            RuntimeContext::DevAllocType::kNone) {
          DeviceAllocation devalloc = kDeviceNullAllocation;

          // NDArray / Texture
          if (host_ctx->args[i]) {
            devalloc = *(DeviceAllocation *)(host_ctx->args[i]);
          }

          if (host_ctx->device_allocation_type[i] ==
              RuntimeContext::DevAllocType::kNdarray) {
            any_arrays[i] = devalloc;
            ndarrays_in_use_.insert(devalloc.alloc_id);
          } else if (host_ctx->device_allocation_type[i] ==
                     RuntimeContext::DevAllocType::kTexture) {
            textures[i] = devalloc;
          } else if (host_ctx->device_allocation_type[i] ==
                     RuntimeContext::DevAllocType::kRWTexture) {
            textures[i] = devalloc;
          } else {
            TI_NOT_IMPLEMENTED;
          }
        } else {
          ext_array_size[i] = host_ctx->array_runtime_sizes[i];
          uint32_t access = uint32_t(
              ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i));

          // Allocate a device buffer for the external (host) array.
          size_t alloc_size = std::max(size_t(32), ext_array_size.at(i));
          bool host_write = access & uint32_t(irpass::ExternalPtrAccess::READ);
          auto allocated = device_->allocate_memory_unique(
              {alloc_size, host_write, false, /*export_sharing=*/false,
               AllocUsage::Storage});
          any_arrays[i] = *allocated.get();
          ctx_buffers_.push_back(std::move(allocated));
        }
      }
      i++;
    }

    ctx_blitter->host_to_device(any_arrays, ext_array_size);
  }

  ensure_current_cmdlist();

  // Record commands.
  const auto &task_attribs = ti_kernel->ti_kernel_attribs().tasks_attribs;

  for (int i = 0; i < task_attribs.size(); ++i) {
    const auto &attribs = task_attribs[i];
    auto vp = ti_kernel->get_pipeline(i);
    const int group_x = (attribs.advisory_total_num_threads +
                         attribs.advisory_num_threads_per_group - 1) /
                        attribs.advisory_num_threads_per_group;
    std::unique_ptr<ShaderResourceSet> bindings =
        device_->create_resource_set_unique();
    for (auto &bind : attribs.buffer_binds) {
      // We might have to bind an invalid buffer (this is fine as long as the
      // shader doesn't do anything with it).
      if (bind.buffer.type == BufferType::ExtArr) {
        bindings->rw_buffer(bind.binding, any_arrays.at(bind.buffer.root_id));
      } else if (bind.buffer.type == BufferType::Args) {
        bindings->buffer(bind.binding,
                         args_buffer ? *args_buffer : kDeviceNullAllocation);
      } else if (bind.buffer.type == BufferType::Rets) {
        bindings->rw_buffer(bind.binding,
                            ret_buffer ? *ret_buffer : kDeviceNullAllocation);
      } else {
        DeviceAllocation *alloc = ti_kernel->get_buffer_bind(bind.buffer);
        bindings->rw_buffer(bind.binding,
                            alloc ? *alloc : kDeviceNullAllocation);
      }
    }

    for (auto &bind : attribs.texture_binds) {
      DeviceAllocation texture = textures.at(bind.arg_id);
      if (bind.is_storage) {
        transition_image(texture, ImageLayout::shader_read_write);
        bindings->rw_image(bind.binding, texture, 0);
      } else {
        transition_image(texture, ImageLayout::shader_read);
        bindings->image(bind.binding, texture, {});
      }
    }

    if (attribs.task_type == OffloadedTaskType::listgen) {
      for (auto &bind : attribs.buffer_binds) {
        if (bind.buffer.type == BufferType::ListGen) {
          // FIXME: properly support multiple lists
          current_cmdlist_->buffer_fill(
              ti_kernel->get_buffer_bind(bind.buffer)->get_ptr(0),
              kBufferSizeEntireSize,
              /*data=*/0);
          current_cmdlist_->buffer_barrier(
              *ti_kernel->get_buffer_bind(bind.buffer));
        }
      }
    }

    current_cmdlist_->bind_pipeline(vp);
    RhiResult status = current_cmdlist_->bind_shader_resources(bindings.get());
    TI_ERROR_IF(status != RhiResult::success,
                "Resource binding error : RhiResult({})", status);

    if (profiler_) {
      current_cmdlist_->begin_profiler_scope(attribs.name);
    }

    status = current_cmdlist_->dispatch(group_x);

    if (profiler_) {
      current_cmdlist_->end_profiler_scope();
    }

    TI_ERROR_IF(status != RhiResult::success, "Dispatch error : RhiResult({})",
                status);
    current_cmdlist_->memory_barrier();
  }

  // Keep the context buffers used in this dispatch alive.
  if (ti_kernel->get_args_buffer_size()) {
    ctx_buffers_.push_back(std::move(args_buffer));
  }
  if (ti_kernel->get_ret_buffer_size()) {
    ctx_buffers_.push_back(std::move(ret_buffer));
  }

  // If we need to sync with the host, sync and remove in-flight references.
  if (ctx_blitter) {
    if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays,
                                    ext_array_size)) {
      current_cmdlist_ = nullptr;
      ctx_buffers_.clear();
    }
  }

  submit_current_cmdlist_if_timeout();
}

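// The copy helpers below only record into the current command list; nothing
// executes until the list is flushed or a synchronizing launch submits it.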
void GfxRuntime::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) {
  ensure_current_cmdlist();
  current_cmdlist_->buffer_copy(dst, src, size);
}

void GfxRuntime::copy_image(DeviceAllocation dst,
                            DeviceAllocation src,
                            const ImageCopyParams &params) {
  ensure_current_cmdlist();
  transition_image(dst, ImageLayout::transfer_dst);
  transition_image(src, ImageLayout::transfer_src);
  current_cmdlist_->copy_image(dst, src, ImageLayout::transfer_dst,
                               ImageLayout::transfer_src, params);
}

DeviceAllocation GfxRuntime::create_image(const ImageParams &params) {
  GraphicsDevice *gfx_device = dynamic_cast<GraphicsDevice *>(device_);
  TI_ERROR_IF(gfx_device == nullptr,
              "Image can only be created on a graphics device");
  DeviceAllocation image = gfx_device->create_image(params);
  track_image(image, ImageLayout::undefined);
  last_image_layouts_.at(image.alloc_id) = params.initial_layout;
  return image;
}

void GfxRuntime::track_image(DeviceAllocation image, ImageLayout layout) {
  last_image_layouts_[image.alloc_id] = layout;
}
void GfxRuntime::untrack_image(DeviceAllocation image) {
  last_image_layouts_.erase(image.alloc_id);
}
void GfxRuntime::transition_image(DeviceAllocation image, ImageLayout layout) {
  ImageLayout &last_layout = last_image_layouts_.at(image.alloc_id);
  ensure_current_cmdlist();
  current_cmdlist_->image_transition(image, last_layout, layout);
  last_layout = layout;
}

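// synchronize() flushes pending work and blocks until the device is idle,
// draining profiler samples; flush() (below) only submits the pending command
// list without waiting.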
void GfxRuntime::synchronize() {
  flush();
  device_->wait_idle();
  // Profiler support
  if (profiler_) {
    device_->profiler_sync();
    auto sampled_records = device_->profiler_flush_sampled_time();
    for (auto &record : sampled_records) {
      profiler_->insert_record(record.first, record.second);
    }
  }
  ctx_buffers_.clear();
  ndarrays_in_use_.clear();
  fflush(stdout);
}

StreamSemaphore GfxRuntime::flush() {
  StreamSemaphore sema;
  if (current_cmdlist_) {
    sema = device_->get_compute_stream()->submit(current_cmdlist_.get());
    current_cmdlist_ = nullptr;
    ctx_buffers_.clear();
  } else {
    auto [cmdlist, res] =
        device_->get_compute_stream()->new_command_list_unique();
    TI_ASSERT(res == RhiResult::success);
    cmdlist->memory_barrier();
    sema = device_->get_compute_stream()->submit(cmdlist.get());
  }
  return sema;
}

Device *GfxRuntime::get_ti_device() const {
  return device_;
}

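// Command lists are created lazily; the timestamp recorded here is what
// submit_current_cmdlist_if_timeout() compares against.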
void GfxRuntime::ensure_current_cmdlist() {
  // Create a new command list if the current one is nullptr.
  if (!current_cmdlist_) {
    current_cmdlist_pending_since_ = high_res_clock::now();
    auto [cmdlist, res] =
        device_->get_compute_stream()->new_command_list_unique();
    TI_ASSERT(res == RhiResult::success);
    current_cmdlist_ = std::move(cmdlist);
  }
}

void GfxRuntime::submit_current_cmdlist_if_timeout() {
  // If we have accumulated some work that does not require a sync, and the
  // accumulated cmdlist has been pending for some time, submit it so the
  // device can start processing.
  if (current_cmdlist_) {
    constexpr uint64_t max_pending_time = 2000;  // 2000us = 2ms
    auto duration = high_res_clock::now() - current_cmdlist_pending_since_;
    if (std::chrono::duration_cast<std::chrono::microseconds>(duration)
            .count() > max_pending_time) {
      flush();
    }
  }
}

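// Allocates and zero-fills the global temporaries and list-gen buffers
// (sizes kGtmpBufferSize / kListGenBufferSize).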
void GfxRuntime::init_nonroot_buffers() {
  global_tmps_buffer_ = device_->allocate_memory_unique(
      {kGtmpBufferSize,
       /*host_write=*/false, /*host_read=*/false,
       /*export_sharing=*/false, AllocUsage::Storage});

  listgen_buffer_ = device_->allocate_memory_unique(
      {kListGenBufferSize,
       /*host_write=*/false, /*host_read=*/false,
       /*export_sharing=*/false, AllocUsage::Storage});

  // Zero-fill the buffers; otherwise they could contain NaNs.
  Stream *stream = device_->get_compute_stream();
  auto [cmdlist, res] =
      device_->get_compute_stream()->new_command_list_unique();
  TI_ASSERT(res == RhiResult::success);

  cmdlist->buffer_fill(global_tmps_buffer_->get_ptr(0), kBufferSizeEntireSize,
                       /*data=*/0);
  cmdlist->buffer_fill(listgen_buffer_->get_ptr(0), kBufferSizeEntireSize,
                       /*data=*/0);
  stream->submit_synced(cmdlist.get());
}

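// Allocates a zero-filled root buffer for a SNode tree and remembers its size
// for get_root_buffer_size().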
void GfxRuntime::add_root_buffer(size_t root_buffer_size) {
  if (root_buffer_size == 0) {
    root_buffer_size = 4;  // there might be empty roots
  }
  std::unique_ptr<DeviceAllocationGuard> new_buffer =
      device_->allocate_memory_unique(
          {root_buffer_size,
           /*host_write=*/false, /*host_read=*/false,
           /*export_sharing=*/false, AllocUsage::Storage});
  Stream *stream = device_->get_compute_stream();
  auto [cmdlist, res] =
      device_->get_compute_stream()->new_command_list_unique();
  TI_ASSERT(res == RhiResult::success);
  cmdlist->buffer_fill(new_buffer->get_ptr(0), kBufferSizeEntireSize,
                       /*data=*/0);
  stream->submit_synced(cmdlist.get());
  root_buffers_.push_back(std::move(new_buffer));
  // Cache the root buffer size.
  root_buffers_size_map_[root_buffers_.back().get()] = root_buffer_size;
}

DeviceAllocation *GfxRuntime::get_root_buffer(int id) const {
  if (id >= root_buffers_.size()) {
    TI_ERROR("root buffer id {} not found", id);
  }
  return root_buffers_[id].get();
}

size_t GfxRuntime::get_root_buffer_size(int id) const {
  // Check the id before indexing into root_buffers_.
  if (id >= root_buffers_.size()) {
    TI_ERROR("root buffer id {} not found", id);
  }
  auto it = root_buffers_size_map_.find(root_buffers_[id].get());
  if (it == root_buffers_size_map_.end()) {
    TI_ERROR("root buffer id {} not found", id);
  }
  return it->second;
}

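// Records a user-provided compute op into the current command list, applying
// the requested image layout transitions before the op and bookkeeping the
// final layouts afterwards.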
void GfxRuntime::enqueue_compute_op_lambda(
    std::function<void(Device *device, CommandList *cmdlist)> op,
    const std::vector<ComputeOpImageRef> &image_refs) {
  for (const auto &ref : image_refs) {
    TI_ASSERT(last_image_layouts_.find(ref.image.alloc_id) !=
              last_image_layouts_.end());
    transition_image(ref.image, ref.initial_layout);
  }

  ensure_current_cmdlist();
  op(device_, current_cmdlist_.get());

  for (const auto &ref : image_refs) {
    last_image_layouts_[ref.image.alloc_id] = ref.final_layout;
  }
}

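// Runs SPIR-V codegen for `kernel` and packs the result into RegisterParams.
// Typical flow (sketch; the exact call sites live in the gfx program
// implementation, and the variable names here are illustrative):
//   auto reg = run_codegen(kernel, arch, caps, compiled_structs, config);
//   auto handle = runtime->register_taichi_kernel(std::move(reg));
//   runtime->launch_kernel(handle, &host_ctx);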
GfxRuntime::RegisterParams run_codegen(
    Kernel *kernel,
    Arch arch,
    const DeviceCapabilityConfig &caps,
    const std::vector<CompiledSNodeStructs> &compiled_structs,
    const CompileConfig &compile_config) {
  const auto id = Program::get_kernel_id();
  const auto taichi_kernel_name(fmt::format("{}_k{:04d}_vk", kernel->name, id));
  TI_TRACE("VK codegen for Taichi kernel={}", taichi_kernel_name);
  spirv::KernelCodegen::Params params;
  params.ti_kernel_name = taichi_kernel_name;
  params.kernel = kernel;
  params.compiled_structs = compiled_structs;
  params.arch = arch;
  params.caps = caps;
  params.enable_spv_opt = compile_config.external_optimization_level > 0;
  spirv::KernelCodegen codegen(params);
  GfxRuntime::RegisterParams res;
  codegen.run(res.kernel_attribs, res.task_spirv_source_codes);
  res.num_snode_trees = compiled_structs.size();
  return res;
}

}  // namespace gfx
}  // namespace taichi::lang