/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Lower allreduce to device-implementable IR.
 * \file lower_thread_allreduce.cc
 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "update_pointer_storage_scope.h"

namespace tvm {
namespace tir {

class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
 public:
  explicit UpdatePointerStorageScopeAllReduce(
      const std::unordered_map<const VarNode*, String>& new_storage_scopes)
      : UpdatePointerStorageScope(new_storage_scopes) {}

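  // Rewrite allocations whose storage scope was remapped by the allreduce
  // lowering. Buffers remapped to "shared" additionally receive a volatile
  // access hint via attr::volatile_scope.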
  Stmt VisitStmt_(const AllocateNode* op) final {
    auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
    auto new_scope = GetPtrStorageScope(remapped);
    if (new_scope != GetPtrStorageScope(op->buffer_var)) {
      Stmt body = StmtExprMutator::VisitStmt(op->body);
      if (new_scope == "shared") {
        // use volatile access to shared buffer.
        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
      }
      return Allocate(remapped, op->dtype, op->extents, op->condition, body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }
};

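/*!
 * \brief Lower calls to builtin::tvm_thread_allreduce into target-level
 *  reduction code.
 *
 *  Depending on the target and the reduction extent, the allreduce is
 *  lowered either to warp-level shuffle instructions or to a tree reduction
 *  over a shared-memory buffer. The storage scopes of the rewritten
 *  allocations are collected in new_storage_scopes_ and applied afterwards
 *  by UpdatePointerStorageScopeAllReduce.
 */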
class ThreadAllreduceBuilder final : public StmtExprMutator {
 public:
  explicit ThreadAllreduceBuilder(const TargetNode* target)
      : target_(target),
        warp_size_(target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()) {}

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      thread_extents_.push_back(op);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      thread_extents_.pop_back();
      return ret;
    } else if (op->attr_key == attr::reduce_scope) {
      const CommReducerNode* combiner = op->node.as<CommReducerNode>();
      ICHECK(combiner);
      reduce_combiner_.push_back(combiner);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      reduce_combiner_.pop_back();
      return ret;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const EvaluateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<EvaluateNode>();
    const CallNode* call = op->value.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_thread_allreduce())) {
      return MakeAllreduce(call);
    } else {
      return stmt;
    }
  }
  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      const AllocateNode* repl = it->second.as<AllocateNode>();
      if (warp_allocs_.count(repl)) {
        new_storage_scopes_[repl->buffer_var.get()] = "local";
      } else {
        new_storage_scopes_[repl->buffer_var.get()] = "shared";
      }
      return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
    } else {
      return stmt;
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    {
      auto it = load_remap_.find(op->buffer->data.get());
      if (it != load_remap_.end()) {
        for (const auto& index : op->indices) {
          ICHECK(is_zero(index));
        }
        return it->second;
      }
    }

    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    op = load.get();

    {
      auto it = buf_remap_.find(op->buffer.get());
      if (it != buf_remap_.end()) {
        return BufferLoad(it->second, op->indices, op->span);
      }
    }

    {
      auto it = var_remap_.find(op->buffer->data.get());
      if (it != var_remap_.end()) {
        Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
                               op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
                               op->buffer->data_alignment, op->buffer->offset_factor,
                               op->buffer->buffer_type, op->buffer->axis_separators,
                               op->buffer->span);
        buf_remap_[op->buffer.get()] = remapped_buffer;
        return BufferLoad(remapped_buffer, op->indices, op->span);
      }
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));

    auto it = store_remap_.find(store->buffer.get());
    if (it != store_remap_.end()) {
      for (const auto& index : op->indices) {
        ICHECK(is_zero(index));
      }

      auto writer = store.CopyOnWrite();
      writer->buffer = it->second;
      return std::move(store);
    }

    {
      auto it = buf_remap_.find(store->buffer.get());
      if (it != buf_remap_.end()) {
        return BufferStore(it->second, store->value, store->indices, store->span);
      }
    }

    {
      auto it = var_remap_.find(store->buffer->data.get());
      if (it != var_remap_.end()) {
        Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
                               store->buffer->strides, store->buffer->elem_offset,
                               store->buffer->name, store->buffer->data_alignment,
                               store->buffer->offset_factor, store->buffer->buffer_type,
                               store->buffer->axis_separators, store->buffer->span);
        buf_remap_[store->buffer.get()] = remapped_buffer;
        return BufferStore(remapped_buffer, store->value, store->indices, store->span);
      }
    }

    return std::move(store);
  }

  std::unordered_map<const VarNode*, String> new_storage_scopes_;

 private:
  // Thread entry
  struct ThreadEntry {
    runtime::ThreadScope scope;
    IterVar iv;
    int extent;
    // comparator
    bool operator<(const ThreadEntry& other) const {
      return scope.dim_index < other.scope.dim_index;
    }
  };

  // make allreduce.
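  // The arguments of the tvm_thread_allreduce call are laid out as:
  //   args[0]                        -- number of reduction values (size)
  //   args[1 .. size]                -- the values being reduced
  //   args[size + 1]                 -- the reduction predicate
  //   args[size + 2 .. 2 * size + 1] -- the destination buffers, given as loads
  //   args[2 * size + 2 .. ]         -- the thread IterVars to reduce over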
  Stmt MakeAllreduce(const CallNode* call) {
    ICHECK(!reduce_combiner_.empty());
    const CommReducerNode* combiner = reduce_combiner_.back();
    size_t size = combiner->result.size();

    const IntImmNode* size_of_args = call->args[0].as<IntImmNode>();
    ICHECK(size_of_args) << call->args[0]->GetTypeKey();
    ICHECK_EQ(size, size_of_args->value);
    Array<PrimExpr> inits = combiner->identity_element;
    std::vector<PrimExpr> values(size);
    std::vector<DataType> types(size);
    PrimExpr cond = call->args[size + 1];
    for (size_t idx = 0; idx < size; ++idx) {
      values[idx] = call->args[1 + idx];
      if (!is_one(cond)) {
        values[idx] = Select(cond, values[idx], inits[idx]);
      }
      types[idx] = values[idx].dtype();
    }
    std::vector<Buffer> buffers(size);
    for (size_t idx = 0; idx < size; ++idx) {
      PrimExpr arg = call->args[2 + size + idx];
      // Loads from boolean buffers may have cast nodes inserted by
      // earlier passes.
      if (auto cast = arg.as<CastNode>()) {
        arg = cast->value;
      }
      buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
    }

    std::unordered_set<const VarNode*> reduce_set;
    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
      const VarNode* v = call->args[i].as<VarNode>();
      // A simple optimization replaces an iteration variable with a constant
      // when the extent of the iteration is 1. As a threaded IterVar always
      // starts from 0, we can simply ignore this variable in that case.
      if (v) {
        reduce_set.insert(v);
      } else {
        ICHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
            << "arg " << i << " should be a VarNode or IntImmNode";
      }
    }

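    // Collect the thread extents in the enclosing scope and split them into
    // reduction threads (vred) and grouping threads (vpar). Threads whose
    // extent is 1 contribute nothing and are skipped.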
    size_t nmatch = 0;
    std::vector<ThreadEntry> vred, vpar;
    for (const AttrStmtNode* attr : thread_extents_) {
      ThreadEntry e;
      IterVar iv = Downcast<IterVar>(attr->node);
      e.scope = runtime::ThreadScope::Create(iv->thread_tag);
      e.iv = iv;
      ICHECK_LE(e.scope.rank, 1);
      ICHECK_GE(e.scope.dim_index, 0) << "vthread does not work with cross-thread reduction";
      if (e.scope.rank == 1) {
        const auto* ptr = attr->value.as<IntImmNode>();
        ICHECK(ptr) << "Need constant extent for reduce set " << iv;
        e.extent = static_cast<int>(ptr->value);
        // ignore threads whose extent is 1
        if (e.extent == 1) {
          continue;
        }

        if (reduce_set.count(iv->var.get())) {
          vred.push_back(e);
          ++nmatch;
        } else {
          vpar.push_back(e);
        }
      }
    }
    ICHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce indices are present in the context";
    std::sort(vred.begin(), vred.end());
    std::sort(vpar.begin(), vpar.end());
    // the size of each index.
    int reduce_extent, group_extent;
    PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
    PrimExpr group_index = FlattenThread(vpar, &group_extent);

    // the longest contiguous reduce extent after flattening
    int contiguous_reduce_extent = 1;
    std::vector<std::tuple<int, int, bool>> block_threads;  // tuple(dim_index, extent, is_reduce)
    for (const ThreadEntry& thr : vred) {
      if (thr.scope.rank == 1) {  // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
      }
    }
    for (const ThreadEntry& thr : vpar) {
      if (thr.scope.rank == 1) {  // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
      }
    }
    // sort according to dim_index
    std::sort(block_threads.begin(), block_threads.end());
    for (auto&& thr_attr : block_threads) {
      auto [dim_index, extent, is_reduce] = thr_attr;
      (void)dim_index;  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      if (is_reduce) {
        contiguous_reduce_extent *= extent;
      } else {
        break;
      }
    }

    std::vector<Stmt> seq;
    std::vector<Var> shared_buffer_vars(size);
    std::vector<Buffer> shared_bufs(size);
    std::vector<Buffer> local_bufs;
    //
    // This is an optimization. For small reduction sizes, it may be beneficial
    // for a single warp to perform the entire reduction. No trips to shared
    // memory and no cross-warp synchronizations are required.
    // The following code emits the reduction as follows:
    //
    // Allocate reduction vars v[i], i = 0..size-1
    //
    // for offset from WARP_SIZE down to 1, halving each step
    //
    //   a    <- load(v[i])
    //   b    <- shuffle_down(load(v[i]), offset)
    //   v[i] <- reduction(a, b)
    //
    // broadcast results from lane 0 to all other lanes and store
    // the final reduction result to the proper location.
    //
    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
      //
      // This is the index into the reduction variable, one reduction
      // variable per warp. Local scope seems easier to reason about
      // without relying on a pattern-match pass to fix it later.
      Array<PrimExpr> zero_indices = {0};

      for (size_t idx = 0; idx < size; ++idx) {
        Array<PrimExpr> shape = {1};

        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
        Var buffer_var = buffer->data;

        shared_buffer_vars[idx] = buffer_var;
        shared_bufs[idx] = buffer;

        PrimExpr pred = const_true(types[idx].lanes());
        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));

        // Uses a local variable to store the shuffled data. Later
        // on, an allocation will be built for this local variable.
        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
      }

      // The mask for this reducer, as this reducer may sit inside
      // a divergent control flow. Here it uses a variable to cache the current
      // active channels.
      //
      DataType mask_dtype = DataType::UInt(32);
      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
      {
        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
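        // When multiple reduction groups share a single warp, restrict the
        // mask to the lanes that belong to this group.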
        if (group_extent > 1) {
          mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
        }
        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
        // Push the buffer description. Later this will have an
        // allocation built for it.
        local_bufs.push_back(mask_buffer);
      }

      // Emit reductions within a warp.
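      // start_offset is the largest power of two strictly smaller than
      // reduce_extent, so the shuffle offsets sweep the whole reduce range.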
      int start_offset = 1;
      while (start_offset * 2 < reduce_extent) {
        start_offset *= 2;
      }
      for (int offset = start_offset; offset > 0; offset /= 2) {
        // Load reduction values, no synchronization needed.
        Array<PrimExpr> a, b;
        for (size_t i = 0; i < size; ++i) {
          Buffer shared_buf = shared_bufs[i];
          BufferLoad val(shared_buf, zero_indices);
          ICHECK_EQ(val->dtype, types[i]);
          a.push_back(val);

          // __shfl_*sync calls shall not appear in if_then_else expressions
          // as this causes extra divergence. E.g.
          //
          // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
          //
          // behaves differently from
          //
          // int t = __shfl_sync(mask, v1, 0);
          // v1 = (v2 < v3) ? v3 : t;
          //
          // The former may cause a deadlock as there is a divergent
          // branch with a warp sync call inside.
          //
          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
          Buffer local_buf = local_bufs[i];
          Stmt s = BufferStore(local_buf, other, zero_indices);
          seq.push_back(s);

          BufferLoad load = BufferLoad(local_buf, zero_indices);
          ICHECK_EQ(load->dtype, types[i]);
          b.push_back(load);
        }

        // Do reductions.
        Array<PrimExpr> ret = (*combiner)(a, b);

        // Store the reduction result back into the reduction buffers.
        std::vector<Stmt> stores(size);
        for (size_t i = 0; i < size; ++i) {
          Buffer buf = shared_bufs[i];
          stores[i] = BufferStore(buf, ret[i], zero_indices);
        }

        // During the sub-warp reduction, values from inactive threads could be read,
        // which is undefined behavior according to the CUDA documentation.
        //
        // In practice, the returned values are usually 0, which does no harm to a sum
        // reduction. However, the result can be incorrect for max or prod reductions.
        // Therefore an additional range check has to be performed to ensure correctness.
        if (offset * 2 > reduce_extent) {
          PrimExpr cond = reduce_index + offset < reduce_extent;
          seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
        } else {
          seq.push_back(SeqStmt::Flatten(stores));
        }
      }

      // Broadcast the reduction result from lane 0 to all other lanes.
      // This avoids emitting predicated stores, as all threads uniformly
      // write the same result.
      //
      for (size_t i = 0; i < size; ++i) {
        Buffer buf = shared_bufs[i];
        PrimExpr val = BufferLoad(buf, zero_indices);
        ICHECK_EQ(val->dtype, types[i]);
        PrimExpr splat =
            WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
        seq.push_back(BufferStore(buf, splat, zero_indices));
      }

      // Record the remappings so that the original allocations and accesses
      // are rewritten when the surrounding statements are visited.
      for (size_t i = 0; i < size; ++i) {
        ICHECK(!load_remap_.count(buffers[i]->data.get()));
        PrimExpr pred = const_true(types[i].lanes());
        Buffer buf = shared_bufs[i];
        PrimExpr val = BufferLoad(buf, zero_indices);
        ICHECK_EQ(val->dtype, types[i]);
        load_remap_[buffers[i]->data.get()] = val;
        store_remap_[buffers[i].get()] = buf;
        Array<PrimExpr> extents{PrimExpr(1)};
        auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
        alloc_remap_[buffers[i]->data.get()] = node;
        var_remap_[buffers[i]->data.get()] = buf->data;
        warp_allocs_.insert(node.get());
      }
    } else {
      if (reduce_extent == 1) {
        // special case, no reduction is needed.
        std::vector<Stmt> stores;
        for (size_t i = 0; i < size; ++i) {
          stores.push_back(BufferStore(buffers[i], values[i], {0}));
        }
        return SeqStmt::Flatten(stores);
      }
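      // General case: stage each value into a shared buffer indexed by
      // (group_index, reduce_index) and reduce across threads with a
      // shared-memory tree reduction.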
      // This sync is necessary because there might be an incomplete read of
      // the previous iteration on the same buffer.
      seq.emplace_back(SyncThread("shared"));
      for (size_t idx = 0; idx < size; ++idx) {
        Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));

        shared_bufs[idx] = buffer;
        shared_buffer_vars[idx] = buffer->data;

        PrimExpr pred = const_true(types[idx].lanes());
        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                     {BufIndex(reduce_index, group_index, reduce_extent)}));
      }
      seq.emplace_back(SyncThread("shared"));
      seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
                                        reduce_extent, group_extent, contiguous_reduce_extent));
      for (size_t idx = 0; idx < size; ++idx) {
        ICHECK(!load_remap_.count(buffers[idx]->data.get()));
        PrimExpr pred = const_true(types[idx].lanes());
        BufferLoad load(shared_bufs[idx],
                        {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
        ICHECK_EQ(load->dtype, types[idx]);
        load_remap_[buffers[idx]->data.get()] = load;
        alloc_remap_[buffers[idx]->data.get()] =
            Allocate(shared_bufs[idx]->data, types[idx],
                     {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
        var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
        store_remap_[buffers[idx].get()] = shared_bufs[idx];
      }
    }

    // Emit allocations for all local buffers now that all statements have been built.
    Stmt body = SeqStmt::Flatten(seq);
    for (Buffer buf : local_bufs) {
      body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
      new_storage_scopes_[buf->data.get()] = "local";
    }

    return body;
  }

  // Make the shared-memory tree allreduce.
  Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                        const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
                        PrimExpr group_index, int reduce_extent, int group_extent,
                        int contiguous_reduce_extent) {
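    // The tree reduction proceeds in up to three phases:
    //   1. one guarded step that folds the tail when reduce_extent is not a
    //      power of two,
    //   2. cross-warp steps separated by shared-memory barriers,
    //   3. in-warp steps separated by warp-level barriers only.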
    // Get next power of two
    int reduce_align = 1;
    while (reduce_extent > reduce_align) {
      reduce_align = reduce_align << 1;
    }
    ICHECK_GT(reduce_align, 1);
    std::vector<Stmt> seq;

    size_t size = shared_bufs.size();
    PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
    // make reduction
    auto fload = [&](int offset) {
      Array<PrimExpr> a, b;
      for (size_t i = 0; i < size; ++i) {
        BufferLoad b_load(shared_bufs[i],
                          {BufIndex(reduce_index + offset, group_index, reduce_extent)});
        ICHECK_EQ(b_load->dtype, types[i]);
        b.push_back(b_load);

        BufferLoad a_load(shared_bufs[i], {buf_index});
        ICHECK_EQ(a_load->dtype, types[i]);
        a.push_back(a_load);
      }
      Array<PrimExpr> ret = (*combiner)(a, b);
      return ret;
    };
    auto fstore = [&](const Array<PrimExpr>& ret) {
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
      }
      return SeqStmt::Flatten(stores);
    };
    auto freduce = [&](int offset) {
      auto ret = fload(offset);
      return fstore(ret);
    };
    // Step one: if reduce_extent is not a power of two, do one guarded
    // reduction step that folds the tail elements into the aligned range.
    if (reduce_align > reduce_extent) {
      // reduction with the boundary condition
      reduce_align = reduce_align >> 1;
      PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }

    // Step two: cross-warp reduction, with a shared-memory barrier after each step.
    bool warp_align = group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
    while (reduce_align > contiguous_reduce_extent || reduce_align > warp_size_ || !warp_align) {
      if (reduce_align == 1) {
        break;
      }
      reduce_align = reduce_align >> 1;
      PrimExpr cond = reduce_index < reduce_align;
      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }
    // Step three: in-warp reduction, synchronized with warp-level barriers.
    if (reduce_align > 1) {
      PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);

      std::vector<Stmt> in_warp_seq;

      while (reduce_align > 1) {
        reduce_align = reduce_align >> 1;

        // freduce can read/write to the same memory location. For
        // example, with reduce_align of 4, threadIdx 3 reads from
        // memory location 7 as threadIdx 7 is writing to it.
        // Therefore, we need to separate out the load from the store
        // with a memory barrier in-between. This isn't necessary for
        // the earlier normal synchronization, because those are each
        // protected by an if-statement. The if-statement is avoided
        // here to reduce thread divergence.
        auto loads = fload(reduce_align);

        Array<Var> in_warp_local_vars;
        for (auto expr : loads) {
          Var var(
              "w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()),
              expr->dtype);
          in_warp_local_vars.push_back(var);
        }

        std::vector<Stmt> in_let_statement;
        in_let_statement.emplace_back(SyncThread("warp"));
        in_let_statement.emplace_back(
            fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
        in_let_statement.emplace_back(SyncThread("warp"));

        Stmt body = SeqStmt::Flatten(in_let_statement);
        for (size_t i = 0; i < size; i++) {
          body = LetStmt(in_warp_local_vars[i], loads[i], body);
        }
        in_warp_seq.push_back(body);
      }

      Stmt warp_body = SeqStmt::Flatten(in_warp_seq);

      seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
      seq.emplace_back(SyncThread("shared"));
    }
    return SeqStmt::Flatten(seq);
  }
  // Flatten the thread index into a single linear index.
  // The combined extent is written to *out_total_extent.
  PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec, int* out_total_extent) {
    int& total_extent = *out_total_extent;
    total_extent = 1;
    if (tvec.size() == 0) {
      return make_zero(DataType::Int(32));
    }

    PrimExpr ret;
    for (const ThreadEntry& e : tvec) {
      if (ret.defined()) {
        ret = ret + e.iv->var * total_extent;
      } else {
        ICHECK_EQ(total_extent, 1);
        ret = e.iv->var;
      }
      total_extent *= e.extent;
    }
    return ret;
  }
  // The local buffer index.
  PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
    if (!is_zero(group_index)) {
      return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
    } else {
      return reduce_index;
    }
  }
  // sync thread op.
  static Stmt SyncThread(const std::string& sync) {
    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}));
  }

  // Emit warp shuffle calls.
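  // Both the shuffle width and the warp-size arguments of the intrinsic are
  // set to the target's warp size here, so the shuffle spans a full warp.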
  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
    Array<PrimExpr> indices = {0};
    PrimExpr mask = BufferLoad(mask_buffer, indices);
    PrimExpr width = IntImm(DataType::Int(32), warp_size_);
    Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
    return Call(val.dtype(), op, args);
  }

  // Check if we can use warp level reduction.
  //
  // Note: The ROCm backend will only have warp reductions for now.
  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
  bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
                         int contiguous_reduce_extent) const {
    // Only the cuda and rocm targets support warp reductions.
    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;

    // rocm only supports 32 bit operands for shuffling at the moment
    if ((target_->kind->name == "rocm") &&
        (std::any_of(types.begin(), types.end(), [](DataType ty) {
          if (ty.is_vector()) return true;
          return ty.bits() != 32;
        }))) {
      return false;
    }

    // Supported types:
    // {u}int, {u}long, {u}long long, float, double, half/half2
    if (std::any_of(types.begin(), types.end(), [](DataType ty) {
          if (ty.is_float16()) return ty.lanes() > 2;
          if (ty.is_vector()) return true;
          return ty.bytes() < 4 || ty.bytes() > 8;
        })) {
      return false;
    }
    if (thread_extents_.empty()) {
      return false;
    }

    // reduce region must be contiguous.
    if (contiguous_reduce_extent != reduce_extent) {
      return false;
    }

    // whether reduce_extent and group_extent are valid for warp reduction.
    if (target_->kind->name == "rocm") {
      return reduce_extent == warp_size_;
    } else {  // target_->kind->name == "cuda"
      if (reduce_extent == 1) {
        return false;  // no need to warp reduce
      } else {
        if (warp_size_ % reduce_extent == 0) {
          return true;  // warp size is multiple of reduce extent
        } else {
          return group_extent == 1 && reduce_extent <= warp_size_;
        }
      }
    }
  }

  // The target.
  const TargetNode* target_ = nullptr;

  // The warp size of the device.
  int warp_size_{1};

  // surrounding scope of thread extent.
  std::vector<const AttrStmtNode*> thread_extents_;
  std::vector<const CommReducerNode*> reduce_combiner_;
  // The load remap
  std::unordered_map<const VarNode*, PrimExpr> load_remap_;
  // The store remap
  std::unordered_map<const BufferNode*, Buffer> store_remap_;
  // Allocate remap
  std::unordered_map<const VarNode*, Stmt> alloc_remap_;
  // BufferVar remap
  std::unordered_map<const VarNode*, Var> var_remap_;
  // Buffer remap
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
  // Allocate from warp reductions
  std::unordered_set<const void*> warp_allocs_;
  // Internal analyzer
  arith::Analyzer analyzer_;
};

namespace transform {

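// Lower cross-thread allreduce calls in a PrimFunc and then rewrite the
// pointer storage scopes recorded during the lowering.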
Pass LowerThreadAllreduce() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
    const TargetNode* target_node = target.as<TargetNode>();
    ThreadAllreduceBuilder thread_all_reduce(target_node);
    auto reduce_body = thread_all_reduce(n->body);
    n->body =
        UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce);

}  // namespace transform
}  // namespace tir
}  // namespace tvm