1 | #include "taichi/runtime/gfx/runtime.h" |
2 | #include "taichi/program/program.h" |
3 | #include "taichi/common/filesystem.hpp" |
4 | |
5 | #include <chrono> |
6 | #include <array> |
7 | #include <iostream> |
8 | #include <memory> |
9 | #include <optional> |
10 | #include <unordered_map> |
11 | #include <unordered_set> |
12 | #include <vector> |
13 | |
14 | #include "fp16.h" |
15 | |
16 | #define TI_RUNTIME_HOST |
17 | #include "taichi/program/context.h" |
18 | #undef TI_RUNTIME_HOST |
19 | |
20 | namespace taichi::lang { |
21 | namespace gfx { |
22 | |
23 | namespace { |
24 | |
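// Shuttles kernel arguments and results between the host RuntimeContext and
// the device-side args/ret buffers: scalars and host external-array contents
// are copied to the device before a launch, and written-back arrays plus
// return values are read back afterwards.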
class HostDeviceContextBlitter {
 public:
  HostDeviceContextBlitter(const KernelContextAttributes *ctx_attribs,
                           RuntimeContext *host_ctx,
                           Device *device,
                           uint64_t *host_result_buffer,
                           DeviceAllocation *device_args_buffer,
                           DeviceAllocation *device_ret_buffer)
      : ctx_attribs_(ctx_attribs),
        host_ctx_(host_ctx),
        host_result_buffer_(host_result_buffer),
        device_args_buffer_(device_args_buffer),
        device_ret_buffer_(device_ret_buffer),
        device_(device) {
  }

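  // Copies scalar arguments, the contents of host external arrays the kernel
  // reads, and the extra-args block from `host_ctx_` into the device args
  // buffer.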
  void host_to_device(
      const std::unordered_map<int, DeviceAllocation> &ext_arrays,
      const std::unordered_map<int, size_t> &ext_arr_size) {
    if (!ctx_attribs_->has_args()) {
      return;
    }

    void *device_base{nullptr};
    TI_ASSERT(device_->map(*device_args_buffer_, &device_base) ==
              RhiResult::success);

#define TO_DEVICE(short_type, type)                 \
  if (arg.dtype == PrimitiveTypeID::short_type) {   \
    auto d = host_ctx_->get_arg<type>(i);           \
    reinterpret_cast<type *>(device_ptr)[0] = d;    \
    break;                                          \
  }

    for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
      const auto &arg = ctx_attribs_->args()[i];
      void *device_ptr = (uint8_t *)device_base + arg.offset_in_mem;
      do {
        if (arg.is_array) {
          if (host_ctx_->device_allocation_type[i] ==
                  RuntimeContext::DevAllocType::kNone &&
              ext_arr_size.at(i)) {
            // Only need to blit ext arrs (host arrays)
            uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
            if (access & uint32_t(irpass::ExternalPtrAccess::READ)) {
              DeviceAllocation buffer = ext_arrays.at(i);
              void *device_arr_ptr{nullptr};
              TI_ASSERT(device_->map(buffer, &device_arr_ptr) ==
                        RhiResult::success);
              const void *host_ptr = host_ctx_->get_arg<void *>(i);
              std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i));
              device_->unmap(buffer);
            }
          }
          // Substitute in the device address.

          // (penguinliong) We don't check the availability of the physical
          // pointer here; that check should happen before this class is used.
          if ((host_ctx_->device_allocation_type[i] ==
                   RuntimeContext::DevAllocType::kNone ||
               host_ctx_->device_allocation_type[i] ==
                   RuntimeContext::DevAllocType::kNdarray)) {
            uint64_t addr =
                device_->get_memory_physical_pointer(ext_arrays.at(i));
            reinterpret_cast<uint64 *>(device_ptr)[0] = addr;
          }
          // We should not process the rest
          break;
        }
        // (penguinliong) Same here. The availability of short/long int types
        // depends on the kernels and compute graphs, and the check should
        // already have been done during module load.
        TO_DEVICE(i8, int8)
        TO_DEVICE(u8, uint8)
        TO_DEVICE(i16, int16)
        TO_DEVICE(u16, uint16)
        TO_DEVICE(i32, int32)
        TO_DEVICE(u32, uint32)
        TO_DEVICE(f32, float32)
        TO_DEVICE(i64, int64)
        TO_DEVICE(u64, uint64)
        TO_DEVICE(f64, float64)
        if (arg.dtype == PrimitiveTypeID::f16) {
          auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg<float>(i));
          reinterpret_cast<uint16 *>(device_ptr)[0] = d;
          break;
        }
        TI_ERROR("Device does not support arg type={}",
                 PrimitiveType::get(arg.dtype).to_string());
      } while (false);
    }

    void *device_ptr =
        (uint8_t *)device_base + ctx_attribs_->extra_args_mem_offset();
    std::memcpy(device_ptr, host_ctx_->extra_args,
                ctx_attribs_->extra_args_bytes());

    device_->unmap(*device_args_buffer_);
#undef TO_DEVICE
  }

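  // Reads back host external arrays the kernel may have written, then copies
  // return values from the device ret buffer into `host_result_buffer_`.
  // Returns true iff `cmdlist` was submitted (i.e. a device sync happened),
  // so the caller can drop the command list and its in-flight buffers.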
  bool device_to_host(
      CommandList *cmdlist,
      const std::unordered_map<int, DeviceAllocation> &ext_arrays,
      const std::unordered_map<int, size_t> &ext_arr_size) {
    if (ctx_attribs_->empty()) {
      return false;
    }

    bool require_sync = ctx_attribs_->rets().size() > 0;
    std::vector<DevicePtr> readback_dev_ptrs;
    std::vector<void *> readback_host_ptrs;
    std::vector<size_t> readback_sizes;

    for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
      const auto &arg = ctx_attribs_->args()[i];
      if (arg.is_array &&
          host_ctx_->device_allocation_type[i] ==
              RuntimeContext::DevAllocType::kNone &&
          ext_arr_size.at(i)) {
        uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
        if (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) {
          // Only need to blit ext arrs (host arrays)
          readback_dev_ptrs.push_back(ext_arrays.at(i).get_ptr(0));
          readback_host_ptrs.push_back(host_ctx_->get_arg<void *>(i));
          readback_sizes.push_back(ext_arr_size.at(i));
          require_sync = true;
        }
      }
    }

    if (require_sync) {
      if (readback_sizes.size()) {
        StreamSemaphore command_complete_sema =
            device_->get_compute_stream()->submit(cmdlist);

        device_->wait_idle();

        // In this case `readback_data` syncs
        TI_ASSERT(device_->readback_data(
                      readback_dev_ptrs.data(), readback_host_ptrs.data(),
                      readback_sizes.data(), int(readback_sizes.size()),
                      {command_complete_sema}) == RhiResult::success);
      } else {
        device_->get_compute_stream()->submit_synced(cmdlist);
      }

      if (!ctx_attribs_->has_rets()) {
        return true;
      }
    } else {
      return false;
    }

    void *device_base{nullptr};
    TI_ASSERT(device_->map(*device_ret_buffer_, &device_base) ==
              RhiResult::success);

#define TO_HOST(short_type, type, offset)                              \
  if (dt->is_primitive(PrimitiveTypeID::short_type)) {                 \
    const type d = *(reinterpret_cast<type *>(device_ptr) + offset);   \
    host_result_buffer_[offset] =                                      \
        taichi_union_cast_with_different_sizes<uint64>(d);             \
    continue;                                                          \
  }

    for (int i = 0; i < ctx_attribs_->rets().size(); ++i) {
      // Note that we are copying the return values from the device ret buffer
      // into the host result buffer.
      const auto &ret = ctx_attribs_->rets()[i];
      void *device_ptr = (uint8_t *)device_base + ret.offset_in_mem;
      const auto dt = PrimitiveType::get(ret.dtype);
      const auto num = ret.stride / data_type_size(dt);
      for (int j = 0; j < num; ++j) {
        // (penguinliong) Again, it's the module loader's responsibility to
        // check the data type availability.
        TO_HOST(i8, int8, j)
        TO_HOST(u8, uint8, j)
        TO_HOST(i16, int16, j)
        TO_HOST(u16, uint16, j)
        TO_HOST(i32, int32, j)
        TO_HOST(u32, uint32, j)
        TO_HOST(f32, float32, j)
        TO_HOST(i64, int64, j)
        TO_HOST(u64, uint64, j)
        TO_HOST(f64, float64, j)
        if (dt->is_primitive(PrimitiveTypeID::f16)) {
          // Read the j-th fp16 element, matching the TO_HOST indexing.
          const float d = fp16_ieee_to_fp32_value(
              *(reinterpret_cast<uint16 *>(device_ptr) + j));
          host_result_buffer_[j] =
              taichi_union_cast_with_different_sizes<uint64>(d);
          continue;
        }
        TI_ERROR("Device does not support return value type={}",
                 data_type_name(PrimitiveType::get(ret.dtype)));
      }
    }
#undef TO_HOST

    device_->unmap(*device_ret_buffer_);

    return true;
  }

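  // Returns nullptr when the kernel takes no arguments and has no return
  // values, so callers can skip the host/device blit entirely.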
  static std::unique_ptr<HostDeviceContextBlitter> maybe_make(
      const KernelContextAttributes *ctx_attribs,
      RuntimeContext *host_ctx,
      Device *device,
      uint64_t *host_result_buffer,
      DeviceAllocation *device_args_buffer,
      DeviceAllocation *device_ret_buffer) {
    if (ctx_attribs->empty()) {
      return nullptr;
    }
    return std::make_unique<HostDeviceContextBlitter>(
        ctx_attribs, host_ctx, device, host_result_buffer, device_args_buffer,
        device_ret_buffer);
  }

 private:
  const KernelContextAttributes *const ctx_attribs_;
  RuntimeContext *const host_ctx_;
  uint64_t *const host_result_buffer_;
  DeviceAllocation *const device_args_buffer_;
  DeviceAllocation *const device_ret_buffer_;
  Device *const device_;
};

} // namespace

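// Sizes of the shared non-root buffers: 1 MiB of global temporaries and a
// 32 MiB list-generation buffer.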
constexpr size_t kGtmpBufferSize = 1024 * 1024;
constexpr size_t kListGenBufferSize = 32 << 20;

// Info for launching a compiled Taichi kernel, which consists of a series of
// Unified Device API pipelines.

CompiledTaichiKernel::CompiledTaichiKernel(const Params &ti_params)
    : ti_kernel_attribs_(*ti_params.ti_kernel_attribs),
      device_(ti_params.device) {
  input_buffers_[BufferType::GlobalTmps] = ti_params.global_tmps_buffer;
  input_buffers_[BufferType::ListGen] = ti_params.listgen_buffer;

  // Compiled_structs can be empty if a kernel is loaded from an AOT module,
  // since the SNodes are not re-compiled/structured there. In that case we
  // assume a single root buffer size configured from the AOT module.
  for (int root = 0; root < ti_params.num_snode_trees; ++root) {
    BufferInfo buffer = {BufferType::Root, root};
    input_buffers_[buffer] = ti_params.root_buffers[root];
  }

  const auto arg_sz = ti_kernel_attribs_.ctx_attribs.args_bytes();
  const auto ret_sz = ti_kernel_attribs_.ctx_attribs.rets_bytes();

  args_buffer_size_ = arg_sz;
  ret_buffer_size_ = ret_sz;

  if (arg_sz) {
    args_buffer_size_ += ti_kernel_attribs_.ctx_attribs.extra_args_bytes();
  }

  const auto &task_attribs = ti_kernel_attribs_.tasks_attribs;
  const auto &spirv_bins = ti_params.spirv_bins;
  TI_ASSERT(task_attribs.size() == spirv_bins.size());

  for (int i = 0; i < task_attribs.size(); ++i) {
    PipelineSourceDesc source_desc{PipelineSourceType::spirv_binary,
                                   (void *)spirv_bins[i].data(),
                                   spirv_bins[i].size() * sizeof(uint32_t)};
    auto [vp, res] = ti_params.device->create_pipeline_unique(
        source_desc, task_attribs[i].name, ti_params.backend_cache);
    TI_ASSERT(res == RhiResult::success);
    pipelines_.push_back(std::move(vp));
  }
}

const TaichiKernelAttributes &CompiledTaichiKernel::ti_kernel_attribs() const {
  return ti_kernel_attribs_;
}

size_t CompiledTaichiKernel::num_pipelines() const {
  return pipelines_.size();
}

size_t CompiledTaichiKernel::get_args_buffer_size() const {
  return args_buffer_size_;
}

size_t CompiledTaichiKernel::get_ret_buffer_size() const {
  return ret_buffer_size_;
}

Pipeline *CompiledTaichiKernel::get_pipeline(int i) {
  return pipelines_[i].get();
}

GfxRuntime::GfxRuntime(const Params &params)
    : device_(params.device),
      host_result_buffer_(params.host_result_buffer),
      profiler_(params.profiler) {
  TI_ASSERT(host_result_buffer_ != nullptr);
  current_cmdlist_pending_since_ = high_res_clock::now();
  init_nonroot_buffers();

  // Read pipeline cache from disk if available.
  std::filesystem::path cache_path(get_repo_dir());
  cache_path /= "rhi_cache.bin";
  std::vector<char> cache_data;
  if (std::filesystem::exists(cache_path)) {
    TI_TRACE("Loading pipeline cache from {}", cache_path.generic_string());
    std::ifstream cache_file(cache_path, std::ios::binary);
    cache_data.assign(std::istreambuf_iterator<char>(cache_file),
                      std::istreambuf_iterator<char>());
  } else {
    TI_TRACE("Pipeline cache not found at {}", cache_path.generic_string());
  }
  auto [cache, res] = device_->create_pipeline_cache_unique(cache_data.size(),
                                                            cache_data.data());
  if (res == RhiResult::success) {
    backend_cache_ = std::move(cache);
  }
}

GfxRuntime::~GfxRuntime() {
  synchronize();

  // Write pipeline cache back to disk.
  if (backend_cache_) {
    uint8_t *cache_data = (uint8_t *)backend_cache_->data();
    size_t cache_size = backend_cache_->size();
    if (cache_data) {
      std::filesystem::path cache_path =
          std::filesystem::path(get_repo_dir()) / "rhi_cache.bin";
      std::ofstream cache_file(cache_path, std::ios::binary | std::ios::trunc);
      std::ostreambuf_iterator<char> output_iterator(cache_file);
      std::copy(cache_data, cache_data + cache_size, output_iterator);
    }
    backend_cache_.reset();
  }

  {
    decltype(ti_kernels_) tmp;
    tmp.swap(ti_kernels_);
  }
  global_tmps_buffer_.reset();
  listgen_buffer_.reset();
}

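// Registers a compiled kernel with the runtime: builds one compute pipeline
// per task from its SPIR-V binaries and returns a handle for later launches.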
370 | |
371 | GfxRuntime::KernelHandle GfxRuntime::register_taichi_kernel( |
372 | GfxRuntime::RegisterParams reg_params) { |
373 | CompiledTaichiKernel::Params params; |
374 | params.ti_kernel_attribs = &(reg_params.kernel_attribs); |
375 | params.num_snode_trees = reg_params.num_snode_trees; |
376 | params.device = device_; |
377 | params.root_buffers = {}; |
378 | for (int root = 0; root < root_buffers_.size(); ++root) { |
379 | params.root_buffers.push_back(root_buffers_[root].get()); |
380 | } |
381 | params.global_tmps_buffer = global_tmps_buffer_.get(); |
382 | params.listgen_buffer = listgen_buffer_.get(); |
383 | params.backend_cache = backend_cache_.get(); |
384 | |
  for (int i = 0; i < reg_params.task_spirv_source_codes.size(); ++i) {
    auto &spirv_src = reg_params.task_spirv_source_codes[i];
    // `reg_params` is taken by value, so the binaries can be moved out of it.
    params.spirv_bins.push_back(std::move(spirv_src));
  }
  KernelHandle res;
  res.id_ = ti_kernels_.size();
  ti_kernels_.push_back(std::make_unique<CompiledTaichiKernel>(params));
  return res;
}

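// Launches a registered kernel: allocates per-launch args/ret buffers, blits
// host arguments to the device, records one dispatch per task into the
// current command list, and then either reads results back (which submits and
// syncs) or defers submission to the pending-time check.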
void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) {
  auto *ti_kernel = ti_kernels_[handle.id_].get();

#if defined(__APPLE__)
  if (profiler_) {
    const int apple_max_query_pool_count = 32;
    int task_count = ti_kernel->ti_kernel_attribs().tasks_attribs.size();
    if (task_count > apple_max_query_pool_count) {
      TI_WARN(
          "Cannot concurrently profile more than 32 tasks in a single Taichi "
          "kernel. Profiling aborted.");
      profiler_ = nullptr;
    } else if (device_->profiler_get_sampler_count() + task_count >
               apple_max_query_pool_count) {
      flush();
      device_->profiler_sync();
    }
  }
#endif

  std::unique_ptr<DeviceAllocationGuard> args_buffer{nullptr},
      ret_buffer{nullptr};

  if (ti_kernel->get_args_buffer_size()) {
    args_buffer = device_->allocate_memory_unique(
        {ti_kernel->get_args_buffer_size(),
         /*host_write=*/true, /*host_read=*/false,
         /*export_sharing=*/false, AllocUsage::Uniform});
  }

  if (ti_kernel->get_ret_buffer_size()) {
    ret_buffer = device_->allocate_memory_unique(
        {ti_kernel->get_ret_buffer_size(),
         /*host_write=*/false, /*host_read=*/true,
         /*export_sharing=*/false, AllocUsage::Storage});
  }

  // Create context blitter
  auto ctx_blitter = HostDeviceContextBlitter::maybe_make(
      &ti_kernel->ti_kernel_attribs().ctx_attribs, host_ctx, device_,
      host_result_buffer_, args_buffer.get(), ret_buffer.get());

  // `any_arrays` contains both external arrays and NDArrays.
  std::unordered_map<int, DeviceAllocation> any_arrays;
  // `ext_array_size` only holds the sizes of external (host) arrays, since
  // the size is only needed when the host has to allocate a staging buffer
  // and transfer the data.
  std::unordered_map<int, size_t> ext_array_size;
  std::unordered_map<int, DeviceAllocation> textures;

  // Prepare context buffers & arrays
  if (ctx_blitter) {
    TI_ASSERT(ti_kernel->get_args_buffer_size() ||
              ti_kernel->get_ret_buffer_size());

    int i = 0;
    const auto &args = ti_kernel->ti_kernel_attribs().ctx_attribs.args();
    for (auto &arg : args) {
      if (arg.is_array) {
        if (host_ctx->device_allocation_type[i] !=
            RuntimeContext::DevAllocType::kNone) {
          DeviceAllocation devalloc = kDeviceNullAllocation;

          // NDArray / Texture
          if (host_ctx->args[i]) {
            devalloc = *(DeviceAllocation *)(host_ctx->args[i]);
          }

          if (host_ctx->device_allocation_type[i] ==
              RuntimeContext::DevAllocType::kNdarray) {
            any_arrays[i] = devalloc;
            ndarrays_in_use_.insert(devalloc.alloc_id);
          } else if (host_ctx->device_allocation_type[i] ==
                     RuntimeContext::DevAllocType::kTexture) {
            textures[i] = devalloc;
          } else if (host_ctx->device_allocation_type[i] ==
                     RuntimeContext::DevAllocType::kRWTexture) {
            textures[i] = devalloc;
          } else {
            TI_NOT_IMPLEMENTED;
          }
        } else {
          ext_array_size[i] = host_ctx->array_runtime_sizes[i];
          uint32_t access = uint32_t(
              ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i));

          // Alloc ext arr
          size_t alloc_size = std::max(size_t(32), ext_array_size.at(i));
          bool host_write = access & uint32_t(irpass::ExternalPtrAccess::READ);
          auto allocated = device_->allocate_memory_unique(
              {alloc_size, host_write, false, /*export_sharing=*/false,
               AllocUsage::Storage});
          any_arrays[i] = *allocated.get();
          ctx_buffers_.push_back(std::move(allocated));
        }
      }
      i++;
    }

    ctx_blitter->host_to_device(any_arrays, ext_array_size);
  }

  ensure_current_cmdlist();

  // Record commands
  const auto &task_attribs = ti_kernel->ti_kernel_attribs().tasks_attribs;

  for (int i = 0; i < task_attribs.size(); ++i) {
    const auto &attribs = task_attribs[i];
    auto vp = ti_kernel->get_pipeline(i);
    const int group_x = (attribs.advisory_total_num_threads +
                         attribs.advisory_num_threads_per_group - 1) /
                        attribs.advisory_num_threads_per_group;
    std::unique_ptr<ShaderResourceSet> bindings =
        device_->create_resource_set_unique();
    for (auto &bind : attribs.buffer_binds) {
      // We might have to bind an invalid buffer (this is fine as long as the
      // shader doesn't do anything with it).
      if (bind.buffer.type == BufferType::ExtArr) {
        bindings->rw_buffer(bind.binding, any_arrays.at(bind.buffer.root_id));
      } else if (bind.buffer.type == BufferType::Args) {
        bindings->buffer(bind.binding,
                         args_buffer ? *args_buffer : kDeviceNullAllocation);
      } else if (bind.buffer.type == BufferType::Rets) {
        bindings->rw_buffer(bind.binding,
                            ret_buffer ? *ret_buffer : kDeviceNullAllocation);
      } else {
        DeviceAllocation *alloc = ti_kernel->get_buffer_bind(bind.buffer);
        bindings->rw_buffer(bind.binding,
                            alloc ? *alloc : kDeviceNullAllocation);
      }
    }

    for (auto &bind : attribs.texture_binds) {
      DeviceAllocation texture = textures.at(bind.arg_id);
      if (bind.is_storage) {
        transition_image(texture, ImageLayout::shader_read_write);
        bindings->rw_image(bind.binding, texture, 0);
      } else {
        transition_image(texture, ImageLayout::shader_read);
        bindings->image(bind.binding, texture, {});
      }
    }

    if (attribs.task_type == OffloadedTaskType::listgen) {
      for (auto &bind : attribs.buffer_binds) {
        if (bind.buffer.type == BufferType::ListGen) {
          // FIXME: properly support multiple lists
          current_cmdlist_->buffer_fill(
              ti_kernel->get_buffer_bind(bind.buffer)->get_ptr(0),
              kBufferSizeEntireSize,
              /*data=*/0);
          current_cmdlist_->buffer_barrier(
              *ti_kernel->get_buffer_bind(bind.buffer));
        }
      }
    }

    current_cmdlist_->bind_pipeline(vp);
    RhiResult status = current_cmdlist_->bind_shader_resources(bindings.get());
    TI_ERROR_IF(status != RhiResult::success,
                "Resource binding error: RhiResult({})", status);

    if (profiler_) {
      current_cmdlist_->begin_profiler_scope(attribs.name);
    }

    status = current_cmdlist_->dispatch(group_x);

    if (profiler_) {
      current_cmdlist_->end_profiler_scope();
    }

    TI_ERROR_IF(status != RhiResult::success, "Dispatch error: RhiResult({})",
                status);
    current_cmdlist_->memory_barrier();
  }

  // Keep context buffers used in this dispatch
  if (ti_kernel->get_args_buffer_size()) {
    ctx_buffers_.push_back(std::move(args_buffer));
  }
  if (ti_kernel->get_ret_buffer_size()) {
    ctx_buffers_.push_back(std::move(ret_buffer));
  }

  // If a host sync is needed, sync and drop the in-flight references.
  if (ctx_blitter) {
    if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays,
                                    ext_array_size)) {
      current_cmdlist_ = nullptr;
      ctx_buffers_.clear();
    }
  }

  submit_current_cmdlist_if_timeout();
}

void GfxRuntime::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) {
  ensure_current_cmdlist();
  current_cmdlist_->buffer_copy(dst, src, size);
}

void GfxRuntime::copy_image(DeviceAllocation dst,
                            DeviceAllocation src,
                            const ImageCopyParams &params) {
  ensure_current_cmdlist();
  transition_image(dst, ImageLayout::transfer_dst);
  transition_image(src, ImageLayout::transfer_src);
  current_cmdlist_->copy_image(dst, src, ImageLayout::transfer_dst,
                               ImageLayout::transfer_src, params);
}

DeviceAllocation GfxRuntime::create_image(const ImageParams &params) {
  GraphicsDevice *gfx_device = dynamic_cast<GraphicsDevice *>(device_);
  TI_ERROR_IF(gfx_device == nullptr,
              "Image can only be created on a graphics device");
  DeviceAllocation image = gfx_device->create_image(params);
  track_image(image, ImageLayout::undefined);
  last_image_layouts_.at(image.alloc_id) = params.initial_layout;
  return image;
}

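// Image layout tracking: the runtime remembers the last known layout of every
// tracked image so that transition_image() can record the correct transition.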
void GfxRuntime::track_image(DeviceAllocation image, ImageLayout layout) {
  last_image_layouts_[image.alloc_id] = layout;
}
void GfxRuntime::untrack_image(DeviceAllocation image) {
  last_image_layouts_.erase(image.alloc_id);
}
void GfxRuntime::transition_image(DeviceAllocation image, ImageLayout layout) {
  ImageLayout &last_layout = last_image_layouts_.at(image.alloc_id);
  ensure_current_cmdlist();
  current_cmdlist_->image_transition(image, last_layout, layout);
  last_layout = layout;
}

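// Submits all pending work, waits for the device to go idle, harvests
// profiler samples, and releases per-launch buffers and in-use ndarray
// records.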
void GfxRuntime::synchronize() {
  flush();
  device_->wait_idle();
  // Profiler support
  if (profiler_) {
    device_->profiler_sync();
    auto sampled_records = device_->profiler_flush_sampled_time();
    for (auto &record : sampled_records) {
      profiler_->insert_record(record.first, record.second);
    }
  }
  ctx_buffers_.clear();
  ndarrays_in_use_.clear();
  fflush(stdout);
}

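// Submits the pending command list, if any; otherwise submits a barrier-only
// command list so the returned semaphore still orders subsequent work.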
StreamSemaphore GfxRuntime::flush() {
  StreamSemaphore sema;
  if (current_cmdlist_) {
    sema = device_->get_compute_stream()->submit(current_cmdlist_.get());
    current_cmdlist_ = nullptr;
    ctx_buffers_.clear();
  } else {
    auto [cmdlist, res] =
        device_->get_compute_stream()->new_command_list_unique();
    TI_ASSERT(res == RhiResult::success);
    cmdlist->memory_barrier();
    sema = device_->get_compute_stream()->submit(cmdlist.get());
  }
  return sema;
}

Device *GfxRuntime::get_ti_device() const {
  return device_;
}

void GfxRuntime::ensure_current_cmdlist() {
  // Create new command list if current one is nullptr
  if (!current_cmdlist_) {
    current_cmdlist_pending_since_ = high_res_clock::now();
    auto [cmdlist, res] =
        device_->get_compute_stream()->new_command_list_unique();
    TI_ASSERT(res == RhiResult::success);
    current_cmdlist_ = std::move(cmdlist);
  }
}

void GfxRuntime::submit_current_cmdlist_if_timeout() {
  // If we have accumulated some work that does not require a sync, and the
  // accumulated cmdlist has been pending for a while, submit it so the device
  // can start processing.
  if (current_cmdlist_) {
    constexpr uint64_t max_pending_time = 2000;  // 2000us = 2ms
    auto duration = high_res_clock::now() - current_cmdlist_pending_since_;
    if (std::chrono::duration_cast<std::chrono::microseconds>(duration)
            .count() > max_pending_time) {
      flush();
    }
  }
}

void GfxRuntime::init_nonroot_buffers() {
  global_tmps_buffer_ = device_->allocate_memory_unique(
      {kGtmpBufferSize,
       /*host_write=*/false, /*host_read=*/false,
       /*export_sharing=*/false, AllocUsage::Storage});

  listgen_buffer_ = device_->allocate_memory_unique(
      {kListGenBufferSize,
       /*host_write=*/false, /*host_read=*/false,
       /*export_sharing=*/false, AllocUsage::Storage});

  // Need to zero fill the buffers, otherwise there could be NaN.
  Stream *stream = device_->get_compute_stream();
  auto [cmdlist, res] =
      device_->get_compute_stream()->new_command_list_unique();
  TI_ASSERT(res == RhiResult::success);

  cmdlist->buffer_fill(global_tmps_buffer_->get_ptr(0), kBufferSizeEntireSize,
                       /*data=*/0);
  cmdlist->buffer_fill(listgen_buffer_->get_ptr(0), kBufferSizeEntireSize,
                       /*data=*/0);
  stream->submit_synced(cmdlist.get());
}

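// Allocates and zero-fills a root buffer for one SNode tree, and records its
// size so it can be queried later.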
void GfxRuntime::add_root_buffer(size_t root_buffer_size) {
  if (root_buffer_size == 0) {
    root_buffer_size = 4;  // there might be empty roots
  }
  std::unique_ptr<DeviceAllocationGuard> new_buffer =
      device_->allocate_memory_unique(
          {root_buffer_size,
           /*host_write=*/false, /*host_read=*/false,
           /*export_sharing=*/false, AllocUsage::Storage});
  Stream *stream = device_->get_compute_stream();
  auto [cmdlist, res] =
      device_->get_compute_stream()->new_command_list_unique();
  TI_ASSERT(res == RhiResult::success);
  cmdlist->buffer_fill(new_buffer->get_ptr(0), kBufferSizeEntireSize,
                       /*data=*/0);
  stream->submit_synced(cmdlist.get());
  root_buffers_.push_back(std::move(new_buffer));
  // cache the root buffer size
  root_buffers_size_map_[root_buffers_.back().get()] = root_buffer_size;
}

DeviceAllocation *GfxRuntime::get_root_buffer(int id) const {
  if (id >= root_buffers_.size()) {
    TI_ERROR("root buffer id {} not found", id);
  }
  return root_buffers_[id].get();
}

size_t GfxRuntime::get_root_buffer_size(int id) const {
  if (id >= root_buffers_.size()) {
    TI_ERROR("root buffer id {} not found", id);
  }
  auto it = root_buffers_size_map_.find(root_buffers_[id].get());
  if (it == root_buffers_size_map_.end()) {
    TI_ERROR("root buffer id {} not found", id);
  }
  return it->second;
}

void GfxRuntime::enqueue_compute_op_lambda(
    std::function<void(Device *device, CommandList *cmdlist)> op,
    const std::vector<ComputeOpImageRef> &image_refs) {
  for (const auto &ref : image_refs) {
    TI_ASSERT(last_image_layouts_.find(ref.image.alloc_id) !=
              last_image_layouts_.end());
    transition_image(ref.image, ref.initial_layout);
  }

  ensure_current_cmdlist();
  op(device_, current_cmdlist_.get());

  for (const auto &ref : image_refs) {
    last_image_layouts_[ref.image.alloc_id] = ref.final_layout;
  }
}

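// Runs SPIR-V codegen for `kernel` against the compiled SNode structs and
// packs the kernel attributes and per-task SPIR-V binaries into
// GfxRuntime::RegisterParams.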
771 | |
772 | GfxRuntime::RegisterParams run_codegen( |
773 | Kernel *kernel, |
774 | Arch arch, |
775 | const DeviceCapabilityConfig &caps, |
776 | const std::vector<CompiledSNodeStructs> &compiled_structs, |
777 | const CompileConfig &compile_config) { |
778 | const auto id = Program::get_kernel_id(); |
779 | const auto taichi_kernel_name(fmt::format("{}_k{:04d}_vk" , kernel->name, id)); |
780 | TI_TRACE("VK codegen for Taichi kernel={}" , taichi_kernel_name); |
781 | spirv::KernelCodegen::Params params; |
782 | params.ti_kernel_name = taichi_kernel_name; |
783 | params.kernel = kernel; |
784 | params.compiled_structs = compiled_structs; |
785 | params.arch = arch; |
786 | params.caps = caps; |
787 | params.enable_spv_opt = compile_config.external_optimization_level > 0; |
788 | spirv::KernelCodegen codegen(params); |
789 | GfxRuntime::RegisterParams res; |
790 | codegen.run(res.kernel_attribs, res.task_spirv_source_codes); |
791 | res.num_snode_trees = compiled_structs.size(); |
792 | return res; |
793 | } |
794 | |
795 | } // namespace gfx |
796 | } // namespace taichi::lang |
797 | |