1 | #pragma once |
#include <cstdint>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "taichi/ir/type.h"
#include "taichi/aot/module_data.h"
#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
#undef TI_RUNTIME_HOST
10 | |
// Converts a value of type `G` into a value of type `T`, where the two types
// may differ in size. Only declared here; the definition lives elsewhere.
// In this header it is used by `aot::IValue::create` to pack a scalar into a
// `uint64`.
// NOTE(review): the name suggests a union-based bit reinterpretation (with
// zero-fill/truncation for size mismatches) -- confirm against the definition.
template <typename T, typename G>
T taichi_union_cast_with_different_sizes(G g);
13 | |
14 | namespace taichi::lang { |
15 | class AotModuleBuilder; |
16 | class Ndarray; |
17 | class Texture; |
18 | class Kernel; |
19 | |
20 | namespace aot { |
/**
 * Kind of a symbolic graph argument (`Arg::tag`) or runtime value
 * (`IValue::tag`).
 *
 * Supported kinds: scalars, matrices, ndarrays, textures and read-write
 * textures. `kUnknown` is the tag of a default-constructed `Arg`.
 *
 * NOTE: `Arg::tag` is part of `Arg`'s serialized fields, so the enumerator
 * order (and therefore the numeric values) should be treated as part of the
 * on-disk format -- do not reorder or insert in the middle.
 */
enum class ArgKind {
  kScalar,     // single scalar value
  kMatrix,     // matrix value
  kNdarray,    // ndarray
  kTexture,    // texture
  kRWTexture,  // read-write texture
  kUnknown     // not yet set (default-constructed Arg)
};
30 | |
31 | /** |
32 | * Symbolic argument used in building `Dispatch` nodes in the `Graph`. |
33 | */ |
34 | struct Arg { |
35 | ArgKind tag; |
36 | std::string name; |
37 | // Ndarray: element_dtype = dtype + element_shape |
38 | // Texture: element_shape carries [width, height, depth] info |
39 | // dtype_id carries channel_format info |
40 | PrimitiveTypeID dtype_id; |
41 | size_t field_dim; |
42 | std::vector<int> element_shape; |
43 | |
44 | // For texture |
45 | size_t num_channels; // TODO: maybe rename field_dim and merge? |
46 | |
47 | // For serialization & deserialization |
48 | explicit Arg() |
49 | : tag(ArgKind::kUnknown), |
50 | name("" ), |
51 | dtype_id(PrimitiveTypeID::unknown), |
52 | field_dim(0), |
53 | element_shape({}) { |
54 | } |
55 | |
56 | explicit Arg(ArgKind tag, |
57 | const std::string &name, |
58 | |
59 | PrimitiveTypeID dtype_id, |
60 | size_t field_dim, |
61 | const std::vector<int> &element_shape) |
62 | : tag(tag), |
63 | name(name), |
64 | dtype_id(dtype_id), |
65 | field_dim(field_dim), |
66 | element_shape(element_shape) { |
67 | } |
68 | |
69 | // Python/C++ interface that's user facing. |
70 | explicit Arg(ArgKind tag, |
71 | const std::string &name, |
72 | const DataType &dtype, |
73 | size_t dim = 0, |
74 | const std::vector<int> &element_shape = {}) |
75 | : tag(tag), name(name), element_shape(element_shape) { |
76 | if (tag == ArgKind::kTexture || tag == ArgKind::kRWTexture) { |
77 | num_channels = dim; |
78 | } else { |
79 | field_dim = dim; |
80 | } |
81 | dtype_id = dtype->as<PrimitiveType>()->type; |
82 | } |
83 | |
84 | DataType dtype() const { |
85 | return PrimitiveType::get(dtype_id); |
86 | } |
87 | |
88 | bool operator==(const Arg &other) const { |
89 | return tag == other.tag && name == other.name && |
90 | field_dim == other.field_dim && dtype_id == other.dtype_id && |
91 | element_shape == other.element_shape; |
92 | } |
93 | |
94 | bool operator!=(const Arg &other) const { |
95 | return !(*this == other); |
96 | } |
97 | |
98 | TI_IO_DEF(name, dtype_id, field_dim, tag, element_shape, num_channels); |
99 | }; |
100 | |
101 | /** |
102 | * Runtime value used in graph execution. |
103 | */ |
104 | struct TI_DLL_EXPORT IValue { |
105 | public: |
106 | uint64 val; |
107 | ArgKind tag; |
108 | |
109 | static IValue create(const Ndarray &ndarray) { |
110 | return IValue(reinterpret_cast<intptr_t>(&ndarray), ArgKind::kNdarray); |
111 | } |
112 | |
113 | static IValue create(const Texture &tex) { |
114 | return IValue(reinterpret_cast<intptr_t>(&tex), ArgKind::kTexture); |
115 | } |
116 | |
117 | template <typename T, |
118 | typename = std::enable_if_t<!std::is_same<T, Ndarray>::value, void>> |
119 | static IValue create(T v) { |
120 | return IValue(taichi_union_cast_with_different_sizes<uint64>(v), |
121 | ArgKind::kScalar); |
122 | } |
123 | |
124 | private: |
125 | IValue(uint64 val, ArgKind tag) : val(val), tag(tag) { |
126 | } |
127 | }; |
128 | |
129 | class TI_DLL_EXPORT Kernel { |
130 | public: |
131 | // Rule of 5 to make MSVC happy |
132 | Kernel() = default; |
133 | virtual ~Kernel() = default; |
134 | Kernel(const Kernel &) = delete; |
135 | Kernel &operator=(const Kernel &) = delete; |
136 | Kernel(Kernel &&) = default; |
137 | Kernel &operator=(Kernel &&) = default; |
138 | |
139 | /** |
140 | * @brief Launches the kernel to the device |
141 | * |
142 | * This does not manage the device to host synchronization. |
143 | * |
144 | * @param ctx Host context |
145 | */ |
146 | virtual void launch(RuntimeContext *ctx) = 0; |
147 | }; |
148 | |
149 | struct CompiledDispatch { |
150 | std::string kernel_name; |
151 | std::vector<Arg> symbolic_args; |
152 | Kernel *compiled_kernel{nullptr}; |
153 | taichi::lang::Kernel *ti_kernel{nullptr}; |
154 | |
155 | TI_IO_DEF(kernel_name, symbolic_args); |
156 | }; |
157 | |
158 | struct TI_DLL_EXPORT CompiledGraph { |
159 | std::vector<CompiledDispatch> dispatches; |
160 | std::unordered_map<std::string, aot::Arg> args; |
161 | RuntimeContext ctx_; |
162 | |
163 | void run(const std::unordered_map<std::string, IValue> &args) const; |
164 | |
165 | TI_IO_DEF(dispatches); |
166 | }; |
167 | |
168 | } // namespace aot |
169 | } // namespace taichi::lang |
170 | |