1#include "taichi/aot/graph_data.h"
2#include "taichi/program/program.h"
3#include "taichi/program/ndarray.h"
4#include "taichi/program/texture.h"
5#include "taichi/program/kernel.h"
6
7#include <numeric>
8
9namespace taichi::lang {
10namespace aot {
11
12void CompiledGraph::run(
13 const std::unordered_map<std::string, IValue> &args) const {
14 for (const auto &dispatch : dispatches) {
15 RuntimeContext ctx = ctx_;
16
17 TI_ASSERT(dispatch.ti_kernel || dispatch.compiled_kernel);
18
19 // Populate args metadata into RuntimeContext
20 const auto &symbolic_args_ = dispatch.symbolic_args;
21 for (int i = 0; i < symbolic_args_.size(); ++i) {
22 auto &symbolic_arg = symbolic_args_[i];
23 auto found = args.find(symbolic_arg.name);
24 TI_ERROR_IF(found == args.end(), "Missing runtime value for {}",
25 symbolic_arg.name);
26 const aot::IValue &ival = found->second;
27 if (symbolic_arg.tag == aot::ArgKind::kNdarray) {
28 TI_ASSERT(ival.tag == aot::ArgKind::kNdarray);
29 Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val);
30
31 TI_ERROR_IF(arr->get_element_shape() != symbolic_arg.element_shape,
32 "Mismatched shape information for argument {}",
33 symbolic_arg.name);
34 TI_ERROR_IF(arr->shape.size() != symbolic_arg.field_dim,
35 "Dispatch node is compiled for argument {} with "
36 "field_dim={} but got an ndarray with field_dim={}",
37 symbolic_arg.name, symbolic_arg.field_dim,
38 arr->shape.size());
39
40 // CGraph uses aot::Arg as symbolic argument, which represents
41 // TensorType via combination of element_shape and PrimitiveTypeID
42 // Therefore we only check for element_type for now.
43 //
44 // TODO(zhanlue): Replace all "element_shape + PrimitiveType" use cases
45 // with direct use of "TensorType",
46 // In the end, "element_shape" should only appear inside
47 // TensorType and nowhere else.
48 //
49 // This refactor includes aot::Arg, kernel::Arg,
50 // MetalDataType, and more...
51 DataType symbolic_arg_primitive_dtype = symbolic_arg.dtype();
52 if (symbolic_arg.dtype()->is<TensorType>()) {
53 symbolic_arg_primitive_dtype =
54 symbolic_arg.dtype()->cast<TensorType>()->get_element_type();
55 }
56
57 DataType arr_primitive_dtype = arr->dtype;
58 if (arr->dtype->is<TensorType>()) {
59 arr_primitive_dtype =
60 arr->dtype->cast<TensorType>()->get_element_type();
61 }
62
63 TI_ERROR_IF(arr_primitive_dtype != symbolic_arg_primitive_dtype,
64 "Dispatch node is compiled for argument {} with "
65 "dtype={} but got an ndarray with dtype={}",
66 symbolic_arg.name, symbolic_arg_primitive_dtype.to_string(),
67 arr_primitive_dtype.to_string());
68 ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
69 arr->shape);
70 } else if (symbolic_arg.tag == aot::ArgKind::kScalar ||
71 symbolic_arg.tag == aot::ArgKind::kMatrix) {
72 TI_ASSERT(ival.tag == aot::ArgKind::kScalar);
73 // Matrix args are flattened so they're same as scalars.
74 ctx.set_arg(i, ival.val);
75 } else if (symbolic_arg.tag == aot::ArgKind::kTexture) {
76 TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
77 Texture *tex = reinterpret_cast<Texture *>(ival.val);
78 ctx.set_arg_texture(i, tex->get_device_allocation_ptr_as_int());
79 } else if (symbolic_arg.tag == aot::ArgKind::kRWTexture) {
80 TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
81 Texture *tex = reinterpret_cast<Texture *>(ival.val);
82 ctx.set_arg_rw_texture(i, tex->get_device_allocation_ptr_as_int(),
83 tex->get_size());
84 } else {
85 TI_ERROR("Error in compiled graph: unknown tag {}", ival.tag);
86 }
87 }
88
89 if (dispatch.compiled_kernel) {
90 // Run cgraph loaded from AOT module
91 dispatch.compiled_kernel->launch(&ctx);
92 } else {
93 // JIT & Run
94 TI_ASSERT(dispatch.ti_kernel);
95 lang::Kernel::LaunchContextBuilder launch_ctx(dispatch.ti_kernel, &ctx);
96 auto *ker = dispatch.ti_kernel;
97 ker->operator()(ker->program->compile_config(), launch_ctx);
98 }
99 }
100}
101} // namespace aot
102} // namespace taichi::lang
103