// Program - Taichi program execution context

#pragma once

#include <functional>
#include <optional>
#include <atomic>
#include <stack>
#include <shared_mutex>

#define TI_RUNTIME_HOST
#include "taichi/aot/module_builder.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/type_factory.h"
#include "taichi/ir/snode.h"
#include "taichi/util/lang_util.h"
#include "taichi/program/program_impl.h"
#include "taichi/program/callable.h"
#include "taichi/program/function.h"
#include "taichi/program/kernel.h"
#include "taichi/program/kernel_profiler.h"
#include "taichi/program/snode_expr_utils.h"
#include "taichi/program/snode_rw_accessors_bank.h"
#include "taichi/program/context.h"
#include "taichi/struct/snode_tree.h"
#include "taichi/system/memory_pool.h"
#include "taichi/system/threading.h"
#include "taichi/system/unified_allocator.h"
#include "taichi/program/sparse_matrix.h"
#include "taichi/ir/mesh.h"

namespace taichi::lang {

struct JITEvaluatorId {
  std::thread::id thread_id;
  // Note that on certain backends (e.g. CUDA), functions created in one
  // thread cannot be used in another. Hence the thread_id member.
  int op;
  DataType ret, lhs, rhs;
  std::string tb;
  bool is_binary;

  UnaryOpType unary_op() const {
    TI_ASSERT(!is_binary);
    return (UnaryOpType)op;
  }

  BinaryOpType binary_op() const {
    TI_ASSERT(is_binary);
    return (BinaryOpType)op;
  }

  bool operator==(const JITEvaluatorId &o) const {
    return thread_id == o.thread_id && op == o.op && ret == o.ret &&
           lhs == o.lhs && rhs == o.rhs && is_binary == o.is_binary &&
           tb == o.tb;
  }
};

}  // namespace taichi::lang

namespace std {
template <>
struct hash<taichi::lang::JITEvaluatorId> {
  std::size_t operator()(
      taichi::lang::JITEvaluatorId const &id) const noexcept {
    return ((std::size_t)id.op | (id.ret.hash() << 8) | (id.lhs.hash() << 16) |
            (id.rhs.hash() << 24) | ((std::size_t)id.is_binary << 31)) ^
           (std::hash<std::thread::id>{}(id.thread_id) << 32);
  }
};
}  // namespace std

namespace taichi::lang {

class StructCompiler;

/**
 * Note [Backend-specific ProgramImpl]
 * Work is in progress to keep the Program class minimal and to move all
 * backend-specific logic into the corresponding backend ProgramImpls.
 *
 * If you are thinking about exposing or adding attributes/methods to the
 * Program class, please first consider whether they are general to all
 * backends:
 * - If so, consider adding them to the ProgramImpl class first.
 * - Otherwise, add them to a backend-specific ProgramImpl, e.g.
 *   LlvmProgramImpl, MetalProgramImpl, etc.
 */

class TI_DLL_EXPORT Program {
 public:
  using Kernel = taichi::lang::Kernel;

  uint64 *result_buffer{nullptr};  // Note: result_buffer is used by all backends

  std::vector<std::unique_ptr<Kernel>> kernels;

  std::unique_ptr<KernelProfilerBase> profiler{nullptr};

  std::unordered_map<JITEvaluatorId, std::unique_ptr<Kernel>>
      jit_evaluator_cache;
  std::mutex jit_evaluator_cache_mut;
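  // Illustrative sketch (not actual library code): lookups into
  // jit_evaluator_cache are keyed by JITEvaluatorId, so entries are
  // effectively per-thread (see the thread_id note above). The ret_type,
  // lhs_type, and rhs_type variables below are hypothetical.
  //
  //   JITEvaluatorId id{std::this_thread::get_id(), (int)BinaryOpType::add,
  //                     ret_type, lhs_type, rhs_type, /*tb=*/"",
  //                     /*is_binary=*/true};
  //   std::lock_guard<std::mutex> lock(jit_evaluator_cache_mut);
  //   auto it = jit_evaluator_cache.find(id);
  //   Kernel *evaluator =
  //       (it != jit_evaluator_cache.end()) ? it->second.get() : nullptr;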

  // Note: for now we let all Programs share a single TypeFactory for smooth
  // migration. In the future each program should have its own copy.
  static TypeFactory &get_type_factory();

  Program() : Program(default_compile_config.arch) {
  }

  explicit Program(Arch arch);

  ~Program();

  const CompileConfig &compile_config() const {
    return compile_config_;
  }

  struct KernelProfilerQueryResult {
    int counter{0};
    double min{0.0};
    double max{0.0};
    double avg{0.0};
  };

  KernelProfilerQueryResult query_kernel_profile_info(const std::string &name) {
    KernelProfilerQueryResult query_result;
    profiler->query(name, query_result.counter, query_result.min,
                    query_result.max, query_result.avg);
    return query_result;
  }
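  // Illustrative usage sketch (not part of this header); the kernel name
  // "fill" is hypothetical:
  //
  //   Program::KernelProfilerQueryResult r =
  //       program.query_kernel_profile_info("fill");
  //   TI_INFO("count={} min={} max={} avg={}", r.counter, r.min, r.max, r.avg);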

  void clear_kernel_profile_info() {
    profiler->clear();
  }

  void profiler_start(const std::string &name) {
    profiler->start(name);
  }

  void profiler_stop() {
    profiler->stop();
  }

  KernelProfilerBase *get_profiler() {
    return profiler.get();
  }

  void synchronize();

  StreamSemaphore flush();

  /**
   * Materializes the runtime.
   */
  void materialize_runtime();

  int get_snode_tree_size();

  Kernel &kernel(const std::function<void(Kernel *)> &body,
                 const std::string &name = "",
                 AutodiffMode autodiff_mode = AutodiffMode::kNone) {
    // Expr::set_allow_store(true);
    auto func = std::make_unique<Kernel>(*this, body, name, autodiff_mode);
    // Expr::set_allow_store(false);
    kernels.emplace_back(std::move(func));
    return *kernels.back();
  }
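  // Illustrative sketch (not library code): creating a kernel whose frontend
  // IR is built by a lambda. The kernel name and body are hypothetical.
  //
  //   Kernel &k = program.kernel(
  //       [](Kernel *kernel) {
  //         // Build the kernel's frontend IR here.
  //       },
  //       "my_kernel");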

  Function *create_function(const FunctionKey &func_key);

  // TODO: This function is doing two things: 1) compiling CHI IR, and 2)
  // offloading them to each backend. We should probably separate the logic?
  FunctionType compile(const CompileConfig &compile_config, Kernel &kernel);

  void check_runtime_error();

  Kernel &get_snode_reader(SNode *snode);

  Kernel &get_snode_writer(SNode *snode);

  uint64 fetch_result_uint64(int i);

  TypedConstant fetch_result(int offset, const Type *dt) {
    return program_impl_->fetch_result((char *)result_buffer, offset, dt);
  }

  template <typename T>
  T fetch_result(int i) {
    return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
  }
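  // Illustrative sketch (not library code): reading a scalar result back on
  // the host; the slot index 0 is hypothetical.
  //
  //   program.synchronize();
  //   float val = program.fetch_result<float>(/*i=*/0);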

  Arch get_host_arch() {
    return host_arch();
  }

  float64 get_total_compilation_time() {
    return total_compilation_time_;
  }

  void finalize();

  static int get_kernel_id() {
    static int id = 0;
    TI_ASSERT(id < 100000);
    return id++;
  }

  static int default_block_dim(const CompileConfig &config);

  // Note that this method is specific to LlvmProgramImpl, but we keep it here
  // since it's exposed to Python.
  void print_memory_profiler_info();

  // Returns zero if the SNode is statically allocated
  std::size_t get_snode_num_dynamically_allocated(SNode *snode);

  inline SNodeFieldMap *get_snode_to_fields() {
    return &snode_to_fields_;
  }

  inline SNodeRwAccessorsBank &get_snode_rw_accessors_bank() {
    return snode_rw_accessors_bank_;
  }

  /**
   * Destroys an existing SNode tree.
   *
   * @param snode_tree The pointer to the SNode tree.
   */
  void destroy_snode_tree(SNodeTree *snode_tree);

  /**
   * Adds a new SNode tree.
   *
   * @param root The root of the new SNode tree.
   * @param compile_only Only generates the compiled type.
   * @return The pointer to the SNode tree.
   *
   * FIXME: compile_only is mostly a hack to make AOT & cross-compilation work.
   * E.g. users who would like to AOT to a specific target backend can do so,
   * even if their platform doesn't support that backend. Unfortunately, the
   * current implementation would leave the backend in a mostly broken state.
   * We need a cleaner design to support both AOT and JIT modes.
   */
  SNodeTree *add_snode_tree(std::unique_ptr<SNode> root, bool compile_only);

  /**
   * Allocates an SNode tree id for a new SNode tree.
   *
   * @return The allocated SNode tree id.
   *
   * Returns and consumes a free SNode tree id if one is available;
   * otherwise returns the size of `snode_trees_`.
   */
  int allocate_snode_tree_id();
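  // A minimal sketch of the reuse policy described above (the actual
  // definition lives in the .cpp file and may differ):
  //
  //   int Program::allocate_snode_tree_id() {
  //     if (free_snode_tree_ids_.empty()) {
  //       return snode_trees_.size();
  //     }
  //     int id = free_snode_tree_ids_.top();
  //     free_snode_tree_ids_.pop();
  //     return id;
  //   }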

  /**
   * Gets the root of a SNode tree.
   *
   * @param tree_id Index of the SNode tree
   * @return Root of the tree
   */
  SNode *get_snode_root(int tree_id);

  std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
      Arch arch,
      const std::vector<std::string> &caps);

  size_t get_field_in_tree_offset(int tree_id, const SNode *child) {
    return program_impl_->get_field_in_tree_offset(tree_id, child);
  }

  DevicePtr get_snode_tree_device_ptr(int tree_id) {
    return program_impl_->get_snode_tree_device_ptr(tree_id);
  }

  Device *get_compute_device() {
    return program_impl_->get_compute_device();
  }

  Device *get_graphics_device() {
    return program_impl_->get_graphics_device();
  }

  // TODO: do we still need result_buffer?
  DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
                                           uint64 *result_buffer) {
    return program_impl_->allocate_memory_ndarray(alloc_size, result_buffer);
  }
  DeviceAllocation allocate_texture(const ImageParams &params) {
    return program_impl_->allocate_texture(params);
  }

  Ndarray *create_ndarray(
      const DataType type,
      const std::vector<int> &shape,
      ExternalArrayLayout layout = ExternalArrayLayout::kNull,
      bool zero_fill = false);
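  // Illustrative usage sketch (not library code); the dtype and shape below
  // are hypothetical:
  //
  //   Ndarray *arr = program.create_ndarray(
  //       PrimitiveType::f32, /*shape=*/{128, 128},
  //       ExternalArrayLayout::kNull, /*zero_fill=*/true);
  //   program.fill_ndarray_fast_u32(arr, /*val=*/0);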

  std::string get_kernel_return_data_layout() {
    return program_impl_->get_kernel_return_data_layout();
  }

  std::string get_kernel_argument_data_layout() {
    return program_impl_->get_kernel_argument_data_layout();
  }

  const StructType *get_struct_type_with_data_layout(
      const StructType *old_ty,
      const std::string &layout) {
    return program_impl_->get_struct_type_with_data_layout(old_ty, layout);
  }

  void delete_ndarray(Ndarray *ndarray);

  Texture *create_texture(const DataType type,
                          int num_channels,
                          const std::vector<int> &shape);

  intptr_t get_ndarray_data_ptr_as_int(const Ndarray *ndarray);

  void fill_ndarray_fast_u32(Ndarray *ndarray, uint32_t val);

  Identifier get_next_global_id(const std::string &name = "") {
    return Identifier(global_id_counter_++, name);
  }

  void prepare_runtime_context(RuntimeContext *ctx);

  /**
   * Enqueues a custom compute op into the current program execution flow.
   *
   * @param op The lambda that is invoked to construct the custom compute op.
   * @param image_refs The image resource references used in this compute op.
   */
  void enqueue_compute_op_lambda(
      std::function<void(Device *device, CommandList *cmdlist)> op,
      const std::vector<ComputeOpImageRef> &image_refs);
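  // Illustrative usage sketch (not library code); the empty image reference
  // list and the lambda body are hypothetical:
  //
  //   program.enqueue_compute_op_lambda(
  //       [](Device *device, CommandList *cmdlist) {
  //         // Record custom commands into cmdlist here.
  //       },
  //       /*image_refs=*/{});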

  /**
   * TODO(zhanlue): Remove this interface.
   *
   * Gets the underlying ProgramImpl object.
   *
   * This interface is essentially a hack to temporarily accommodate
   * historical design issues with the LLVM backend.
   *
   * Please limit its use to the LLVM backend only.
   */
  ProgramImpl *get_program_impl() {
    TI_ASSERT(arch_uses_llvm(compile_config().arch));
    return program_impl_.get();
  }

  // TODO(zhanlue): Move these members and corresponding interfaces to
  // ProgramImpl. Ideally, Program should serve as a pure interface class and
  // all the implementations should fall inside ProgramImpl.
  //
  // Once we have migrated these implementations to ProgramImpl, lower-level
  // objects could store ProgramImpl rather than Program.

 private:
  CompileConfig compile_config_;

  uint64 ndarray_writer_counter_{0};
  uint64 ndarray_reader_counter_{0};
  int global_id_counter_{0};

  // SNode information that requires using Program.
  SNodeFieldMap snode_to_fields_;
  SNodeRwAccessorsBank snode_rw_accessors_bank_;

  std::vector<std::unique_ptr<SNodeTree>> snode_trees_;
  std::stack<int> free_snode_tree_ids_;

  std::vector<std::unique_ptr<Function>> functions_;
  std::unordered_map<FunctionKey, Function *> function_map_;

  std::unique_ptr<ProgramImpl> program_impl_;
  float64 total_compilation_time_{0.0};
  static std::atomic<int> num_instances_;
  bool finalized_{false};

  std::unique_ptr<MemoryPool> memory_pool_{nullptr};
  // TODO: Move ndarrays_ and textures_ to be managed by the runtime.
  std::unordered_map<void *, std::unique_ptr<Ndarray>> ndarrays_;
  std::vector<std::unique_ptr<Texture>> textures_;
};

}  // namespace taichi::lang