1#pragma once
2
3#include <atomic>
4
5#include "taichi/inc/constants.h"
6#include "taichi/ir/expr.h"
7#include "taichi/ir/snode_types.h"
8#include "taichi/ir/type.h"
9#include "taichi/program/snode_expr_utils.h"
10
11namespace taichi::lang {
12class Program;
13class SNodeRwAccessorsBank;
14
15/**
16 * Dimension (or axis) of a tensor.
17 *
18 * For example, in the frontend we have ti.ij, which is translated to
19 * {Axis{0}, Axis{1}}.
20 */
21class Axis {
22 public:
23 int value;
24 Axis() {
25 value = 0;
26 }
27 explicit Axis(int value) : value(value) {
28 TI_ERROR_UNLESS(0 <= value && value < taichi_max_num_indices,
29 "Too many dimensions. The maximum dimensionality is {}",
30 taichi_max_num_indices);
31 }
32};
33
34/**
35 * SNode shape metadata at a specific Axis.
36 */
37struct AxisExtractor {
38 /**
39 * Number of elements from root at this index.
40 */
41 int num_elements_from_root{1};
42 /**
43 * Shape at this index.
44 */
45 int shape{1};
46 /**
47 * Accumulated shape from the last activated index to the first one.
48 */
49 int acc_shape{1};
50 /**
51 * Whether this index (axis) is activated.
52 */
53 bool active{false};
54};
55
56/**
57 * Structural nodes
58 */
59class SNode {
60 public:
61 // This class decouples SNode from the frontend expression.
62 class GradInfoProvider {
63 public:
64 virtual ~GradInfoProvider() = default;
65 virtual bool is_primal() const = 0;
66 virtual SNodeGradType get_snode_grad_type() const = 0;
67 virtual SNode *adjoint_snode() const = 0;
68 virtual SNode *dual_snode() const = 0;
69 virtual SNode *adjoint_checkbit_snode() const = 0;
70
71 template <typename T>
72 T *cast() {
73 return static_cast<T *>(this);
74 }
75 };
76 std::vector<std::unique_ptr<SNode>> ch;
77
78 AxisExtractor extractors[taichi_max_num_indices];
79 std::vector<int> index_offsets;
80 int num_active_indices{0};
81 int physical_index_position[taichi_max_num_indices]{};
82 // physical indices are (ti.i, ti.j, ti.k, ti.l, ...)
83 // physical_index_position[i] =
84 // which physical index does the i-th virtual index (the one exposed to
85 // programmers) refer to? i.e. in a[i, j, k], "i", "j", and "k" are virtual
86 // indices.
87
88 static std::atomic<int> counter;
89 int id{0};
90 int depth{0};
91
92 std::string name;
93 // Product of the |shape| of all the activated axes identified by
94 // |extractors|.
95 // See https://docs.taichi-lang.org/docs/internal for terms
96 // like cell and container.
97 int64 num_cells_per_container{1};
98 int chunk_size{0};
99 std::size_t cell_size_bytes{0};
100 std::size_t offset_bytes_in_parent_cell{0};
101 DataType dt;
102 bool has_ambient{false};
103 TypedConstant ambient_val;
104 // Note: parent will not be set until structural nodes are compiled!
105 SNode *parent{nullptr};
106 std::unique_ptr<GradInfoProvider> grad_info{nullptr};
107
108 // Quant
109 PrimitiveType *physical_type{nullptr}; // for bit_struct and quant_array only
110 int id_in_bit_struct{-1}; // for children of bit_struct only
111 bool is_bit_level{false}; // true if inside bit_struct or quant_array
112
113 // Whether the path from root to |this| contains only `dense` SNodes.
114 bool is_path_all_dense{true};
115
116 explicit SNode(SNodeFieldMap *snode_to_fields = nullptr,
117 SNodeRwAccessorsBank *snode_rw_accessors_bank = nullptr);
118
119 SNode(int depth,
120 SNodeType t,
121 SNodeFieldMap *snode_to_fields = nullptr,
122 SNodeRwAccessorsBank *snode_rw_accessors_bank = nullptr);
123
124 SNode(const SNode &);
125
126 ~SNode() = default;
127
128 std::string node_type_name;
129 SNodeType type;
130 bool _morton{false};
131
132 std::string get_node_type_name() const;
133
134 std::string get_node_type_name_hinted() const;
135
136 SNode &insert_children(SNodeType t);
137
138 SNode &create_node(std::vector<Axis> axes,
139 std::vector<int> sizes,
140 SNodeType type,
141 const std::string &tb);
142
143 // SNodes maintains how flattened index bits are taken from indices
144 SNode &dense(const std::vector<Axis> &axes,
145 const std::vector<int> &sizes,
146 const std::string &tb) {
147 return create_node(axes, sizes, SNodeType::dense, tb);
148 }
149
150 SNode &dense(const std::vector<Axis> &axes,
151 int sizes,
152 const std::string &tb) {
153 return create_node(axes, std::vector<int>{sizes}, SNodeType::dense, tb);
154 }
155
156 SNode &dense(const Axis &axis, int size, const std::string &tb) {
157 return SNode::dense(std::vector<Axis>{axis}, size, tb);
158 }
159
160 SNode &pointer(const std::vector<Axis> &axes,
161 const std::vector<int> &sizes,
162 const std::string &tb) {
163 return create_node(axes, sizes, SNodeType::pointer, tb);
164 }
165
166 SNode &pointer(const std::vector<Axis> &axes,
167 int sizes,
168 const std::string &tb) {
169 return create_node(axes, std::vector<int>{sizes}, SNodeType::pointer, tb);
170 }
171
172 SNode &pointer(const Axis &axis, int size, const std::string &tb) {
173 return SNode::pointer(std::vector<Axis>{axis}, size, tb);
174 }
175
176 SNode &bitmasked(const std::vector<Axis> &axes,
177 const std::vector<int> &sizes,
178 const std::string &tb) {
179 return create_node(axes, sizes, SNodeType::bitmasked, tb);
180 }
181
182 SNode &bitmasked(const std::vector<Axis> &axes,
183 int sizes,
184 const std::string &tb) {
185 return create_node(axes, std::vector<int>{sizes}, SNodeType::bitmasked, tb);
186 }
187
188 SNode &bitmasked(const Axis &axis, int size, const std::string &tb) {
189 return SNode::bitmasked(std::vector<Axis>{axis}, size, tb);
190 }
191
192 SNode &hash(const std::vector<Axis> &axes,
193 const std::vector<int> &sizes,
194 const std::string &tb) {
195 return create_node(axes, sizes, SNodeType::hash, tb);
196 }
197
198 SNode &hash(const std::vector<Axis> &axes, int sizes, const std::string &tb) {
199 return create_node(axes, std::vector<int>{sizes}, SNodeType::hash, tb);
200 }
201
202 SNode &hash(const Axis &axis, int size, const std::string &tb) {
203 return hash(std::vector<Axis>{axis}, size, tb);
204 }
205
206 std::string type_name() {
207 return snode_type_name(type);
208 }
209
210 SNode &bit_struct(BitStructType *bit_struct_type, const std::string &tb);
211
212 SNode &quant_array(const std::vector<Axis> &axes,
213 const std::vector<int> &sizes,
214 int bits,
215 const std::string &tb);
216
217 void print();
218
219 void set_index_offsets(std::vector<int> index_offsets);
220
221 SNode &dynamic(const Axis &expr,
222 int n,
223 int chunk_size,
224 const std::string &tb);
225
226 SNode &morton(bool val = true) {
227 _morton = val;
228 return *this;
229 }
230
231 int child_id(SNode *c) {
232 for (int i = 0; i < (int)ch.size(); i++) {
233 if (ch[i].get() == c) {
234 return i;
235 }
236 }
237 return -1;
238 }
239
240 bool has_null() const {
241 return type == SNodeType::pointer || type == SNodeType::hash;
242 }
243
244 bool has_allocator() const {
245 return type == SNodeType::pointer || type == SNodeType::hash ||
246 type == SNodeType::root;
247 }
248
249 bool need_activation() const;
250
251 bool is_primal() const;
252
253 SNodeGradType get_snode_grad_type() const;
254
255 bool is_place() const;
256
257 bool is_scalar() const;
258
259 bool has_adjoint() const;
260
261 SNode *get_adjoint() const;
262
263 bool has_adjoint_checkbit() const;
264
265 SNode *get_adjoint_checkbit() const;
266
267 bool has_dual() const;
268
269 SNode *get_dual() const;
270
271 SNode *get_least_sparse_ancestor() const;
272
273 std::string get_name() const {
274 return node_type_name;
275 }
276
277 std::string element_listgen_func_name() const {
278 return get_name() + "_element_listgen";
279 }
280
281 std::string get_ch_from_parent_func_name() const {
282 TI_ASSERT(parent != nullptr);
283 return fmt::format("get_ch_{}_to_{}", parent->get_name(), get_name());
284 }
285
286 std::string refine_coordinates_func_name() const {
287 TI_ASSERT(type != SNodeType::place);
288 return fmt::format("{}_refine_coordinates", get_name());
289 }
290
291 int64 max_num_elements() const {
292 return num_cells_per_container;
293 }
294
295 int64 get_total_num_elements_towards_root() const {
296 int64 total_num_elemts = 1;
297 for (auto *s = this; s != nullptr; s = s->parent)
298 total_num_elemts *= (int)s->max_num_elements();
299 return total_num_elemts;
300 }
301
302 int shape_along_axis(int i) const;
303
304 void place(Expr &expr, const std::vector<int> &offset, int id_in_bit_struct) {
305 place_child(&expr, offset, id_in_bit_struct, this, snode_to_fields_);
306 }
307
308 void lazy_grad();
309
310 void lazy_dual();
311
312 void allocate_adjoint_checkbit();
313
314 int64 read_int(const std::vector<int> &i);
315 uint64 read_uint(const std::vector<int> &i);
316 float64 read_float(const std::vector<int> &i);
317 void write_int(const std::vector<int> &i, int64 val);
318 void write_uint(const std::vector<int> &i, uint64 val);
319 void write_float(const std::vector<int> &i, float64 val);
320
321 Expr get_expr() const;
322
323 uint64 fetch_reader_result(); // TODO: refactor
324
325 // SNodeTree part
326
327 void set_snode_tree_id(int id);
328
329 int get_snode_tree_id() const;
330
331 const SNode *get_root() const;
332
333 static void reset_counter() {
334 counter = 0;
335 }
336
337 private:
338 int snode_tree_id_{0};
339 SNodeFieldMap *snode_to_fields_{nullptr};
340 SNodeRwAccessorsBank *snode_rw_accessors_bank_{nullptr};
341};
342
343} // namespace taichi::lang
344