#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/transforms/make_mesh_block_local.h"

namespace taichi::lang {

const PassID MakeMeshBlockLocal::id = "MakeMeshBlockLocal";

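// Fold a reordered-index conversion applied to a localized one: g2r(l2g(i))
// on the same mesh and element type is equivalent to a single l2r(i) lookup.
// A sketch in IR-like pseudocode (names illustrative):
//   %1 = mesh_index_conversion l2g, idx=%i   // local -> global
//   %2 = mesh_index_conversion g2r, idx=%1   // global -> reordered
// becomes
//   %2 = mesh_index_conversion l2r, idx=%i   // one lookup instead of two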
void MakeMeshBlockLocal::simplify_nested_conversion() {
  std::vector<MeshIndexConversionStmt *> stmts;
  std::vector<Stmt *> ori_indices;

  irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) {
    if (auto conv1 = stmt->cast<MeshIndexConversionStmt>()) {
      if (auto conv2 = conv1->idx->cast<MeshIndexConversionStmt>()) {
        if (conv1->conv_type == mesh::ConvType::g2r &&
            conv2->conv_type == mesh::ConvType::l2g &&
            conv1->mesh == conv2->mesh &&
            conv1->idx_type == conv2->idx_type) {  // nested
          stmts.push_back(conv1);
          ori_indices.push_back(conv2->idx);
        }
      }
    }
    return false;
  });

  for (size_t i = 0; i < stmts.size(); ++i) {
    stmts[i]->replace_with(Stmt::make<MeshIndexConversionStmt>(
        stmts[i]->mesh, stmts[i]->idx_type, ori_indices[i],
        mesh::ConvType::l2r));
  }
}

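// Collect the (element type, conversion type) pairs whose index mappings are
// candidates for BLS caching: conversions on the from-end and/or to-end of
// this mesh-for's major and minor relations, gated by the corresponding
// mesh_localize_{from,to}_end_mapping config flags.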
void MakeMeshBlockLocal::gather_candidate_mapping() {
  irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) {
    if (auto conv = stmt->cast<MeshIndexConversionStmt>()) {
      if (conv->conv_type != mesh::ConvType::g2r) {
        bool is_from_end = (conv->idx_type == offload_->major_from_type);
        bool is_to_end = false;
        for (auto type : offload_->major_to_types) {
          is_to_end |= (conv->idx_type == type);
        }
        for (auto rel : offload_->minor_relation_types) {
          auto from_type =
              mesh::MeshElementType(mesh::from_end_element_order(rel));
          auto to_type = mesh::MeshElementType(mesh::to_end_element_order(rel));
          is_from_end |= (conv->idx_type == from_type);
          is_to_end |= (conv->idx_type == to_type);
        }
        if ((is_to_end && config_.mesh_localize_to_end_mapping) ||
            (is_from_end && config_.mesh_localize_from_end_mapping)) {
          mappings_.insert(std::make_pair(conv->idx_type, conv->conv_type));
        }
      }
    }
    return false;
  });
}

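// Rewrite every matching MeshIndexConversionStmt into a load from the BLS
// copy of the mapping. The BLS address is computed as (illustrative
// arithmetic, all quantities in bytes):
//   addr = mapping_bls_offset_in_bytes_ + idx * mapping_dtype_size_
// e.g. with the mapping region starting at byte 128 and i32 entries,
// element 5 lives at 128 + 5 * 4 = 148.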
void MakeMeshBlockLocal::replace_conv_statements() {
  std::vector<MeshIndexConversionStmt *> idx_conv_stmts;

  irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) {
    if (auto idx_conv = stmt->cast<MeshIndexConversionStmt>()) {
      if (idx_conv->mesh == offload_->mesh &&
          idx_conv->conv_type == conv_type_ &&
          idx_conv->idx_type == element_type_) {
        idx_conv_stmts.push_back(idx_conv);
      }
    }
    return false;
  });

  for (auto stmt : idx_conv_stmts) {
    VecStatement bls;
    Stmt *bls_element_offset_bytes = bls.push_back<ConstStmt>(
        TypedConstant{(int32)mapping_bls_offset_in_bytes_});
    Stmt *idx_byte = bls.push_back<BinaryOpStmt>(
        BinaryOpType::mul, stmt->idx,
        bls.push_back<ConstStmt>(TypedConstant(mapping_dtype_size_)));
    Stmt *offset = bls.push_back<BinaryOpStmt>(
        BinaryOpType::add, bls_element_offset_bytes, idx_byte);
    Stmt *bls_ptr = bls.push_back<BlockLocalPtrStmt>(
        offset,
        TypeFactory::get_instance().get_pointer_type(mapping_data_type_));
    [[maybe_unused]] Stmt *bls_load = bls.push_back<GlobalLoadStmt>(bls_ptr);
    stmt->replace_with(std::move(bls));
  }
}

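// Redirect accesses to a cached attribute from global memory to its BLS copy:
// any GlobalPtrStmt on `snode` indexed through a mesh index conversion
// becomes a BlockLocalPtrStmt at
//   attr_bls_offset_in_bytes_[snode] + local_idx * dtype_size.
// On CPU backends this is followed by demoting BLS atomic adds to plain
// load-add-store, since a block's BLS buffer is not shared across threads
// there.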
void MakeMeshBlockLocal::replace_global_ptrs(SNode *snode) {
  auto data_type = snode->dt.ptr_removed();
  auto dtype_size = data_type_size(data_type);
  auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second;

  std::vector<GlobalPtrStmt *> global_ptrs;
  irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) {
    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
      if (global_ptr->snode == snode &&
          global_ptr->indices[0]->is<MeshIndexConversionStmt>()) {
        global_ptrs.push_back(global_ptr);
      }
    }
    return false;
  });

  for (auto global_ptr : global_ptrs) {
    VecStatement bls;
    Stmt *local_idx =
        global_ptr->indices[0]->as<MeshIndexConversionStmt>()->idx;
    Stmt *local_idx_byte = bls.push_back<BinaryOpStmt>(
        BinaryOpType::mul, local_idx,
        bls.push_back<ConstStmt>(TypedConstant(dtype_size)));
    Stmt *offset =
        bls.push_back<ConstStmt>(TypedConstant(int32(offset_in_bytes)));
    Stmt *index =
        bls.push_back<BinaryOpStmt>(BinaryOpType::add, offset, local_idx_byte);
    [[maybe_unused]] Stmt *bls_ptr = bls.push_back<BlockLocalPtrStmt>(
        index, TypeFactory::get_instance().get_pointer_type(data_type));
    global_ptr->replace_with(std::move(bls));
  }

  // On the CPU backend, atomic ops on the BLS buffer in the body block can be
  // demoted to non-atomic load-add-store sequences.
  if (config_.arch != Arch::x64 && config_.arch != Arch::arm64) {
    return;
  }
  std::vector<AtomicOpStmt *> atomic_ops;
  irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) {
    if (auto atomic_op = stmt->cast<AtomicOpStmt>()) {
      if (atomic_op->op_type == AtomicOpType::add &&
          atomic_op->dest->is<BlockLocalPtrStmt>()) {
        atomic_ops.push_back(atomic_op);
      }
    }
    return false;
  });

  for (auto atomic_op : atomic_ops) {
    VecStatement non_atomic;
    Stmt *dest_val = non_atomic.push_back<GlobalLoadStmt>(atomic_op->dest);
    Stmt *res_val = non_atomic.push_back<BinaryOpStmt>(
        BinaryOpType::add, dest_val, atomic_op->val);
    non_atomic.push_back<GlobalStoreStmt>(atomic_op->dest, res_val);
    atomic_op->replace_with(std::move(non_atomic));
  }
}


// This function creates a loop like:
//   int i = start_val;
//   while (i < end_val) {
//     body(i);
//     i += blockDim.x;
//   }
Stmt *MakeMeshBlockLocal::create_xlogue(
    Stmt *start_val,
    Stmt *end_val,
    std::function<void(Block * /*block*/, Stmt * /*idx_val*/)> body_) {
  Stmt *idx = block_->push_back<AllocaStmt>(mapping_data_type_);
  [[maybe_unused]] Stmt *init_val =
      block_->push_back<LocalStoreStmt>(idx, start_val);
  Stmt *block_dim_val;
  if (config_.arch == Arch::x64 || config_.arch == Arch::arm64) {
    block_dim_val = block_->push_back<ConstStmt>(TypedConstant(1));
  } else {
    block_dim_val =
        block_->push_back<ConstStmt>(TypedConstant{offload_->block_dim});
  }

  std::unique_ptr<Block> body = std::make_unique<Block>();
  {
    Stmt *idx_val = body->push_back<LocalLoadStmt>(idx);
    Stmt *cond =
        body->push_back<BinaryOpStmt>(BinaryOpType::cmp_lt, idx_val, end_val);
    body->push_back<WhileControlStmt>(nullptr, cond);
    body_(body.get(), idx_val);
    Stmt *idx_val_ = body->push_back<BinaryOpStmt>(BinaryOpType::add, idx_val,
                                                   block_dim_val);
    [[maybe_unused]] Stmt *idx_store =
        body->push_back<LocalStoreStmt>(idx, idx_val_);
  }
  block_->push_back<WhileStmt>(std::move(body));
  Stmt *idx_val = block_->push_back<LocalLoadStmt>(idx);
  return idx_val;
}

// This function creates a loop like:
//   int i = start_val;
//   while (i < end_val) {
//     mapping_shared[i] = global_val(i);
//     i += blockDim.x;
//   }
Stmt *MakeMeshBlockLocal::create_cache_mapping(
    Stmt *start_val,
    Stmt *end_val,
    std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> global_val) {
  Stmt *bls_element_offset_bytes = block_->push_back<ConstStmt>(
      TypedConstant{(int32)mapping_bls_offset_in_bytes_});
  return create_xlogue(start_val, end_val, [&](Block *body, Stmt *idx_val) {
    Stmt *idx_val_byte = body->push_back<BinaryOpStmt>(
        BinaryOpType::mul, idx_val,
        body->push_back<ConstStmt>(TypedConstant(mapping_dtype_size_)));
    Stmt *offset = body->push_back<BinaryOpStmt>(
        BinaryOpType::add, bls_element_offset_bytes, idx_val_byte);
    Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
        offset,
        TypeFactory::get_instance().get_pointer_type(mapping_data_type_));
    Stmt *casted_val = body->push_back<UnaryOpStmt>(UnaryOpType::cast_value,
                                                    global_val(body, idx_val));
    casted_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
    [[maybe_unused]] Stmt *bls_store =
        body->push_back<GlobalStoreStmt>(bls_ptr, casted_val);
  });
}

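// For every attribute recorded for (element_type_, conv_type_), allocate a
// BLS region on first use and emit the prologue store that fills it. The
// region start is aligned up to the attribute's dtype size; e.g. with the
// running offset at 6 bytes and a 4-byte element type,
//   6 + (4 - 6 % 4) % 4 = 8,
// so the region starts at byte 8 (numbers are illustrative).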
void MakeMeshBlockLocal::fetch_attr_to_bls(Block *body,
                                           Stmt *idx_val,
                                           Stmt *mapping_val) {
  auto attrs = rec_.find(std::make_pair(element_type_, conv_type_));
  if (attrs == rec_.end()) {
    return;
  }
  for (auto [snode, total_flags] : attrs->second) {
    auto data_type = snode->dt.ptr_removed();
    auto dtype_size = data_type_size(data_type);

    bool bls_has_read = total_flags & AccessFlag::read;
    bool bls_has_write = total_flags & AccessFlag::write;
    bool bls_has_accumulate = total_flags & AccessFlag::accumulate;

    TI_ASSERT_INFO(!bls_has_write, "BLS with write accesses is not supported.");
    TI_ASSERT_INFO(!(bls_has_accumulate && bls_has_read),
                   "BLS with both read and accumulation is not supported.");

    bool first_allocate = {false};
    if (attr_bls_offset_in_bytes_.find(snode) ==
        attr_bls_offset_in_bytes_.end()) {
      first_allocate = {true};
      bls_offset_in_bytes_ +=
          (dtype_size - bls_offset_in_bytes_ % dtype_size) % dtype_size;
      attr_bls_offset_in_bytes_.insert(
          std::make_pair(snode, bls_offset_in_bytes_));
      bls_offset_in_bytes_ +=
          dtype_size *
          offload_->mesh->patch_max_element_num.find(element_type_)->second;
    }
    auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second;

    Stmt *value{nullptr};
    if (bls_has_read) {
      // Read access:
      // fetch from global memory to BLS.
      Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
          snode, std::vector<Stmt *>{mapping_val});
      value = body->push_back<GlobalLoadStmt>(global_ptr);
    } else {
      // Accumulation access:
      // zero-fill.
      value = body->push_back<ConstStmt>(TypedConstant(data_type, 0));
    }

    Stmt *offset =
        body->push_back<ConstStmt>(TypedConstant(int32(offset_in_bytes)));
    Stmt *idx_val_byte = body->push_back<BinaryOpStmt>(
        BinaryOpType::mul, idx_val,
        body->push_back<ConstStmt>(TypedConstant(dtype_size)));
    Stmt *index =
        body->push_back<BinaryOpStmt>(BinaryOpType::add, offset, idx_val_byte);
    Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
        index, TypeFactory::get_instance().get_pointer_type(data_type));
    body->push_back<GlobalStoreStmt>(bls_ptr, value);

    // Step 3-2-1:
    // Make the loop body load from BLS instead of global fields.
    // NOTE: first_allocate ensures this step runs only ONCE per SNode.
    if (first_allocate) {
      replace_global_ptrs(snode);
    }
  }
}

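// Epilogue counterpart of fetch_attr_to_bls: for attributes accessed with
// accumulation, read each partial sum back out of BLS and atomically add it
// to the global field through the cached index mapping.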
void MakeMeshBlockLocal::push_attr_to_global(Block *body,
                                             Stmt *idx_val,
                                             Stmt *mapping_val) {
  auto attrs = rec_.find(std::make_pair(element_type_, conv_type_));
  if (attrs == rec_.end()) {
    return;
  }
  for (auto [snode, total_flags] : attrs->second) {
    bool bls_has_accumulate = total_flags & AccessFlag::accumulate;
    if (!bls_has_accumulate) {
      continue;
    }
    auto data_type = snode->dt.ptr_removed();
    auto dtype_size = data_type_size(data_type);
    auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second;

    Stmt *offset =
        body->push_back<ConstStmt>(TypedConstant(int32(offset_in_bytes)));
    Stmt *idx_val_byte = body->push_back<BinaryOpStmt>(
        BinaryOpType::mul, idx_val,
        body->push_back<ConstStmt>(TypedConstant(dtype_size)));
    Stmt *index =
        body->push_back<BinaryOpStmt>(BinaryOpType::add, offset, idx_val_byte);
    Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
        index, TypeFactory::get_instance().get_pointer_type(data_type));
    Stmt *bls_val = body->push_back<GlobalLoadStmt>(bls_ptr);

    Stmt *global_ptr =
        body->push_back<GlobalPtrStmt>(snode, std::vector<Stmt *>{mapping_val});
    body->push_back<AtomicOpStmt>(AtomicOpType::add, global_ptr, bls_val);
  }
}

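// Drive the per-element fetch loops. When optimize_mesh_reordered_mapping is
// on and the conversion is l2r, owned elements are handled first: their l2r
// mapping is simply the index plus the owned offset, so it can be
// materialized without reading the global mapping array; only the remaining
// (non-owned) elements fall back to loading mapping[i + total_offset].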
void MakeMeshBlockLocal::fetch_mapping(
    std::function<
        Stmt *(Stmt * /*start_val*/,
               Stmt * /*end_val*/,
               std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)>)>
        mapping_callback_handler,
    std::function<void(Block *body, Stmt *idx_val, Stmt *mapping_val)>
        attr_callback_handler) {
  Stmt *thread_idx_stmt;
  if (config_.arch == Arch::x64 || config_.arch == Arch::arm64) {
    thread_idx_stmt = block_->push_back<ConstStmt>(TypedConstant(0));
  } else {
    thread_idx_stmt = block_->push_back<LoopLinearIndexStmt>(
        offload_);  // Equivalent to CUDA threadIdx
  }
  Stmt *total_element_num =
      offload_->total_num_local.find(element_type_)->second;
  Stmt *total_element_offset =
      offload_->total_offset_local.find(element_type_)->second;

  if (config_.optimize_mesh_reordered_mapping &&
      conv_type_ == mesh::ConvType::l2r) {
    // int i = threadIdx.x;
    // while (i < owned_{}_num) {
    //   mapping_shared[i] = i + owned_{}_offset;
    //   {
    //     x0_shared[i] = x0[mapping_shared[i]];
    //     ...
    //   }
    //   i += blockDim.x;
    // }
    // while (i < total_{}_num) {
    //   mapping_shared[i] = mapping[i + total_{}_offset];
    //   {
    //     x0_shared[i] = x0[mapping_shared[i]];
    //     ...
    //   }
    //   i += blockDim.x;
    // }
    Stmt *owned_element_num =
        offload_->owned_num_local.find(element_type_)->second;
    Stmt *owned_element_offset =
        offload_->owned_offset_local.find(element_type_)->second;
    Stmt *pre_idx_val = mapping_callback_handler(
        thread_idx_stmt, owned_element_num, [&](Block *body, Stmt *idx_val) {
          Stmt *global_index = body->push_back<BinaryOpStmt>(
              BinaryOpType::add, idx_val, owned_element_offset);
          attr_callback_handler(body, idx_val, global_index);
          return global_index;
        });
    mapping_callback_handler(
        pre_idx_val, total_element_num, [&](Block *body, Stmt *idx_val) {
          Stmt *global_offset = body->push_back<BinaryOpStmt>(
              BinaryOpType::add, total_element_offset, idx_val);
          Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
              mapping_snode_, std::vector<Stmt *>{global_offset});
          Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
          Stmt *casted_global_load = body->push_back<UnaryOpStmt>(
              UnaryOpType::cast_value, global_load);
          casted_global_load->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
          attr_callback_handler(body, idx_val, casted_global_load);
          return casted_global_load;
        });
  } else {
    // int i = threadIdx.x;
    // while (i < total_{}_num) {
    //   mapping_shared[i] = mapping[i + total_{}_offset];
    //   {
    //     x0_shared[i] = x0[mapping_shared[i]];
    //     ...
    //   }
    //   i += blockDim.x;
    // }
    mapping_callback_handler(
        thread_idx_stmt, total_element_num, [&](Block *body, Stmt *idx_val) {
          Stmt *global_offset = body->push_back<BinaryOpStmt>(
              BinaryOpType::add, total_element_offset, idx_val);
          Stmt *global_ptr = body->push_back<GlobalPtrStmt>(
              mapping_snode_, std::vector<Stmt *>{global_offset});
          Stmt *global_load = body->push_back<GlobalLoadStmt>(global_ptr);
          Stmt *casted_global_load = body->push_back<UnaryOpStmt>(
              UnaryOpType::cast_value, global_load);
          casted_global_load->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
          attr_callback_handler(body, idx_val, casted_global_load);
          return casted_global_load;
        });
  }
}

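// The constructor performs the whole transformation for one mesh_for offload.
// A sketch of the resulting kernel structure (names illustrative):
//   bls_prologue:  cooperatively fill mapping_shared and the attr caches
//   body:          index conversions and attribute accesses hit BLS
//   bls_epilogue:  atomically flush accumulated attributes back to global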
MakeMeshBlockLocal::MakeMeshBlockLocal(OffloadedStmt *offload,
                                       const CompileConfig &config)
    : config_(config), offload_(offload) {
  // Step 0: simplify l2g + g2r -> l2r
  simplify_nested_conversion();

  // Step 1: analyze which index mappings should be localized
  mappings_.clear();
  gather_candidate_mapping();

  // Step 2: use the mesh BLS analyzer to gather which mesh attributes the
  // user declared to cache
  bool auto_mesh_local = config.experimental_auto_mesh_local;
  if (offload->major_to_types.size() !=
          1 ||  // multiple major relations are not supported yet
      offload->minor_relation_types.size() >
          0 ||  // minor relations are not supported yet
      offload->mem_access_opt.get_snodes_with_flag(SNodeAccessFlag::mesh_local)
              .size() > 0) {  // disabled when the user manually specifies
                              // which attributes to cache
    auto_mesh_local = false;
  }
  auto caches = irpass::analysis::initialize_mesh_local_attribute(
      offload, auto_mesh_local, config);

  if (auto_mesh_local && config.arch == Arch::cuda) {
    const auto to_type = *offload->major_to_types.begin();
    std::size_t shared_mem_size_per_block =
        default_shared_mem_size / config.auto_mesh_local_default_occupacy;
    int available_bytes =
        shared_mem_size_per_block /
        offload->mesh->patch_max_element_num.find(to_type)->second;
    if (mappings_.find(std::make_pair(to_type, mesh::ConvType::l2g)) !=
        mappings_.end()) {
      available_bytes -= 4;
    }
    if (mappings_.find(std::make_pair(to_type, mesh::ConvType::l2r)) !=
        mappings_.end()) {
      available_bytes -= 4;
    }
    TI_TRACE("available cache attributes bytes = {}", available_bytes);
    TI_TRACE("caches size = {}", caches->caches.size());
    std::vector<MeshBLSCache> priority_caches;
    for (const auto &[snode, cache] : caches->caches) {
      priority_caches.push_back(cache);
    }
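    // Rank candidate attributes by profitability before greedily admitting
    // them into the per-element byte budget below: stronger access flags
    // first, then accesses through the loop index, then more unique accesses.
    // (available_bytes is bytes per mesh element; each localized i32 mapping
    // already consumed 4 of them above.)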
    std::sort(priority_caches.begin(), priority_caches.end(),
              [](const MeshBLSCache &a, const MeshBLSCache &b) {
                return a.total_flags > b.total_flags ||
                       (a.total_flags == b.total_flags &&
                        a.loop_index > b.loop_index) ||
                       (a.total_flags == b.total_flags &&
                        a.loop_index == b.loop_index &&
                        a.unique_accessed > b.unique_accessed);
              });
    caches->caches.clear();
    for (const auto &cache : priority_caches) {
      available_bytes -= data_type_size(cache.snode->dt);
      if (available_bytes < 0) {
        break;  // not enough space to maintain the target occupancy
      }
      TI_TRACE("available = {}, x = {}, loop_index = {}, unique_access = {}",
               available_bytes, cache.total_flags, int(cache.loop_index),
               cache.unique_accessed);
      caches->caches.insert(std::make_pair(cache.snode, cache));
    }
  }
  rec_ = caches->finalize();

  // If a mesh attribute is cached in BLS, this config forces its index
  // mapping to be cached in BLS as well.
  if (config.mesh_localize_all_attr_mappings &&
      !config.experimental_auto_mesh_local) {
    for (auto [mapping, attr_set] : rec_) {
      if (mappings_.find(mapping) == mappings_.end()) {
        mappings_.insert(mapping);
      }
    }
  }

  auto has_acc = [&](mesh::MeshElementType element_type,
                     mesh::ConvType conv_type) {
    auto ptr = rec_.find(std::make_pair(element_type, conv_type));
    if (ptr == rec_.end()) {
      return false;
    }
    bool has_accumulate = {false};
    for (auto [snode, total_flags] : ptr->second) {
      has_accumulate |= (total_flags & AccessFlag::accumulate);
    }
    return has_accumulate;
  };

  // Step 3: Cache the mappings and the attributes
  bls_offset_in_bytes_ = offload->bls_size;
  if (offload->bls_prologue == nullptr) {
    offload->bls_prologue = std::make_unique<Block>();
    offload->bls_prologue->parent_stmt = offload;
  }
  if (offload->bls_epilogue == nullptr) {
    offload->bls_epilogue = std::make_unique<Block>();
    offload->bls_epilogue->parent_stmt = offload;
  }

  // Cache both the mappings and the mesh attributes
  for (auto [element_type, conv_type] : mappings_) {
    this->element_type_ = element_type;
    this->conv_type_ = conv_type;
    TI_ASSERT(conv_type != mesh::ConvType::g2r);  // g2r will not be cached.
    // If there is no corresponding mesh element attribute read/write,
    // localizing this mapping is useless.
    if (offload->total_offset_local.find(element_type) ==
        offload->total_offset_local.end()) {
      continue;
    }

    mapping_snode_ = (offload->mesh->index_mapping
                          .find(std::make_pair(element_type, conv_type))
                          ->second);
    // mapping_data_type_ = mapping_snode_->dt.ptr_removed();
    mapping_data_type_ = PrimitiveType::i32;
    mapping_dtype_size_ = data_type_size(mapping_data_type_);

    // Ensure BLS alignment
    bls_offset_in_bytes_ +=
        (mapping_dtype_size_ - bls_offset_in_bytes_ % mapping_dtype_size_) %
        mapping_dtype_size_;
    mapping_bls_offset_in_bytes_ = bls_offset_in_bytes_;
    // Allocate storage for the BLS variable
    bls_offset_in_bytes_ +=
        mapping_dtype_size_ *
        offload->mesh->patch_max_element_num.find(element_type)->second;

    // Step 3-1:
    // Fetch the index mapping into the BLS block.
    // Step 3-2:
    // Fetch mesh attributes into the BLS block at the same time.
    // TODO(changyu): better way to use lambda
    block_ = offload->bls_prologue.get();
    fetch_mapping(
        [&](Stmt *start_val, Stmt *end_val,
            std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)>
                global_val) {
          return create_cache_mapping(start_val, end_val, global_val);
        },
        [&](Block *body, Stmt *idx_val, Stmt *mapping_val) {
          fetch_attr_to_bls(body, idx_val, mapping_val);
        });

    // Step 3-3:
    // Make mesh index mapping loads read from BLS instead of global fields.
    replace_conv_statements();

    // Step 3-4:
    // Atomically add the BLS contribution to its global version if necessary.
    if (!has_acc(element_type, conv_type)) {
      continue;
    }
    block_ = offload->bls_epilogue.get();
    {
      Stmt *thread_idx_stmt = block_->push_back<LoopLinearIndexStmt>(
          offload);  // Equivalent to CUDA threadIdx
      Stmt *total_element_num =
          offload->total_num_local.find(element_type)->second;
      [[maybe_unused]] Stmt *total_element_offset =
          offload->total_offset_local.find(element_type)->second;
      create_xlogue(
          thread_idx_stmt, total_element_num, [&](Block *body, Stmt *idx_val) {
            Stmt *bls_element_offset_bytes = body->push_back<ConstStmt>(
                TypedConstant{(int32)mapping_bls_offset_in_bytes_});
            Stmt *idx_byte = body->push_back<BinaryOpStmt>(
                BinaryOpType::mul, idx_val,
                body->push_back<ConstStmt>(TypedConstant(mapping_dtype_size_)));
            Stmt *offset = body->push_back<BinaryOpStmt>(
                BinaryOpType::add, bls_element_offset_bytes, idx_byte);
            Stmt *bls_ptr = body->push_back<BlockLocalPtrStmt>(
                offset, TypeFactory::get_instance().get_pointer_type(
                            mapping_data_type_));
            Stmt *global_val = body->push_back<GlobalLoadStmt>(bls_ptr);
            this->push_attr_to_global(body, idx_val, global_val);
          });
    }
  }

  // Cache mesh attributes only
  for (auto [mapping, attr_set] : rec_) {
    if (mappings_.find(mapping) != mappings_.end()) {
      continue;
    }

    this->element_type_ = mapping.first;
    this->conv_type_ = mapping.second;
    TI_ASSERT(conv_type_ != mesh::ConvType::g2r);  // g2r will not be cached.

    mapping_snode_ = (offload->mesh->index_mapping
                          .find(std::make_pair(element_type_, conv_type_))
                          ->second);
    mapping_data_type_ = mapping_snode_->dt.ptr_removed();
    mapping_dtype_size_ = data_type_size(mapping_data_type_);

    // Step 3-1:
    // Only fetch mesh attributes into the BLS block.
    // TODO(changyu): better way to use lambda
    block_ = offload->bls_prologue.get();
    fetch_mapping(
        [&](Stmt *start_val, Stmt *end_val,
            std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)>
                global_val) {
          return create_xlogue(
              start_val, end_val,
              [&](Block *block, Stmt *idx_val) { global_val(block, idx_val); });
        },
        [&](Block *body, Stmt *idx_val, Stmt *mapping_val) {
          fetch_attr_to_bls(body, idx_val, mapping_val);
        });

    // Step 3-2:
    // Atomically add the BLS contribution to its global version if necessary.
    if (!has_acc(element_type_, conv_type_)) {
      continue;
    }
    block_ = offload->bls_epilogue.get();
    fetch_mapping(
        [&](Stmt *start_val, Stmt *end_val,
            std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)>
                global_val) {
          return create_xlogue(
              start_val, end_val,
              [&](Block *block, Stmt *idx_val) { global_val(block, idx_val); });
        },
        [&](Block *body, Stmt *idx_val, Stmt *mapping_val) {
          push_attr_to_global(body, idx_val, mapping_val);
        });
  }

  offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes_);
}

void MakeMeshBlockLocal::run(OffloadedStmt *offload,
                             const CompileConfig &config,
                             const std::string &kernel_name) {
  if (offload->task_type != OffloadedStmt::TaskType::mesh_for) {
    return;
  }

  MakeMeshBlockLocal(offload, config);
}

namespace irpass {

// This pass should happen after offloading but before lower_access
void make_mesh_block_local(IRNode *root,
                           const CompileConfig &config,
                           const MakeMeshBlockLocal::Args &args) {
  TI_AUTO_PROF;

  // ==========================================================================
  // This pass generates code like this:
  // // Load V_l2g
  // for (int i = threadIdx.x; i < total_vertices; i += blockDim.x) {
  //   V_l2g[i] = _V_l2g[i + total_vertices_offset];
  //   sx[i] = x[V_l2g[i]];
  //   sJ[i] = 0.0f;
  // }

  if (auto root_block = root->cast<Block>()) {
    for (auto &offload : root_block->statements) {
      MakeMeshBlockLocal::run(offload->cast<OffloadedStmt>(), config,
                              args.kernel_name);
    }
  } else {
    MakeMeshBlockLocal::run(root->as<OffloadedStmt>(), config,
                            args.kernel_name);
  }

  type_check(root, config);
}

}  // namespace irpass
}  // namespace taichi::lang