1#include "taichi/ir/snode.h"
2
3#include <limits>
4
5#include "taichi/ir/ir.h"
6#include "taichi/ir/statements.h"
7#include "taichi/program/program.h"
8#include "taichi/program/snode_rw_accessors_bank.h"
9
10namespace taichi::lang {
11
12std::atomic<int> SNode::counter{0};
13
14SNode &SNode::insert_children(SNodeType t) {
15 TI_ASSERT(t != SNodeType::root);
16
17 auto new_ch = std::make_unique<SNode>(depth + 1, t, snode_to_fields_,
18 snode_rw_accessors_bank_);
19 new_ch->parent = this;
20 new_ch->is_path_all_dense = (is_path_all_dense && !new_ch->need_activation());
21 for (int i = 0; i < taichi_max_num_indices; i++) {
22 new_ch->extractors[i].num_elements_from_root *=
23 extractors[i].num_elements_from_root;
24 }
25 std::memcpy(new_ch->physical_index_position, physical_index_position,
26 sizeof(physical_index_position));
27 new_ch->num_active_indices = num_active_indices;
28 if (type == SNodeType::bit_struct || type == SNodeType::quant_array) {
29 new_ch->is_bit_level = true;
30 } else {
31 new_ch->is_bit_level = is_bit_level;
32 }
33 ch.push_back(std::move(new_ch));
34 return *ch.back();
35}
36
37SNode &SNode::create_node(std::vector<Axis> axes,
38 std::vector<int> sizes,
39 SNodeType type,
40 const std::string &tb) {
41 TI_ASSERT(axes.size() == sizes.size() || sizes.size() == 1);
42 if (sizes.size() == 1) {
43 sizes = std::vector<int>(axes.size(), sizes[0]);
44 }
45
46 if (type == SNodeType::hash)
47 TI_ASSERT_INFO(depth == 0,
48 "hashed node must be child of root due to initialization "
49 "memset limitation.");
50
51 auto &new_node = insert_children(type);
52 for (int i = 0; i < (int)axes.size(); i++) {
53 if (sizes[i] <= 0) {
54 throw TaichiRuntimeError(
55 "Every dimension of a Taichi field should be positive");
56 }
57 int ind = axes[i].value;
58 auto end = new_node.physical_index_position + new_node.num_active_indices;
59 bool is_first_division =
60 std::find(new_node.physical_index_position, end, ind) == end;
61 if (is_first_division) {
62 new_node.physical_index_position[new_node.num_active_indices++] = ind;
63 } else {
64 TI_WARN_IF(
65 !bit::is_power_of_two(sizes[i]),
66 "Shape {} is detected on non-first division of axis {}:\n{} For "
67 "best performance, we recommend that you set it to a power of two.",
68 sizes[i], char('i' + ind), tb);
69 }
70 new_node.extractors[ind].active = true;
71 new_node.extractors[ind].num_elements_from_root *= sizes[i];
72 new_node.extractors[ind].shape = sizes[i];
73 }
74 std::sort(new_node.physical_index_position,
75 new_node.physical_index_position + new_node.num_active_indices);
76 // infer extractors
77 int64 acc_shape = 1;
78 for (int i = taichi_max_num_indices - 1; i >= 0; i--) {
79 // casting to int32 in extractors.
80 new_node.extractors[i].acc_shape = static_cast<int>(acc_shape);
81 acc_shape *= new_node.extractors[i].shape;
82 }
83 if (acc_shape > std::numeric_limits<int>::max()) {
84 TI_WARN(
85 "SNode index might be out of int32 boundary but int64 indexing is not "
86 "supported yet. Struct fors might not work either.");
87 }
88 new_node.num_cells_per_container = acc_shape;
89
90 if (new_node.type == SNodeType::dynamic) {
91 int active_extractor_counder = 0;
92 for (int i = 0; i < taichi_max_num_indices; i++) {
93 if (new_node.extractors[i].active) {
94 active_extractor_counder += 1;
95 SNode *p = new_node.parent;
96 while (p) {
97 TI_ASSERT_INFO(
98 !p->extractors[i].active,
99 "Dynamic SNode must have a standalone dimensionality.");
100 p = p->parent;
101 }
102 }
103 }
104 TI_ASSERT_INFO(active_extractor_counder == 1,
105 "Dynamic SNode can have only one index extractor.");
106 }
107 return new_node;
108}
109
110SNode &SNode::dynamic(const Axis &expr,
111 int n,
112 int chunk_size,
113 const std::string &tb) {
114 auto &snode = create_node({expr}, {n}, SNodeType::dynamic, tb);
115 snode.chunk_size = chunk_size;
116 return snode;
117}
118
119SNode &SNode::bit_struct(BitStructType *bit_struct_type,
120 const std::string &tb) {
121 auto &snode = create_node({}, {}, SNodeType::bit_struct, tb);
122 snode.dt = bit_struct_type;
123 snode.physical_type = bit_struct_type->get_physical_type();
124 return snode;
125}
126
127SNode &SNode::quant_array(const std::vector<Axis> &axes,
128 const std::vector<int> &sizes,
129 int bits,
130 const std::string &tb) {
131 auto &snode = create_node(axes, sizes, SNodeType::quant_array, tb);
132 snode.physical_type =
133 TypeFactory::get_instance().get_primitive_int_type(bits, false);
134 return snode;
135}
136
137bool SNode::is_place() const {
138 return type == SNodeType::place;
139}
140
141bool SNode::is_scalar() const {
142 return is_place() && (num_active_indices == 0);
143}
144
145SNode *SNode::get_least_sparse_ancestor() const {
146 if (is_path_all_dense) {
147 return nullptr;
148 }
149 auto *result = const_cast<SNode *>(this);
150 while (!result->need_activation()) {
151 result = result->parent;
152 TI_ASSERT(result);
153 }
154 return result;
155}
156
157int SNode::shape_along_axis(int i) const {
158 const auto &extractor = extractors[physical_index_position[i]];
159 return extractor.num_elements_from_root;
160}
161
162int64 SNode::read_int(const std::vector<int> &i) {
163 return snode_rw_accessors_bank_->get(this).read_int(i);
164}
165
166uint64 SNode::read_uint(const std::vector<int> &i) {
167 return snode_rw_accessors_bank_->get(this).read_uint(i);
168}
169
170float64 SNode::read_float(const std::vector<int> &i) {
171 return snode_rw_accessors_bank_->get(this).read_float(i);
172}
173
174void SNode::write_int(const std::vector<int> &i, int64 val) {
175 snode_rw_accessors_bank_->get(this).write_int(i, val);
176}
177
178void SNode::write_uint(const std::vector<int> &i, uint64 val) {
179 snode_rw_accessors_bank_->get(this).write_uint(i, val);
180}
181
182void SNode::write_float(const std::vector<int> &i, float64 val) {
183 snode_rw_accessors_bank_->get(this).write_float(i, val);
184}
185
186Expr SNode::get_expr() const {
187 return Expr(snode_to_fields_->at(this));
188}
189
190SNode::SNode(SNodeFieldMap *snode_to_fields,
191 SNodeRwAccessorsBank *snode_rw_accessors_bank)
192 : SNode(0, SNodeType::undefined, snode_to_fields, snode_rw_accessors_bank) {
193}
194
195SNode::SNode(int depth,
196 SNodeType t,
197 SNodeFieldMap *snode_to_fields,
198 SNodeRwAccessorsBank *snode_rw_accessors_bank)
199 : depth(depth),
200 type(t),
201 snode_to_fields_(snode_to_fields),
202 snode_rw_accessors_bank_(snode_rw_accessors_bank) {
203 id = counter++;
204 node_type_name = get_node_type_name();
205 num_active_indices = 0;
206 std::memset(physical_index_position, -1, sizeof(physical_index_position));
207 parent = nullptr;
208 has_ambient = false;
209 dt = PrimitiveType::gen;
210 _morton = false;
211}
212
213SNode::SNode(const SNode &) {
214 TI_NOT_IMPLEMENTED; // Copying an SNode is forbidden. However we need the
215 // definition here to make pybind11 happy.
216}
217
218std::string SNode::get_node_type_name() const {
219 return fmt::format("S{}", id);
220}
221
222std::string SNode::get_node_type_name_hinted() const {
223 std::string suffix;
224 if (type == SNodeType::place || type == SNodeType::bit_struct)
225 suffix = fmt::format("<{}>", dt->to_string());
226 if (is_bit_level)
227 suffix += "<bit>";
228 return fmt::format("S{}{}{}", id, snode_type_name(type), suffix);
229}
230
231void SNode::print() {
232 for (int i = 0; i < depth; i++) {
233 fmt::print(" ");
234 }
235 fmt::print("{}", get_node_type_name_hinted());
236 fmt::print("\n");
237 for (auto &c : ch) {
238 c->print();
239 }
240}
241
242void SNode::set_index_offsets(std::vector<int> index_offsets_) {
243 TI_ASSERT(this->index_offsets.empty());
244 TI_ASSERT(!index_offsets_.empty());
245 TI_ASSERT(type == SNodeType::place);
246 TI_ASSERT(index_offsets_.size() == this->num_active_indices);
247 this->index_offsets = index_offsets_;
248}
249
250// TODO: rename to is_sparse?
251bool SNode::need_activation() const {
252 return type == SNodeType::pointer || type == SNodeType::hash ||
253 type == SNodeType::bitmasked || type == SNodeType::dynamic;
254}
255
256void SNode::lazy_grad() {
257 make_lazy_place(
258 this, snode_to_fields_,
259 [this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_grads) {
260 if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
261 !c->has_adjoint()) {
262 new_grads.push_back(snode_to_fields_->at(c.get())->adjoint);
263 }
264 });
265}
266
267void SNode::lazy_dual() {
268 make_lazy_place(
269 this, snode_to_fields_,
270 [this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_duals) {
271 if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
272 !c->has_dual()) {
273 new_duals.push_back(snode_to_fields_->at(c.get())->dual);
274 }
275 });
276}
277
278void SNode::allocate_adjoint_checkbit() {
279 make_lazy_place(this, snode_to_fields_,
280 [this](std::unique_ptr<SNode> &c,
281 std::vector<Expr> &new_adjoint_checkbits) {
282 if (c->type == SNodeType::place && c->is_primal() &&
283 is_real(c->dt) && c->has_adjoint()) {
284 new_adjoint_checkbits.push_back(
285 snode_to_fields_->at(c.get())->adjoint_checkbit);
286 }
287 });
288}
289
290bool SNode::is_primal() const {
291 return grad_info && grad_info->is_primal();
292}
293
294SNodeGradType SNode::get_snode_grad_type() const {
295 TI_ASSERT(grad_info);
296 return grad_info->get_snode_grad_type();
297}
298
299bool SNode::has_adjoint() const {
300 return is_primal() && (grad_info->adjoint_snode() != nullptr);
301}
302
303bool SNode::has_adjoint_checkbit() const {
304 return is_primal() && (grad_info->adjoint_checkbit_snode() != nullptr);
305}
306
307bool SNode::has_dual() const {
308 return is_primal() && (grad_info->dual_snode() != nullptr);
309}
310
311SNode *SNode::get_adjoint() const {
312 TI_ASSERT(has_adjoint());
313 return grad_info->adjoint_snode();
314}
315
316SNode *SNode::get_adjoint_checkbit() const {
317 // TI_ASSERT(has_adjoint());
318 return grad_info->adjoint_checkbit_snode();
319}
320
321SNode *SNode::get_dual() const {
322 TI_ASSERT(has_dual());
323 return grad_info->dual_snode();
324}
325
326void SNode::set_snode_tree_id(int id) {
327 snode_tree_id_ = id;
328 for (auto &child : ch) {
329 child->set_snode_tree_id(id);
330 }
331}
332
333int SNode::get_snode_tree_id() const {
334 return snode_tree_id_;
335}
336
337const SNode *SNode::get_root() const {
338 if (!parent) { // root->parent == nullptr
339 return this;
340 }
341 return parent->get_root();
342}
343
344} // namespace taichi::lang
345