1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/transforms/make_mesh_thread_local.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | const PassID MakeMeshThreadLocal::id = "MakeMeshThreadLocal" ; |
10 | |
11 | namespace irpass { |
12 | |
13 | void make_mesh_thread_local_offload(OffloadedStmt *offload, |
14 | const CompileConfig &config, |
15 | const std::string &kernel_name) { |
16 | if (offload->task_type != OffloadedStmt::TaskType::mesh_for) { |
17 | return; |
18 | } |
19 | |
20 | std::pair</* owned= */ std::unordered_set<mesh::MeshElementType>, |
21 | /* total= */ std::unordered_set<mesh::MeshElementType>> |
22 | accessed = analysis::gather_mesh_thread_local(offload, config); |
23 | |
24 | std::size_t tls_offset = offload->tls_size; |
25 | |
26 | auto data_type = PrimitiveType::u32; // unt32_t type address |
27 | auto dtype_size = data_type_size(data_type); |
28 | |
29 | if (offload->tls_prologue == nullptr) { |
30 | offload->tls_prologue = std::make_unique<Block>(); |
31 | offload->tls_prologue->parent_stmt = offload; |
32 | } |
33 | |
34 | if (offload->mesh_prologue == nullptr) { |
35 | offload->mesh_prologue = std::make_unique<Block>(); |
36 | offload->mesh_prologue->parent_stmt = offload; |
37 | } |
38 | |
39 | auto patch_idx = |
40 | offload->tls_prologue->insert(std::make_unique<MeshPatchIndexStmt>(), -1); |
41 | auto one = offload->tls_prologue->insert( |
42 | std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 1)), -1); |
43 | auto patch_idx_1 = offload->tls_prologue->insert( |
44 | std::make_unique<BinaryOpStmt>(BinaryOpType::add, patch_idx, one), -1); |
45 | |
46 | auto make_thread_local_store = |
47 | [&](mesh::MeshElementType element_type, |
48 | const std::unordered_map<mesh::MeshElementType, SNode *> &offset_, |
49 | std::unordered_map<mesh::MeshElementType, Stmt *> &offset_local, |
50 | std::unordered_map<mesh::MeshElementType, Stmt *> &num_local) { |
51 | const auto offset_tls_offset = |
52 | (tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size); |
53 | tls_offset += dtype_size; // allocate storage for the TLS variable |
54 | |
55 | const auto num_tls_offset = |
56 | (tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size); |
57 | tls_offset += dtype_size; |
58 | |
59 | // Step 1: |
60 | // Create thread local storage |
61 | { |
62 | auto offset_ptr = |
63 | offload->tls_prologue->push_back<ThreadLocalPtrStmt>( |
64 | offset_tls_offset, |
65 | TypeFactory::get_instance().get_pointer_type(data_type)); |
66 | auto num_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>( |
67 | num_tls_offset, |
68 | TypeFactory::get_instance().get_pointer_type(data_type)); |
69 | |
70 | const auto offset_snode = offset_.find(element_type); |
71 | TI_ASSERT(offset_snode != offset_.end()); |
72 | auto offset_globalptr = offload->tls_prologue->insert( |
73 | std::make_unique<GlobalPtrStmt>(offset_snode->second, |
74 | std::vector<Stmt *>{patch_idx}), |
75 | -1); |
76 | auto offset_load = offload->tls_prologue->insert( |
77 | std::make_unique<GlobalLoadStmt>(offset_globalptr), -1); |
78 | auto offset_1_globalptr = offload->tls_prologue->insert( |
79 | std::make_unique<GlobalPtrStmt>(offset_snode->second, |
80 | std::vector<Stmt *>{patch_idx_1}), |
81 | -1); |
82 | auto offset_1_load = offload->tls_prologue->insert( |
83 | std::make_unique<GlobalLoadStmt>(offset_1_globalptr), -1); |
84 | auto num_load = offload->tls_prologue->insert( |
85 | std::make_unique<BinaryOpStmt>(BinaryOpType::sub, offset_1_load, |
86 | offset_load), |
87 | -1); |
88 | |
89 | // TODO: do not use GlobalStore for TLS ptr. |
90 | offload->tls_prologue->push_back<GlobalStoreStmt>(offset_ptr, |
91 | offset_load); |
92 | offload->tls_prologue->push_back<GlobalStoreStmt>(num_ptr, num_load); |
93 | } |
94 | |
95 | // Step 2: |
96 | // Store TLS mesh_prologue ptr to the offloaded statement |
97 | { |
98 | auto offset_ptr = |
99 | offload->mesh_prologue->push_back<ThreadLocalPtrStmt>( |
100 | offset_tls_offset, |
101 | TypeFactory::get_instance().get_pointer_type(data_type)); |
102 | auto _offset_val = |
103 | offload->mesh_prologue->push_back<GlobalLoadStmt>(offset_ptr); |
104 | auto offset_val = offload->mesh_prologue->push_back<UnaryOpStmt>( |
105 | UnaryOpType::cast_value, _offset_val); |
106 | offset_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
107 | auto num_ptr = offload->mesh_prologue->push_back<ThreadLocalPtrStmt>( |
108 | num_tls_offset, |
109 | TypeFactory::get_instance().get_pointer_type(data_type)); |
110 | auto _num_val = |
111 | offload->mesh_prologue->push_back<GlobalLoadStmt>(num_ptr); |
112 | auto num_val = offload->mesh_prologue->push_back<UnaryOpStmt>( |
113 | UnaryOpType::cast_value, _num_val); |
114 | num_val->as<UnaryOpStmt>()->cast_type = PrimitiveType::i32; |
115 | |
116 | offset_local.insert(std::pair(element_type, offset_val)); |
117 | num_local.insert(std::pair(element_type, num_val)); |
118 | } |
119 | }; |
120 | |
121 | for (auto element_type : accessed.first) { |
122 | make_thread_local_store(element_type, offload->mesh->owned_offset, |
123 | offload->owned_offset_local, |
124 | offload->owned_num_local); |
125 | } |
126 | |
127 | for (auto element_type : accessed.second) { |
128 | make_thread_local_store(element_type, offload->mesh->total_offset, |
129 | offload->total_offset_local, |
130 | offload->total_num_local); |
131 | } |
132 | offload->tls_size = std::max(std::size_t(1), tls_offset); |
133 | } |
134 | |
135 | // This pass should happen after offloading but before lower_access |
136 | void make_mesh_thread_local(IRNode *root, |
137 | const CompileConfig &config, |
138 | const MakeBlockLocalPass::Args &args) { |
139 | TI_AUTO_PROF; |
140 | |
141 | // ========================================================================================= |
142 | // This pass generates code like this: |
143 | // uint32_t total_vertices_offset = _total_vertices_offset[blockIdx.x]; |
144 | // uint32_t total_vertices = _total_vertices_offset[blockIdx.x + 1] - |
145 | // total_vertices_offset; |
146 | |
147 | // uint32_t total_cells_offset = _total_cells_offset[blockIdx.x]; |
148 | // uint32_t total_cells = _total_cells_offset[blockIdx.x + 1] - |
149 | // total_cells_offset; |
150 | |
151 | // uint32_t owned_cells_offset = _owned_cells_offset[blockIdx.x]; |
152 | // uint32_t owned_cells = _owned_cells_offset[blockIdx.x + 1] - |
153 | // owned_cells_offset; |
154 | // ========================================================================================= |
155 | |
156 | if (auto root_block = root->cast<Block>()) { |
157 | for (auto &offload : root_block->statements) { |
158 | make_mesh_thread_local_offload(offload->cast<OffloadedStmt>(), config, |
159 | args.kernel_name); |
160 | } |
161 | } else { |
162 | make_mesh_thread_local_offload(root->as<OffloadedStmt>(), config, |
163 | args.kernel_name); |
164 | } |
165 | type_check(root, config); |
166 | } |
167 | |
168 | } // namespace irpass |
169 | } // namespace taichi::lang |
170 | |