1 | #include "taichi/aot/graph_data.h" |
2 | #include "taichi/program/program.h" |
3 | #include "taichi/program/ndarray.h" |
4 | #include "taichi/program/texture.h" |
5 | #include "taichi/program/kernel.h" |
6 | |
7 | #include <numeric> |
8 | |
9 | namespace taichi::lang { |
10 | namespace aot { |
11 | |
12 | void CompiledGraph::run( |
13 | const std::unordered_map<std::string, IValue> &args) const { |
14 | for (const auto &dispatch : dispatches) { |
15 | RuntimeContext ctx = ctx_; |
16 | |
17 | TI_ASSERT(dispatch.ti_kernel || dispatch.compiled_kernel); |
18 | |
19 | // Populate args metadata into RuntimeContext |
20 | const auto &symbolic_args_ = dispatch.symbolic_args; |
21 | for (int i = 0; i < symbolic_args_.size(); ++i) { |
22 | auto &symbolic_arg = symbolic_args_[i]; |
23 | auto found = args.find(symbolic_arg.name); |
24 | TI_ERROR_IF(found == args.end(), "Missing runtime value for {}" , |
25 | symbolic_arg.name); |
26 | const aot::IValue &ival = found->second; |
27 | if (symbolic_arg.tag == aot::ArgKind::kNdarray) { |
28 | TI_ASSERT(ival.tag == aot::ArgKind::kNdarray); |
29 | Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val); |
30 | |
31 | TI_ERROR_IF(arr->get_element_shape() != symbolic_arg.element_shape, |
32 | "Mismatched shape information for argument {}" , |
33 | symbolic_arg.name); |
34 | TI_ERROR_IF(arr->shape.size() != symbolic_arg.field_dim, |
35 | "Dispatch node is compiled for argument {} with " |
36 | "field_dim={} but got an ndarray with field_dim={}" , |
37 | symbolic_arg.name, symbolic_arg.field_dim, |
38 | arr->shape.size()); |
39 | |
40 | // CGraph uses aot::Arg as symbolic argument, which represents |
41 | // TensorType via combination of element_shape and PrimitiveTypeID |
42 | // Therefore we only check for element_type for now. |
43 | // |
44 | // TODO(zhanlue): Replace all "element_shape + PrimitiveType" use cases |
45 | // with direct use of "TensorType", |
46 | // In the end, "element_shape" should only appear inside |
47 | // TensorType and nowhere else. |
48 | // |
49 | // This refactor includes aot::Arg, kernel::Arg, |
50 | // MetalDataType, and more... |
51 | DataType symbolic_arg_primitive_dtype = symbolic_arg.dtype(); |
52 | if (symbolic_arg.dtype()->is<TensorType>()) { |
53 | symbolic_arg_primitive_dtype = |
54 | symbolic_arg.dtype()->cast<TensorType>()->get_element_type(); |
55 | } |
56 | |
57 | DataType arr_primitive_dtype = arr->dtype; |
58 | if (arr->dtype->is<TensorType>()) { |
59 | arr_primitive_dtype = |
60 | arr->dtype->cast<TensorType>()->get_element_type(); |
61 | } |
62 | |
63 | TI_ERROR_IF(arr_primitive_dtype != symbolic_arg_primitive_dtype, |
64 | "Dispatch node is compiled for argument {} with " |
65 | "dtype={} but got an ndarray with dtype={}" , |
66 | symbolic_arg.name, symbolic_arg_primitive_dtype.to_string(), |
67 | arr_primitive_dtype.to_string()); |
68 | ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(), |
69 | arr->shape); |
70 | } else if (symbolic_arg.tag == aot::ArgKind::kScalar || |
71 | symbolic_arg.tag == aot::ArgKind::kMatrix) { |
72 | TI_ASSERT(ival.tag == aot::ArgKind::kScalar); |
73 | // Matrix args are flattened so they're same as scalars. |
74 | ctx.set_arg(i, ival.val); |
75 | } else if (symbolic_arg.tag == aot::ArgKind::kTexture) { |
76 | TI_ASSERT(ival.tag == aot::ArgKind::kTexture); |
77 | Texture *tex = reinterpret_cast<Texture *>(ival.val); |
78 | ctx.set_arg_texture(i, tex->get_device_allocation_ptr_as_int()); |
79 | } else if (symbolic_arg.tag == aot::ArgKind::kRWTexture) { |
80 | TI_ASSERT(ival.tag == aot::ArgKind::kTexture); |
81 | Texture *tex = reinterpret_cast<Texture *>(ival.val); |
82 | ctx.set_arg_rw_texture(i, tex->get_device_allocation_ptr_as_int(), |
83 | tex->get_size()); |
84 | } else { |
85 | TI_ERROR("Error in compiled graph: unknown tag {}" , ival.tag); |
86 | } |
87 | } |
88 | |
89 | if (dispatch.compiled_kernel) { |
90 | // Run cgraph loaded from AOT module |
91 | dispatch.compiled_kernel->launch(&ctx); |
92 | } else { |
93 | // JIT & Run |
94 | TI_ASSERT(dispatch.ti_kernel); |
95 | lang::Kernel::LaunchContextBuilder launch_ctx(dispatch.ti_kernel, &ctx); |
96 | auto *ker = dispatch.ti_kernel; |
97 | ker->operator()(ker->program->compile_config(), launch_ctx); |
98 | } |
99 | } |
100 | } |
101 | } // namespace aot |
102 | } // namespace taichi::lang |
103 | |