1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/analysis.h"
5#include "taichi/transforms/make_mesh_thread_local.h"
6
7namespace taichi::lang {
8
9const PassID MakeMeshThreadLocal::id = "MakeMeshThreadLocal";
10
11namespace irpass {
12
13void make_mesh_thread_local_offload(OffloadedStmt *offload,
14 const CompileConfig &config,
15 const std::string &kernel_name) {
16 if (offload->task_type != OffloadedStmt::TaskType::mesh_for) {
17 return;
18 }
19
20 std::pair</* owned= */ std::unordered_set<mesh::MeshElementType>,
21 /* total= */ std::unordered_set<mesh::MeshElementType>>
22 accessed = analysis::gather_mesh_thread_local(offload, config);
23
24 std::size_t tls_offset = offload->tls_size;
25
26 auto data_type = PrimitiveType::u32; // unt32_t type address
27 auto dtype_size = data_type_size(data_type);
28
29 if (offload->tls_prologue == nullptr) {
30 offload->tls_prologue = std::make_unique<Block>();
31 offload->tls_prologue->parent_stmt = offload;
32 }
33
34 if (offload->mesh_prologue == nullptr) {
35 offload->mesh_prologue = std::make_unique<Block>();
36 offload->mesh_prologue->parent_stmt = offload;
37 }
38
39 auto patch_idx =
40 offload->tls_prologue->insert(std::make_unique<MeshPatchIndexStmt>(), -1);
41 auto one = offload->tls_prologue->insert(
42 std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 1)), -1);
43 auto patch_idx_1 = offload->tls_prologue->insert(
44 std::make_unique<BinaryOpStmt>(BinaryOpType::add, patch_idx, one), -1);
45
46 auto make_thread_local_store =
47 [&](mesh::MeshElementType element_type,
48 const std::unordered_map<mesh::MeshElementType, SNode *> &offset_,
49 std::unordered_map<mesh::MeshElementType, Stmt *> &offset_local,
50 std::unordered_map<mesh::MeshElementType, Stmt *> &num_local) {
51 const auto offset_tls_offset =
52 (tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size);
53 tls_offset += dtype_size; // allocate storage for the TLS variable
54
55 const auto num_tls_offset =
56 (tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size);
57 tls_offset += dtype_size;
58
59 // Step 1:
60 // Create thread local storage
61 {
62 auto offset_ptr =
63 offload->tls_prologue->push_back<ThreadLocalPtrStmt>(
64 offset_tls_offset,
65 TypeFactory::get_instance().get_pointer_type(data_type));
66 auto num_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>(
67 num_tls_offset,
68 TypeFactory::get_instance().get_pointer_type(data_type));
69
70 const auto offset_snode = offset_.find(element_type);
71 TI_ASSERT(offset_snode != offset_.end());
72 auto offset_globalptr = offload->tls_prologue->insert(
73 std::make_unique<GlobalPtrStmt>(offset_snode->second,
74 std::vector<Stmt *>{patch_idx}),
75 -1);
76 auto offset_load = offload->tls_prologue->insert(
77 std::make_unique<GlobalLoadStmt>(offset_globalptr), -1);
78 auto offset_1_globalptr = offload->tls_prologue->insert(
79 std::make_unique<GlobalPtrStmt>(offset_snode->second,
80 std::vector<Stmt *>{patch_idx_1}),
81 -1);
82 auto offset_1_load = offload->tls_prologue->insert(
83 std::make_unique<GlobalLoadStmt>(offset_1_globalptr), -1);
84 auto num_load = offload->tls_prologue->insert(
85 std::make_unique<BinaryOpStmt>(BinaryOpType::sub, offset_1_load,
86 offset_load),
87 -1);
88
89 // TODO: do not use GlobalStore for TLS ptr.
90 offload->tls_prologue->push_back<GlobalStoreStmt>(offset_ptr,
91 offset_load);
92 offload->tls_prologue->push_back<GlobalStoreStmt>(num_ptr, num_load);
93 }
94
95 // Step 2:
96 // Store TLS mesh_prologue ptr to the offloaded statement
97 {
98 auto offset_ptr =
99 offload->mesh_prologue->push_back<ThreadLocalPtrStmt>(
100 offset_tls_offset,
101 TypeFactory::get_instance().get_pointer_type(data_type));
102 auto _offset_val =
103 offload->mesh_prologue->push_back<GlobalLoadStmt>(offset_ptr);
104 auto offset_val = offload->mesh_prologue->push_back<UnaryOpStmt>(
105 UnaryOpType::cast_value, _offset_val);
106 offset_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
107 auto num_ptr = offload->mesh_prologue->push_back<ThreadLocalPtrStmt>(
108 num_tls_offset,
109 TypeFactory::get_instance().get_pointer_type(data_type));
110 auto _num_val =
111 offload->mesh_prologue->push_back<GlobalLoadStmt>(num_ptr);
112 auto num_val = offload->mesh_prologue->push_back<UnaryOpStmt>(
113 UnaryOpType::cast_value, _num_val);
114 num_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32;
115
116 offset_local.insert(std::pair(element_type, offset_val));
117 num_local.insert(std::pair(element_type, num_val));
118 }
119 };
120
121 for (auto element_type : accessed.first) {
122 make_thread_local_store(element_type, offload->mesh->owned_offset,
123 offload->owned_offset_local,
124 offload->owned_num_local);
125 }
126
127 for (auto element_type : accessed.second) {
128 make_thread_local_store(element_type, offload->mesh->total_offset,
129 offload->total_offset_local,
130 offload->total_num_local);
131 }
132 offload->tls_size = std::max(std::size_t(1), tls_offset);
133}
134
135// This pass should happen after offloading but before lower_access
136void make_mesh_thread_local(IRNode *root,
137 const CompileConfig &config,
138 const MakeBlockLocalPass::Args &args) {
139 TI_AUTO_PROF;
140
141 // =========================================================================================
142 // This pass generates code like this:
143 // uint32_t total_vertices_offset = _total_vertices_offset[blockIdx.x];
144 // uint32_t total_vertices = _total_vertices_offset[blockIdx.x + 1] -
145 // total_vertices_offset;
146
147 // uint32_t total_cells_offset = _total_cells_offset[blockIdx.x];
148 // uint32_t total_cells = _total_cells_offset[blockIdx.x + 1] -
149 // total_cells_offset;
150
151 // uint32_t owned_cells_offset = _owned_cells_offset[blockIdx.x];
152 // uint32_t owned_cells = _owned_cells_offset[blockIdx.x + 1] -
153 // owned_cells_offset;
154 // =========================================================================================
155
156 if (auto root_block = root->cast<Block>()) {
157 for (auto &offload : root_block->statements) {
158 make_mesh_thread_local_offload(offload->cast<OffloadedStmt>(), config,
159 args.kernel_name);
160 }
161 } else {
162 make_mesh_thread_local_offload(root->as<OffloadedStmt>(), config,
163 args.kernel_name);
164 }
165 type_check(root, config);
166}
167
168} // namespace irpass
169} // namespace taichi::lang
170