/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Lower allreduce to device-implementable IR.
 * \file lower_thread_allreduce.cc
 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "update_pointer_storage_scope.h"

namespace tvm {
namespace tir {

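// Updates the storage scopes of buffers rewritten by ThreadAllreduceBuilder.
// In addition to the base-class rewriting, allocations that become "shared"
// are wrapped in a volatile_scope attribute so that values written by other
// threads of the reduction are re-read from shared memory.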
class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
 public:
  explicit UpdatePointerStorageScopeAllReduce(
      const std::unordered_map<const VarNode*, String>& new_storage_scopes)
      : UpdatePointerStorageScope(new_storage_scopes) {}

  Stmt VisitStmt_(const AllocateNode* op) final {
    auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
    auto new_scope = GetPtrStorageScope(remapped);
    if (new_scope != GetPtrStorageScope(op->buffer_var)) {
      Stmt body = StmtExprMutator::VisitStmt(op->body);
      if (new_scope == "shared") {
        // use volatile access to shared buffer.
        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
      }
      return Allocate(remapped, op->dtype, op->extents, op->condition, body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }
};

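// Lowers tvm_thread_allreduce calls into either warp-level shuffle reductions
// or shared-memory tree reductions, depending on the target and the reduction
// shape. The remapping tables it fills are then used to rewrite the
// surrounding loads, stores, and allocations.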
class ThreadAllreduceBuilder final : public StmtExprMutator {
 public:
  explicit ThreadAllreduceBuilder(const TargetNode* target)
      : target_(target),
        warp_size_(target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()) {}

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      thread_extents_.push_back(op);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      thread_extents_.pop_back();
      return ret;
    } else if (op->attr_key == attr::reduce_scope) {
      const CommReducerNode* combiner = op->node.as<CommReducerNode>();
      ICHECK(combiner);
      reduce_combiner_.push_back(combiner);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      reduce_combiner_.pop_back();
      return ret;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const EvaluateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<EvaluateNode>();
    const CallNode* call = op->value.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_thread_allreduce())) {
      return MakeAllreduce(call);
    } else {
      return stmt;
    }
  }
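  // If this allocation was remapped by MakeAllreduce, replace it with the
  // reduction staging buffer and record its storage scope: "local" for
  // warp-shuffle reductions, "shared" for shared-memory reductions.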
  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      const AllocateNode* repl = it->second.as<AllocateNode>();
      if (warp_allocs_.count(repl)) {
        new_storage_scopes_[repl->buffer_var.get()] = "local";
      } else {
        new_storage_scopes_[repl->buffer_var.get()] = "shared";
      }
      return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
    } else {
      return stmt;
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

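  // Loads are rewritten in three steps: load_remap_ replaces a load of the
  // reduction result with the already-computed value, buf_remap_ reuses a
  // buffer that was rebuilt earlier, and var_remap_ rebuilds the buffer
  // around the remapped data variable on first encounter.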
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    {
      auto it = load_remap_.find(op->buffer->data.get());
      if (it != load_remap_.end()) {
        for (const auto& index : op->indices) {
          ICHECK(is_zero(index));
        }
        return it->second;
      }
    }

    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    op = load.get();

    {
      auto it = buf_remap_.find(op->buffer.get());
      if (it != buf_remap_.end()) {
        return BufferLoad(it->second, op->indices, op->span);
      }
    }

    {
      auto it = var_remap_.find(op->buffer->data.get());
      if (it != var_remap_.end()) {
        Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
                               op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
                               op->buffer->data_alignment, op->buffer->offset_factor,
                               op->buffer->buffer_type, op->buffer->axis_separators,
                               op->buffer->span);
        buf_remap_[op->buffer.get()] = remapped_buffer;
        return BufferLoad(remapped_buffer, op->indices, op->span);
      }
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));

    auto it = store_remap_.find(store->buffer.get());
    if (it != store_remap_.end()) {
      for (const auto& index : op->indices) {
        ICHECK(is_zero(index));
      }

      auto writer = store.CopyOnWrite();
      writer->buffer = it->second;
      return std::move(store);
    }

    {
      auto it = buf_remap_.find(store->buffer.get());
      if (it != buf_remap_.end()) {
        return BufferStore(it->second, store->value, store->indices, store->span);
      }
    }

    {
      auto it = var_remap_.find(store->buffer->data.get());
      if (it != var_remap_.end()) {
        Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
                               store->buffer->strides, store->buffer->elem_offset,
                               store->buffer->name, store->buffer->data_alignment,
                               store->buffer->offset_factor, store->buffer->buffer_type,
                               store->buffer->axis_separators, store->buffer->span);
        buf_remap_[store->buffer.get()] = remapped_buffer;
        return BufferStore(remapped_buffer, store->value, store->indices, store->span);
      }
    }

    return std::move(store);
  }

  std::unordered_map<const VarNode*, String> new_storage_scopes_;

 private:
  // Thread entry
  struct ThreadEntry {
    runtime::ThreadScope scope;
    IterVar iv;
    int extent;
    // comparator
    bool operator<(const ThreadEntry& other) const {
      return scope.dim_index < other.scope.dim_index;
    }
  };

  // Lower a single tvm_thread_allreduce call.
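  //
  // Argument layout of the intrinsic call, as parsed below:
  //   args[0]                            : number of reduction values (size)
  //   args[1] .. args[size]              : the values being reduced
  //   args[size + 1]                     : the reduction predicate
  //   args[size + 2] .. args[2*size + 1] : loads of the result buffers
  //   args[2*size + 2] .. end            : the threaded IterVars to reduce over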
  Stmt MakeAllreduce(const CallNode* call) {
    ICHECK(!reduce_combiner_.empty());
    const CommReducerNode* combiner = reduce_combiner_.back();
    size_t size = combiner->result.size();

    const IntImmNode* size_of_args = call->args[0].as<IntImmNode>();
    ICHECK(size_of_args) << call->args[0]->GetTypeKey();
    ICHECK_EQ(size, size_of_args->value);
    Array<PrimExpr> inits = combiner->identity_element;
    std::vector<PrimExpr> values(size);
    std::vector<DataType> types(size);
    PrimExpr cond = call->args[size + 1];
    for (size_t idx = 0; idx < size; ++idx) {
      values[idx] = call->args[1 + idx];
      if (!is_one(cond)) {
        values[idx] = Select(cond, values[idx], inits[idx]);
      }
      types[idx] = values[idx].dtype();
    }
    std::vector<Buffer> buffers(size);
    for (size_t idx = 0; idx < size; ++idx) {
      PrimExpr arg = call->args[2 + size + idx];
      // Loads from boolean buffers may have cast nodes inserted by
      // earlier passes.
      if (auto cast = arg.as<CastNode>()) {
        arg = cast->value;
      }
      buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
    }

    std::unordered_set<const VarNode*> reduce_set;
    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
      const VarNode* v = call->args[i].as<VarNode>();
      // A simple optimization replaces an iteration variable with a constant
      // when the extent of the iteration is 1. As threaded IterVars always
      // start from 0, we can simply ignore such a variable here.
      if (v) {
        reduce_set.insert(v);
      } else {
        ICHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
            << "arg " << i << " should be a VarNode or IntImmNode";
      }
    }

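    // Partition the launched thread IterVars: threads that appear in
    // reduce_set take part in the reduction (vred), the remaining ones
    // index independent reduction groups (vpar).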
    size_t nmatch = 0;
    std::vector<ThreadEntry> vred, vpar;
    for (const AttrStmtNode* attr : thread_extents_) {
      ThreadEntry e;
      IterVar iv = Downcast<IterVar>(attr->node);
      e.scope = runtime::ThreadScope::Create(iv->thread_tag);
      e.iv = iv;
      ICHECK_LE(e.scope.rank, 1);
      ICHECK_GE(e.scope.dim_index, 0) << "vthread does not work with cross-thread reduction";
      if (e.scope.rank == 1) {
        const auto* ptr = attr->value.as<IntImmNode>();
        ICHECK(ptr) << "Need constant extent for reduce set " << iv;
        e.extent = static_cast<int>(ptr->value);
        // skip thread axes whose extent is 1
        if (e.extent == 1) {
          continue;
        }

        if (reduce_set.count(iv->var.get())) {
          vred.push_back(e);
          ++nmatch;
        } else {
          vpar.push_back(e);
        }
      }
    }
    ICHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce indices are present in the context";
    std::sort(vred.begin(), vred.end());
    std::sort(vpar.begin(), vpar.end());
    // the size of each index.
    int reduce_extent, group_extent;
    PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
    PrimExpr group_index = FlattenThread(vpar, &group_extent);

    // the longest contiguous reduce extent after flattening
    int contiguous_reduce_extent = 1;
    std::vector<std::tuple<int, int, bool>> block_threads;  // tuple(dim_index, extent, is_reduce)
    for (const ThreadEntry& thr : vred) {
      if (thr.scope.rank == 1) {  // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
      }
    }
    for (const ThreadEntry& thr : vpar) {
      if (thr.scope.rank == 1) {  // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
      }
    }
    // sort according to dim_index
    std::sort(block_threads.begin(), block_threads.end());
    for (auto&& thr_attr : block_threads) {
      auto [dim_index, extent, is_reduce] = thr_attr;
      (void)dim_index;  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      if (is_reduce) {
        contiguous_reduce_extent *= extent;
      } else {
        break;
      }
    }
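    // For example, with threadIdx.x as a reduce axis of extent 16 and
    // threadIdx.y as a group axis of extent 4, contiguous_reduce_extent is 16;
    // if threadIdx.x were a group axis instead, it would stay at 1.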

    std::vector<Stmt> seq;
    std::vector<Var> shared_buffer_vars(size);
    std::vector<Buffer> shared_bufs(size);
    std::vector<Buffer> local_bufs;
    //
    // This is an optimization. For small reduction sizes, it may be beneficial
    // for a single warp to perform the entire reduction. No trips to shared
    // memory and no cross-warp synchronizations are required.
    // The following code emits the reduction as follows:
    //
    // Allocate reduction vars v[i], i = 0..size-1
    //
    // for offset from the largest power of two below reduce_extent down to 1,
    // halving each iteration:
    //
    //   a <- load(v[i])
    //   b <- shuffle_down(load(v[i]), offset)
    //   v[i] <- reduction(a, b)
    //
    // broadcast results from lane 0 to all other lanes and store
    // the final reduction result to the proper location.
    //
    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
      //
      // This is the index to the reduction variable, one reduction
      // variable per warp. Local scope seems easier to reason about without
      // relying on a pattern-match pass to fix it later.
      Array<PrimExpr> zero_indices = {0};

      for (size_t idx = 0; idx < size; ++idx) {
        Array<PrimExpr> shape = {1};

        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
        Var buffer_var = buffer->data;

        shared_buffer_vars[idx] = buffer_var;
        shared_bufs[idx] = buffer;

        PrimExpr pred = const_true(types[idx].lanes());
        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));

        // Uses a local variable to store the shuffled data. Later
        // on, an allocation will be built for this local variable.
        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
      }

      // The mask for this reducer, as this reducer may sit inside
      // a divergent control flow. Here it uses a variable to cache the
      // currently active lanes.
      //
      DataType mask_dtype = DataType::UInt(32);
      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
      {
        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
        if (group_extent > 1) {
          mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
        }
        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
        // Push the buffer description. Later this will have an
        // allocation built for it.
        local_bufs.push_back(mask_buffer);
      }
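      // For instance, with reduce_extent = 8 the group with group_index = 1
      // keeps only lanes 8..15 active in the mask.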

      // Emit reductions within a warp.
      int start_offset = 1;
      while (start_offset * 2 < reduce_extent) {
        start_offset *= 2;
      }
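      // start_offset is now the largest power of two strictly smaller than
      // reduce_extent (e.g. 16 for reduce_extent 32, 4 for reduce_extent 7),
      // and the loop below halves the shuffle offset down to 1.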
      for (int offset = start_offset; offset > 0; offset /= 2) {
        // Load reduction values, no synchronization needed.
        Array<PrimExpr> a, b;
        for (size_t i = 0; i < size; ++i) {
          Buffer shared_buf = shared_bufs[i];
          BufferLoad val(shared_buf, zero_indices);
          ICHECK_EQ(val->dtype, types[i]);
          a.push_back(val);

          // __shfl_*sync calls shall not appear in if_then_else expressions
          // as this causes extra divergence. E.g.
          //
          // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
          //
          // behaves differently from
          //
          // int t = __shfl_sync(mask, v1, 0);
          // v1 = (v2 < v3) ? v3 : t;
          //
          // The former may cause a deadlock as there is a divergent
          // branch with a warp sync call inside.
          //
          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
          Buffer local_buf = local_bufs[i];
          Stmt s = BufferStore(local_buf, other, zero_indices);
          seq.push_back(s);

          BufferLoad load = BufferLoad(local_buf, zero_indices);
          ICHECK_EQ(load->dtype, types[i]);
          b.push_back(load);
        }

        // Do reductions.
        Array<PrimExpr> ret = (*combiner)(a, b);

        // Store the reduction result to itself.
        std::vector<Stmt> stores(size);
        for (size_t i = 0; i < size; ++i) {
          Buffer buf = shared_bufs[i];
          stores[i] = BufferStore(buf, ret[i], zero_indices);
        }

        // During the sub-warp reduction, values from inactive threads could be
        // read, which is undefined behavior according to the CUDA documentation.
        //
        // In practice, the returned values are usually 0, which does no harm to
        // a sum reduction. However, the result can be incorrect for a max or
        // prod reduction. Therefore an additional range check has to be
        // performed to ensure correctness.
        if (offset * 2 > reduce_extent) {
          PrimExpr cond = reduce_index + offset < reduce_extent;
          seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
        } else {
          seq.push_back(SeqStmt::Flatten(stores));
        }
      }

      // Broadcast the reduction result from lane 0 to all other lanes.
      // This avoids emitting predicated stores, as all threads are
      // uniformly writing the same result.
      //
      for (size_t i = 0; i < size; ++i) {
        Buffer buf = shared_bufs[i];
        PrimExpr val = BufferLoad(buf, zero_indices);
        ICHECK_EQ(val->dtype, types[i]);
        PrimExpr splat =
            WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
        seq.push_back(BufferStore(buf, splat, zero_indices));
      }

      // Update existing allocations.
      for (size_t i = 0; i < size; ++i) {
        ICHECK(!load_remap_.count(buffers[i]->data.get()));
        PrimExpr pred = const_true(types[i].lanes());
        Buffer buf = shared_bufs[i];
        PrimExpr val = BufferLoad(buf, zero_indices);
        ICHECK_EQ(val->dtype, types[i]);
        load_remap_[buffers[i]->data.get()] = val;
        store_remap_[buffers[i].get()] = buf;
        Array<PrimExpr> extents{PrimExpr(1)};
        auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
        alloc_remap_[buffers[i]->data.get()] = node;
        var_remap_[buffers[i]->data.get()] = buf->data;
        warp_allocs_.insert(node.get());
      }
    } else {
      if (reduce_extent == 1) {
        // special case, no reduction is needed.
        std::vector<Stmt> stores;
        for (size_t i = 0; i < size; ++i) {
          stores.push_back(BufferStore(buffers[i], values[i], {0}));
        }
        return SeqStmt::Flatten(stores);
      }
      // This sync is necessary because there might be incomplete reads from a
      // previous iteration on the same buffer.
      seq.emplace_back(SyncThread("shared"));
      for (size_t idx = 0; idx < size; ++idx) {
        Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));

        shared_bufs[idx] = buffer;
        shared_buffer_vars[idx] = buffer->data;

        PrimExpr pred = const_true(types[idx].lanes());
        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                     {BufIndex(reduce_index, group_index, reduce_extent)}));
      }
      seq.emplace_back(SyncThread("shared"));
      seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
                                        reduce_extent, group_extent, contiguous_reduce_extent));
      for (size_t idx = 0; idx < size; ++idx) {
        ICHECK(!load_remap_.count(buffers[idx]->data.get()));
        PrimExpr pred = const_true(types[idx].lanes());
        BufferLoad load(shared_bufs[idx],
                        {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
        ICHECK_EQ(load->dtype, types[idx]);
        load_remap_[buffers[idx]->data.get()] = load;
        alloc_remap_[buffers[idx]->data.get()] =
            Allocate(shared_bufs[idx]->data, types[idx],
                     {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
        var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
        store_remap_[buffers[idx].get()] = shared_bufs[idx];
      }
    }

    // Emit allocations for the local staging buffers now that all statements
    // have been built.
    Stmt body = SeqStmt::Flatten(seq);
    for (Buffer buf : local_bufs) {
      body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
      new_storage_scopes_[buf->data.get()] = "local";
    }

    return body;
  }

  // Emit a tree allreduce over the shared-memory staging buffers.
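  //
  // The generated code is roughly:
  //
  //   if (reduce_extent is not a power of two)
  //     guarded: fold the tail entries into the first reduce_align entries
  //   while (the stride is above the contiguous-warp boundary)
  //     guarded pairwise combine at the current stride, then a shared sync
  //   within a warp:
  //     pairwise combines separated by warp-level barriers
  //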
  Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                        const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
                        PrimExpr group_index, int reduce_extent, int group_extent,
                        int contiguous_reduce_extent) {
    // Round reduce_extent up to the next power of two.
    int reduce_align = 1;
    while (reduce_extent > reduce_align) {
      reduce_align = reduce_align << 1;
    }
    ICHECK_GT(reduce_align, 1);
    std::vector<Stmt> seq;

    size_t size = shared_bufs.size();
    PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
    // make reduction
    auto fload = [&](int offset) {
      Array<PrimExpr> a, b;
      for (size_t i = 0; i < size; ++i) {
        BufferLoad b_load(shared_bufs[i],
                          {BufIndex(reduce_index + offset, group_index, reduce_extent)});
        ICHECK_EQ(b_load->dtype, types[i]);
        b.push_back(b_load);

        BufferLoad a_load(shared_bufs[i], {buf_index});
        ICHECK_EQ(a_load->dtype, types[i]);
        a.push_back(a_load);
      }
      Array<PrimExpr> ret = (*combiner)(a, b);
      return ret;
    };
    auto fstore = [&](const Array<PrimExpr>& ret) {
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
      }
      return SeqStmt::Flatten(stores);
    };
    auto freduce = [&](int offset) {
      auto ret = fload(offset);
      return fstore(ret);
    };
    // Step one: if reduce_extent is not a power of two, fold the tail
    // entries into the first half with a guarded reduction.
    if (reduce_align > reduce_extent) {
      // reduction with the boundary condition
      reduce_align = reduce_align >> 1;
      PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }

    // normal synchronization
    bool warp_align = group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
    while (reduce_align > contiguous_reduce_extent || reduce_align > warp_size_ || !warp_align) {
      if (reduce_align == 1) {
        break;
      }
      reduce_align = reduce_align >> 1;
      PrimExpr cond = reduce_index < reduce_align;
      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }
    // in-warp synchronization.
    if (reduce_align > 1) {
      PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);

      std::vector<Stmt> in_warp_seq;

      while (reduce_align > 1) {
        reduce_align = reduce_align >> 1;

        // freduce can read/write to the same memory location. For
        // example, with reduce_align of 4, threadIdx 3 reads from
        // memory location 7 as threadIdx 7 is writing to it.
        // Therefore, we need to separate out the load from the store
        // with a memory barrier in-between. This isn't necessary for
        // the earlier normal synchronization, because those are each
        // protected by an if-statement. The if-statement is avoided
        // here to reduce thread divergence.
        auto loads = fload(reduce_align);

        Array<Var> in_warp_local_vars;
        for (auto expr : loads) {
          Var var(
              "w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()),
              expr->dtype);
          in_warp_local_vars.push_back(var);
        }

        std::vector<Stmt> in_let_statement;
        in_let_statement.emplace_back(SyncThread("warp"));
        in_let_statement.emplace_back(
            fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
        in_let_statement.emplace_back(SyncThread("warp"));

        Stmt body = SeqStmt::Flatten(in_let_statement);
        for (size_t i = 0; i < size; i++) {
          body = LetStmt(in_warp_local_vars[i], loads[i], body);
        }
        in_warp_seq.push_back(body);
      }

      Stmt warp_body = SeqStmt::Flatten(in_warp_seq);

      seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
      seq.emplace_back(SyncThread("shared"));
    }
    return SeqStmt::Flatten(seq);
  }
  // Flatten the thread IterVars into a single linear index and return it.
  // The product of their extents is returned through out_total_extent.
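  // For example, for threadIdx.x with extent 4 and threadIdx.y with extent 8,
  // the flattened index is x + y * 4 and *out_total_extent becomes 32.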
  PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec, int* out_total_extent) {
    int& total_extent = *out_total_extent;
    total_extent = 1;
    if (tvec.size() == 0) {
      return make_zero(DataType::Int(32));
    }

    PrimExpr ret;
    for (const ThreadEntry& e : tvec) {
      if (ret.defined()) {
        ret = ret + e.iv->var * total_extent;
      } else {
        ICHECK_EQ(total_extent, 1);
        ret = e.iv->var;
      }
      total_extent *= e.extent;
    }
    return ret;
  }
  // The local buffer index.
  PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
    if (!is_zero(group_index)) {
      return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
    } else {
      return reduce_index;
    }
  }
  // sync thread op.
  static Stmt SyncThread(const std::string& sync) {
    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}));
  }

  // Emit warp shuffle calls.
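  // Arguments passed: the cached active-lane mask, the value to shuffle, the
  // lane delta (or source lane), and warp_size_ twice, for the last two
  // operands of the shuffle intrinsic (shuffle width and warp size).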
  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
    Array<PrimExpr> indices = {0};
    PrimExpr mask = BufferLoad(mask_buffer, indices);
    PrimExpr width = IntImm(DataType::Int(32), warp_size_);
    Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
    return Call(val.dtype(), op, args);
  }

  // Check if we can use warp-level reduction.
  //
  // Note: on the ROCm backend, warp reductions are only used when the reduce
  // extent equals the warp size. Also, the warp/wavefront size differs
  // (64 on ROCm, 32 on CUDA).
  bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
                         int contiguous_reduce_extent) const {
    // Warp reductions are only supported on CUDA and ROCm targets.
    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;

    // rocm only supports 32 bit operands for shuffling at the moment
    if ((target_->kind->name == "rocm") &&
        (std::any_of(types.begin(), types.end(), [](DataType ty) {
          if (ty.is_vector()) return true;
          return ty.bits() != 32;
        }))) {
      return false;
    }

    // Supported types:
    // {u}int, {u}long, {u}long long, float, double, half/half2
    if (std::any_of(types.begin(), types.end(), [](DataType ty) {
          if (ty.is_float16()) return ty.lanes() > 2;
          if (ty.is_vector()) return true;
          return ty.bytes() < 4 || ty.bytes() > 8;
        })) {
      return false;
    }
    if (thread_extents_.empty()) {
      return false;
    }

    // reduce region must be contiguous.
    if (contiguous_reduce_extent != reduce_extent) {
      return false;
    }

    // whether reduce_extent and group_extent are valid for warp reduction.
    if (target_->kind->name == "rocm") {
      return reduce_extent == warp_size_;
    } else {  // target_->kind->name == "cuda"
      if (reduce_extent == 1) {
        return false;  // no need to warp reduce
      } else {
        if (warp_size_ % reduce_extent == 0) {
          return true;  // warp size is multiple of reduce extent
        } else {
          return group_extent == 1 && reduce_extent <= warp_size_;
        }
      }
    }
  }

  // The target.
  const TargetNode* target_ = nullptr;

  // The warp size of the device.
  int warp_size_{1};

  // surrounding scope of thread extent.
  std::vector<const AttrStmtNode*> thread_extents_;
  std::vector<const CommReducerNode*> reduce_combiner_;
  // The load remap
  std::unordered_map<const VarNode*, PrimExpr> load_remap_;
  // The store remap
  std::unordered_map<const BufferNode*, Buffer> store_remap_;
  // Allocate remap
  std::unordered_map<const VarNode*, Stmt> alloc_remap_;
  // BufferVar remap
  std::unordered_map<const VarNode*, Var> var_remap_;
  // Buffer remap
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
  // Allocate from warp reductions
  std::unordered_set<const void*> warp_allocs_;
  // Internal analyzer
  arith::Analyzer analyzer_;
};

namespace transform {

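// Lower tvm_thread_allreduce intrinsics in the PrimFunc body and update the
// storage scopes of the staging buffers introduced by the lowering.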
Pass LowerThreadAllreduce() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
    const TargetNode* target_node = target.as<TargetNode>();
    ThreadAllreduceBuilder thread_all_reduce(target_node);
    auto reduce_body = thread_all_reduce(n->body);
    n->body =
        UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce);

}  // namespace transform
}  // namespace tir
}  // namespace tvm