#pragma once
#include "taichi/util/lang_util.h"

#include <vector>
#include <chrono>

#include "taichi/rhi/device.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/codegen/spirv/kernel_utils.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/program/compile_config.h"
#include "taichi/struct/snode_tree.h"
#include "taichi/program/snode_expr_utils.h"
#include "taichi/program/program_impl.h"

namespace taichi::lang {
namespace gfx {

using namespace taichi::lang::spirv;

using BufferType = TaskAttributes::BufferType;
using BufferInfo = TaskAttributes::BufferInfo;
using BufferBind = TaskAttributes::BufferBind;
using BufferInfoHasher = TaskAttributes::BufferInfoHasher;

using high_res_clock = std::chrono::high_resolution_clock;

// TODO: In the future this doesn't necessarily need to be a pointer, since
// DeviceAllocation is already a pretty cheap handle.
using InputBuffersMap =
    std::unordered_map<BufferInfo, DeviceAllocation *, BufferInfoHasher>;

class SNodeTreeManager;

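// Per-kernel compiled artifact for the graphics backends. It owns the device
// compute pipelines built from the SPIR-V binaries of the kernel's offloaded
// tasks, together with the buffer bindings (root buffers, global temporaries,
// list-gen buffer, ...) that those pipelines expect at launch time.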
class CompiledTaichiKernel {
 public:
  struct Params {
    const TaichiKernelAttributes *ti_kernel_attribs{nullptr};
    std::vector<std::vector<uint32_t>> spirv_bins;
    std::size_t num_snode_trees{0};

    Device *device{nullptr};
    std::vector<DeviceAllocation *> root_buffers;
    DeviceAllocation *global_tmps_buffer{nullptr};
    DeviceAllocation *listgen_buffer{nullptr};

    PipelineCache *backend_cache{nullptr};
  };

  explicit CompiledTaichiKernel(const Params &ti_params);

  const TaichiKernelAttributes &ti_kernel_attribs() const;

  size_t num_pipelines() const;

  size_t get_args_buffer_size() const;
  size_t get_ret_buffer_size() const;

  Pipeline *get_pipeline(int i);

  DeviceAllocation *get_buffer_bind(const BufferInfo &bind) {
    return input_buffers_[bind];
  }

 private:
  TaichiKernelAttributes ti_kernel_attribs_;
  std::vector<TaskAttributes> tasks_attribs_;

  [[maybe_unused]] Device *device_;

  InputBuffersMap input_buffers_;

  size_t args_buffer_size_{0};
  size_t ret_buffer_size_{0};
  std::vector<std::unique_ptr<Pipeline>> pipelines_;
};

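// Runtime for the SPIR-V based graphics backends. It owns the device-side
// buffers (root buffers, global temporaries, list-gen buffer), registers
// CompiledTaichiKernels and launches them by recording into a pending command
// list, and tracks image layouts and in-flight ndarrays so that resources are
// only released when it is safe to do so.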
class TI_DLL_EXPORT GfxRuntime {
 public:
  struct Params {
    uint64_t *host_result_buffer{nullptr};
    Device *device{nullptr};
    KernelProfilerBase *profiler{nullptr};
  };

  explicit GfxRuntime(const Params &params);
  // To make Pimpl + std::unique_ptr work
  ~GfxRuntime();

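  // Opaque handle to a kernel registered with this runtime; returned by
  // register_taichi_kernel() and consumed by launch_kernel().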
  class KernelHandle {
   private:
    friend class GfxRuntime;
    int id_ = -1;
  };

  struct RegisterParams {
    TaichiKernelAttributes kernel_attribs;
    std::vector<std::vector<uint32_t>> task_spirv_source_codes;
    std::size_t num_snode_trees{0};
  };

  KernelHandle register_taichi_kernel(RegisterParams params);

  void launch_kernel(KernelHandle handle, RuntimeContext *host_ctx);

  void buffer_copy(DevicePtr dst, DevicePtr src, size_t size);
  void copy_image(DeviceAllocation dst,
                  DeviceAllocation src,
                  const ImageCopyParams &params);

  DeviceAllocation create_image(const ImageParams &params);
  void track_image(DeviceAllocation image, ImageLayout layout);
  void untrack_image(DeviceAllocation image);
  void transition_image(DeviceAllocation image, ImageLayout layout);

  void synchronize();

  StreamSemaphore flush();

  Device *get_ti_device() const;

  void add_root_buffer(size_t root_buffer_size);

  DeviceAllocation *get_root_buffer(int id) const;

  size_t get_root_buffer_size(int id) const;

  void enqueue_compute_op_lambda(
      std::function<void(Device *device, CommandList *cmdlist)> op,
      const std::vector<ComputeOpImageRef> &image_refs);

  bool used_in_kernel(DeviceAllocationId id) {
    return ndarrays_in_use_.count(id) > 0;
  }

 private:
  friend class taichi::lang::gfx::SNodeTreeManager;

  void ensure_current_cmdlist();
  void submit_current_cmdlist_if_timeout();

  void init_nonroot_buffers();

  Device *device_{nullptr};
  uint64_t *const host_result_buffer_;
  KernelProfilerBase *profiler_;

  std::unique_ptr<PipelineCache> backend_cache_{nullptr};

  std::vector<std::unique_ptr<DeviceAllocationGuard>> root_buffers_;
  std::unique_ptr<DeviceAllocationGuard> global_tmps_buffer_;
  // FIXME: Support proper multiple lists
  std::unique_ptr<DeviceAllocationGuard> listgen_buffer_;

  std::vector<std::unique_ptr<DeviceAllocationGuard>> ctx_buffers_;

  std::unique_ptr<CommandList> current_cmdlist_{nullptr};
  high_res_clock::time_point current_cmdlist_pending_since_;

  std::vector<std::unique_ptr<CompiledTaichiKernel>> ti_kernels_;

  std::unordered_map<DeviceAllocation *, size_t> root_buffers_size_map_;
  std::unordered_map<DeviceAllocationId, ImageLayout> last_image_layouts_;
  // [Note] Why do we need to track ndarrays that are in use?
  // Command-list submission is asynchronous, so Taichi needs a way to know
  // whether an ndarray is still referenced by kernels that are pending
  // execution. ndarrays_in_use_ tracks this, so that the memory allocated for
  // an ndarray can be freed as soon as it is safe to do so.
  std::unordered_set<DeviceAllocationId> ndarrays_in_use_;
};

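// Compiles `kernel` into per-task SPIR-V binaries plus the corresponding
// kernel attributes, packaged as RegisterParams that can be handed to
// GfxRuntime::register_taichi_kernel().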
GfxRuntime::RegisterParams run_codegen(
    Kernel *kernel,
    Arch arch,
    const DeviceCapabilityConfig &caps,
    const std::vector<CompiledSNodeStructs> &compiled_structs,
    const CompileConfig &compile_config);
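//
// A rough end-to-end usage sketch based on the declarations above; `runtime`,
// `kernel`, `host_ctx`, `arch`, `caps`, `compiled_structs` and
// `compile_config` are assumed to be set up by the caller:
//
//   GfxRuntime::RegisterParams reg_params =
//       run_codegen(kernel, arch, caps, compiled_structs, compile_config);
//   GfxRuntime::KernelHandle handle =
//       runtime.register_taichi_kernel(std::move(reg_params));
//   runtime.launch_kernel(handle, &host_ctx);
//   runtime.synchronize();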

}  // namespace gfx
}  // namespace taichi::lang