#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/scratch_pad.h"
#include "taichi/transforms/make_block_local.h"
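
// This pass implements block-local storage (BLS) for struct-for offloads:
// selected SNodes are staged into per-block fast memory (e.g. CUDA shared
// memory) by a prologue, the loop body is rewritten to address the staged
// buffer, and accumulated contributions are flushed back by an epilogue.
//
// A minimal sketch of the user-facing trigger, assuming the Python-side
// ti.block_local hint (field names are hypothetical):
//
//   @ti.kernel
//   def blur():
//       for i in y:
//           ti.block_local(x)  # cache x in BLS for this struct-for
//           y[i] = x[i - 1] + x[i] + x[i + 1]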

namespace taichi::lang {

namespace {

void make_block_local_offload(OffloadedStmt *offload,
                              const CompileConfig &config,
                              const std::string &kernel_name) {
  if (offload->task_type != OffloadedStmt::TaskType::struct_for)
    return;

  bool debug = config.debug;

  auto pads = irpass::initialize_scratch_pad(offload);
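  // Each entry in `pads->pads` maps a SNode to its scratch-pad descriptor:
  // per-axis access bounds, pad/block sizes, and the union of the access
  // flags gathered by the analysis above.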

  std::size_t bls_offset_in_bytes = 0;

  for (auto &pad : pads->pads) {
    auto snode = pad.first;
    auto data_type = snode->dt.ptr_removed();
    auto dtype_size = data_type_size(data_type);

    bool bls_has_read = pad.second.total_flags & AccessFlag::read;
    bool bls_has_write = pad.second.total_flags & AccessFlag::write;
    bool bls_has_accumulate = pad.second.total_flags & AccessFlag::accumulate;
    TI_ASSERT_INFO(!bls_has_write,
                   "BLS with write accesses is not supported.");
    TI_ASSERT_INFO(!(bls_has_accumulate && bls_has_read),
                   "BLS with both read and accumulation is not supported.");

    // dim = Dimensionality of the BLS buffer and the block
    const auto dim = (int)pad.second.pad_size.size();
    TI_ASSERT(dim == snode->num_active_indices);

    const auto bls_num_elements = pad.second.pad_size_linear();

    std::vector<int> block_strides(dim);
    std::vector<int> bls_strides(dim);
    block_strides[dim - 1] = 1;
    bls_strides[dim - 1] = 1;
    for (int i = dim - 2; i >= 0; i--) {
      // TODO: fix the virtual/physical index correspondence here
      // TODO: rename "pad"
      // "pad" is the BLS buffer ("scratch pad")
      block_strides[i] = block_strides[i + 1] * pad.second.block_size[i + 1];
      bls_strides[i] = bls_strides[i + 1] * pad.second.pad_size[i + 1];
    }
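    // Both layouts are row-major with the last axis contiguous. For example
    // (hypothetical sizes): a 2D pad with pad_size = {6, 10} over a block
    // with block_size = {4, 8} gives bls_strides = {10, 1} and
    // block_strides = {8, 1}.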

    // TODO: improve IR builder to make this part easier to read

    // Ensure BLS alignment
    bls_offset_in_bytes +=
        (dtype_size - bls_offset_in_bytes % dtype_size) % dtype_size;
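    // e.g. with bls_offset_in_bytes == 6 and dtype_size == 4, this adds
    // (4 - 6 % 4) % 4 == 2, rounding the offset up to the aligned value 8.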

    // This lambda is used for both BLS prologue and epilogue creation
    auto create_xlogue =
        [&](std::unique_ptr<Block> &block,
            const std::function<void(
                Block * element_block, std::vector<Stmt *> global_indices,
                Stmt * bls_element_offset_bytes)> &operation) {
          if (block == nullptr) {
            block = std::make_unique<Block>();
            block->parent_stmt = offload;
          }
          // Equivalent to CUDA threadIdx
          Stmt *thread_idx_stmt =
              block->push_back<LoopLinearIndexStmt>(offload);

          /*
          Note that since there are fewer elements in the block than in BLS,
          each thread may have to fetch more than one element to BLS.
          Therefore on CUDA we need something like

          auto bls_element_id = thread_idx_stmt;
          while (bls_element_id < bls_num_elements) {
            i, j, k = bls_to_global(bls_element_id)
            bls[bls_element_id] = x[i, j, k]
            // or x[i, j, k] = bls[bls_element_id]
            bls_element_id += block_dim;
          }

          func bls_to_global(bls_element_id):
            partial = bls_element_id
            global_indices = []  // "i, j, k"
            for i in reversed(range(0, dim)):
              pad_size = pad.pad_size[i]  // a.k.a. bounds[i].range()
              bls_coord = partial % pad_size
              partial = partial / pad_size
              global_index_at_i = BlockCorner[i] * coefficients[i] + bls_coord
              global_index_at_i += pad.bounds[i].low
              global_indices[i] = global_index_at_i

          Since we know block_dim and bls_size at compile time and there are
          usually not many iterations, we directly unroll this while loop
          for performance when constructing prologues/epilogues.
          */

          // Unroll the while-loop
          int loop_offset = 0;
          const int block_dim = offload->block_dim;
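          // The loop below unrolls into ceil(bls_num_elements / block_dim)
          // iterations; e.g. (hypothetical sizes) 128 threads covering a
          // 150-element pad take two iterations, with only the second one
          // guarded by an IfStmt.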
          while (loop_offset < bls_num_elements) {
            Block *element_block = nullptr;
            auto loop_offset_stmt =
                block->push_back<ConstStmt>(TypedConstant(loop_offset));

            auto bls_element_id_this_iteration = block->push_back<BinaryOpStmt>(
                BinaryOpType::add, loop_offset_stmt, thread_idx_stmt);

            auto bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
                BinaryOpType::mul, bls_element_id_this_iteration,
                block->push_back<ConstStmt>(TypedConstant(dtype_size)));

            bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
                BinaryOpType::add, bls_element_offset_bytes,
                block->push_back<ConstStmt>(
                    TypedConstant((int32)bls_offset_in_bytes)));

            if (loop_offset + block_dim > bls_num_elements) {
              // Create an IfStmt to safeguard, since the BLS size may not be
              // a multiple of block_dim, and in this iteration some threads
              // may go past bls_num_elements (a "block-stride" loop).
              auto cond = block->push_back<BinaryOpStmt>(
                  BinaryOpType::cmp_lt, bls_element_id_this_iteration,
                  block->push_back<ConstStmt>(TypedConstant(bls_num_elements)));
              auto if_stmt =
                  dynamic_cast<IfStmt *>(block->push_back<IfStmt>(cond));
              if_stmt->set_true_statements(std::make_unique<Block>());
              element_block = if_stmt->true_statements.get();
            } else {
              // No need to create an if, since every thread is within
              // bls_num_elements.
              element_block = block.get();
            }

            std::vector<Stmt *> global_indices(dim);

            // Convert bls_element_id to global indices
            // via a series of % and /.
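            // Per axis i:
            //   global_index[i] = block_corner[i] * coefficients[i]
            //                     + bls_coord[i] + bounds[i].low
            // where bls_coord is peeled off bls_element_id from the
            // innermost axis outward.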
            auto bls_element_id_partial = bls_element_id_this_iteration;
            for (int i = dim - 1; i >= 0; i--) {
              auto pad_size_stmt = element_block->push_back<ConstStmt>(
                  TypedConstant(pad.second.pad_size[i]));

              auto bls_coord = element_block->push_back<BinaryOpStmt>(
                  BinaryOpType::mod, bls_element_id_partial, pad_size_stmt);
              bls_element_id_partial = element_block->push_back<BinaryOpStmt>(
                  BinaryOpType::div, bls_element_id_partial, pad_size_stmt);

              auto global_index_this_dim =
                  element_block->push_back<BinaryOpStmt>(
                      BinaryOpType::add, bls_coord,
                      element_block->push_back<ConstStmt>(
                          TypedConstant(pad.second.bounds[i].low)));

              auto block_corner =
                  element_block->push_back<BlockCornerIndexStmt>(offload, i);
              if (pad.second.coefficients[i] > 1) {
                block_corner = element_block->push_back<BinaryOpStmt>(
                    BinaryOpType::mul, block_corner,
                    element_block->push_back<ConstStmt>(
                        TypedConstant(pad.second.coefficients[i])));
              }

              global_index_this_dim = element_block->push_back<BinaryOpStmt>(
                  BinaryOpType::add, global_index_this_dim, block_corner);

              global_indices[i] = global_index_this_dim;
            }

            operation(element_block, global_indices, bls_element_offset_bytes);
            // TODO: do not use GlobalStore for BLS ptr.

            loop_offset += block_dim;
          }
        };

    // Step 1:
    // Fetch to BLS
    {
      create_xlogue(
          offload->bls_prologue,
          [&](Block *element_block, std::vector<Stmt *> global_indices,
              Stmt *bls_element_offset_bytes) {
            Stmt *value;
            if (bls_has_read) {
              // Read access:
              // fetch from global memory to BLS
              auto global_pointer = element_block->push_back<GlobalPtrStmt>(
                  snode, global_indices);
              value = element_block->push_back<GlobalLoadStmt>(global_pointer);
            } else {
              // Accumulation access:
              // zero-fill
              value = element_block->push_back<ConstStmt>(
                  TypedConstant(data_type, 0));
            }
            auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
                bls_element_offset_bytes,
                TypeFactory::get_instance().get_pointer_type(data_type));
            element_block->push_back<GlobalStoreStmt>(bls_ptr, value);
          });
    }
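    // Note: the block-level barrier that makes the prologue's stores visible
    // to the loop body is not inserted by this pass; the backend is expected
    // to emit one between the BLS prologue and the loop body during codegen.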

    // Step 2:
    // Make loop body load from BLS instead of global fields
    {
      std::vector<GlobalPtrStmt *> global_ptrs;

      // TODO: no more abuse of gather_statements...
      irpass::analysis::gather_statements(offload->body.get(), [&](Stmt *stmt) {
        if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
          if (global_ptr->snode == snode) {
            global_ptrs.push_back(global_ptr);
          }
        }
        return false;
      });

      for (auto global_ptr : global_ptrs) {
        VecStatement bls;
        Stmt *bls_element_offset = nullptr;
        auto global_indices = global_ptr->indices;
        for (int i = 0; i < dim; i++) {
          // BLS index = sum_i inc_i, where
          //   inc_i = bls_stride_i *
          //       (gbl_idx_i - coeff_i * block_corner_i - bls_lower_bound_i)
          // Note that when index offsets are used, the offset contributions
          // are already included in bls_lower_bound_i.
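          // e.g. (hypothetical 1D case): with block_corner_i = 16,
          // coeff_i = 1, bls_lower_bound_i = -1 and gbl_idx_i = 15 (the left
          // halo cell), inc_i = (15 - 16 - (-1)) * bls_stride_i = 0, i.e.
          // the first pad element on this axis.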

          Stmt *block_corner = bls.push_back<BlockCornerIndexStmt>(offload, i);
          if (pad.second.coefficients[i] > 1) {
            block_corner = bls.push_back<BinaryOpStmt>(
                BinaryOpType::mul, block_corner,
                bls.push_back<ConstStmt>(
                    TypedConstant(pad.second.coefficients[i])));
          }

          auto inc = bls.push_back<BinaryOpStmt>(
              BinaryOpType::sub, global_indices[i], block_corner);
          inc = bls.push_back<BinaryOpStmt>(
              BinaryOpType::sub, inc,
              bls.push_back<ConstStmt>(
                  TypedConstant(pad.second.bounds[i].low)));

          if (debug) {
            // Insert an assertion to make sure the BLS access is within
            // bounds.
            auto bls_axis_size =
                pad.second.bounds[i].high - pad.second.bounds[i].low;
            std::string msg = fmt::format(
                "(kernel={}, body) Access out of bound: BLS buffer axis {} "
                "(size {}) with index %d.",
                kernel_name, i, bls_axis_size);

            auto lower_bound = bls.push_back<ConstStmt>(TypedConstant(0));
            auto check_lower_bound = bls.push_back<BinaryOpStmt>(
                BinaryOpType::cmp_ge, inc, lower_bound);

            auto upper_bound =
                bls.push_back<ConstStmt>(TypedConstant(bls_axis_size));
            auto check_upper_bound = bls.push_back<BinaryOpStmt>(
                BinaryOpType::cmp_lt, inc, upper_bound);

            auto check_i = bls.push_back<BinaryOpStmt>(
                BinaryOpType::bit_and, check_lower_bound, check_upper_bound);

            bls.push_back<AssertStmt>(check_i, msg, std::vector<Stmt *>{inc});
          }

          inc = bls.push_back<BinaryOpStmt>(
              BinaryOpType::mul, inc,
              bls.push_back<ConstStmt>(TypedConstant(bls_strides[i])));

          if (!bls_element_offset) {
            bls_element_offset = inc;
          } else {
            bls_element_offset = bls.push_back<BinaryOpStmt>(
                BinaryOpType::add, bls_element_offset, inc);
          }
        }

        // Convert the element offset to bytes
        bls_element_offset = bls.push_back<BinaryOpStmt>(
            BinaryOpType::mul, bls_element_offset,
            bls.push_back<ConstStmt>(TypedConstant(dtype_size)));

        // Add the base offset of this pad within the whole BLS buffer
        bls_element_offset = bls.push_back<BinaryOpStmt>(
            BinaryOpType::add, bls_element_offset,
            bls.push_back<ConstStmt>(
                TypedConstant((int32)bls_offset_in_bytes)));

        bls.push_back<BlockLocalPtrStmt>(
            bls_element_offset,
            TypeFactory::get_instance().get_pointer_type(data_type));
        global_ptr->replace_with(std::move(bls));
      }
    }

    // Step 3:
    // Atomic-add BLS contribution to its global version if necessary
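    // Pads of neighboring blocks overlap in their halo regions, so several
    // blocks may flush partial sums into the same global cells; the flush
    // below must therefore use atomic adds rather than plain stores.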
    if (bls_has_accumulate) {
      create_xlogue(
          offload->bls_epilogue,
          [&](Block *element_block, std::vector<Stmt *> global_indices,
              Stmt *bls_element_offset_bytes) {
            // Store/accumulate from BLS back to the global field
            auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
                bls_element_offset_bytes,
                TypeFactory::get_instance().get_pointer_type(data_type));
            auto bls_val = element_block->push_back<GlobalLoadStmt>(bls_ptr);

            auto global_pointer =
                element_block->push_back<GlobalPtrStmt>(snode, global_indices);
            element_block->push_back<AtomicOpStmt>(AtomicOpType::add,
                                                   global_pointer, bls_val);
          });
    }

    // Allocate storage for this pad in the BLS buffer
    bls_offset_in_bytes += dtype_size * bls_num_elements;
  }  // for (auto &pad : pads->pads)
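
  // Guarantee a non-zero BLS size, so that backends never have to allocate
  // a zero-sized buffer even when no pad was created.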
  offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes);
}

}  // namespace

const PassID MakeBlockLocalPass::id = "MakeBlockLocalPass";

namespace irpass {

// This pass should happen after offloading but before lower_access:
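// it rewrites GlobalPtrStmts inside the offloaded tasks, which lower_access
// would otherwise have already lowered into SNode-access instructions.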
void make_block_local(IRNode *root,
                      const CompileConfig &config,
                      const MakeBlockLocalPass::Args &args) {
  TI_AUTO_PROF;

  if (auto root_block = root->cast<Block>()) {
    for (auto &offload : root_block->statements) {
      make_block_local_offload(offload->cast<OffloadedStmt>(), config,
                               args.kernel_name);
    }
  } else {
    make_block_local_offload(root->as<OffloadedStmt>(), config,
                             args.kernel_name);
  }
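  // The rewrite above inserts untyped statements (ConstStmt, BinaryOpStmt,
  // etc.); re-run type checking so every new statement gets a proper type.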
  type_check(root, config);
}

}  // namespace irpass

}  // namespace taichi::lang