#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/scratch_pad.h"
#include "taichi/transforms/make_block_local.h"

namespace taichi::lang {

namespace {

void make_block_local_offload(OffloadedStmt *offload,
                              const CompileConfig &config,
                              const std::string &kernel_name) {
  if (offload->task_type != OffloadedStmt::TaskType::struct_for)
    return;

  bool debug = config.debug;

  auto pads = irpass::initialize_scratch_pad(offload);

  std::size_t bls_offset_in_bytes = 0;
  for (auto &pad : pads->pads) {
    auto snode = pad.first;
    auto data_type = snode->dt.ptr_removed();
    auto dtype_size = data_type_size(data_type);

    bool bls_has_read = pad.second.total_flags & AccessFlag::read;
    bool bls_has_write = pad.second.total_flags & AccessFlag::write;
    bool bls_has_accumulate = pad.second.total_flags & AccessFlag::accumulate;

    TI_ASSERT_INFO(!bls_has_write, "BLS with write accesses is not supported.")
    TI_ASSERT_INFO(!(bls_has_accumulate && bls_has_read),
                   "BLS with both read and accumulation is not supported.")

    // dim = Dimensionality of the BLS buffer and the block
    const auto dim = (int)pad.second.pad_size.size();
    TI_ASSERT(dim == snode->num_active_indices);

    const auto bls_num_elements = pad.second.pad_size_linear();
    std::vector<int> block_strides(dim);
    std::vector<int> bls_strides(dim);
    block_strides[dim - 1] = 1;
    bls_strides[dim - 1] = 1;
    for (int i = dim - 2; i >= 0; i--) {
      // TODO: fix the virtual/physical index correspondence here
      // TODO: rename "pad"
      // "pad" is the BLS buffer ("scratch pad")
      block_strides[i] = block_strides[i + 1] * pad.second.block_size[i + 1];
      bls_strides[i] = bls_strides[i + 1] * pad.second.pad_size[i + 1];
    }
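    // Both sets of strides are row-major with the last axis contiguous.
    // Illustrative example: with block_size == {4, 4, 4} and
    // pad_size == {6, 6, 6}, block_strides == {16, 4, 1} and
    // bls_strides == {36, 6, 1}.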

    // TODO: improve IR builder to make this part easier to read

    // Ensure BLS alignment
    bls_offset_in_bytes +=
        (dtype_size - bls_offset_in_bytes % dtype_size) % dtype_size;
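    // This rounds bls_offset_in_bytes up to a multiple of dtype_size, e.g.
    // an offset of 6 with dtype_size == 4 becomes 8, while an already-aligned
    // offset is left unchanged.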

    // This lambda is used for both BLS prologue and epilogue creation
    auto create_xlogue =
        [&](std::unique_ptr<Block> &block,
            const std::function<void(
                Block * element_block, std::vector<Stmt *> global_indices,
                Stmt * bls_element_offset_bytes)> &operation) {
          if (block == nullptr) {
            block = std::make_unique<Block>();
            block->parent_stmt = offload;
          }
          // Equivalent to CUDA threadIdx
          Stmt *thread_idx_stmt =
              block->push_back<LoopLinearIndexStmt>(offload);

          /*
          Note that since there are fewer elements in the block than in BLS,
          each thread may have to fetch more than one element to BLS.
          Therefore on CUDA we need something like

          auto bls_element_id = thread_idx_stmt;
          while (bls_element_id < bls_num_elements) {
            i, j, k = bls_to_global(bls_element_id)
            bls[bls_element_id] = x[i, j, k]
            // or x[i, j, k] = bls[bls_element_id]
            bls_element_id += block_dim;
          }

          func bls_to_global(bls_element_id):
            partial = bls_element_id
            global_indices = []  // "i, j, k"
            for i in reversed(range(0, dim)):
              pad_size = pad.pad_size[i]  // a.k.a. bounds[i].range()
              bls_coord = partial % pad_size
              partial = partial / pad_size
              global_index_at_i = BlockCorner[i] + bls_coord
              global_index_at_i += pad.bounds[i].low
              global_indices[i] = global_index_at_i

          Since we know block_dim and bls_size at compile time, and there are
          usually not many iterations, we directly unroll this while loop
          for performance when constructing prologues/epilogues.
          */
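          // Illustrative case: with bls_num_elements == 216 and
          // block_dim == 128, the unrolled loop below emits two iterations
          // (loop_offset == 0 and 128); only the second one needs the
          // guarding IfStmt, since 128 + 128 > 216.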

          // Unroll the while-loop
          int loop_offset = 0;
          const int block_dim = offload->block_dim;
          while (loop_offset < bls_num_elements) {
            Block *element_block = nullptr;
            auto loop_offset_stmt =
                block->push_back<ConstStmt>(TypedConstant(loop_offset));

            auto bls_element_id_this_iteration = block->push_back<BinaryOpStmt>(
                BinaryOpType::add, loop_offset_stmt, thread_idx_stmt);

            auto bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
                BinaryOpType::mul, bls_element_id_this_iteration,
                block->push_back<ConstStmt>(TypedConstant(dtype_size)));

            bls_element_offset_bytes = block->push_back<BinaryOpStmt>(
                BinaryOpType::add, bls_element_offset_bytes,
                block->push_back<ConstStmt>(
                    TypedConstant((int32)bls_offset_in_bytes)));

            if (loop_offset + block_dim > bls_num_elements) {
              // Need an IfStmt to safeguard this iteration, since the BLS size
              // may not be a multiple of block_size and some threads may go
              // past bls_num_elements ("block-stride" loop).
              auto cond = block->push_back<BinaryOpStmt>(
                  BinaryOpType::cmp_lt, bls_element_id_this_iteration,
                  block->push_back<ConstStmt>(TypedConstant(bls_num_elements)));
              auto if_stmt =
                  dynamic_cast<IfStmt *>(block->push_back<IfStmt>(cond));
              if_stmt->set_true_statements(std::make_unique<Block>());
              element_block = if_stmt->true_statements.get();
            } else {
              // No need to create an if since every thread is within
              // bls_num_elements.
              element_block = block.get();
            }

            std::vector<Stmt *> global_indices(dim);

            // Convert bls_element_id to global indices
            // via a series of % and /.
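            // Illustrative case: with pad_size == {6, 6}, bls_element_id 13
            // decomposes into BLS coordinates (2, 1); each coordinate is then
            // shifted by the lower bound and the (possibly scaled) block
            // corner to obtain the global index.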
            auto bls_element_id_partial = bls_element_id_this_iteration;
            for (int i = dim - 1; i >= 0; i--) {
              auto pad_size_stmt = element_block->push_back<ConstStmt>(
                  TypedConstant(pad.second.pad_size[i]));

              auto bls_coord = element_block->push_back<BinaryOpStmt>(
                  BinaryOpType::mod, bls_element_id_partial, pad_size_stmt);
              bls_element_id_partial = element_block->push_back<BinaryOpStmt>(
                  BinaryOpType::div, bls_element_id_partial, pad_size_stmt);

              auto global_index_this_dim =
                  element_block->push_back<BinaryOpStmt>(
                      BinaryOpType::add, bls_coord,
                      element_block->push_back<ConstStmt>(
                          TypedConstant(pad.second.bounds[i].low)));

              auto block_corner =
                  element_block->push_back<BlockCornerIndexStmt>(offload, i);
              if (pad.second.coefficients[i] > 1) {
                block_corner = element_block->push_back<BinaryOpStmt>(
                    BinaryOpType::mul, block_corner,
                    element_block->push_back<ConstStmt>(
                        TypedConstant(pad.second.coefficients[i])));
              }

              global_index_this_dim = element_block->push_back<BinaryOpStmt>(
                  BinaryOpType::add, global_index_this_dim, block_corner);

              global_indices[i] = global_index_this_dim;
            }

            operation(element_block, global_indices, bls_element_offset_bytes);
            // TODO: do not use GlobalStore for BLS ptr.

            loop_offset += block_dim;
          }
        };

    // Step 1:
    // Fetch to BLS
    {
      create_xlogue(
          offload->bls_prologue,
          [&](Block *element_block, std::vector<Stmt *> global_indices,
              Stmt *bls_element_offset_bytes) {
            Stmt *value;
            if (bls_has_read) {
              // Read access
              // Fetch from global to BLS

              auto global_pointer = element_block->push_back<GlobalPtrStmt>(
                  snode, global_indices);
              value = element_block->push_back<GlobalLoadStmt>(global_pointer);
            } else {
              // Accumulation access
              // Zero-fill
              value = element_block->push_back<ConstStmt>(
                  TypedConstant(data_type, 0));
            }
            auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
                bls_element_offset_bytes,
                TypeFactory::get_instance().get_pointer_type(data_type));
            element_block->push_back<GlobalStoreStmt>(bls_ptr, value);
          });
    }

    // Step 2:
    // Make loop body load from BLS instead of global fields
    {
      std::vector<GlobalPtrStmt *> global_ptrs;

      // TODO: no more abuse of gather_statements...
      irpass::analysis::gather_statements(offload->body.get(), [&](Stmt *stmt) {
        if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
          if (global_ptr->snode == snode) {
            global_ptrs.push_back(global_ptr);
          }
        }
        return false;
      });

      for (auto global_ptr : global_ptrs) {
        VecStatement bls;
        Stmt *bls_element_offset = nullptr;
        auto global_indices = global_ptr->indices;
        for (int i = 0; i < dim; i++) {
          // BLS index = sum_i inc_i
          // where inc_i =
          //   bls_stride_i * (gbl_idx_i - block_corner_i - bls_lower_bound_i)
          // Note that when index offsets are used, the offset contributions
          // are already included in bls_lower_bound_i.
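          // Illustrative case: with bls_strides == {6, 1} and lower bounds
          // {-1, -1}, an access x[i, j] inside a block whose (scaled) corner
          // is (bi, bj) maps to BLS element 6 * (i - bi + 1) + (j - bj + 1),
          // which is converted to a byte offset further below.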

          Stmt *block_corner = bls.push_back<BlockCornerIndexStmt>(offload, i);
          if (pad.second.coefficients[i] > 1) {
            block_corner = bls.push_back<BinaryOpStmt>(
                BinaryOpType::mul, block_corner,
                bls.push_back<ConstStmt>(
                    TypedConstant(pad.second.coefficients[i])));
          }

          auto inc = bls.push_back<BinaryOpStmt>(
              BinaryOpType::sub, global_indices[i], block_corner);
          inc = bls.push_back<BinaryOpStmt>(
              BinaryOpType::sub, inc,
              bls.push_back<ConstStmt>(
                  TypedConstant(pad.second.bounds[i].low)));

          if (debug) {
            // This part inserts an assertion to make sure the BLS access is
            // within bounds.
            auto bls_axis_size =
                pad.second.bounds[i].high - pad.second.bounds[i].low;
            std::string msg = fmt::format(
                "(kernel={}, body) Access out of bound: BLS buffer axis {} "
                "(size {}) with "
                "index %d.",
                kernel_name, i, bls_axis_size);

            auto lower_bound = bls.push_back<ConstStmt>(TypedConstant(0));
            auto check_lower_bound = bls.push_back<BinaryOpStmt>(
                BinaryOpType::cmp_ge, inc, lower_bound);

            auto upper_bound =
                bls.push_back<ConstStmt>(TypedConstant(bls_axis_size));
            auto check_upper_bound = bls.push_back<BinaryOpStmt>(
                BinaryOpType::cmp_lt, inc, upper_bound);

            auto check_i = bls.push_back<BinaryOpStmt>(
                BinaryOpType::bit_and, check_lower_bound, check_upper_bound);

            bls.push_back<AssertStmt>(check_i, msg, std::vector<Stmt *>{inc});
          }

          inc = bls.push_back<BinaryOpStmt>(
              BinaryOpType::mul, inc,
              bls.push_back<ConstStmt>(TypedConstant(bls_strides[i])));

          if (!bls_element_offset) {
            bls_element_offset = inc;
          } else {
            bls_element_offset = bls.push_back<BinaryOpStmt>(
                BinaryOpType::add, bls_element_offset, inc);
          }
        }

        // Convert the element offset to bytes
        bls_element_offset = bls.push_back<BinaryOpStmt>(
            BinaryOpType::mul, bls_element_offset,
            bls.push_back<ConstStmt>(TypedConstant(dtype_size)));

        // Add the base offset of this pad within the BLS buffer
        bls_element_offset = bls.push_back<BinaryOpStmt>(
            BinaryOpType::add, bls_element_offset,
            bls.push_back<ConstStmt>(
                TypedConstant((int32)bls_offset_in_bytes)));

        bls.push_back<BlockLocalPtrStmt>(
            bls_element_offset,
            TypeFactory::get_instance().get_pointer_type(data_type));
        global_ptr->replace_with(std::move(bls));
      }
    }

    // Step 3:
    // Atomic-add BLS contribution to its global version if necessary
    if (bls_has_accumulate) {
      create_xlogue(
          offload->bls_epilogue,
          [&](Block *element_block, std::vector<Stmt *> global_indices,
              Stmt *bls_element_offset_bytes) {
            // Store/accumulate from BLS to global
            auto bls_ptr = element_block->push_back<BlockLocalPtrStmt>(
                bls_element_offset_bytes,
                TypeFactory::get_instance().get_pointer_type(data_type));
            auto bls_val = element_block->push_back<GlobalLoadStmt>(bls_ptr);

            auto global_pointer =
                element_block->push_back<GlobalPtrStmt>(snode, global_indices);
            element_block->push_back<AtomicOpStmt>(AtomicOpType::add,
                                                   global_pointer, bls_val);
          });
    }

    // Allocate storage for the BLS variable
    bls_offset_in_bytes += dtype_size * bls_num_elements;
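    // Pads are packed back-to-back (subject to the per-dtype alignment
    // above), so the next pad in this offload starts right after this one.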
  }  // for (auto &pad : pads->pads)

  offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes);
}

}  // namespace

const PassID MakeBlockLocalPass::id = "MakeBlockLocalPass";

namespace irpass {

// This pass should happen after offloading but before lower_access
void make_block_local(IRNode *root,
                      const CompileConfig &config,
                      const MakeBlockLocalPass::Args &args) {
  TI_AUTO_PROF;

  if (auto root_block = root->cast<Block>()) {
    for (auto &offload : root_block->statements) {
      make_block_local_offload(offload->cast<OffloadedStmt>(), config,
                               args.kernel_name);
    }
  } else {
    make_block_local_offload(root->as<OffloadedStmt>(), config,
                             args.kernel_name);
  }
  type_check(root, config);
}

}  // namespace irpass

}  // namespace taichi::lang