#pragma once
#include <cstdint>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "taichi/ir/type.h"
#include "taichi/aot/module_data.h"
#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
#undef TI_RUNTIME_HOST
10
// Declared at global scope; its definition is not part of this header.
// Presumably reinterprets the bits of `g` as a `T` where the two sizes may
// differ (inferred from the name -- verify against the definition).
// IValue::create relies on it to pack scalars into a uint64.
template <class T, class G>
T taichi_union_cast_with_different_sizes(G g);
13
14namespace taichi::lang {
// Forward declarations; the full definitions are not needed by this header.
class AotModuleBuilder;
class Kernel;
class Ndarray;
class Texture;
19
20namespace aot {
// Kind of a graph argument: scalar, matrix, ndarray, texture or RW texture.
// kUnknown is the kind of an argument whose kind has not been set.
//
// The numeric values are written out explicitly because `Arg::tag` is part
// of the serialized data (TI_IO_DEF). Presumably reordering would break
// previously saved AOT modules, so keep these values stable.
enum class ArgKind {
  kScalar = 0,
  kMatrix = 1,
  kNdarray = 2,
  kTexture = 3,
  kRWTexture = 4,
  kUnknown = 5,
};
30
31/**
32 * Symbolic argument used in building `Dispatch` nodes in the `Graph`.
33 */
34struct Arg {
35 ArgKind tag;
36 std::string name;
37 // Ndarray: element_dtype = dtype + element_shape
38 // Texture: element_shape carries [width, height, depth] info
39 // dtype_id carries channel_format info
40 PrimitiveTypeID dtype_id;
41 size_t field_dim;
42 std::vector<int> element_shape;
43
44 // For texture
45 size_t num_channels; // TODO: maybe rename field_dim and merge?
46
47 // For serialization & deserialization
48 explicit Arg()
49 : tag(ArgKind::kUnknown),
50 name(""),
51 dtype_id(PrimitiveTypeID::unknown),
52 field_dim(0),
53 element_shape({}) {
54 }
55
56 explicit Arg(ArgKind tag,
57 const std::string &name,
58
59 PrimitiveTypeID dtype_id,
60 size_t field_dim,
61 const std::vector<int> &element_shape)
62 : tag(tag),
63 name(name),
64 dtype_id(dtype_id),
65 field_dim(field_dim),
66 element_shape(element_shape) {
67 }
68
69 // Python/C++ interface that's user facing.
70 explicit Arg(ArgKind tag,
71 const std::string &name,
72 const DataType &dtype,
73 size_t dim = 0,
74 const std::vector<int> &element_shape = {})
75 : tag(tag), name(name), element_shape(element_shape) {
76 if (tag == ArgKind::kTexture || tag == ArgKind::kRWTexture) {
77 num_channels = dim;
78 } else {
79 field_dim = dim;
80 }
81 dtype_id = dtype->as<PrimitiveType>()->type;
82 }
83
84 DataType dtype() const {
85 return PrimitiveType::get(dtype_id);
86 }
87
88 bool operator==(const Arg &other) const {
89 return tag == other.tag && name == other.name &&
90 field_dim == other.field_dim && dtype_id == other.dtype_id &&
91 element_shape == other.element_shape;
92 }
93
94 bool operator!=(const Arg &other) const {
95 return !(*this == other);
96 }
97
98 TI_IO_DEF(name, dtype_id, field_dim, tag, element_shape, num_channels);
99};
100
101/**
102 * Runtime value used in graph execution.
103 */
104struct TI_DLL_EXPORT IValue {
105 public:
106 uint64 val;
107 ArgKind tag;
108
109 static IValue create(const Ndarray &ndarray) {
110 return IValue(reinterpret_cast<intptr_t>(&ndarray), ArgKind::kNdarray);
111 }
112
113 static IValue create(const Texture &tex) {
114 return IValue(reinterpret_cast<intptr_t>(&tex), ArgKind::kTexture);
115 }
116
117 template <typename T,
118 typename = std::enable_if_t<!std::is_same<T, Ndarray>::value, void>>
119 static IValue create(T v) {
120 return IValue(taichi_union_cast_with_different_sizes<uint64>(v),
121 ArgKind::kScalar);
122 }
123
124 private:
125 IValue(uint64 val, ArgKind tag) : val(val), tag(tag) {
126 }
127};
128
129class TI_DLL_EXPORT Kernel {
130 public:
131 // Rule of 5 to make MSVC happy
132 Kernel() = default;
133 virtual ~Kernel() = default;
134 Kernel(const Kernel &) = delete;
135 Kernel &operator=(const Kernel &) = delete;
136 Kernel(Kernel &&) = default;
137 Kernel &operator=(Kernel &&) = default;
138
139 /**
140 * @brief Launches the kernel to the device
141 *
142 * This does not manage the device to host synchronization.
143 *
144 * @param ctx Host context
145 */
146 virtual void launch(RuntimeContext *ctx) = 0;
147};
148
149struct CompiledDispatch {
150 std::string kernel_name;
151 std::vector<Arg> symbolic_args;
152 Kernel *compiled_kernel{nullptr};
153 taichi::lang::Kernel *ti_kernel{nullptr};
154
155 TI_IO_DEF(kernel_name, symbolic_args);
156};
157
158struct TI_DLL_EXPORT CompiledGraph {
159 std::vector<CompiledDispatch> dispatches;
160 std::unordered_map<std::string, aot::Arg> args;
161 RuntimeContext ctx_;
162
163 void run(const std::unordered_map<std::string, IValue> &args) const;
164
165 TI_IO_DEF(dispatches);
166};
167
168} // namespace aot
169} // namespace taichi::lang
170