1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/transforms/make_mesh_block_local.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
// Static identifier string of the MakeMeshBlockLocal pass.
const PassID MakeMeshBlockLocal::id = "MakeMeshBlockLocal";
10 | |
11 | void MakeMeshBlockLocal::simplify_nested_conversion() { |
12 | std::vector<MeshIndexConversionStmt *> stmts; |
13 | std::vector<Stmt *> ori_indices; |
14 | |
15 | irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { |
16 | if (auto conv1 = stmt->cast<MeshIndexConversionStmt>()) { |
17 | if (auto conv2 = conv1->idx->cast<MeshIndexConversionStmt>()) { |
18 | if (conv1->conv_type == mesh::ConvType::g2r && |
19 | conv2->conv_type == mesh::ConvType::l2g && |
20 | conv1->mesh == conv2->mesh && |
21 | conv1->idx_type == conv2->idx_type) { // nested |
22 | stmts.push_back(conv1); |
23 | ori_indices.push_back(conv2->idx); |
24 | } |
25 | } |
26 | } |
27 | return false; |
28 | }); |
29 | |
30 | for (size_t i = 0; i < stmts.size(); ++i) { |
31 | stmts[i]->replace_with(Stmt::make<MeshIndexConversionStmt>( |
32 | stmts[i]->mesh, stmts[i]->idx_type, ori_indices[i], |
33 | mesh::ConvType::l2r)); |
34 | } |
35 | } |
36 | |
37 | void MakeMeshBlockLocal::gather_candidate_mapping() { |
38 | irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { |
39 | if (auto conv = stmt->cast<MeshIndexConversionStmt>()) { |
40 | if (conv->conv_type != mesh::ConvType::g2r) { |
41 | bool is_from_end = (conv->idx_type == offload_->major_from_type); |
42 | bool is_to_end = false; |
43 | for (auto type : offload_->major_to_types) { |
44 | is_to_end |= (conv->idx_type == type); |
45 | } |
46 | for (auto rel : offload_->minor_relation_types) { |
47 | auto from_type = |
48 | mesh::MeshElementType(mesh::from_end_element_order(rel)); |
49 | auto to_type = mesh::MeshElementType(mesh::to_end_element_order(rel)); |
50 | is_from_end |= (conv->idx_type == from_type); |
51 | is_to_end |= (conv->idx_type == to_type); |
52 | } |
53 | if ((is_to_end && config_.mesh_localize_to_end_mapping) || |
54 | (is_from_end && config_.mesh_localize_from_end_mapping)) { |
55 | mappings_.insert(std::make_pair(conv->idx_type, conv->conv_type)); |
56 | } |
57 | } |
58 | } |
59 | return false; |
60 | }); |
61 | } |
62 | |
63 | void MakeMeshBlockLocal::replace_conv_statements() { |
64 | std::vector<MeshIndexConversionStmt *> idx_conv_stmts; |
65 | |
66 | irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { |
67 | if (auto idx_conv = stmt->cast<MeshIndexConversionStmt>()) { |
68 | if (idx_conv->mesh == offload_->mesh && |
69 | idx_conv->conv_type == conv_type_ && |
70 | idx_conv->idx_type == element_type_) { |
71 | idx_conv_stmts.push_back(idx_conv); |
72 | } |
73 | } |
74 | return false; |
75 | }); |
76 | |
77 | for (auto stmt : idx_conv_stmts) { |
78 | VecStatement bls; |
79 | Stmt *bls_element_offset_bytes = bls.push_back<ConstStmt>( |
80 | TypedConstant{(int32)mapping_bls_offset_in_bytes_}); |
81 | Stmt *idx_byte = bls.push_back<BinaryOpStmt>( |
82 | BinaryOpType::mul, stmt->idx, |
83 | bls.push_back<ConstStmt>(TypedConstant(mapping_dtype_size_))); |
84 | Stmt *offset = bls.push_back<BinaryOpStmt>( |
85 | BinaryOpType::add, bls_element_offset_bytes, idx_byte); |
86 | Stmt *bls_ptr = bls.push_back<BlockLocalPtrStmt>( |
87 | offset, |
88 | TypeFactory::get_instance().get_pointer_type(mapping_data_type_)); |
89 | [[maybe_unused]] Stmt *bls_load = bls.push_back<GlobalLoadStmt>(bls_ptr); |
90 | stmt->replace_with(std::move(bls)); |
91 | } |
92 | } |
93 | |
94 | void MakeMeshBlockLocal::replace_global_ptrs(SNode *snode) { |
95 | auto data_type = snode->dt.ptr_removed(); |
96 | auto dtype_size = data_type_size(data_type); |
97 | auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second; |
98 | |
99 | std::vector<GlobalPtrStmt *> global_ptrs; |
100 | irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { |
101 | if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) { |
102 | if (global_ptr->snode == snode && |
103 | global_ptr->indices[0]->is<MeshIndexConversionStmt>()) { |
104 | global_ptrs.push_back(global_ptr); |
105 | } |
106 | } |
107 | return false; |
108 | }); |
109 | |
110 | for (auto global_ptr : global_ptrs) { |
111 | VecStatement bls; |
112 | Stmt *local_idx = |
113 | global_ptr->indices[0]->as<MeshIndexConversionStmt>()->idx; |
114 | Stmt *local_idx_byte = bls.push_back<BinaryOpStmt>( |
115 | BinaryOpType::mul, local_idx, |
116 | bls.push_back<ConstStmt>(TypedConstant(dtype_size))); |
117 | Stmt *offset = |
118 | bls.push_back<ConstStmt>(TypedConstant(int32(offset_in_bytes))); |
119 | Stmt *index = |
120 | bls.push_back<BinaryOpStmt>(BinaryOpType::add, offset, local_idx_byte); |
121 | [[maybe_unused]] Stmt *bls_ptr = bls.push_back<BlockLocalPtrStmt>( |
122 | index, TypeFactory::get_instance().get_pointer_type(data_type)); |
123 | global_ptr->replace_with(std::move(bls)); |
124 | } |
125 | |
126 | // in the cpu backend, atomic op in body block could be demoted to non-atomic |
127 | if (config_.arch != Arch::x64 && config_.arch != Arch::arm64) { |
128 | return; |
129 | } |
130 | std::vector<AtomicOpStmt *> atomic_ops; |
131 | irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { |
132 | if (auto atomic_op = stmt->cast<AtomicOpStmt>()) { |
133 | if (atomic_op->op_type == AtomicOpType::add && |
134 | atomic_op->dest->is<BlockLocalPtrStmt>()) { |
135 | atomic_ops.push_back(atomic_op); |
136 | } |
137 | } |
138 | return false; |
139 | }); |
140 | |
141 | for (auto atomic_op : atomic_ops) { |
142 | VecStatement non_atomic; |
143 | Stmt *dest_val = non_atomic.push_back<GlobalLoadStmt>(atomic_op->dest); |
144 | Stmt *res_val = non_atomic.push_back<BinaryOpStmt>( |
145 | BinaryOpType::add, dest_val, atomic_op->val); |
146 | non_atomic.push_back<GlobalStoreStmt>(atomic_op->dest, res_val); |
147 | atomic_op->replace_with(std::move(non_atomic)); |
148 | } |
149 | } |
150 | |
151 | // This function creates loop like: |
152 | // int i = start_val; |
153 | // while (i < end_val) { |
154 | // body(i); |
155 | // i += blockDim.x; |
156 | // } |
157 | Stmt *MakeMeshBlockLocal::create_xlogue( |
158 | Stmt *start_val, |
159 | Stmt *end_val, |
160 | std::function<void(Block * /*block*/, Stmt * /*idx_val*/)> body_) { |
161 | Stmt *idx = block_->push_back<AllocaStmt>(mapping_data_type_); |
162 | [[maybe_unused]] Stmt *init_val = |
163 | block_->push_back<LocalStoreStmt>(idx, start_val); |
164 | Stmt *block_dim_val; |
165 | if (config_.arch == Arch::x64 || config_.arch == Arch::arm64) { |
166 | block_dim_val = block_->push_back<ConstStmt>(TypedConstant(1)); |
167 | } else { |
168 | block_dim_val = |
169 | block_->push_back<ConstStmt>(TypedConstant{offload_->block_dim}); |
170 | } |
171 | |
172 | std::unique_ptr<Block> body = std::make_unique<Block>(); |
173 | { |
174 | Stmt *idx_val = body->push_back<LocalLoadStmt>(idx); |
175 | Stmt *cond = |
176 | body->push_back<BinaryOpStmt>(BinaryOpType::cmp_lt, idx_val, end_val); |
177 | body->push_back<WhileControlStmt>(nullptr, cond); |
178 | body_(body.get(), idx_val); |
179 | Stmt *idx_val_ = body->push_back<BinaryOpStmt>(BinaryOpType::add, idx_val, |
180 | block_dim_val); |
181 | [[maybe_unused]] Stmt *idx_store = |
182 | body->push_back<LocalStoreStmt>(idx, idx_val_); |
183 | } |
184 | block_->push_back<WhileStmt>(std::move(body)); |
185 | Stmt *idx_val = block_->push_back<LocalLoadStmt>(idx); |
186 | return idx_val; |
187 | } |
188 | |
189 | // This function creates loop like: |
190 | // int i = start_val; |
191 | // while (i < end_val) { |
192 | // mapping_shared[i] = global_val(i); |
193 | // i += blockDim.x; |
194 | // } |
195 | Stmt *MakeMeshBlockLocal::create_cache_mapping( |
196 | Stmt *start_val, |
197 | Stmt *end_val, |
198 | std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> global_val) { |
199 | Stmt *bls_element_offset_bytes = block_->push_back<ConstStmt>( |
200 | TypedConstant{(int32)mapping_bls_offset_in_bytes_}); |
201 | return create_xlogue(start_val, end_val, [&](Block *body, Stmt *idx_val) { |
202 | Stmt *idx_val_byte = body->push_back<BinaryOpStmt>( |
203 | BinaryOpType::mul, idx_val, |
204 | body->push_back<ConstStmt>(TypedConstant(mapping_dtype_size_))); |
205 | Stmt *offset = body->push_back<BinaryOpStmt>( |
206 | BinaryOpType::add, bls_element_offset_bytes, idx_val_byte); |
207 | Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>( |
208 | offset, |
209 | TypeFactory::get_instance().get_pointer_type(mapping_data_type_)); |
210 | Stmt *casted_val = body->push_back<UnaryOpStmt>(UnaryOpType::cast_value, |
211 | global_val(body, idx_val)); |
212 | casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
213 | [[maybe_unused]] Stmt *bls_store = |
214 | body->push_back<GlobalStoreStmt>(bls_ptr, casted_val); |
215 | }); |
216 | } |
217 | |
218 | void MakeMeshBlockLocal::fetch_attr_to_bls(Block *body, |
219 | Stmt *idx_val, |
220 | Stmt *mapping_val) { |
221 | auto attrs = rec_.find(std::make_pair(element_type_, conv_type_)); |
222 | if (attrs == rec_.end()) { |
223 | return; |
224 | } |
225 | for (auto [snode, total_flags] : attrs->second) { |
226 | auto data_type = snode->dt.ptr_removed(); |
227 | auto dtype_size = data_type_size(data_type); |
228 | |
229 | bool bls_has_read = total_flags & AccessFlag::read; |
230 | bool bls_has_write = total_flags & AccessFlag::write; |
231 | bool bls_has_accumulate = total_flags & AccessFlag::accumulate; |
232 | |
233 | TI_ASSERT_INFO(!bls_has_write, "BLS with write accesses is not supported." ); |
234 | TI_ASSERT_INFO(!(bls_has_accumulate && bls_has_read), |
235 | "BLS with both read and accumulation is not supported." ); |
236 | |
237 | bool first_allocate = {false}; |
238 | if (attr_bls_offset_in_bytes_.find(snode) == |
239 | attr_bls_offset_in_bytes_.end()) { |
240 | first_allocate = {true}; |
241 | bls_offset_in_bytes_ += |
242 | (dtype_size - bls_offset_in_bytes_ % dtype_size) % dtype_size; |
243 | attr_bls_offset_in_bytes_.insert( |
244 | std::make_pair(snode, bls_offset_in_bytes_)); |
245 | bls_offset_in_bytes_ += |
246 | dtype_size * |
247 | offload_->mesh->patch_max_element_num.find(element_type_)->second; |
248 | } |
249 | auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second; |
250 | |
251 | Stmt *value{nullptr}; |
252 | if (bls_has_read) { |
253 | // Read access |
254 | // Fetch from global to BLS |
255 | Stmt *global_ptr = body->push_back<GlobalPtrStmt>( |
256 | snode, std::vector<Stmt *>{mapping_val}); |
257 | value = body->push_back<GlobalLoadStmt>(global_ptr); |
258 | } else { |
259 | // Accumulation access |
260 | // Zero-fill |
261 | value = body->push_back<ConstStmt>(TypedConstant(data_type, 0)); |
262 | } |
263 | |
264 | Stmt *offset = |
265 | body->push_back<ConstStmt>(TypedConstant(int32(offset_in_bytes))); |
266 | Stmt *idx_val_byte = body->push_back<BinaryOpStmt>( |
267 | BinaryOpType::mul, idx_val, |
268 | body->push_back<ConstStmt>(TypedConstant(dtype_size))); |
269 | Stmt *index = |
270 | body->push_back<BinaryOpStmt>(BinaryOpType::add, offset, idx_val_byte); |
271 | Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>( |
272 | index, TypeFactory::get_instance().get_pointer_type(data_type)); |
273 | body->push_back<GlobalStoreStmt>(bls_ptr, value); |
274 | |
275 | // Step 3-2-1: |
276 | // Make loop body load from BLS instead of global fields |
277 | // NOTE that first_allocate ensures this step only do ONCE |
278 | if (first_allocate) { |
279 | replace_global_ptrs(snode); |
280 | } |
281 | } |
282 | } |
283 | |
284 | void MakeMeshBlockLocal::push_attr_to_global(Block *body, |
285 | Stmt *idx_val, |
286 | Stmt *mapping_val) { |
287 | auto attrs = rec_.find(std::make_pair(element_type_, conv_type_)); |
288 | if (attrs == rec_.end()) { |
289 | return; |
290 | } |
291 | for (auto [snode, total_flags] : attrs->second) { |
292 | bool bls_has_accumulate = total_flags & AccessFlag::accumulate; |
293 | if (!bls_has_accumulate) { |
294 | continue; |
295 | } |
296 | auto data_type = snode->dt.ptr_removed(); |
297 | auto dtype_size = data_type_size(data_type); |
298 | auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second; |
299 | |
300 | Stmt *offset = |
301 | body->push_back<ConstStmt>(TypedConstant(int32(offset_in_bytes))); |
302 | Stmt *idx_val_byte = body->push_back<BinaryOpStmt>( |
303 | BinaryOpType::mul, idx_val, |
304 | body->push_back<ConstStmt>(TypedConstant(dtype_size))); |
305 | Stmt *index = |
306 | body->push_back<BinaryOpStmt>(BinaryOpType::add, offset, idx_val_byte); |
307 | Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>( |
308 | index, TypeFactory::get_instance().get_pointer_type(data_type)); |
309 | Stmt *bls_val = body->push_back<GlobalLoadStmt>(bls_ptr); |
310 | |
311 | Stmt *global_ptr = |
312 | body->push_back<GlobalPtrStmt>(snode, std::vector<Stmt *>{mapping_val}); |
313 | body->push_back<AtomicOpStmt>(AtomicOpType::add, global_ptr, bls_val); |
314 | } |
315 | } |
316 | |
317 | void MakeMeshBlockLocal::fetch_mapping( |
318 | std::function< |
319 | Stmt *(Stmt * /*start_val*/, |
320 | Stmt * /*end_val*/, |
321 | std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)>)> |
322 | mapping_callback_handler, |
323 | std::function<void(Block *body, Stmt *idx_val, Stmt *mapping_val)> |
324 | attr_callback_handler) { |
325 | Stmt *thread_idx_stmt; |
326 | if (config_.arch == Arch::x64 || config_.arch == Arch::arm64) { |
327 | thread_idx_stmt = block_->push_back<ConstStmt>(TypedConstant(0)); |
328 | } else { |
329 | thread_idx_stmt = block_->push_back<LoopLinearIndexStmt>( |
330 | offload_); // Equivalent to CUDA threadIdx |
331 | } |
332 | Stmt *total_element_num = |
333 | offload_->total_num_local.find(element_type_)->second; |
334 | Stmt *total_element_offset = |
335 | offload_->total_offset_local.find(element_type_)->second; |
336 | |
337 | if (config_.optimize_mesh_reordered_mapping && |
338 | conv_type_ == mesh::ConvType::l2r) { |
339 | // int i = threadIdx.x; |
340 | // while (i < owned_{}_num) { |
341 | // mapping_shared[i] = i + owned_{}_offset; |
342 | // { |
343 | // x0_shared[i] = x0[mapping_shared[i]]; |
344 | // ... |
345 | // } |
346 | // i += blockDim.x; |
347 | // } |
348 | // while (i < total_{}_num) { |
349 | // mapping_shared[i] = mapping[i + total_{}_offset]; |
350 | // { |
351 | // x0_shared[i] = x0[mapping_shared[i]]; |
352 | // ... |
353 | // } |
354 | // i += blockDim.x; |
355 | // } |
356 | Stmt *owned_element_num = |
357 | offload_->owned_num_local.find(element_type_)->second; |
358 | Stmt *owned_element_offset = |
359 | offload_->owned_offset_local.find(element_type_)->second; |
360 | Stmt *pre_idx_val = mapping_callback_handler( |
361 | thread_idx_stmt, owned_element_num, [&](Block *body, Stmt *idx_val) { |
362 | Stmt *global_index = body->push_back<BinaryOpStmt>( |
363 | BinaryOpType::add, idx_val, owned_element_offset); |
364 | attr_callback_handler(body, idx_val, global_index); |
365 | return global_index; |
366 | }); |
367 | mapping_callback_handler( |
368 | pre_idx_val, total_element_num, [&](Block *body, Stmt *idx_val) { |
369 | Stmt *global_offset = body->push_back<BinaryOpStmt>( |
370 | BinaryOpType::add, total_element_offset, idx_val); |
371 | Stmt *global_ptr = body->push_back<GlobalPtrStmt>( |
372 | mapping_snode_, std::vector<Stmt *>{global_offset}); |
373 | Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr); |
374 | Stmt *casted_global_load = body->push_back<UnaryOpStmt>( |
375 | UnaryOpType::cast_value, global_load); |
376 | casted_global_load->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
377 | attr_callback_handler(body, idx_val, casted_global_load); |
378 | return casted_global_load; |
379 | }); |
380 | } else { |
381 | // int i = threadIdx.x; |
382 | // while (i < total_{}_num) { |
383 | // mapping_shared[i] = mapping[i + total_{}_offset]; |
384 | // { |
385 | // x0_shared[i] = x0[mapping_shared[i]]; |
386 | // ... |
387 | // } |
388 | // i += blockDim.x; |
389 | // } |
390 | mapping_callback_handler( |
391 | thread_idx_stmt, total_element_num, [&](Block *body, Stmt *idx_val) { |
392 | Stmt *global_offset = body->push_back<BinaryOpStmt>( |
393 | BinaryOpType::add, total_element_offset, idx_val); |
394 | Stmt *global_ptr = body->push_back<GlobalPtrStmt>( |
395 | mapping_snode_, std::vector<Stmt *>{global_offset}); |
396 | Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr); |
397 | Stmt *casted_global_load = body->push_back<UnaryOpStmt>( |
398 | UnaryOpType::cast_value, global_load); |
399 | casted_global_load->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
400 | attr_callback_handler(body, idx_val, casted_global_load); |
401 | return casted_global_load; |
402 | }); |
403 | } |
404 | } |
405 | |
406 | MakeMeshBlockLocal::MakeMeshBlockLocal(OffloadedStmt *offload, |
407 | const CompileConfig &config) |
408 | : config_(config), offload_(offload) { |
409 | // Step 0: simplify l2g + g2r -> l2r |
410 | simplify_nested_conversion(); |
411 | |
412 | // Step 1: A analyzer to determine which mapping should be localized |
413 | mappings_.clear(); |
414 | gather_candidate_mapping(); |
415 | |
416 | // Step 1: use Mesh BLS analyzer to gather which mesh attributes user declared |
417 | // to cache |
418 | bool auto_mesh_local = config.experimental_auto_mesh_local; |
419 | if (offload->major_to_types.size() != |
420 | 1 || // not support multiple major relations yet |
421 | offload->minor_relation_types.size() > |
422 | 0 || // not support minor relations yet |
423 | offload->mem_access_opt.get_snodes_with_flag(SNodeAccessFlag::mesh_local) |
424 | .size() > 0) { // disable when user determine which attributes to |
425 | // be cached manually |
426 | auto_mesh_local = false; |
427 | } |
428 | auto caches = irpass::analysis::initialize_mesh_local_attribute( |
429 | offload, auto_mesh_local, config); |
430 | |
431 | if (auto_mesh_local && config.arch == Arch::cuda) { |
432 | const auto to_type = *offload->major_to_types.begin(); |
433 | std::size_t shared_mem_size_per_block = |
434 | default_shared_mem_size / config.auto_mesh_local_default_occupacy; |
435 | int available_bytes = |
436 | shared_mem_size_per_block / |
437 | offload->mesh->patch_max_element_num.find(to_type)->second; |
438 | if (mappings_.find(std::make_pair(to_type, mesh::ConvType::l2g)) != |
439 | mappings_.end()) { |
440 | available_bytes -= 4; |
441 | } |
442 | if (mappings_.find(std::make_pair(to_type, mesh::ConvType::l2r)) != |
443 | mappings_.end()) { |
444 | available_bytes -= 4; |
445 | } |
446 | TI_TRACE("available cache attributes bytes = {}" , available_bytes); |
447 | TI_TRACE("caches size = {}" , caches->caches.size()); |
448 | std::vector<MeshBLSCache> priority_caches; |
449 | for (const auto &[snode, cache] : caches->caches) { |
450 | priority_caches.push_back(cache); |
451 | } |
452 | std::sort(priority_caches.begin(), priority_caches.end(), |
453 | [](const MeshBLSCache &a, const MeshBLSCache &b) { |
454 | return a.total_flags > b.total_flags || |
455 | (a.total_flags == b.total_flags && |
456 | a.loop_index > b.loop_index) || |
457 | (a.total_flags == b.total_flags && |
458 | a.loop_index == b.loop_index && |
459 | a.unique_accessed > b.unique_accessed); |
460 | }); |
461 | caches->caches.clear(); |
462 | for (const auto &cache : priority_caches) { |
463 | available_bytes -= data_type_size(cache.snode->dt); |
464 | if (available_bytes < 0) { |
465 | break; // not enough space to ensure occupacy |
466 | } |
467 | TI_TRACE("available = {}, x = {}, loop_index = {}, unique_access = {}" , |
468 | available_bytes, cache.total_flags, int(cache.loop_index), |
469 | cache.unique_accessed); |
470 | caches->caches.insert(std::make_pair(cache.snode, cache)); |
471 | } |
472 | } |
473 | rec_ = caches->finalize(); |
474 | |
475 | // If a mesh attribute is in bls, the config makes its index mapping must also |
476 | // be in bls |
477 | if (config.mesh_localize_all_attr_mappings && |
478 | !config.experimental_auto_mesh_local) { |
479 | for (auto [mapping, attr_set] : rec_) { |
480 | if (mappings_.find(mapping) == mappings_.end()) { |
481 | mappings_.insert(mapping); |
482 | } |
483 | } |
484 | } |
485 | |
486 | auto has_acc = [&](mesh::MeshElementType element_type, |
487 | mesh::ConvType conv_type) { |
488 | auto ptr = rec_.find(std::make_pair(element_type, conv_type)); |
489 | if (ptr == rec_.end()) { |
490 | return false; |
491 | } |
492 | bool has_accumulate = {false}; |
493 | for (auto [snode, total_flags] : ptr->second) { |
494 | has_accumulate |= (total_flags & AccessFlag::accumulate); |
495 | } |
496 | return has_accumulate; |
497 | }; |
498 | |
499 | // Step 3: Cache the mappings and the attributes |
500 | bls_offset_in_bytes_ = offload->bls_size; |
501 | if (offload->bls_prologue == nullptr) { |
502 | offload->bls_prologue = std::make_unique<Block>(); |
503 | offload->bls_prologue->parent_stmt = offload; |
504 | } |
505 | if (offload->bls_epilogue == nullptr) { |
506 | offload->bls_epilogue = std::make_unique<Block>(); |
507 | offload->bls_epilogue->parent_stmt = offload; |
508 | } |
509 | |
510 | // Cache both mappings and mesh attribute |
511 | for (auto [element_type, conv_type] : mappings_) { |
512 | this->element_type_ = element_type; |
513 | this->conv_type_ = conv_type; |
514 | TI_ASSERT(conv_type != mesh::ConvType::g2r); // g2r will not be cached. |
515 | // There is not corresponding mesh element attribute read/write, |
516 | // It's useless to localize this mapping |
517 | if (offload->total_offset_local.find(element_type) == |
518 | offload->total_offset_local.end()) { |
519 | continue; |
520 | } |
521 | |
522 | mapping_snode_ = (offload->mesh->index_mapping |
523 | .find(std::make_pair(element_type, conv_type)) |
524 | ->second); |
525 | // mapping_data_type_ = mapping_snode_->dt.ptr_removed(); |
526 | mapping_data_type_ = PrimitiveType::i32; |
527 | mapping_dtype_size_ = data_type_size(mapping_data_type_); |
528 | |
529 | // Ensure BLS alignment |
530 | bls_offset_in_bytes_ += |
531 | (mapping_dtype_size_ - bls_offset_in_bytes_ % mapping_dtype_size_) % |
532 | mapping_dtype_size_; |
533 | mapping_bls_offset_in_bytes_ = bls_offset_in_bytes_; |
534 | // allocate storage for the BLS variable |
535 | bls_offset_in_bytes_ += |
536 | mapping_dtype_size_ * |
537 | offload->mesh->patch_max_element_num.find(element_type)->second; |
538 | |
539 | // Step 3-1: |
540 | // Fetch index mapping to the BLS block |
541 | // Step 3-2 |
542 | // Fetch mesh attributes to the BLS block at the same time |
543 | // TODO(changyu): better way to use lambda |
544 | block_ = offload->bls_prologue.get(); |
545 | fetch_mapping( |
546 | [&](Stmt *start_val, Stmt *end_val, |
547 | std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> |
548 | global_val) { |
549 | return create_cache_mapping(start_val, end_val, global_val); |
550 | }, |
551 | [&](Block *body, Stmt *idx_val, Stmt *mapping_val) { |
552 | fetch_attr_to_bls(body, idx_val, mapping_val); |
553 | }); |
554 | |
555 | // Step 3-3: |
556 | // Make mesh index mapping load from BLS instead of global fields |
557 | replace_conv_statements(); |
558 | |
559 | // Step 3-4 |
560 | // Atomic-add BLS contribution to its global version if necessary |
561 | if (!has_acc(element_type, conv_type)) { |
562 | continue; |
563 | } |
564 | block_ = offload->bls_epilogue.get(); |
565 | { |
566 | Stmt *thread_idx_stmt = block_->push_back<LoopLinearIndexStmt>( |
567 | offload); // Equivalent to CUDA threadIdx |
568 | Stmt *total_element_num = |
569 | offload->total_num_local.find(element_type)->second; |
570 | [[maybe_unused]] Stmt *total_element_offset = |
571 | offload->total_offset_local.find(element_type)->second; |
572 | create_xlogue( |
573 | thread_idx_stmt, total_element_num, [&](Block *body, Stmt *idx_val) { |
574 | Stmt *bls_element_offset_bytes = body->push_back<ConstStmt>( |
575 | TypedConstant{(int32)mapping_bls_offset_in_bytes_}); |
576 | Stmt *idx_byte = body->push_back<BinaryOpStmt>( |
577 | BinaryOpType::mul, idx_val, |
578 | body->push_back<ConstStmt>(TypedConstant(mapping_dtype_size_))); |
579 | Stmt *offset = body->push_back<BinaryOpStmt>( |
580 | BinaryOpType::add, bls_element_offset_bytes, idx_byte); |
581 | Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>( |
582 | offset, TypeFactory::get_instance().get_pointer_type( |
583 | mapping_data_type_)); |
584 | Stmt *global_val = body->push_back<GlobalLoadStmt>(bls_ptr); |
585 | this->push_attr_to_global(body, idx_val, global_val); |
586 | }); |
587 | } |
588 | } |
589 | |
590 | // Cache mesh attribute only |
591 | for (auto [mapping, attr_set] : rec_) { |
592 | if (mappings_.find(mapping) != mappings_.end()) { |
593 | continue; |
594 | } |
595 | |
596 | this->element_type_ = mapping.first; |
597 | this->conv_type_ = mapping.second; |
598 | TI_ASSERT(conv_type_ != mesh::ConvType::g2r); // g2r will not be cached. |
599 | |
600 | mapping_snode_ = (offload->mesh->index_mapping |
601 | .find(std::make_pair(element_type_, conv_type_)) |
602 | ->second); |
603 | mapping_data_type_ = mapping_snode_->dt.ptr_removed(); |
604 | mapping_dtype_size_ = data_type_size(mapping_data_type_); |
605 | |
606 | // Step 3-1 |
607 | // Only fetch mesh attributes to the BLS block |
608 | // TODO(changyu): better way to use lambda |
609 | block_ = offload->bls_prologue.get(); |
610 | fetch_mapping( |
611 | [&](Stmt *start_val, Stmt *end_val, |
612 | std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> |
613 | global_val) { |
614 | return create_xlogue( |
615 | start_val, end_val, |
616 | [&](Block *block, Stmt *idx_val) { global_val(block, idx_val); }); |
617 | }, |
618 | [&](Block *body, Stmt *idx_val, Stmt *mapping_val) { |
619 | fetch_attr_to_bls(body, idx_val, mapping_val); |
620 | }); |
621 | |
622 | // Step 3-2 |
623 | // Atomic-add BLS contribution to its global version if necessary |
624 | if (!has_acc(element_type_, conv_type_)) { |
625 | continue; |
626 | } |
627 | block_ = offload->bls_epilogue.get(); |
628 | fetch_mapping( |
629 | [&](Stmt *start_val, Stmt *end_val, |
630 | std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> |
631 | global_val) { |
632 | return create_xlogue( |
633 | start_val, end_val, |
634 | [&](Block *block, Stmt *idx_val) { global_val(block, idx_val); }); |
635 | }, |
636 | [&](Block *body, Stmt *idx_val, Stmt *mapping_val) { |
637 | push_attr_to_global(body, idx_val, mapping_val); |
638 | }); |
639 | } |
640 | |
641 | offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes_); |
642 | } |
643 | |
644 | void MakeMeshBlockLocal::run(OffloadedStmt *offload, |
645 | const CompileConfig &config, |
646 | const std::string &kernel_name) { |
647 | if (offload->task_type != OffloadedStmt::TaskType::mesh_for) { |
648 | return; |
649 | } |
650 | |
651 | MakeMeshBlockLocal(offload, config); |
652 | } |
653 | |
654 | namespace irpass { |
655 | |
656 | // This pass should happen after offloading but before lower_access |
657 | void make_mesh_block_local(IRNode *root, |
658 | const CompileConfig &config, |
659 | const MakeMeshBlockLocal::Args &args) { |
660 | TI_AUTO_PROF; |
661 | |
662 | // ========================================================================================= |
663 | // This pass generates code like this: |
664 | // // Load V_l2g |
665 | // for (int i = threadIdx.x; i < total_vertices; i += blockDim.x) { |
666 | // V_l2g[i] = _V_l2g[i + total_vertices_offset]; |
667 | // sx[i] = x[V_l2g[i]]; |
668 | // sJ[i] = 0.0f; |
669 | // } |
670 | |
671 | if (auto root_block = root->cast<Block>()) { |
672 | for (auto &offload : root_block->statements) { |
673 | MakeMeshBlockLocal::run(offload->cast<OffloadedStmt>(), config, |
674 | args.kernel_name); |
675 | } |
676 | } else { |
677 | MakeMeshBlockLocal::run(root->as<OffloadedStmt>(), config, |
678 | args.kernel_name); |
679 | } |
680 | |
681 | type_check(root, config); |
682 | } |
683 | |
684 | } // namespace irpass |
685 | } // namespace taichi::lang |
686 | |