/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lower_cross_thread_reduction.cc
 * \brief Lower cross-thread reductions from thread bindings to intrinsic function calls
 */
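// A schematic overview, for illustration only (TVMScript-like pseudocode; the
// actual output is constructed by TransformReductionBlock below). Given a
// reduction block whose reduction loop is bound to a GPU thread axis, e.g.
//
//   for k in T.thread_binding(128, thread="threadIdx.x"):
//     with T.block("B"):
//       vi, vk = T.axis.remap("SR", [i, k])
//       with T.init():
//         B[vi] = T.float32(0)
//       B[vi] = B[vi] + A[vi, vk]
//
// this pass rewrites the block so that each thread reduces into a local
// scratchpad, the per-thread partial results are combined with the
// `tvm_thread_allreduce` intrinsic, and the reduced value is then written
// back to B[vi].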
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../schedule/analysis.h"
#include "./ir_utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Check if a loop is bound to threadIdx.x/y/z
 * \param loop The loop to be checked
 * \return True if the loop is bound to threadIdx.x/y/z
 */
bool IsBoundToThreadIdx(const ForNode* loop) {
  if (!loop->thread_binding.defined()) {
    return false;
  }
  runtime::ThreadScope scope =
      runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
  return scope.rank == 1 && scope.dim_index >= 0;
}
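
// For illustration: a loop bound to "threadIdx.x" (thread scope rank 1,
// dim_index 0) passes this check, while a loop bound to "blockIdx.x" (rank 0)
// or to "vthread" (dim_index -1) does not.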

/*!
 * \brief Check the dominant property of a block:
 * the block is the only writer of its output, dominating the readers of its output buffers
 * \param scope_block The scope block of the block to be checked
 * \param block The block whose dominant property is to be checked
 * \return A boolean indicating if the block is a dominant block
 */
bool IsDominantBlock(const Block& scope_block, const Block& block) {
  // Step 1. Count the number of writers for each buffer written by the blocks under the scope
  // block.
  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
  PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) {
    if (const auto* block = obj.as<BlockNode>()) {
      for (const BufferRegion& buffer_region : block->writes) {
        ++buffer_writer_cnt[buffer_region->buffer.get()];
      }
      return false;
    }
    return true;
  });
  // Step 2. Check whether `block` is the only writer of its outputs.
  for (const BufferRegion& buffer_region : block->writes) {
    ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
    if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) {
      return false;
    }
  }
  return true;
}
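
// For illustration: if two blocks under the same scope both write buffer C,
// then `buffer_writer_cnt[C]` is 2 and neither block is dominant with respect
// to C, so this function returns false for both of them.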

/*!
 * \brief Check whether the input block is a reduction block.
 * \param realize The block-realize which contains the block to be checked
 * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
 * \param scope_block The scope block of the input block
 * \param analyzer The analyzer
 * \return A boolean indicating whether the input block is a reduction block.
 * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
 * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
 * check again.
 */
bool IsReductionBlock(const BlockRealize& realize, const Map<Var, Range>& loop_range_map,
                      const Block& scope_block, arith::Analyzer* analyzer) {
  const auto* block = realize->block.as<BlockNode>();
  // Cond 1. The block has the `init` statement.
  if (!block->init.defined()) {
    return false;
  }
  // Cond 2. All the block bindings are quasi-affine expressions.
  if (!IsAffineBinding(realize, loop_range_map, analyzer)) {
    return false;
  }
  // Cond 3. All block vars are either data-parallel block vars or reduction block vars.
  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
    return false;
  }
  // Cond 4. Dominant: the block is the only writer of its output, dominating the readers of its
  // output buffers.
  if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
    return false;
  }
  // Cond 5. The reduction block vars are not used to index the output buffers.
  return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block));
}
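
// A schematic block satisfying all five conditions (TVMScript-like
// pseudocode, for illustration only):
//
//   with T.block("B"):
//     vi, vk = T.axis.remap("SR", [i, k])  # Cond 2/3: affine, data-par + reduce
//     with T.init():                       # Cond 1: has the `init` statement
//       B[vi] = T.float32(0)
//     B[vi] = B[vi] + A[vi, vk]            # Cond 5: vk does not index B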

/*!
 * \brief Create intermediate buffers according to the input buffers and buffer kind
 * \param reduction_buffers The old reduction buffers which provide the buffer names and data types
 * \param is_cross_thread_buffer A boolean indicating whether the buffers are created for the
 * cross-thread computation results, which is used to determine the buffer name prefix
 * \return The created buffers
 */
Array<Buffer> MakeScratchpads(const Array<Buffer>& reduction_buffers, bool is_cross_thread_buffer) {
  Array<Buffer> new_buffers;
  new_buffers.reserve(reduction_buffers.size());
  for (const Buffer& buffer : reduction_buffers) {
    String name = is_cross_thread_buffer ? "cross" : "in";
    name = name + "_thread_" + buffer->name;
    new_buffers.push_back(Buffer(/*ptr=*/Var(name, PointerType(PrimType(buffer->dtype), "local")),
                                 /*dtype=*/buffer->dtype,
                                 /*shape=*/{Integer(1)},
                                 /*strides=*/{Integer(1)},
                                 /*elem_offset=*/PrimExpr{nullptr},
                                 /*name=*/name,
                                 /*data_alignment=*/0,
                                 /*offset_factor=*/0,
                                 /*buffer_type=*/kDefault));
  }
  return new_buffers;
}
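
// For illustration: given a reduction buffer named "B" with dtype float32,
// this creates a single-element buffer in "local" scope named
// "cross_thread_B" (when `is_cross_thread_buffer` is true) or "in_thread_B"
// (otherwise).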

/*!
 * \brief Substitute the given source buffers with the given target buffers respectively in the
 * input statement
 */
class BufferReplacer : private StmtExprMutator {
 public:
  static Stmt Run(Array<Buffer> src_buffers, Array<Buffer> tgt_buffers, Stmt stmt) {
    Map<Buffer, Buffer> buffer_map;
    ICHECK_EQ(src_buffers.size(), tgt_buffers.size());
    int n_buffers = src_buffers.size();
    for (int i = 0; i < n_buffers; ++i) {
      buffer_map.Set(src_buffers[i], tgt_buffers[i]);
    }
    return BufferReplacer(buffer_map)(std::move(stmt));
  }

 private:
  explicit BufferReplacer(Map<Buffer, Buffer> buffer_map) : buffer_map_(std::move(buffer_map)) {}

  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
    auto it = buffer_map_.find(load->buffer);
    return it != buffer_map_.end() ? BufferLoad((*it).second, {0}) : GetRef<BufferLoad>(load);
  }

  Stmt VisitStmt_(const BufferStoreNode* store) final {
    auto it = buffer_map_.find(store->buffer);
    if (it != buffer_map_.end()) {
      PrimExpr value = StmtExprMutator::VisitExpr(store->value);
      return BufferStore((*it).second, std::move(value), {0});
    } else {
      return StmtMutator::VisitStmt_(store);
    }
  }

  Map<Buffer, Buffer> buffer_map_;
};
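
// For example (illustrative only), with src_buffers = {B} and
// tgt_buffers = {in_thread_B}, the store
//   B[vi] = B[vi] + A[vi, vk]
// is rewritten to
//   in_thread_B[0] = in_thread_B[0] + A[vi, vk]
// since every access to a substituted buffer is redirected to index 0 of the
// single-element scratchpad.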

/*!
 * \brief Substitute a given source block with a given target block, or remove the source block
 * branch from the AST if the target block is undefined
 */
class InThreadReducerMaker : private StmtMutator {
 public:
  static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
                             Optional<BlockRealize> tgt_realize, Stmt stmt) {
    return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
  }

 private:
  explicit InThreadReducerMaker(const BlockRealizeNode* src_realize,
                                Optional<BlockRealize> tgt_realize)
      : src_realize_(src_realize), tgt_realize_(tgt_realize) {}

  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
    if (realize == src_realize_) {
      return tgt_realize_.defined()  //
                 ? tgt_realize_.value()
                 : Stmt{nullptr};
    }
    return GetRef<BlockRealize>(realize);
  }

  Stmt VisitStmt_(const ForNode* loop) final {
    if (Optional<For> opt_res = Downcast<Optional<For>>(StmtMutator::VisitStmt_(loop))) {
      For res = opt_res.value();
      if (res->thread_binding.defined()) {
        return res->body;
      } else {
        return std::move(res);
      }
    } else {
      return Stmt{nullptr};
    }
  }

  Stmt VisitStmt_(const SeqStmtNode* seq) final {
    Array<Stmt> stmts;
    stmts.reserve(seq->size());
    for (const Stmt& stmt : seq->seq) {
      if (Optional<Stmt> opt_res = VisitStmt(stmt)) {
        stmts.push_back(opt_res.value());
      }
    }
    return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts);
  }

  const BlockRealizeNode* src_realize_;
  Optional<BlockRealize> tgt_realize_;
};

/*!
 * \brief Create the lowered allreduce block transformed from the input reduction block
 * \param realize The block-realize which contains the old reduction block
 * \param it_buffers The buffers to store in-thread reduction results
 * \param ct_buffers The buffers to store cross-thread reduction results
 * \param wb_buffers The buffers to store the final reduction results
 * \param old_wb_indices The indices used to access the write-back buffers when storing the final
 * reduction results into the write-back buffers
 * \param reducer The reduction function
 * \param combiner_rhs The RHS values of the combiner
 * \param reduction_loops The reduction loops
 */
Stmt TransformReductionBlock(const BlockRealizeNode* realize,            //
                             const Optional<Array<Buffer>>& it_buffers,  //
                             const Array<Buffer>& ct_buffers,            //
                             const Array<Buffer>& wb_buffers,            //
                             const Array<PrimExpr>& old_wb_indices,      //
                             const CommReducer& reducer,                 //
                             const Array<PrimExpr>& combiner_rhs,        //
                             const std::vector<const ForNode*>& reduction_loops) {
  int n_buffers = wb_buffers.size();
  const BlockNode* block = realize->block.get();

  auto f_create_buffer_regions = [](Array<Buffer> buffers) {
    Array<BufferRegion> regions;
    regions.reserve(buffers.size());
    for (const Buffer& buffer : buffers) {
      regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
    }
    return regions;
  };

  Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
  Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
  if (it_buffers.defined()) {
    it_buffer_regions = f_create_buffer_regions(it_buffers.value());
  }
  // In total, the block is transformed into at most 4 statements
  // - Stmt 1: initialize the buffer for in-thread reduction
  // - Stmt 2: do in-thread reduction
  // - Stmt 3: do cross-thread reduction
  // - Stmt 4: write cross-thread reduction result to the original buffer
  Array<Stmt> stmts;
  stmts.reserve(4);
  // Stmt 1: initialize the buffer for in-thread reduction
  if (it_buffers.defined()) {
    Array<Stmt> inits;
    inits.reserve(n_buffers);
    for (int i = 0; i < n_buffers; ++i) {
      inits.push_back(
          BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)}));
    }
    stmts.push_back(BlockRealize(/*iter_values=*/{},
                                 /*predicate=*/const_true(),
                                 /*block=*/
                                 Block(/*iter_vars=*/{},
                                       /*reads=*/{},
                                       /*writes=*/it_buffer_regions.value(),
                                       /*name_hint=*/block->name_hint + "_in_thread_init",
                                       /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0])));
  }
  // Stmt 2: do in-thread reduction
  {
    Optional<BlockRealize> new_realize = NullOpt;
    // If in-thread reduction is needed, replace `wb_buffers` with `it_buffers` accordingly in
    // the given BlockRealize; otherwise, directly remove the given BlockRealize.
    if (it_buffers.defined()) {
      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
      new_block->writes = it_buffer_regions.value();
      new_block->name_hint = new_block->name_hint + "_in_thread";
      new_block->body =
          BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body));
      new_block->init = NullOpt;
      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
      n->block = Block(new_block);
      new_realize = BlockRealize(n);
    }
    For loop = GetRef<For>(reduction_loops[0]);
    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
      stmts.push_back(stmt.value());
    }
  }
  // Stmt 3: do cross-thread reduction
  {
    // Step 3.1. Create the parameters to the intrinsic
    Array<PrimExpr> parameters;
    parameters.reserve(reduction_loops.size() + 4);
    // First argument: the number of buffers
    parameters.push_back(make_const(DataType::UInt(32), n_buffers));
    // Next `n_buffers` arguments: the sources
    if (it_buffers.defined()) {
      for (int i = 0; i < n_buffers; ++i) {
        parameters.push_back(BufferLoad(it_buffers.value()[i], {Integer(0)}));
      }
    } else {
      parameters.insert(parameters.end(), combiner_rhs.begin(), combiner_rhs.end());
    }
    // Next argument: the predicate
    parameters.push_back(const_true());
    // Next `n_buffers` arguments: the destinations
    for (int i = 0; i < n_buffers; ++i) {
      parameters.push_back(BufferLoad(ct_buffers[i], {0}));
    }
    // Remaining arguments: all the reduction threads
    for (const ForNode* reduction_loop : reduction_loops) {
      if (reduction_loop->thread_binding.defined()) {
        parameters.push_back(reduction_loop->loop_var);
      }
    }
    // Step 3.2. Create the block and the block-realize.
    Array<IterVar> iter_vars{nullptr};
    Array<PrimExpr> bindings{nullptr};
    Array<BufferRegion> reads{nullptr};
    if (it_buffers.defined()) {
      iter_vars = Array<IterVar>{};
      bindings = Array<PrimExpr>{};
      reads = it_buffer_regions.value();
    } else {
      iter_vars = block->iter_vars;
      bindings = realize->iter_values;
      reads = block->reads;
    }
    stmts.push_back(BlockRealize(
        /*iter_values=*/std::move(bindings),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/std::move(iter_vars),
              /*reads=*/std::move(reads),
              /*writes=*/ct_buffer_regions,
              /*name_hint=*/block->name_hint + "_cross_thread",
              /*body=*/
              AttrStmt(/*node=*/reducer,
                       /*attr_key=*/tir::attr::reduce_scope,
                       /*value=*/make_zero(DataType::Handle()),
                       /*body=*/
                       Evaluate(Call(/*dtype=*/DataType::Handle(),
                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
                                     /*args=*/std::move(parameters)))))));
  }
  // Stmt 4: write cross-thread reduction result to the original buffer
  {
    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
    int n_iter = static_cast<int>(block->iter_vars.size());
    Array<IterVar> iter_vars;
    Array<PrimExpr> bindings;
    Map<Var, PrimExpr> var_map;
    iter_vars.reserve(n_iter);
    bindings.reserve(n_iter);
    for (int i = 0; i < n_iter; ++i) {
      const IterVar& iter_var = block->iter_vars[i];
      const PrimExpr& binding = realize->iter_values[i];
      if (iter_var->iter_type != kCommReduce) {
        IterVar new_iter_var{nullptr};
        {
          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
          n->var = Var(v);
          new_iter_var = IterVar(n);
        }
        iter_vars.push_back(new_iter_var);
        bindings.push_back(binding);
        var_map.Set(iter_var->var, new_iter_var->var);
      }
    }
    Array<Stmt> wb_updates;
    Array<BufferRegion> wb_regions;
    wb_updates.reserve(n_buffers);
    wb_regions.reserve(n_buffers);
    int n_dim = static_cast<int>(old_wb_indices.size());
    Array<Range> region = Substitute(block->writes[0]->region, var_map);
    Array<PrimExpr> wb_indices;
    wb_indices.reserve(n_dim);
    for (int d = 0; d < n_dim; ++d) {
      wb_indices.push_back(Substitute(old_wb_indices[d], var_map));
    }
    for (int i = 0; i < n_buffers; ++i) {
      wb_updates.push_back(
          BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
      wb_regions.push_back(BufferRegion(wb_buffers[i], region));
    }
    stmts.push_back(BlockRealize(
        /*iter_values=*/std::move(bindings),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/std::move(iter_vars),
              /*reads=*/std::move(ct_buffer_regions),
              /*writes=*/std::move(wb_regions),
              /*name_hint=*/block->name_hint + "_write_back",
              /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
  }
  // Final step: wrap the statements above with the reduction loops bound to threadIdx
  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
    const ForNode* loop = *rit;
    if (loop->thread_binding.defined()) {
      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
      n->body = std::move(new_stmt);
      new_stmt = For(n);
    }
  }
  return new_stmt;
}
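
// A sketch of the result when in-thread reduction is needed (schematic
// TVMScript-like pseudocode, not produced verbatim; buffer names follow
// MakeScratchpads):
//
//   for k_t in T.thread_binding(...):                   # re-attached thread loop
//     in_thread_B[0] = identity                         # Stmt 1
//     for k_o in ...:                                   # non-thread reduction loops
//       in_thread_B[0] = reducer(in_thread_B[0], ...)   # Stmt 2
//     T.tvm_thread_allreduce(T.uint32(1), in_thread_B[0], True,
//                            cross_thread_B[0], k_t)    # Stmt 3
//     B[vi] = cross_thread_B[0]                         # Stmt 4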

/*!
 * \brief Detect cross-thread reduction patterns and transform them
 */
class CrossThreadReductionTransformer : public StmtMutator {
 private:
  // Check if the input block needs cross-thread reduction.
  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* realize) {
    // Step 0. If the block is the root block, just return.
    if (block_stack_.empty()) {
      return {};
    }

    // Step 1. If the block is not a reduction block, cross-thread reduction is not needed.
    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
                          GetRef<Block>(block_stack_.back()), &analyzer_)) {
      return {};
    }

    // Step 2. Collect all the vars that appear in the bindings of reduction block iters.
    std::unordered_set<const VarNode*> reduction_vars;
    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars);

    // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters.
    // We call these loops "reduction-related".
    // Step 4. See whether at least one reduction-related loop is bound to a GPU thread axis - if
    // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to a
    // thread axis, cross-thread reduction is not needed for the input block.
    bool need = false;
    std::vector<const ForNode*> reduction_loops;
    for (const ForNode* loop : loop_stack_) {
      if (reduction_vars.count(loop->loop_var.get())) {
        // Step 3. Collect the loop.
        reduction_loops.push_back(loop);
        // Step 4. See whether the loop is bound to some thread axis.
        if (loop->thread_binding.defined()) {
          need = true;
        }
      }
    }
    return need ? reduction_loops : std::vector<const ForNode*>{};
  }

  /*!
   * \brief Given that the input block needs cross-thread reduction, check if cross-thread
   * reduction can be applied to the block (i.e., the block satisfies all necessary conditions of
   * cross-thread reduction)
   * \param block The block to be checked
   * \param reduction_loops The reduction loops above the block
   * \return A tuple consisting of five elements:
   * - an integer which indicates the number of reduction loops that are bound to thread axes,
   * - the detected commutative reducer of the reduction,
   * - the reduction buffers which store the reduction results,
   * - the RHS values of the reduction updates,
   * - the indices which are used to access the reduction buffers when storing the reduction
   *   results
   */
  std::tuple<int, CommReducer, Array<Buffer>, Array<PrimExpr>, Array<PrimExpr>>
  CheckCanApplyCrossThreadReduction(const BlockNode* block,
                                    const std::vector<const ForNode*>& reduction_loops) const {
    // Condition 1. All the reduction-related loops should be the deepest among all statements
    // outside the block (ignoring SeqStmt here).
    int n_deepest_reduction_loops = 0;
    for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
      const StmtNode* stmt = *rit;
      if (stmt->IsInstance<SeqStmtNode>()) {
        // Skip SeqStmt.
        continue;
      }
      if (std::find(reduction_loops.begin(), reduction_loops.end(),
                    reinterpret_cast<const ForNode*>(stmt)) == reduction_loops.end()) {
        break;
      }
      ++n_deepest_reduction_loops;
    }
    CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size())
        << "ValueError: Cross-thread reduction requires all the reduction-related loops to be "
           "the deepest among all statements outside the desired block. However, block "
        << block->name_hint
        << " needs cross-thread reduction, while the reduction-related loops outside of it are "
           "not the deepest statements, which violates the condition.";

    // Condition 2. All the reduction-related loops that are bound to thread axes should only be
    // bound to `threadIdx.x/y/z`.
    int n_bound_reduction_loops = 0;
    for (const ForNode* reduction_loop : reduction_loops) {
      if (reduction_loop->thread_binding.defined()) {
        ++n_bound_reduction_loops;
        CHECK(IsBoundToThreadIdx(reduction_loop))
            << "ValueError: Cross-thread reduction requires all the reduction-related loops that "
               "are bound to GPU thread axes to be bound only to `threadIdx.x/y/z`. However, "
               "loop "
            << reduction_loop->loop_var->name_hint << " violates the condition.";
      }
    }

    // Condition 3. Get the init values and the BufferStore updates of the reduction. Extract the
    // commutative reducer, the combiner LHS and the combiner RHS from the reduction init values
    // and the reduction updates.
    Array<PrimExpr> init_values{nullptr};
    Array<BufferStore> updates{nullptr};
    CommReducer reducer{nullptr};
    Array<PrimExpr> combiner_lhs{nullptr};
    Array<PrimExpr> combiner_rhs{nullptr};
    std::tie(init_values, updates) =
        GetInitValuesAndUpdatesFromReductionBlock(NullOpt, GetRef<Block>(block));
    std::tie(reducer, combiner_lhs, combiner_rhs) =
        GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates);

    Array<Buffer> reduction_buffers;
    reduction_buffers.reserve(updates.size());
    for (const BufferStore& buf_store : updates) {
      reduction_buffers.push_back(buf_store->buffer);
    }

    // Condition 4. The block should be the last block under the first reduction-related loop.
    bool visit = false;
    PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
      if (const auto* realize = obj.as<BlockRealizeNode>()) {
        CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction "
                         "block isn't the last block under its first reduction-related loop";
        if (realize->block.get() == block) {
          visit = true;
        }
        return false;
      }
      return true;
    });
    return std::make_tuple(n_bound_reduction_loops,       //
                           std::move(reducer),            //
                           std::move(reduction_buffers),  //
                           std::move(combiner_rhs),       //
                           updates[0]->indices);
  }

  Stmt VisitStmt(const Stmt& stmt) final {
    statement_stack_.push_back(stmt.get());
    Stmt result = StmtMutator::VisitStmt(stmt);
    statement_stack_.pop_back();
    return result;
  }

  Stmt VisitStmt_(const ForNode* loop) final {
    loop_stack_.push_back(loop);
    loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    Stmt result = StmtMutator::VisitStmt_(loop);
    loop_stack_.pop_back();
    loop_range_map_.erase(loop->loop_var);

    // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
    auto it = loop2new_stmt_.find(loop);
    if (it != loop2new_stmt_.end()) {
      return it->second;
    } else {
      return result;
    }
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Map<Var, Range> old_loop_range_map;

    block_stack_.push_back(block);
    std::swap(old_loop_range_map, loop_range_map_);
    Block new_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
    block_stack_.pop_back();
    std::swap(old_loop_range_map, loop_range_map_);

    // Insert the newly allocated buffers into the block's `alloc_buffers` field.
    auto it = block2new_buffers_.find(block);
    if (it != block2new_buffers_.end()) {
      BlockNode* p_new_block = new_block.CopyOnWrite();
      for (const Buffer& new_buffer : it->second) {
        if (new_buffer.defined()) {
          p_new_block->alloc_buffers.push_back(new_buffer);
        }
      }
    }
    return std::move(new_block);
  }

  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
    const BlockNode* block = realize->block.get();
    // Step 1. Check whether cross-thread reduction is needed. If not, skip this block.
    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
    if (reduction_loops.empty()) {
      return StmtMutator::VisitStmt_(realize);
    }
    // Step 2. Check whether cross-thread reduction can be applied. If not, throw an exception
    // indicating which condition the block violates.
    int n_bound_reduction_loops = 0;
    CommReducer reducer{nullptr};
    Array<Buffer> reduction_buffers{nullptr};
    Array<PrimExpr> combiner_rhs{nullptr};
    Array<PrimExpr> wb_indices{nullptr};
    std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) =
        CheckCanApplyCrossThreadReduction(block, reduction_loops);
    // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when
    // - not all the reduction-related loops are bound to thread axes, or
    // - the block-realize has a non-constant-true predicate.
    bool need_in_thread_reduction =
        n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
        !is_one(realize->predicate);
    // Step 4. Create intermediate buffers, storing them in `ct_buffers` and
    // `it_buffers`. Let the scope block allocate these new buffers.
    Array<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
    Array<Buffer> ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true);
    new_buffers.insert(new_buffers.end(), ct_buffers.begin(), ct_buffers.end());
    Optional<Array<Buffer>> it_buffers = NullOpt;
    if (need_in_thread_reduction) {
      it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
      new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
    }
    // Step 5. Transform.
    loop2new_stmt_[reduction_loops[0]] =
        TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                reducer, combiner_rhs, reduction_loops);
    // Step 6. Return an empty statement, because the transformation result will be inserted when
    // returning to the first reduction-related loop.
    return Stmt{nullptr};
  }

 private:
  std::vector<const StmtNode*> statement_stack_;
  std::vector<const ForNode*> loop_stack_;
  std::vector<const BlockNode*> block_stack_;
  std::unordered_map<const BlockNode*, Array<Buffer>> block2new_buffers_;
  std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
  Map<Var, Range> loop_range_map_;
  arith::Analyzer analyzer_;
};

PrimFunc LowerCrossThreadReduction(PrimFunc f) {
  // Only apply this pass to TIR that is not from TE schedules
  if (!IsFromLegacyTESchedule(f)) {
    PrimFuncNode* fptr = f.CopyOnWrite();
    fptr->body = CrossThreadReductionTransformer()(f->body);
    return f;
  } else {
    return f;
  }
}

namespace transform {

Pass LowerCrossThreadReduction() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return LowerCrossThreadReduction(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {});
}
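
// A minimal usage sketch (assuming the standard pass infrastructure; for
// illustration only):
//
//   IRModule mod = ...;  // module with GPU-bound reduction blocks
//   mod = LowerCrossThreadReduction()(std::move(mod));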

TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction")
    .set_body_typed(LowerCrossThreadReduction);

}  // namespace transform

}  // namespace tir
}  // namespace tvm