#include "taichi/program/kernel.h"

#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/codegen/codegen.h"
#include "taichi/common/logging.h"
#include "taichi/common/task.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
#include "taichi/util/action_recorder.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#endif

namespace taichi::lang {

class Function;

Kernel::Kernel(Program &program,
               const std::function<void()> &func,
               const std::string &primal_name,
               AutodiffMode autodiff_mode) {
  this->init(program, func, primal_name, autodiff_mode);
}

Kernel::Kernel(Program &program,
               const std::function<void(Kernel *)> &func,
               const std::string &primal_name,
               AutodiffMode autodiff_mode) {
  // due to #6362, we cannot write [func, this] { return func(this); }
  this->init(
      program, [&] { return func(this); }, primal_name, autodiff_mode);
}

Kernel::Kernel(Program &program,
               std::unique_ptr<IRNode> &&ir,
               const std::string &primal_name,
               AutodiffMode autodiff_mode)
    : autodiff_mode(autodiff_mode), lowered_(false) {
  this->ir = std::move(ir);
  this->program = &program;
  is_accessor = false;
  is_evaluator = false;
  compiled_ = nullptr;
  ir_is_ast_ = false;  // CHI IR

  if (autodiff_mode == AutodiffMode::kNone) {
    name = primal_name;
  } else if (autodiff_mode == AutodiffMode::kForward) {
    name = primal_name + "_forward_grad";
  } else if (autodiff_mode == AutodiffMode::kReverse) {
    name = primal_name + "_reverse_grad";
  }
}

void Kernel::compile(const CompileConfig &compile_config) {
  compiled_ = program->compile(compile_config, *this);
}

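// Launching a kernel compiles it lazily on first use and then invokes the
// cached compiled function with the caller-provided RuntimeContext. A typical
// call sequence from host code looks roughly like the sketch below
// (illustrative only; `cfg` and `kernel` are assumed to exist on the caller's
// side):
//
//   auto ctx = kernel.make_launch_context();
//   ctx.set_arg_int(/*arg_id=*/0, 42);
//   kernel(cfg, ctx);  // compiles on demand, then launches
//
// In debug mode on CPU, CUDA, and AMDGPU backends, the launch is followed by
// a runtime error check.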
void Kernel::operator()(const CompileConfig &compile_config,
                        LaunchContextBuilder &ctx_builder) {
  if (!compiled_) {
    compile(compile_config);
  }

  compiled_(ctx_builder.get_context());

  const auto arch = compile_config.arch;
  if (compile_config.debug &&
      (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu)) {
    program->check_runtime_error();
  }
}

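// LaunchContextBuilder either borrows a caller-owned RuntimeContext or owns a
// freshly allocated one; `ctx_` always points at whichever of the two is in
// use, so the argument setters below never need to distinguish the cases.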
Kernel::LaunchContextBuilder Kernel::make_launch_context() {
  return LaunchContextBuilder(this);
}

Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel,
                                                   RuntimeContext *ctx)
    : kernel_(kernel), owned_ctx_(nullptr), ctx_(ctx) {
}

Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel)
    : kernel_(kernel),
      owned_ctx_(std::make_unique<RuntimeContext>()),
      ctx_(owned_ctx_.get()) {
}

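// The scalar argument setters dispatch on the *declared* parameter dtype
// rather than on the C++ type of the incoming value, so a value coming across
// the Python boundary is narrowed or widened to whatever the kernel signature
// expects. f16 is a special case: it is passed through f32, which is the
// representation used when interacting with Python.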
void Kernel::LaunchContextBuilder::set_arg_float(int arg_id, float64 d) {
  TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array,
                 "Assigning scalar value to external (numpy) array argument is "
                 "not allowed.");

  ActionRecorder::get_instance().record(
      "set_kernel_arg_float64",
      {ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
       ActionArg("val", d)});

  auto dt = kernel_->parameter_list[arg_id].get_dtype();
  if (dt->is_primitive(PrimitiveTypeID::f32)) {
    ctx_->set_arg(arg_id, (float32)d);
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    ctx_->set_arg(arg_id, (float64)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
    ctx_->set_arg(arg_id, (int32)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
    ctx_->set_arg(arg_id, (int64)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
    ctx_->set_arg(arg_id, (int8)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
    ctx_->set_arg(arg_id, (int16)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
    ctx_->set_arg(arg_id, (uint8)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
    ctx_->set_arg(arg_id, (uint16)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
    ctx_->set_arg(arg_id, (uint32)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
    ctx_->set_arg(arg_id, (uint64)d);
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    // use f32 to interact with python
    ctx_->set_arg(arg_id, (float32)d);
  } else {
    TI_NOT_IMPLEMENTED
  }
}

void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
  TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array,
                 "Assigning scalar value to external (numpy) array argument is "
                 "not allowed.");

  ActionRecorder::get_instance().record(
      "set_kernel_arg_integer",
      {ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
       ActionArg("val", d)});

  auto dt = kernel_->parameter_list[arg_id].get_dtype();
  if (dt->is_primitive(PrimitiveTypeID::i32)) {
    ctx_->set_arg(arg_id, (int32)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
    ctx_->set_arg(arg_id, (int64)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
    ctx_->set_arg(arg_id, (int8)d);
  } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
    ctx_->set_arg(arg_id, (int16)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
    ctx_->set_arg(arg_id, (uint8)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
    ctx_->set_arg(arg_id, (uint16)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
    ctx_->set_arg(arg_id, (uint32)d);
  } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
    ctx_->set_arg(arg_id, (uint64)d);
  } else {
    TI_INFO(dt->to_string());
    TI_NOT_IMPLEMENTED
  }
}

void Kernel::LaunchContextBuilder::set_arg_uint(int arg_id, uint64 d) {
  set_arg_int(arg_id, d);
}

void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {
  ctx_->extra_args[i][j] = d;
}

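// External (numpy/ndarray) array arguments carry a raw base pointer, a size
// in bytes, and a shape; the shape's rank must not exceed
// taichi_max_num_indices, and it is forwarded to the RuntimeContext together
// with the pointer.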
void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
    int arg_id,
    uintptr_t ptr,
    uint64 size,
    const std::vector<int64> &shape) {
  TI_ASSERT_INFO(
      kernel_->parameter_list[arg_id].is_array,
      "Assigning external (numpy) array to scalar argument is not allowed.");

  ActionRecorder::get_instance().record(
      "set_kernel_arg_ext_ptr",
      {ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
       ActionArg("address", fmt::format("0x{:x}", ptr)),
       ActionArg("array_size_in_bytes", (int64)size)});

  TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices,
                 "External array cannot have > {} indices",
                 taichi_max_num_indices);
  ctx_->set_arg_external_array(arg_id, ptr, size, shape);
}

void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id,
                                                   const Ndarray &arr) {
  intptr_t ptr = arr.get_device_allocation_ptr_as_int();
  TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
                 "External array cannot have > {} indices",
                 taichi_max_num_indices);
  ctx_->set_arg_ndarray(arg_id, ptr, arr.shape);
}

void Kernel::LaunchContextBuilder::set_arg_ndarray_with_grad(
    int arg_id,
    const Ndarray &arr,
    const Ndarray &arr_grad) {
  intptr_t ptr = arr.get_device_allocation_ptr_as_int();
  intptr_t ptr_grad = arr_grad.get_device_allocation_ptr_as_int();
  TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
                 "External array cannot have > {} indices",
                 taichi_max_num_indices);
  ctx_->set_arg_ndarray(arg_id, ptr, arr.shape, true, ptr_grad);
}

void Kernel::LaunchContextBuilder::set_arg_texture(int arg_id,
                                                   const Texture &tex) {
  intptr_t ptr = tex.get_device_allocation_ptr_as_int();
  ctx_->set_arg_texture(arg_id, ptr);
}

void Kernel::LaunchContextBuilder::set_arg_rw_texture(int arg_id,
                                                      const Texture &tex) {
  intptr_t ptr = tex.get_device_allocation_ptr_as_int();
  ctx_->set_arg_rw_texture(arg_id, ptr, tex.get_size());
}

void Kernel::LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) {
  TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array,
                 "Assigning scalar value to external (numpy) array argument is "
                 "not allowed.");

  if (!kernel_->is_evaluator) {
    ActionRecorder::get_instance().record(
        "set_arg_raw",
        {ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
         ActionArg("val", (int64)d)});
  }
  ctx_->set_arg<uint64>(arg_id, d);
}

RuntimeContext &Kernel::LaunchContextBuilder::get_context() {
  kernel_->program->prepare_runtime_context(ctx_);
  return *ctx_;
}

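// fetch_ret reads the i-th return slot using the kernel's declared primitive
// type and casts it to the requested C++ type T. As with arguments, f16
// results travel through f32.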
template <typename T>
T Kernel::fetch_ret(DataType dt, int i) {
  if (dt->is_primitive(PrimitiveTypeID::f32)) {
    return (T)program->fetch_result<float32>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    return (T)program->fetch_result<float64>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
    return (T)program->fetch_result<int32>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
    return (T)program->fetch_result<int64>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
    return (T)program->fetch_result<int8>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
    return (T)program->fetch_result<int16>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
    return (T)program->fetch_result<uint8>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
    return (T)program->fetch_result<uint16>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
    return (T)program->fetch_result<uint32>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
    return (T)program->fetch_result<uint64>(i);
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    // use f32 to interact with python
    return (T)program->fetch_result<float32>(i);
  } else {
    TI_NOT_IMPLEMENTED
  }
}

float64 Kernel::get_ret_float(int i) {
  auto dt = rets[i].dt->get_compute_type();
  return fetch_ret<float64>(dt, i);
}

int64 Kernel::get_ret_int(int i) {
  auto dt = rets[i].dt->get_compute_type();
  return fetch_ret<int64>(dt, i);
}

uint64 Kernel::get_ret_uint(int i) {
  auto dt = rets[i].dt->get_compute_type();
  return fetch_ret<uint64>(dt, i);
}

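// Tensor-typed returns are flattened: each get_ret_*_tensor helper walks the
// element count of the TensorType and fetches elements one by one through the
// scalar fetch_ret path, collecting them into a std::vector.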
std::vector<int64> Kernel::get_ret_int_tensor(int i) {
  DataType dt = rets[i].dt->as<TensorType>()->get_element_type();
  int size = rets[i].dt->as<TensorType>()->get_num_elements();
  std::vector<int64> res;
  for (int j = 0; j < size; j++) {
    res.emplace_back(fetch_ret<int64>(dt, j));
  }
  return res;
}

std::vector<uint64> Kernel::get_ret_uint_tensor(int i) {
  DataType dt = rets[i].dt->as<TensorType>()->get_element_type();
  int size = rets[i].dt->as<TensorType>()->get_num_elements();
  std::vector<uint64> res;
  for (int j = 0; j < size; j++) {
    res.emplace_back(fetch_ret<uint64>(dt, j));
  }
  return res;
}

std::vector<float64> Kernel::get_ret_float_tensor(int i) {
  DataType dt = rets[i].dt->as<TensorType>()->get_element_type();
  int size = rets[i].dt->as<TensorType>()->get_num_elements();
  std::vector<float64> res;
  for (int j = 0; j < size; j++) {
    res.emplace_back(fetch_ret<float64>(dt, j));
  }
  return res;
}

std::string Kernel::get_name() const {
  return name;
}

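// init() is the common setup path used by the std::function-based
// constructors: it records the owning program, builds a FrontendContext for
// the configured arch (the kernel body is an AST at this point, unlike the
// CHI-IR constructor above), derives the kernel name from the autodiff mode,
// and finally runs `func` to populate the AST.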
void Kernel::init(Program &program,
                  const std::function<void()> &func,
                  const std::string &primal_name,
                  AutodiffMode autodiff_mode) {
  this->autodiff_mode = autodiff_mode;
  this->lowered_ = false;
  this->program = &program;

  is_accessor = false;
  is_evaluator = false;
  compiled_ = nullptr;
  context = std::make_unique<FrontendContext>(program.compile_config().arch);
  ir = context->get_root();
  ir_is_ast_ = true;

  if (autodiff_mode == AutodiffMode::kNone) {
    name = primal_name;
  } else if (autodiff_mode == AutodiffMode::kCheckAutodiffValid) {
    name = primal_name + "_validate_grad";
  } else if (autodiff_mode == AutodiffMode::kForward) {
    name = primal_name + "_forward_grad";
  } else if (autodiff_mode == AutodiffMode::kReverse) {
    name = primal_name + "_reverse_grad";
  }

  func();
}

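// Struct returns are addressed by an index path into the (possibly nested)
// return type: the element type and offset are resolved from ret_type, and
// the value at that offset is fetched as a TypedConstant.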
TypedConstant Kernel::fetch_ret(const std::vector<int> &index) {
  const Type *dt = ret_type->get_element_type(index);
  int offset = ret_type->get_element_offset(index);
  return program->fetch_result(offset, dt);
}

float64 Kernel::get_struct_ret_float(const std::vector<int> &index) {
  return fetch_ret(index).val_float();
}

int64 Kernel::get_struct_ret_int(const std::vector<int> &index) {
  return fetch_ret(index).val_int();
}

uint64 Kernel::get_struct_ret_uint(const std::vector<int> &index) {
  return fetch_ret(index).val_uint();
}

}  // namespace taichi::lang