1 | #pragma once |
2 | |
3 | #include <atomic> |
4 | |
5 | #include "taichi/inc/constants.h" |
6 | #include "taichi/ir/expr.h" |
7 | #include "taichi/ir/snode_types.h" |
8 | #include "taichi/ir/type.h" |
9 | #include "taichi/program/snode_expr_utils.h" |
10 | |
11 | namespace taichi::lang { |
12 | class Program; |
13 | class SNodeRwAccessorsBank; |
14 | |
15 | /** |
16 | * Dimension (or axis) of a tensor. |
17 | * |
18 | * For example, in the frontend we have ti.ij, which is translated to |
19 | * {Axis{0}, Axis{1}}. |
20 | */ |
21 | class Axis { |
22 | public: |
23 | int value; |
24 | Axis() { |
25 | value = 0; |
26 | } |
27 | explicit Axis(int value) : value(value) { |
28 | TI_ERROR_UNLESS(0 <= value && value < taichi_max_num_indices, |
29 | "Too many dimensions. The maximum dimensionality is {}" , |
30 | taichi_max_num_indices); |
31 | } |
32 | }; |
33 | |
/**
 * SNode shape metadata at a specific Axis.
 *
 * Every counter starts at 1 and the axis starts out inactive.
 */
// NOTE(review): the member names |num_elements_from_root|, |acc_shape| and
// |active| were restored to match the upstream Taichi header -- confirm
// against the users of this struct.
struct AxisExtractor {
  /**
   * Number of elements from root at this index.
   */
  int num_elements_from_root{1};
  /**
   * Shape at this index.
   */
  int shape{1};
  /**
   * Accumulated shape from the last activated index to the first one.
   */
  int acc_shape{1};
  /**
   * Whether this index (axis) is activated.
   */
  bool active{false};
};
55 | |
56 | /** |
57 | * Structural nodes |
58 | */ |
59 | class SNode { |
60 | public: |
61 | // This class decouples SNode from the frontend expression. |
62 | class GradInfoProvider { |
63 | public: |
64 | virtual ~GradInfoProvider() = default; |
65 | virtual bool is_primal() const = 0; |
66 | virtual SNodeGradType get_snode_grad_type() const = 0; |
67 | virtual SNode *adjoint_snode() const = 0; |
68 | virtual SNode *dual_snode() const = 0; |
69 | virtual SNode *adjoint_checkbit_snode() const = 0; |
70 | |
71 | template <typename T> |
72 | T *cast() { |
73 | return static_cast<T *>(this); |
74 | } |
75 | }; |
76 | std::vector<std::unique_ptr<SNode>> ch; |
77 | |
78 | AxisExtractor [taichi_max_num_indices]; |
79 | std::vector<int> index_offsets; |
80 | int num_active_indices{0}; |
81 | int physical_index_position[taichi_max_num_indices]{}; |
82 | // physical indices are (ti.i, ti.j, ti.k, ti.l, ...) |
83 | // physical_index_position[i] = |
84 | // which physical index does the i-th virtual index (the one exposed to |
85 | // programmers) refer to? i.e. in a[i, j, k], "i", "j", and "k" are virtual |
86 | // indices. |
87 | |
88 | static std::atomic<int> counter; |
89 | int id{0}; |
90 | int depth{0}; |
91 | |
92 | std::string name; |
93 | // Product of the |shape| of all the activated axes identified by |
94 | // |extractors|. |
95 | // See https://docs.taichi-lang.org/docs/internal for terms |
96 | // like cell and container. |
97 | int64 num_cells_per_container{1}; |
98 | int chunk_size{0}; |
99 | std::size_t cell_size_bytes{0}; |
100 | std::size_t offset_bytes_in_parent_cell{0}; |
101 | DataType dt; |
102 | bool has_ambient{false}; |
103 | TypedConstant ambient_val; |
104 | // Note: parent will not be set until structural nodes are compiled! |
105 | SNode *parent{nullptr}; |
106 | std::unique_ptr<GradInfoProvider> grad_info{nullptr}; |
107 | |
108 | // Quant |
109 | PrimitiveType *physical_type{nullptr}; // for bit_struct and quant_array only |
110 | int id_in_bit_struct{-1}; // for children of bit_struct only |
111 | bool is_bit_level{false}; // true if inside bit_struct or quant_array |
112 | |
113 | // Whether the path from root to |this| contains only `dense` SNodes. |
114 | bool is_path_all_dense{true}; |
115 | |
116 | explicit SNode(SNodeFieldMap *snode_to_fields = nullptr, |
117 | SNodeRwAccessorsBank *snode_rw_accessors_bank = nullptr); |
118 | |
119 | SNode(int depth, |
120 | SNodeType t, |
121 | SNodeFieldMap *snode_to_fields = nullptr, |
122 | SNodeRwAccessorsBank *snode_rw_accessors_bank = nullptr); |
123 | |
124 | SNode(const SNode &); |
125 | |
126 | ~SNode() = default; |
127 | |
128 | std::string node_type_name; |
129 | SNodeType type; |
130 | bool _morton{false}; |
131 | |
132 | std::string get_node_type_name() const; |
133 | |
134 | std::string get_node_type_name_hinted() const; |
135 | |
136 | SNode &insert_children(SNodeType t); |
137 | |
138 | SNode &create_node(std::vector<Axis> axes, |
139 | std::vector<int> sizes, |
140 | SNodeType type, |
141 | const std::string &tb); |
142 | |
143 | // SNodes maintains how flattened index bits are taken from indices |
144 | SNode &dense(const std::vector<Axis> &axes, |
145 | const std::vector<int> &sizes, |
146 | const std::string &tb) { |
147 | return create_node(axes, sizes, SNodeType::dense, tb); |
148 | } |
149 | |
150 | SNode &dense(const std::vector<Axis> &axes, |
151 | int sizes, |
152 | const std::string &tb) { |
153 | return create_node(axes, std::vector<int>{sizes}, SNodeType::dense, tb); |
154 | } |
155 | |
156 | SNode &dense(const Axis &axis, int size, const std::string &tb) { |
157 | return SNode::dense(std::vector<Axis>{axis}, size, tb); |
158 | } |
159 | |
160 | SNode &pointer(const std::vector<Axis> &axes, |
161 | const std::vector<int> &sizes, |
162 | const std::string &tb) { |
163 | return create_node(axes, sizes, SNodeType::pointer, tb); |
164 | } |
165 | |
166 | SNode &pointer(const std::vector<Axis> &axes, |
167 | int sizes, |
168 | const std::string &tb) { |
169 | return create_node(axes, std::vector<int>{sizes}, SNodeType::pointer, tb); |
170 | } |
171 | |
172 | SNode &pointer(const Axis &axis, int size, const std::string &tb) { |
173 | return SNode::pointer(std::vector<Axis>{axis}, size, tb); |
174 | } |
175 | |
176 | SNode &bitmasked(const std::vector<Axis> &axes, |
177 | const std::vector<int> &sizes, |
178 | const std::string &tb) { |
179 | return create_node(axes, sizes, SNodeType::bitmasked, tb); |
180 | } |
181 | |
182 | SNode &bitmasked(const std::vector<Axis> &axes, |
183 | int sizes, |
184 | const std::string &tb) { |
185 | return create_node(axes, std::vector<int>{sizes}, SNodeType::bitmasked, tb); |
186 | } |
187 | |
188 | SNode &bitmasked(const Axis &axis, int size, const std::string &tb) { |
189 | return SNode::bitmasked(std::vector<Axis>{axis}, size, tb); |
190 | } |
191 | |
192 | SNode &hash(const std::vector<Axis> &axes, |
193 | const std::vector<int> &sizes, |
194 | const std::string &tb) { |
195 | return create_node(axes, sizes, SNodeType::hash, tb); |
196 | } |
197 | |
198 | SNode &hash(const std::vector<Axis> &axes, int sizes, const std::string &tb) { |
199 | return create_node(axes, std::vector<int>{sizes}, SNodeType::hash, tb); |
200 | } |
201 | |
202 | SNode &hash(const Axis &axis, int size, const std::string &tb) { |
203 | return hash(std::vector<Axis>{axis}, size, tb); |
204 | } |
205 | |
206 | std::string type_name() { |
207 | return snode_type_name(type); |
208 | } |
209 | |
210 | SNode &bit_struct(BitStructType *bit_struct_type, const std::string &tb); |
211 | |
212 | SNode &quant_array(const std::vector<Axis> &axes, |
213 | const std::vector<int> &sizes, |
214 | int bits, |
215 | const std::string &tb); |
216 | |
217 | void print(); |
218 | |
219 | void set_index_offsets(std::vector<int> index_offsets); |
220 | |
221 | SNode &dynamic(const Axis &expr, |
222 | int n, |
223 | int chunk_size, |
224 | const std::string &tb); |
225 | |
226 | SNode &morton(bool val = true) { |
227 | _morton = val; |
228 | return *this; |
229 | } |
230 | |
231 | int child_id(SNode *c) { |
232 | for (int i = 0; i < (int)ch.size(); i++) { |
233 | if (ch[i].get() == c) { |
234 | return i; |
235 | } |
236 | } |
237 | return -1; |
238 | } |
239 | |
240 | bool has_null() const { |
241 | return type == SNodeType::pointer || type == SNodeType::hash; |
242 | } |
243 | |
244 | bool has_allocator() const { |
245 | return type == SNodeType::pointer || type == SNodeType::hash || |
246 | type == SNodeType::root; |
247 | } |
248 | |
249 | bool need_activation() const; |
250 | |
251 | bool is_primal() const; |
252 | |
253 | SNodeGradType get_snode_grad_type() const; |
254 | |
255 | bool is_place() const; |
256 | |
257 | bool is_scalar() const; |
258 | |
259 | bool has_adjoint() const; |
260 | |
261 | SNode *get_adjoint() const; |
262 | |
263 | bool has_adjoint_checkbit() const; |
264 | |
265 | SNode *get_adjoint_checkbit() const; |
266 | |
267 | bool has_dual() const; |
268 | |
269 | SNode *get_dual() const; |
270 | |
271 | SNode *get_least_sparse_ancestor() const; |
272 | |
273 | std::string get_name() const { |
274 | return node_type_name; |
275 | } |
276 | |
277 | std::string element_listgen_func_name() const { |
278 | return get_name() + "_element_listgen" ; |
279 | } |
280 | |
281 | std::string get_ch_from_parent_func_name() const { |
282 | TI_ASSERT(parent != nullptr); |
283 | return fmt::format("get_ch_{}_to_{}" , parent->get_name(), get_name()); |
284 | } |
285 | |
286 | std::string refine_coordinates_func_name() const { |
287 | TI_ASSERT(type != SNodeType::place); |
288 | return fmt::format("{}_refine_coordinates" , get_name()); |
289 | } |
290 | |
291 | int64 max_num_elements() const { |
292 | return num_cells_per_container; |
293 | } |
294 | |
295 | int64 get_total_num_elements_towards_root() const { |
296 | int64 total_num_elemts = 1; |
297 | for (auto *s = this; s != nullptr; s = s->parent) |
298 | total_num_elemts *= (int)s->max_num_elements(); |
299 | return total_num_elemts; |
300 | } |
301 | |
302 | int shape_along_axis(int i) const; |
303 | |
304 | void place(Expr &expr, const std::vector<int> &offset, int id_in_bit_struct) { |
305 | place_child(&expr, offset, id_in_bit_struct, this, snode_to_fields_); |
306 | } |
307 | |
308 | void lazy_grad(); |
309 | |
310 | void lazy_dual(); |
311 | |
312 | void allocate_adjoint_checkbit(); |
313 | |
314 | int64 read_int(const std::vector<int> &i); |
315 | uint64 read_uint(const std::vector<int> &i); |
316 | float64 read_float(const std::vector<int> &i); |
317 | void write_int(const std::vector<int> &i, int64 val); |
318 | void write_uint(const std::vector<int> &i, uint64 val); |
319 | void write_float(const std::vector<int> &i, float64 val); |
320 | |
321 | Expr get_expr() const; |
322 | |
323 | uint64 fetch_reader_result(); // TODO: refactor |
324 | |
325 | // SNodeTree part |
326 | |
327 | void set_snode_tree_id(int id); |
328 | |
329 | int get_snode_tree_id() const; |
330 | |
331 | const SNode *get_root() const; |
332 | |
333 | static void reset_counter() { |
334 | counter = 0; |
335 | } |
336 | |
337 | private: |
338 | int snode_tree_id_{0}; |
339 | SNodeFieldMap *snode_to_fields_{nullptr}; |
340 | SNodeRwAccessorsBank *snode_rw_accessors_bank_{nullptr}; |
341 | }; |
342 | |
343 | } // namespace taichi::lang |
344 | |