#pragma once

#include "taichi/aot/module_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/system/memory_pool.h"
#include "taichi/common/logging.h"
#include "taichi/struct/snode_tree.h"
#include "taichi/program/snode_expr_utils.h"
#include "taichi/program/kernel_profiler.h"
#include "taichi/rhi/device.h"
#include "taichi/aot/graph_data.h"

namespace taichi::lang {

// Represents an image resource reference for a compute/render Op.
struct ComputeOpImageRef {
  DeviceAllocation image;
  // The layout the image is requested to be in when the Op is invoked.
  ImageLayout initial_layout;
  // The layout the image will be in once the Op finishes.
  ImageLayout final_layout;
};
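
// Illustrative sketch (not part of the API): a compute op that samples one
// image and writes another could describe them as follows. `src_alloc` and
// `dst_alloc` are hypothetical DeviceAllocation handles, and the specific
// ImageLayout enumerators are an assumption about the RHI.
//
//   ComputeOpImageRef src_ref{src_alloc, ImageLayout::shader_read,
//                             ImageLayout::shader_read};
//   ComputeOpImageRef dst_ref{dst_alloc, ImageLayout::shader_write,
//                             ImageLayout::shader_write};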

struct RuntimeContext;

class ProgramImpl {
 public:
  // TODO: Make this safer; it is exposed for now because it is accessed
  // directly from outside.
  CompileConfig *config;

 public:
  explicit ProgramImpl(CompileConfig &config);

  /**
   * Generates backend-specific code for @param kernel.
   */
  virtual FunctionType compile(const CompileConfig &compile_config,
                               Kernel *kernel) = 0;

  /**
   * Allocates runtime buffers, e.g. result_buffer, as well as
   * backend-specific ones, e.g. preallocated_device_buffer on CUDA.
   */
  virtual void materialize_runtime(MemoryPool *memory_pool,
                                   KernelProfilerBase *profiler,
                                   uint64 **result_buffer_ptr) = 0;

  /**
   * JIT compiles @param tree to backend-specific data types.
   */
  virtual void compile_snode_tree_types(SNodeTree *tree);

  /**
   * Compiles the types of @param tree and allocates the runtime buffer for it.
   */
  virtual void materialize_snode_tree(SNodeTree *tree,
                                      uint64 *result_buffer_ptr) = 0;

  virtual void destroy_snode_tree(SNodeTree *snode_tree) = 0;

  virtual std::size_t get_snode_num_dynamically_allocated(
      SNode *snode,
      uint64 *result_buffer) = 0;

  /**
   * Performs a backend synchronization.
   */
  virtual void synchronize() = 0;

  virtual StreamSemaphore flush() {
    synchronize();
    return nullptr;
  }

  /**
   * Makes an AotModuleBuilder, currently only supported by metal and wasm.
   */
  virtual std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
      const DeviceCapabilityConfig &caps) = 0;
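
  // Illustrative sketch (hypothetical backend): a concrete ProgramImpl must
  // override at least the pure virtual methods above, e.g.:
  //
  //   class MyBackendProgramImpl : public ProgramImpl {
  //    public:
  //     explicit MyBackendProgramImpl(CompileConfig &config)
  //         : ProgramImpl(config) {
  //     }
  //     FunctionType compile(const CompileConfig &compile_config,
  //                          Kernel *kernel) override;
  //     void materialize_runtime(MemoryPool *memory_pool,
  //                              KernelProfilerBase *profiler,
  //                              uint64 **result_buffer_ptr) override;
  //     void materialize_snode_tree(SNodeTree *tree,
  //                                 uint64 *result_buffer_ptr) override;
  //     void destroy_snode_tree(SNodeTree *snode_tree) override;
  //     std::size_t get_snode_num_dynamically_allocated(
  //         SNode *snode,
  //         uint64 *result_buffer) override;
  //     void synchronize() override;
  //     std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
  //         const DeviceCapabilityConfig &caps) override;
  //   };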

  /**
   * Dumps the offline cache data to disk.
   */
  virtual void dump_cache_data_to_disk() {
  }

  virtual Device *get_compute_device() {
    return nullptr;
  }

  virtual Device *get_graphics_device() {
    return nullptr;
  }

  virtual size_t get_field_in_tree_offset(int tree_id, const SNode *child) {
    return 0;
  }

  virtual DevicePtr get_snode_tree_device_ptr(int tree_id) {
    return kDeviceNullPtr;
  }

  virtual DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
                                                   uint64 *result_buffer) {
    return kDeviceNullAllocation;
  }

  virtual bool used_in_kernel(DeviceAllocationId) {
    return false;
  }

  virtual DeviceAllocation allocate_texture(const ImageParams &params) {
    return kDeviceNullAllocation;
  }

  virtual ~ProgramImpl() {
  }

  // TODO: Move to Runtime Object
  virtual uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) {
    TI_ERROR(
        "get_ndarray_alloc_info_ptr() not implemented on the current backend");
    return nullptr;
  }

  // TODO: Move to Runtime Object
  virtual void fill_ndarray(const DeviceAllocation &alloc,
                            std::size_t size,
                            uint32_t data) {
    TI_ERROR("fill_ndarray() not implemented on the current backend");
  }
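
  // Illustrative sketch: how the ndarray helpers above fit together on a
  // backend that implements them. `impl`, `result_buffer`, and the sizes are
  // hypothetical; whether `size` counts elements or bytes is
  // backend-specific.
  //
  //   DeviceAllocation alloc =
  //       impl->allocate_memory_ndarray(/*alloc_size=*/1024, result_buffer);
  //   impl->fill_ndarray(alloc, /*size=*/256, /*data=*/0u);
  //   uint64_t *info = impl->get_ndarray_alloc_info_ptr(alloc);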

  // TODO: Move to Runtime Object
  virtual void prepare_runtime_context(RuntimeContext *ctx) {
  }

  virtual void enqueue_compute_op_lambda(
      std::function<void(Device *device, CommandList *cmdlist)> op,
      const std::vector<ComputeOpImageRef> &image_refs) {
    TI_NOT_IMPLEMENTED;
  }
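
  // Illustrative sketch: enqueueing a compute op that reads `src_alloc` and
  // writes `dst_alloc` (both hypothetical). The lambda records device
  // commands; the image refs declare the layouts the op expects and leaves
  // the images in.
  //
  //   impl->enqueue_compute_op_lambda(
  //       [&](Device *device, CommandList *cmdlist) {
  //         // Record backend commands into cmdlist here.
  //       },
  //       {{src_alloc, ImageLayout::shader_read, ImageLayout::shader_read},
  //        {dst_alloc, ImageLayout::shader_write, ImageLayout::shader_write}});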

  virtual void print_memory_profiler_info(
      std::vector<std::unique_ptr<SNodeTree>> &snode_trees_,
      uint64 *result_buffer) {
    TI_ERROR(
        "print_memory_profiler_info() not implemented on the current backend");
  }

  virtual void check_runtime_error(uint64 *result_buffer) {
    TI_ERROR("check_runtime_error() not implemented on the current backend");
  }

  virtual void finalize() {
  }

  virtual uint64 fetch_result_uint64(int i, uint64 *result_buffer) {
    return result_buffer[i];
  }
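
  // Illustrative sketch: reading back a result slot after a launch. `impl`
  // and `result_buffer` (obtained earlier from materialize_runtime()) are
  // hypothetical names; slot 0 is just an example index.
  //
  //   impl->synchronize();
  //   uint64 raw = impl->fetch_result_uint64(/*i=*/0, result_buffer);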

  virtual TypedConstant fetch_result(char *result_buffer,
                                     int offset,
                                     const Type *dt) {
    TI_NOT_IMPLEMENTED;
  }

  virtual std::string get_kernel_return_data_layout() {
    return "";
  }

  virtual std::string get_kernel_argument_data_layout() {
    return "";
  }

  virtual const StructType *get_struct_type_with_data_layout(
      const StructType *old_ty,
      const std::string &layout) {
    return old_ty;
  }
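
  // Illustrative sketch: combining the layout queries above. `ret_type` is a
  // hypothetical StructType describing a kernel's return value; the backend
  // rewrites it according to its reported return-data layout string.
  //
  //   const StructType *adjusted = impl->get_struct_type_with_data_layout(
  //       ret_type, impl->get_kernel_return_data_layout());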

 private:
};

}  // namespace taichi::lang