1 | #include "taichi/ir/snode.h" |
2 | |
3 | #include <limits> |
4 | |
5 | #include "taichi/ir/ir.h" |
6 | #include "taichi/ir/statements.h" |
7 | #include "taichi/program/program.h" |
8 | #include "taichi/program/snode_rw_accessors_bank.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | std::atomic<int> SNode::counter{0}; |
13 | |
14 | SNode &SNode::insert_children(SNodeType t) { |
15 | TI_ASSERT(t != SNodeType::root); |
16 | |
17 | auto new_ch = std::make_unique<SNode>(depth + 1, t, snode_to_fields_, |
18 | snode_rw_accessors_bank_); |
19 | new_ch->parent = this; |
20 | new_ch->is_path_all_dense = (is_path_all_dense && !new_ch->need_activation()); |
21 | for (int i = 0; i < taichi_max_num_indices; i++) { |
22 | new_ch->extractors[i].num_elements_from_root *= |
23 | extractors[i].num_elements_from_root; |
24 | } |
25 | std::memcpy(new_ch->physical_index_position, physical_index_position, |
26 | sizeof(physical_index_position)); |
27 | new_ch->num_active_indices = num_active_indices; |
28 | if (type == SNodeType::bit_struct || type == SNodeType::quant_array) { |
29 | new_ch->is_bit_level = true; |
30 | } else { |
31 | new_ch->is_bit_level = is_bit_level; |
32 | } |
33 | ch.push_back(std::move(new_ch)); |
34 | return *ch.back(); |
35 | } |
36 | |
37 | SNode &SNode::create_node(std::vector<Axis> axes, |
38 | std::vector<int> sizes, |
39 | SNodeType type, |
40 | const std::string &tb) { |
41 | TI_ASSERT(axes.size() == sizes.size() || sizes.size() == 1); |
42 | if (sizes.size() == 1) { |
43 | sizes = std::vector<int>(axes.size(), sizes[0]); |
44 | } |
45 | |
46 | if (type == SNodeType::hash) |
47 | TI_ASSERT_INFO(depth == 0, |
48 | "hashed node must be child of root due to initialization " |
49 | "memset limitation." ); |
50 | |
51 | auto &new_node = insert_children(type); |
52 | for (int i = 0; i < (int)axes.size(); i++) { |
53 | if (sizes[i] <= 0) { |
54 | throw TaichiRuntimeError( |
55 | "Every dimension of a Taichi field should be positive" ); |
56 | } |
57 | int ind = axes[i].value; |
58 | auto end = new_node.physical_index_position + new_node.num_active_indices; |
59 | bool is_first_division = |
60 | std::find(new_node.physical_index_position, end, ind) == end; |
61 | if (is_first_division) { |
62 | new_node.physical_index_position[new_node.num_active_indices++] = ind; |
63 | } else { |
64 | TI_WARN_IF( |
65 | !bit::is_power_of_two(sizes[i]), |
66 | "Shape {} is detected on non-first division of axis {}:\n{} For " |
67 | "best performance, we recommend that you set it to a power of two." , |
68 | sizes[i], char('i' + ind), tb); |
69 | } |
70 | new_node.extractors[ind].active = true; |
71 | new_node.extractors[ind].num_elements_from_root *= sizes[i]; |
72 | new_node.extractors[ind].shape = sizes[i]; |
73 | } |
74 | std::sort(new_node.physical_index_position, |
75 | new_node.physical_index_position + new_node.num_active_indices); |
76 | // infer extractors |
77 | int64 acc_shape = 1; |
78 | for (int i = taichi_max_num_indices - 1; i >= 0; i--) { |
79 | // casting to int32 in extractors. |
80 | new_node.extractors[i].acc_shape = static_cast<int>(acc_shape); |
81 | acc_shape *= new_node.extractors[i].shape; |
82 | } |
83 | if (acc_shape > std::numeric_limits<int>::max()) { |
84 | TI_WARN( |
85 | "SNode index might be out of int32 boundary but int64 indexing is not " |
86 | "supported yet. Struct fors might not work either." ); |
87 | } |
88 | new_node.num_cells_per_container = acc_shape; |
89 | |
90 | if (new_node.type == SNodeType::dynamic) { |
91 | int = 0; |
92 | for (int i = 0; i < taichi_max_num_indices; i++) { |
93 | if (new_node.extractors[i].active) { |
94 | active_extractor_counder += 1; |
95 | SNode *p = new_node.parent; |
96 | while (p) { |
97 | TI_ASSERT_INFO( |
98 | !p->extractors[i].active, |
99 | "Dynamic SNode must have a standalone dimensionality." ); |
100 | p = p->parent; |
101 | } |
102 | } |
103 | } |
104 | TI_ASSERT_INFO(active_extractor_counder == 1, |
105 | "Dynamic SNode can have only one index extractor." ); |
106 | } |
107 | return new_node; |
108 | } |
109 | |
110 | SNode &SNode::dynamic(const Axis &expr, |
111 | int n, |
112 | int chunk_size, |
113 | const std::string &tb) { |
114 | auto &snode = create_node({expr}, {n}, SNodeType::dynamic, tb); |
115 | snode.chunk_size = chunk_size; |
116 | return snode; |
117 | } |
118 | |
119 | SNode &SNode::bit_struct(BitStructType *bit_struct_type, |
120 | const std::string &tb) { |
121 | auto &snode = create_node({}, {}, SNodeType::bit_struct, tb); |
122 | snode.dt = bit_struct_type; |
123 | snode.physical_type = bit_struct_type->get_physical_type(); |
124 | return snode; |
125 | } |
126 | |
127 | SNode &SNode::quant_array(const std::vector<Axis> &axes, |
128 | const std::vector<int> &sizes, |
129 | int bits, |
130 | const std::string &tb) { |
131 | auto &snode = create_node(axes, sizes, SNodeType::quant_array, tb); |
132 | snode.physical_type = |
133 | TypeFactory::get_instance().get_primitive_int_type(bits, false); |
134 | return snode; |
135 | } |
136 | |
137 | bool SNode::is_place() const { |
138 | return type == SNodeType::place; |
139 | } |
140 | |
141 | bool SNode::is_scalar() const { |
142 | return is_place() && (num_active_indices == 0); |
143 | } |
144 | |
145 | SNode *SNode::get_least_sparse_ancestor() const { |
146 | if (is_path_all_dense) { |
147 | return nullptr; |
148 | } |
149 | auto *result = const_cast<SNode *>(this); |
150 | while (!result->need_activation()) { |
151 | result = result->parent; |
152 | TI_ASSERT(result); |
153 | } |
154 | return result; |
155 | } |
156 | |
157 | int SNode::shape_along_axis(int i) const { |
158 | const auto & = extractors[physical_index_position[i]]; |
159 | return extractor.num_elements_from_root; |
160 | } |
161 | |
162 | int64 SNode::read_int(const std::vector<int> &i) { |
163 | return snode_rw_accessors_bank_->get(this).read_int(i); |
164 | } |
165 | |
166 | uint64 SNode::read_uint(const std::vector<int> &i) { |
167 | return snode_rw_accessors_bank_->get(this).read_uint(i); |
168 | } |
169 | |
170 | float64 SNode::read_float(const std::vector<int> &i) { |
171 | return snode_rw_accessors_bank_->get(this).read_float(i); |
172 | } |
173 | |
174 | void SNode::write_int(const std::vector<int> &i, int64 val) { |
175 | snode_rw_accessors_bank_->get(this).write_int(i, val); |
176 | } |
177 | |
178 | void SNode::write_uint(const std::vector<int> &i, uint64 val) { |
179 | snode_rw_accessors_bank_->get(this).write_uint(i, val); |
180 | } |
181 | |
182 | void SNode::write_float(const std::vector<int> &i, float64 val) { |
183 | snode_rw_accessors_bank_->get(this).write_float(i, val); |
184 | } |
185 | |
186 | Expr SNode::get_expr() const { |
187 | return Expr(snode_to_fields_->at(this)); |
188 | } |
189 | |
190 | SNode::SNode(SNodeFieldMap *snode_to_fields, |
191 | SNodeRwAccessorsBank *snode_rw_accessors_bank) |
192 | : SNode(0, SNodeType::undefined, snode_to_fields, snode_rw_accessors_bank) { |
193 | } |
194 | |
195 | SNode::SNode(int depth, |
196 | SNodeType t, |
197 | SNodeFieldMap *snode_to_fields, |
198 | SNodeRwAccessorsBank *snode_rw_accessors_bank) |
199 | : depth(depth), |
200 | type(t), |
201 | snode_to_fields_(snode_to_fields), |
202 | snode_rw_accessors_bank_(snode_rw_accessors_bank) { |
203 | id = counter++; |
204 | node_type_name = get_node_type_name(); |
205 | num_active_indices = 0; |
206 | std::memset(physical_index_position, -1, sizeof(physical_index_position)); |
207 | parent = nullptr; |
208 | has_ambient = false; |
209 | dt = PrimitiveType::gen; |
210 | _morton = false; |
211 | } |
212 | |
213 | SNode::SNode(const SNode &) { |
214 | TI_NOT_IMPLEMENTED; // Copying an SNode is forbidden. However we need the |
215 | // definition here to make pybind11 happy. |
216 | } |
217 | |
218 | std::string SNode::get_node_type_name() const { |
219 | return fmt::format("S{}" , id); |
220 | } |
221 | |
222 | std::string SNode::get_node_type_name_hinted() const { |
223 | std::string suffix; |
224 | if (type == SNodeType::place || type == SNodeType::bit_struct) |
225 | suffix = fmt::format("<{}>" , dt->to_string()); |
226 | if (is_bit_level) |
227 | suffix += "<bit>" ; |
228 | return fmt::format("S{}{}{}" , id, snode_type_name(type), suffix); |
229 | } |
230 | |
231 | void SNode::print() { |
232 | for (int i = 0; i < depth; i++) { |
233 | fmt::print(" " ); |
234 | } |
235 | fmt::print("{}" , get_node_type_name_hinted()); |
236 | fmt::print("\n" ); |
237 | for (auto &c : ch) { |
238 | c->print(); |
239 | } |
240 | } |
241 | |
242 | void SNode::set_index_offsets(std::vector<int> index_offsets_) { |
243 | TI_ASSERT(this->index_offsets.empty()); |
244 | TI_ASSERT(!index_offsets_.empty()); |
245 | TI_ASSERT(type == SNodeType::place); |
246 | TI_ASSERT(index_offsets_.size() == this->num_active_indices); |
247 | this->index_offsets = index_offsets_; |
248 | } |
249 | |
250 | // TODO: rename to is_sparse? |
251 | bool SNode::need_activation() const { |
252 | return type == SNodeType::pointer || type == SNodeType::hash || |
253 | type == SNodeType::bitmasked || type == SNodeType::dynamic; |
254 | } |
255 | |
256 | void SNode::lazy_grad() { |
257 | make_lazy_place( |
258 | this, snode_to_fields_, |
259 | [this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_grads) { |
260 | if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) && |
261 | !c->has_adjoint()) { |
262 | new_grads.push_back(snode_to_fields_->at(c.get())->adjoint); |
263 | } |
264 | }); |
265 | } |
266 | |
267 | void SNode::lazy_dual() { |
268 | make_lazy_place( |
269 | this, snode_to_fields_, |
270 | [this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_duals) { |
271 | if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) && |
272 | !c->has_dual()) { |
273 | new_duals.push_back(snode_to_fields_->at(c.get())->dual); |
274 | } |
275 | }); |
276 | } |
277 | |
278 | void SNode::allocate_adjoint_checkbit() { |
279 | make_lazy_place(this, snode_to_fields_, |
280 | [this](std::unique_ptr<SNode> &c, |
281 | std::vector<Expr> &new_adjoint_checkbits) { |
282 | if (c->type == SNodeType::place && c->is_primal() && |
283 | is_real(c->dt) && c->has_adjoint()) { |
284 | new_adjoint_checkbits.push_back( |
285 | snode_to_fields_->at(c.get())->adjoint_checkbit); |
286 | } |
287 | }); |
288 | } |
289 | |
290 | bool SNode::is_primal() const { |
291 | return grad_info && grad_info->is_primal(); |
292 | } |
293 | |
294 | SNodeGradType SNode::get_snode_grad_type() const { |
295 | TI_ASSERT(grad_info); |
296 | return grad_info->get_snode_grad_type(); |
297 | } |
298 | |
299 | bool SNode::has_adjoint() const { |
300 | return is_primal() && (grad_info->adjoint_snode() != nullptr); |
301 | } |
302 | |
303 | bool SNode::has_adjoint_checkbit() const { |
304 | return is_primal() && (grad_info->adjoint_checkbit_snode() != nullptr); |
305 | } |
306 | |
307 | bool SNode::has_dual() const { |
308 | return is_primal() && (grad_info->dual_snode() != nullptr); |
309 | } |
310 | |
311 | SNode *SNode::get_adjoint() const { |
312 | TI_ASSERT(has_adjoint()); |
313 | return grad_info->adjoint_snode(); |
314 | } |
315 | |
316 | SNode *SNode::get_adjoint_checkbit() const { |
317 | // TI_ASSERT(has_adjoint()); |
318 | return grad_info->adjoint_checkbit_snode(); |
319 | } |
320 | |
321 | SNode *SNode::get_dual() const { |
322 | TI_ASSERT(has_dual()); |
323 | return grad_info->dual_snode(); |
324 | } |
325 | |
326 | void SNode::set_snode_tree_id(int id) { |
327 | snode_tree_id_ = id; |
328 | for (auto &child : ch) { |
329 | child->set_snode_tree_id(id); |
330 | } |
331 | } |
332 | |
333 | int SNode::get_snode_tree_id() const { |
334 | return snode_tree_id_; |
335 | } |
336 | |
337 | const SNode *SNode::get_root() const { |
338 | if (!parent) { // root->parent == nullptr |
339 | return this; |
340 | } |
341 | return parent->get_root(); |
342 | } |
343 | |
344 | } // namespace taichi::lang |
345 | |