1 | #pragma once |
2 | #include "taichi/util/lang_util.h" |
3 | |
4 | #include <vector> |
5 | #include <chrono> |
6 | |
7 | #include "taichi/rhi/device.h" |
8 | #include "taichi/codegen/spirv/snode_struct_compiler.h" |
9 | #include "taichi/codegen/spirv/kernel_utils.h" |
10 | #include "taichi/codegen/spirv/spirv_codegen.h" |
11 | #include "taichi/program/compile_config.h" |
12 | #include "taichi/struct/snode_tree.h" |
13 | #include "taichi/program/snode_expr_utils.h" |
14 | #include "taichi/program/program_impl.h" |
15 | |
16 | namespace taichi::lang { |
17 | namespace gfx { |
18 | |
19 | using namespace taichi::lang::spirv; |
20 | |
21 | using BufferType = TaskAttributes::BufferType; |
22 | using BufferInfo = TaskAttributes::BufferInfo; |
23 | using BufferBind = TaskAttributes::BufferBind; |
24 | using BufferInfoHasher = TaskAttributes::BufferInfoHasher; |
25 | |
26 | using high_res_clock = std::chrono::high_resolution_clock; |
27 | |
28 | // TODO: In the future this isn't necessarily a pointer, since DeviceAllocation |
29 | // is already a pretty cheap handle> |
30 | using InputBuffersMap = |
31 | std::unordered_map<BufferInfo, DeviceAllocation *, BufferInfoHasher>; |
32 | |
33 | class SNodeTreeManager; |
34 | |
35 | class CompiledTaichiKernel { |
36 | public: |
37 | struct Params { |
38 | const TaichiKernelAttributes *ti_kernel_attribs{nullptr}; |
39 | std::vector<std::vector<uint32_t>> spirv_bins; |
40 | std::size_t num_snode_trees{0}; |
41 | |
42 | Device *device{nullptr}; |
43 | std::vector<DeviceAllocation *> root_buffers; |
44 | DeviceAllocation *global_tmps_buffer{nullptr}; |
45 | DeviceAllocation *listgen_buffer{nullptr}; |
46 | |
47 | PipelineCache *backend_cache{nullptr}; |
48 | }; |
49 | |
50 | explicit CompiledTaichiKernel(const Params &ti_params); |
51 | |
52 | const TaichiKernelAttributes &ti_kernel_attribs() const; |
53 | |
54 | size_t num_pipelines() const; |
55 | |
56 | size_t get_args_buffer_size() const; |
57 | size_t get_ret_buffer_size() const; |
58 | |
59 | Pipeline *get_pipeline(int i); |
60 | |
61 | DeviceAllocation *get_buffer_bind(const BufferInfo &bind) { |
62 | return input_buffers_[bind]; |
63 | } |
64 | |
65 | private: |
66 | TaichiKernelAttributes ti_kernel_attribs_; |
67 | std::vector<TaskAttributes> tasks_attribs_; |
68 | |
69 | [[maybe_unused]] Device *device_; |
70 | |
71 | InputBuffersMap input_buffers_; |
72 | |
73 | size_t args_buffer_size_{0}; |
74 | size_t ret_buffer_size_{0}; |
75 | std::vector<std::unique_ptr<Pipeline>> pipelines_; |
76 | }; |
77 | |
78 | class TI_DLL_EXPORT GfxRuntime { |
79 | public: |
80 | struct Params { |
81 | uint64_t *host_result_buffer{nullptr}; |
82 | Device *device{nullptr}; |
83 | KernelProfilerBase *profiler{nullptr}; |
84 | }; |
85 | |
86 | explicit GfxRuntime(const Params ¶ms); |
87 | // To make Pimpl + std::unique_ptr work |
88 | ~GfxRuntime(); |
89 | |
90 | class KernelHandle { |
91 | private: |
92 | friend class GfxRuntime; |
93 | int id_ = -1; |
94 | }; |
95 | |
96 | struct RegisterParams { |
97 | TaichiKernelAttributes kernel_attribs; |
98 | std::vector<std::vector<uint32_t>> task_spirv_source_codes; |
99 | std::size_t num_snode_trees{0}; |
100 | }; |
101 | |
102 | KernelHandle register_taichi_kernel(RegisterParams params); |
103 | |
104 | void launch_kernel(KernelHandle handle, RuntimeContext *host_ctx); |
105 | |
106 | void buffer_copy(DevicePtr dst, DevicePtr src, size_t size); |
107 | void copy_image(DeviceAllocation dst, |
108 | DeviceAllocation src, |
109 | const ImageCopyParams ¶ms); |
110 | |
111 | DeviceAllocation create_image(const ImageParams ¶ms); |
112 | void track_image(DeviceAllocation image, ImageLayout layout); |
113 | void untrack_image(DeviceAllocation image); |
114 | void transition_image(DeviceAllocation image, ImageLayout layout); |
115 | |
116 | void synchronize(); |
117 | |
118 | StreamSemaphore flush(); |
119 | |
120 | Device *get_ti_device() const; |
121 | |
122 | void add_root_buffer(size_t root_buffer_size); |
123 | |
124 | DeviceAllocation *get_root_buffer(int id) const; |
125 | |
126 | size_t get_root_buffer_size(int id) const; |
127 | |
128 | void enqueue_compute_op_lambda( |
129 | std::function<void(Device *device, CommandList *cmdlist)> op, |
130 | const std::vector<ComputeOpImageRef> &image_refs); |
131 | |
132 | bool used_in_kernel(DeviceAllocationId id) { |
133 | return ndarrays_in_use_.count(id) > 0; |
134 | } |
135 | |
136 | private: |
137 | friend class taichi::lang::gfx::SNodeTreeManager; |
138 | |
139 | void ensure_current_cmdlist(); |
140 | void submit_current_cmdlist_if_timeout(); |
141 | |
142 | void init_nonroot_buffers(); |
143 | |
144 | Device *device_{nullptr}; |
145 | uint64_t *const host_result_buffer_; |
146 | KernelProfilerBase *profiler_; |
147 | |
148 | std::unique_ptr<PipelineCache> backend_cache_{nullptr}; |
149 | |
150 | std::vector<std::unique_ptr<DeviceAllocationGuard>> root_buffers_; |
151 | std::unique_ptr<DeviceAllocationGuard> global_tmps_buffer_; |
152 | // FIXME: Support proper multiple lists |
153 | std::unique_ptr<DeviceAllocationGuard> listgen_buffer_; |
154 | |
155 | std::vector<std::unique_ptr<DeviceAllocationGuard>> ctx_buffers_; |
156 | |
157 | std::unique_ptr<CommandList> current_cmdlist_{nullptr}; |
158 | high_res_clock::time_point current_cmdlist_pending_since_; |
159 | |
160 | std::vector<std::unique_ptr<CompiledTaichiKernel>> ti_kernels_; |
161 | |
162 | std::unordered_map<DeviceAllocation *, size_t> root_buffers_size_map_; |
163 | std::unordered_map<DeviceAllocationId, ImageLayout> last_image_layouts_; |
164 | // [Note] Why do we need to track ndarrays that are in use? |
165 | // Since we separate cmdlist is async, taichi needs a way to know whether |
166 | // ndarrays are still used by pending kernels to be executed. So we use |
167 | // ndarray_in_use_ to track this so that we can free memory allocated for |
168 | // ndarray whenever it's safe to do so. |
169 | std::unordered_set<DeviceAllocationId> ndarrays_in_use_; |
170 | }; |
171 | |
172 | GfxRuntime::RegisterParams run_codegen( |
173 | Kernel *kernel, |
174 | Arch arch, |
175 | const DeviceCapabilityConfig &caps, |
176 | const std::vector<CompiledSNodeStructs> &compiled_structs, |
177 | const CompileConfig &compile_config); |
178 | |
179 | } // namespace gfx |
180 | } // namespace taichi::lang |
181 | |