1 | #include "taichi/program/kernel.h" |
2 | |
3 | #include "taichi/rhi/cuda/cuda_driver.h" |
4 | #include "taichi/codegen/codegen.h" |
5 | #include "taichi/common/logging.h" |
6 | #include "taichi/common/task.h" |
7 | #include "taichi/ir/statements.h" |
8 | #include "taichi/program/program.h" |
9 | #include "taichi/util/action_recorder.h" |
10 | |
11 | #ifdef TI_WITH_LLVM |
12 | #include "taichi/runtime/program_impls/llvm/llvm_program.h" |
13 | #endif |
14 | |
15 | namespace taichi::lang { |
16 | |
17 | class Function; |
18 | |
19 | Kernel::Kernel(Program &program, |
20 | const std::function<void()> &func, |
21 | const std::string &primal_name, |
22 | AutodiffMode autodiff_mode) { |
23 | this->init(program, func, primal_name, autodiff_mode); |
24 | } |
25 | |
26 | Kernel::Kernel(Program &program, |
27 | const std::function<void(Kernel *)> &func, |
28 | const std::string &primal_name, |
29 | AutodiffMode autodiff_mode) { |
30 | // due to #6362, we cannot write [func, this] { return func(this); } |
31 | this->init( |
32 | program, [&] { return func(this); }, primal_name, autodiff_mode); |
33 | } |
34 | |
35 | Kernel::Kernel(Program &program, |
36 | std::unique_ptr<IRNode> &&ir, |
37 | const std::string &primal_name, |
38 | AutodiffMode autodiff_mode) |
39 | : autodiff_mode(autodiff_mode), lowered_(false) { |
40 | this->ir = std::move(ir); |
41 | this->program = &program; |
42 | is_accessor = false; |
43 | is_evaluator = false; |
44 | compiled_ = nullptr; |
45 | ir_is_ast_ = false; // CHI IR |
46 | |
47 | if (autodiff_mode == AutodiffMode::kNone) { |
48 | name = primal_name; |
49 | } else if (autodiff_mode == AutodiffMode::kForward) { |
50 | name = primal_name + "_forward_grad" ; |
51 | } else if (autodiff_mode == AutodiffMode::kReverse) { |
52 | name = primal_name + "_reverse_grad" ; |
53 | } |
54 | } |
55 | |
56 | void Kernel::compile(const CompileConfig &compile_config) { |
57 | compiled_ = program->compile(compile_config, *this); |
58 | } |
59 | |
60 | void Kernel::operator()(const CompileConfig &compile_config, |
61 | LaunchContextBuilder &ctx_builder) { |
62 | if (!compiled_) { |
63 | compile(compile_config); |
64 | } |
65 | |
66 | compiled_(ctx_builder.get_context()); |
67 | |
68 | const auto arch = compile_config.arch; |
69 | if (compile_config.debug && |
70 | (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu)) { |
71 | program->check_runtime_error(); |
72 | } |
73 | } |
74 | |
75 | Kernel::LaunchContextBuilder Kernel::make_launch_context() { |
76 | return LaunchContextBuilder(this); |
77 | } |
78 | |
79 | Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel, |
80 | RuntimeContext *ctx) |
81 | : kernel_(kernel), owned_ctx_(nullptr), ctx_(ctx) { |
82 | } |
83 | |
84 | Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel) |
85 | : kernel_(kernel), |
86 | owned_ctx_(std::make_unique<RuntimeContext>()), |
87 | ctx_(owned_ctx_.get()) { |
88 | } |
89 | |
90 | void Kernel::LaunchContextBuilder::set_arg_float(int arg_id, float64 d) { |
91 | TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array, |
92 | "Assigning scalar value to external (numpy) array argument is " |
93 | "not allowed." ); |
94 | |
95 | ActionRecorder::get_instance().record( |
96 | "set_kernel_arg_float64" , |
97 | {ActionArg("kernel_name" , kernel_->name), ActionArg("arg_id" , arg_id), |
98 | ActionArg("val" , d)}); |
99 | |
100 | auto dt = kernel_->parameter_list[arg_id].get_dtype(); |
101 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
102 | ctx_->set_arg(arg_id, (float32)d); |
103 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
104 | ctx_->set_arg(arg_id, (float64)d); |
105 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
106 | ctx_->set_arg(arg_id, (int32)d); |
107 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
108 | ctx_->set_arg(arg_id, (int64)d); |
109 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
110 | ctx_->set_arg(arg_id, (int8)d); |
111 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
112 | ctx_->set_arg(arg_id, (int16)d); |
113 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
114 | ctx_->set_arg(arg_id, (uint8)d); |
115 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
116 | ctx_->set_arg(arg_id, (uint16)d); |
117 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
118 | ctx_->set_arg(arg_id, (uint32)d); |
119 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
120 | ctx_->set_arg(arg_id, (uint64)d); |
121 | } else if (dt->is_primitive(PrimitiveTypeID::f16)) { |
122 | // use f32 to interact with python |
123 | ctx_->set_arg(arg_id, (float32)d); |
124 | } else { |
125 | TI_NOT_IMPLEMENTED |
126 | } |
127 | } |
128 | |
129 | void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) { |
130 | TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array, |
131 | "Assigning scalar value to external (numpy) array argument is " |
132 | "not allowed." ); |
133 | |
134 | ActionRecorder::get_instance().record( |
135 | "set_kernel_arg_integer" , |
136 | {ActionArg("kernel_name" , kernel_->name), ActionArg("arg_id" , arg_id), |
137 | ActionArg("val" , d)}); |
138 | |
139 | auto dt = kernel_->parameter_list[arg_id].get_dtype(); |
140 | if (dt->is_primitive(PrimitiveTypeID::i32)) { |
141 | ctx_->set_arg(arg_id, (int32)d); |
142 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
143 | ctx_->set_arg(arg_id, (int64)d); |
144 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
145 | ctx_->set_arg(arg_id, (int8)d); |
146 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
147 | ctx_->set_arg(arg_id, (int16)d); |
148 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
149 | ctx_->set_arg(arg_id, (uint8)d); |
150 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
151 | ctx_->set_arg(arg_id, (uint16)d); |
152 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
153 | ctx_->set_arg(arg_id, (uint32)d); |
154 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
155 | ctx_->set_arg(arg_id, (uint64)d); |
156 | } else { |
157 | TI_INFO(dt->to_string()); |
158 | TI_NOT_IMPLEMENTED |
159 | } |
160 | } |
161 | |
162 | void Kernel::LaunchContextBuilder::set_arg_uint(int arg_id, uint64 d) { |
163 | set_arg_int(arg_id, d); |
164 | } |
165 | |
166 | void Kernel::LaunchContextBuilder::(int i, int j, int32 d) { |
167 | ctx_->extra_args[i][j] = d; |
168 | } |
169 | |
170 | void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape( |
171 | int arg_id, |
172 | uintptr_t ptr, |
173 | uint64 size, |
174 | const std::vector<int64> &shape) { |
175 | TI_ASSERT_INFO( |
176 | kernel_->parameter_list[arg_id].is_array, |
177 | "Assigning external (numpy) array to scalar argument is not allowed." ); |
178 | |
179 | ActionRecorder::get_instance().record( |
180 | "set_kernel_arg_ext_ptr" , |
181 | {ActionArg("kernel_name" , kernel_->name), ActionArg("arg_id" , arg_id), |
182 | ActionArg("address" , fmt::format("0x{:x}" , ptr)), |
183 | ActionArg("array_size_in_bytes" , (int64)size)}); |
184 | |
185 | TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices, |
186 | "External array cannot have > {max_num_indices} indices" ); |
187 | ctx_->set_arg_external_array(arg_id, ptr, size, shape); |
188 | } |
189 | |
190 | void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id, |
191 | const Ndarray &arr) { |
192 | intptr_t ptr = arr.get_device_allocation_ptr_as_int(); |
193 | TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices, |
194 | "External array cannot have > {max_num_indices} indices" ); |
195 | ctx_->set_arg_ndarray(arg_id, ptr, arr.shape); |
196 | } |
197 | |
198 | void Kernel::LaunchContextBuilder::set_arg_ndarray_with_grad( |
199 | int arg_id, |
200 | const Ndarray &arr, |
201 | const Ndarray &arr_grad) { |
202 | intptr_t ptr = arr.get_device_allocation_ptr_as_int(); |
203 | intptr_t ptr_grad = arr_grad.get_device_allocation_ptr_as_int(); |
204 | TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices, |
205 | "External array cannot have > {max_num_indices} indices" ); |
206 | ctx_->set_arg_ndarray(arg_id, ptr, arr.shape, true, ptr_grad); |
207 | } |
208 | |
209 | void Kernel::LaunchContextBuilder::set_arg_texture(int arg_id, |
210 | const Texture &tex) { |
211 | intptr_t ptr = tex.get_device_allocation_ptr_as_int(); |
212 | ctx_->set_arg_texture(arg_id, ptr); |
213 | } |
214 | |
215 | void Kernel::LaunchContextBuilder::set_arg_rw_texture(int arg_id, |
216 | const Texture &tex) { |
217 | intptr_t ptr = tex.get_device_allocation_ptr_as_int(); |
218 | ctx_->set_arg_rw_texture(arg_id, ptr, tex.get_size()); |
219 | } |
220 | |
221 | void Kernel::LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) { |
222 | TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array, |
223 | "Assigning scalar value to external (numpy) array argument is " |
224 | "not allowed." ); |
225 | |
226 | if (!kernel_->is_evaluator) { |
227 | ActionRecorder::get_instance().record( |
228 | "set_arg_raw" , |
229 | {ActionArg("kernel_name" , kernel_->name), ActionArg("arg_id" , arg_id), |
230 | ActionArg("val" , (int64)d)}); |
231 | } |
232 | ctx_->set_arg<uint64>(arg_id, d); |
233 | } |
234 | |
235 | RuntimeContext &Kernel::LaunchContextBuilder::get_context() { |
236 | kernel_->program->prepare_runtime_context(ctx_); |
237 | return *ctx_; |
238 | } |
239 | |
240 | template <typename T> |
241 | T Kernel::fetch_ret(DataType dt, int i) { |
242 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
243 | return (T)program->fetch_result<float32>(i); |
244 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
245 | return (T)program->fetch_result<float64>(i); |
246 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
247 | return (T)program->fetch_result<int32>(i); |
248 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
249 | return (T)program->fetch_result<int64>(i); |
250 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
251 | return (T)program->fetch_result<int8>(i); |
252 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
253 | return (T)program->fetch_result<int16>(i); |
254 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
255 | return (T)program->fetch_result<uint8>(i); |
256 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
257 | return (T)program->fetch_result<uint16>(i); |
258 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
259 | return (T)program->fetch_result<uint32>(i); |
260 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
261 | return (T)program->fetch_result<uint64>(i); |
262 | } else if (dt->is_primitive(PrimitiveTypeID::f16)) { |
263 | // use f32 to interact with python |
264 | return (T)program->fetch_result<float32>(i); |
265 | } else { |
266 | TI_NOT_IMPLEMENTED |
267 | } |
268 | } |
269 | |
270 | float64 Kernel::get_ret_float(int i) { |
271 | auto dt = rets[i].dt->get_compute_type(); |
272 | return fetch_ret<float64>(dt, i); |
273 | } |
274 | |
275 | int64 Kernel::get_ret_int(int i) { |
276 | auto dt = rets[i].dt->get_compute_type(); |
277 | return fetch_ret<int64>(dt, i); |
278 | } |
279 | |
280 | uint64 Kernel::get_ret_uint(int i) { |
281 | auto dt = rets[i].dt->get_compute_type(); |
282 | return fetch_ret<uint64>(dt, i); |
283 | } |
284 | |
285 | std::vector<int64> Kernel::get_ret_int_tensor(int i) { |
286 | DataType dt = rets[i].dt->as<TensorType>()->get_element_type(); |
287 | int size = rets[i].dt->as<TensorType>()->get_num_elements(); |
288 | std::vector<int64> res; |
289 | for (int j = 0; j < size; j++) { |
290 | res.emplace_back(fetch_ret<int64>(dt, j)); |
291 | } |
292 | return res; |
293 | } |
294 | |
295 | std::vector<uint64> Kernel::get_ret_uint_tensor(int i) { |
296 | DataType dt = rets[i].dt->as<TensorType>()->get_element_type(); |
297 | int size = rets[i].dt->as<TensorType>()->get_num_elements(); |
298 | std::vector<uint64> res; |
299 | for (int j = 0; j < size; j++) { |
300 | res.emplace_back(fetch_ret<uint64>(dt, j)); |
301 | } |
302 | return res; |
303 | } |
304 | |
305 | std::vector<float64> Kernel::get_ret_float_tensor(int i) { |
306 | DataType dt = rets[i].dt->as<TensorType>()->get_element_type(); |
307 | int size = rets[i].dt->as<TensorType>()->get_num_elements(); |
308 | std::vector<float64> res; |
309 | for (int j = 0; j < size; j++) { |
310 | res.emplace_back(fetch_ret<float64>(dt, j)); |
311 | } |
312 | return res; |
313 | } |
314 | |
315 | std::string Kernel::get_name() const { |
316 | return name; |
317 | } |
318 | |
319 | void Kernel::init(Program &program, |
320 | const std::function<void()> &func, |
321 | const std::string &primal_name, |
322 | AutodiffMode autodiff_mode) { |
323 | this->autodiff_mode = autodiff_mode; |
324 | this->lowered_ = false; |
325 | this->program = &program; |
326 | |
327 | is_accessor = false; |
328 | is_evaluator = false; |
329 | compiled_ = nullptr; |
330 | context = std::make_unique<FrontendContext>(program.compile_config().arch); |
331 | ir = context->get_root(); |
332 | ir_is_ast_ = true; |
333 | |
334 | if (autodiff_mode == AutodiffMode::kNone) { |
335 | name = primal_name; |
336 | } else if (autodiff_mode == AutodiffMode::kCheckAutodiffValid) { |
337 | name = primal_name + "_validate_grad" ; |
338 | } else if (autodiff_mode == AutodiffMode::kForward) { |
339 | name = primal_name + "_forward_grad" ; |
340 | } else if (autodiff_mode == AutodiffMode::kReverse) { |
341 | name = primal_name + "_reverse_grad" ; |
342 | } |
343 | |
344 | func(); |
345 | } |
346 | |
347 | TypedConstant Kernel::fetch_ret(const std::vector<int> &index) { |
348 | const Type *dt = ret_type->get_element_type(index); |
349 | int offset = ret_type->get_element_offset(index); |
350 | return program->fetch_result(offset, dt); |
351 | } |
352 | |
353 | float64 Kernel::get_struct_ret_float(const std::vector<int> &index) { |
354 | return fetch_ret(index).val_float(); |
355 | } |
356 | |
357 | int64 Kernel::get_struct_ret_int(const std::vector<int> &index) { |
358 | return fetch_ret(index).val_int(); |
359 | } |
360 | |
361 | uint64 Kernel::get_struct_ret_uint(const std::vector<int> &index) { |
362 | return fetch_ret(index).val_uint(); |
363 | } |
364 | |
365 | } // namespace taichi::lang |
366 | |