1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*!
25 * \brief A helper class to create a new scope that contains decomposed init body
26 * and replaced old reduction block.
27 */
28class DecomposeReductionBlockReplacer : public StmtMutator {
29 public:
30 /*!
31 * \brief The open interface to users to call the helper class
32 * \param old_scope_root The original block scope before decomposition
33 * \param target_loop The loop we insert the decomposed init body before
34 * \param decompose_body The decomposed init body
35 * \param old_reduction_block The reduction block we want to decompose
36 * \return The new block scope and the updated reduction block
37 */
38 static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
39 Stmt decomposed_body, Block old_reduction_block) {
40 DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
41 std::move(old_reduction_block));
42 return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
43 replacer.new_reduction_block_);
44 }
45
46 private:
47 explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
48 Block old_reduction_block)
49 : target_loop_(std::move(target_loop)),
50 decomposed_body_(std::move(decomposed_body)),
51 old_reduction_block_(std::move(old_reduction_block)) {}
52
53 Stmt VisitStmt_(const ForNode* loop) final {
54 Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
55 if (loop == target_loop_.get()) {
56 return SeqStmt({decomposed_body_, mutated_stmt});
57 } else {
58 return mutated_stmt;
59 }
60 }
61
62 Stmt VisitStmt_(const BlockNode* block) final {
63 if (block == old_reduction_block_.get()) {
64 ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
65 p_new_block->name_hint = p_new_block->name_hint + "_update";
66 p_new_block->init = NullOpt;
67 // Add write regions back to read regions in update block.
68 Array<BufferRegion> new_reads;
69 std::unordered_set<const BufferNode*> read_bufs;
70 for (const BufferRegion& read_access : block->reads) {
71 read_bufs.insert(read_access->buffer.get());
72 }
73 for (const BufferRegion& write_access : block->writes) {
74 if (read_bufs.find(write_access->buffer.get()) == read_bufs.end()) {
75 new_reads.push_back(write_access);
76 }
77 }
78 for (const BufferRegion& read_access : block->reads) {
79 new_reads.push_back(read_access);
80 }
81 p_new_block->reads = new_reads;
82 new_reduction_block_ = Block(p_new_block);
83 return new_reduction_block_;
84 } else {
85 return StmtMutator::VisitStmt_(block);
86 }
87 }
88
89 Stmt VisitStmt_(const SeqStmtNode* seq) final {
90 Array<Stmt> new_stmts;
91 new_stmts.reserve(seq->seq.size());
92 for (const Stmt& old_stmt : seq->seq) {
93 new_stmts.push_back(VisitStmt(old_stmt));
94 }
95 return SeqStmt::Flatten(new_stmts);
96 }
97
98 private:
99 For target_loop_;
100 Stmt decomposed_body_;
101 Block old_reduction_block_;
102 Block new_reduction_block_;
103};
104
105class LoopHeightError : public ScheduleError {
106 public:
107 static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
108 const BlockRealizeNode* realize,
109 const Array<StmtSRef>& loops,
110 const StmtSRef& loop_sref) {
111 for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
112 // For each block var of type kCommReduce, check its binding
113 const IterVar& iter_var = block->iter_vars[i];
114 const PrimExpr& binding = realize->iter_values[i];
115 if (iter_var->iter_type != IterVarType::kCommReduce) {
116 continue;
117 }
118 for (const StmtSRef& higher_loop : loops) {
119 // Only check loops not lower than the target loop
120 if (higher_loop.same_as(loop_sref)) {
121 break;
122 }
123 // loop_var of a higher loop shouldn't contain loop var
124 const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
125 if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
126 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
127 throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
128 }
129 }
130 }
131 }
132
133 explicit LoopHeightError(IRModule mod, For loop, Block block)
134 : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
135
136 String FastErrorString() const final {
137 return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
138 "related to reduce block var";
139 }
140
141 String DetailRenderTemplate() const final {
142 std::ostringstream os;
143 os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
144 "related to reduce block var of block {1}";
145 return os.str();
146 }
147
148 IRModule mod() const final { return mod_; }
149 Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
150
151 IRModule mod_;
152 For loop_;
153 Block block_;
154};
155
156PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
157 if (is_one(pred)) return Bool(true);
158 PrimExpr new_pred = Bool(true);
159 auto f = [&](const VarNode* var) { return discarded_loops.count(var); };
160 arith::PVar<PrimExpr> lhs, rhs, rest;
161 for (;;) {
162 if ((rest && (lhs < rhs)).Match(pred)) {
163 if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
164 pred = rest.Eval();
165 } else if ((lhs < rhs).Match(pred)) {
166 if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
167 break;
168 } else {
169 ICHECK(false) << "Unexpected predicate for reduction block";
170 }
171 }
172 return new_pred;
173}
174
175StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
176 const StmtSRef& loop_sref) {
177 /*!
178 * Check
179 * - block is a reduction block
180 * - loop is not lower than all the loops related to reduce block var
181 * Mutate
182 * - generate loops related to data par block vars
183 * - generate corresponding init block and update block
184 */
185 // Condition Checks and Information Collection
186 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
187 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
188 // Get the outer loops from high to low
189 Array<StmtSRef> loops = GetLoops(block_sref);
190 const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
191 // Cond 0. Check loop_sref is an ancestor of block_sref
192 if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
193 throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
194 "decompose_reduction");
195 }
196 // Cond 1. Check block is reduction
197 StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
198 /*require_stage_pipeline=*/false);
199 CheckReductionBlock(self, block_sref, scope_root_sref);
200 // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
201 LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
202 // IR Manipulation
203 ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
204 ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
205 init_block->name_hint = block->name_hint + "_init";
206 init_block->annotations = block->annotations;
207 init_realize->iter_values = {};
208 init_realize->block = Block(init_block);
209 // Step 1. Create new block vars and their bindings
210 // Maps an old block var to the new corresponding block var
211 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
212 block_var_map.reserve(block->iter_vars.size());
213 for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
214 const IterVar& iter_var = block->iter_vars[i];
215 const PrimExpr& binding = realize->iter_values[i];
216 // Only process data parallel block vars
217 if (iter_var->iter_type != IterVarType::kDataPar) {
218 continue;
219 }
220 // Create a new block var
221 IterVar new_iter_var(/*dom=*/iter_var->dom,
222 /*var=*/iter_var->var.copy_with_suffix(""),
223 /*iter_type=*/iter_var->iter_type,
224 /*thread_tag=*/iter_var->thread_tag);
225 // Add a block var and its binding
226 init_block->iter_vars.push_back(new_iter_var);
227 init_realize->iter_values.push_back(binding);
228 // Add a mapping from old block vars to new block vars
229 block_var_map[iter_var->var] = new_iter_var->var;
230 }
231 // Step 2. After copying block vars, substitute them in init block
232 init_block->body = Substitute(block->init.value(), block_var_map);
233 for (const BufferRegion& write : block->writes) {
234 init_block->writes.push_back(
235 BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
236 }
237 // Step 3. Scan loops not higher than the specified loop above the reduction block.
238 // If the loop is used in the init block binding, then it is chosen.
239 // Otherwise, it is discarded.
240 std::unordered_set<const VarNode*> discarded_loops;
241 std::vector<int> chosen_loops;
242 for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
243 const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
244 bool discarded = true;
245 for (const PrimExpr& expr : init_realize->iter_values) {
246 if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
247 continue;
248 }
249 // The loop is related to init block bindings;
250 chosen_loops.push_back(i);
251 discarded = false;
252 break;
253 }
254 if (discarded) discarded_loops.insert(loop_var);
255 // Only scan loops not higher than the given loop
256 if (loops[i].same_as(loop_sref)) {
257 break;
258 }
259 }
260 // Step 4. After scanning loops, make a new predicate in the init block realize
261 // We discard predicate that is related to discarded loops
262 init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops);
263 // Step 5. Create new loops above init block
264 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
265 Stmt body = BlockRealize(init_realize);
266 for (int i : chosen_loops) {
267 const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]);
268 // Create a new equivalent to the chosen loop
269 Var old_loop_var = old_loop->loop_var;
270 Var new_loop_var = old_loop_var.copy_with_suffix("_init");
271 loop_var_map[old_loop_var] = new_loop_var;
272 body = For(/*loop_var=*/new_loop_var,
273 /*min=*/old_loop->min,
274 /*extent=*/old_loop->extent,
275 /*kind=*/old_loop->kind,
276 /*body=*/body);
277 }
278 body = Substitute(body, loop_var_map);
279 // Step 6. Mutate IR
280 const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
281 auto [new_scope_root, new_reduction_block] = DecomposeReductionBlockReplacer::Replace(
282 GetRef<Block>(old_scope_root), GetRef<For>(loop), body, GetRef<Block>(block));
283 self->Replace(scope_root_sref, new_scope_root,
284 {{GetRef<Block>(old_scope_root), new_scope_root},
285 {GetRef<Block>(block), new_reduction_block}});
286 self->UpdateScopeBlockInfo(new_scope_root);
287 return self->stmt2ref.at(init_block.get());
288}
289
290/******** Commutative Reducer ********/
291
292/*!
293 * \brief A structure used for registering new commutative reducers, and store all the registered
294 * reducers. The reducers are preserved in a list, in the form of "reducer-getter function". When
295 * invoking a reducer-getter function with a specific datatype, the reducer-getter will return the
296 * CommReducer of the corresponding reduction pattern and the specific datatype
297 */
298struct ReducerRegistry {
299 ReducerRegistry()
300 : reducer_getters{
301 CreateReducerGetter(
302 /*n_buffers=*/1,
303 [](const Array<Var>& x, const Array<Var>& y) {
304 return Array<PrimExpr>{x[0] + y[0]};
305 },
306 [](const Array<PrimExpr>& values) {
307 return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
308 }),
309 CreateReducerGetter(
310 /*n_buffers=*/1,
311 [](const Array<Var>& x, const Array<Var>& y) {
312 return Array<PrimExpr>{x[0] * y[0]};
313 },
314 [](const Array<PrimExpr>& values) {
315 return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
316 }),
317 CreateReducerGetter(
318 /*n_buffers=*/1,
319 [](const Array<Var>& x, const Array<Var>& y) {
320 return Array<PrimExpr>{min(x[0], y[0])};
321 },
322 [](const Array<PrimExpr>& values) {
323 return Array<PrimExpr>{max_value(values[0]->dtype)};
324 }),
325 CreateReducerGetter(
326 /*n_buffers=*/1,
327 [](const Array<Var>& x, const Array<Var>& y) {
328 return Array<PrimExpr>{max(x[0], y[0])};
329 },
330 [](const Array<PrimExpr>& values) {
331 return Array<PrimExpr>{min_value(values[0]->dtype)};
332 }),
333 CreateReducerGetter(
334 /*n_buffers=*/2,
335 [](const Array<Var>& x, const Array<Var>& y) {
336 return Array<PrimExpr>{x[0] + y[0], x[1] + y[1]};
337 },
338 [](const Array<PrimExpr>& values) {
339 return Array<PrimExpr>{make_const(values[0]->dtype, 0),
340 make_const(values[1]->dtype, 0)};
341 }),
342 CreateReducerGetter(
343 /*n_buffers=*/2,
344 [](const Array<Var>& x, const Array<Var>& y) {
345 PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
346 PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
347 return Array<PrimExpr>{idx, val};
348 },
349 [](const Array<PrimExpr>& values) {
350 return Array<PrimExpr>{make_const(values[0]->dtype, -1),
351 min_value(values[1]->dtype)};
352 }),
353 CreateReducerGetter(
354 /*n_buffers=*/2,
355 [](const Array<Var>& x, const Array<Var>& y) {
356 PrimExpr idx =
357 Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))),
358 x[0], y[0]);
359 PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]);
360 return Array<PrimExpr>{idx, val};
361 },
362 [](const Array<PrimExpr>& values) {
363 return Array<PrimExpr>{make_const(values[0]->dtype, -1),
364 min_value(values[1]->dtype)};
365 }),
366 CreateReducerGetter(
367 /*n_buffers=*/2,
368 [](const Array<Var>& x, const Array<Var>& y) {
369 PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
370 PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
371 return Array<PrimExpr>{idx, val};
372 },
373 [](const Array<PrimExpr>& values) {
374 return Array<PrimExpr>{make_const(values[0]->dtype, -1),
375 max_value(values[1]->dtype)};
376 }),
377 CreateReducerGetter(
378 /*n_buffers=*/2,
379 [](const Array<Var>& x, const Array<Var>& y) {
380 PrimExpr idx = Select(
381 Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]);
382 PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]);
383 return Array<PrimExpr>{idx, val};
384 },
385 [](const Array<PrimExpr>& values) {
386 return Array<PrimExpr>{make_const(values[0]->dtype, -1),
387 max_value(values[1]->dtype)};
388 })} {}
389
390 static void RegisterReducer(
391 int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter,
392 TypedPackedFunc<Array<PrimExpr>(Array<PrimExpr>)> identity_getter) {
393 ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter(
394 n_buffers, std::move(combiner_getter), std::move(identity_getter)));
395 }
396
397 static TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)> CreateReducerGetter(
398 int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter,
399 TypedPackedFunc<Array<PrimExpr>(Array<PrimExpr>)> identity_getter) {
400 return [n_buffers, //
401 combiner_getter = std::move(combiner_getter), //
402 identity_getter = std::move(identity_getter) //
403 ](Array<PrimExpr> values) -> Optional<CommReducer> {
404 if (static_cast<int>(values.size()) != n_buffers) {
405 return NullOpt;
406 }
407 Array<Var> lhs;
408 Array<Var> rhs;
409 for (int i = 0; i < n_buffers; ++i) {
410 lhs.push_back(Var("x" + std::to_string(i), values[i]->dtype));
411 rhs.push_back(Var("y" + std::to_string(i), values[i]->dtype));
412 }
413 return CommReducer(lhs, rhs, combiner_getter(lhs, rhs), identity_getter(values));
414 };
415 }
416
417 static ReducerRegistry* Global() {
418 static ReducerRegistry instance;
419 return &instance;
420 }
421
422 std::vector<TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> reducer_getters;
423};
424
425std::vector<TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> GetReducerGetters() {
426 return ReducerRegistry::Global()->reducer_getters;
427}
428
429class NotSerialLoopKindError : public ScheduleError {
430 public:
431 explicit NotSerialLoopKindError(IRModule mod, For loop)
432 : mod_(std::move(mod)), loop_(std::move(loop)) {}
433
434 String FastErrorString() const final {
435 return "ScheduleError: The input loop of rfactor is required to be `kSerial`";
436 }
437
438 String DetailRenderTemplate() const final {
439 String str_kind = ForKind2String(loop_->kind);
440 std::ostringstream os;
441 os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. However, the "
442 "kind of {0} is `"
443 << str_kind << "`";
444 return os.str();
445 }
446
447 IRModule mod() const final { return mod_; }
448 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
449
450 IRModule mod_;
451 For loop_;
452};
453
454class FactorAxisOutOfRangeError : public ScheduleError {
455 public:
456 explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis)
457 : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {}
458
459 String FastErrorString() const final {
460 return "ScheduleError: The input `factor_axis` is out of range. It is required to be in range "
461 "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer";
462 }
463
464 String DetailRenderTemplate() const final {
465 std::ostringstream os;
466 int ndim = static_cast<int>(buffer_->shape.size());
467 os << "The write buffer " << buffer_->name << " has " << ndim
468 << " dimension(s), so `factor_axis` is required to be in [" << -(ndim + 1) << ", " << ndim
469 << "] for rfactor. However, the input `factor_axis` is " << factor_axis_
470 << ", which is out of the expected range";
471 return os.str();
472 }
473
474 IRModule mod() const final { return mod_; }
475 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
476
477 static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) {
478 int ndim = static_cast<int>(buffer->shape.size());
479 if (factor_axis < -(ndim + 1) || factor_axis > ndim) {
480 throw FactorAxisOutOfRangeError(mod, buffer, factor_axis);
481 }
482 // If factor_axis is negative, convert it to a non-negative one.
483 if (factor_axis < 0) {
484 factor_axis += ndim + 1;
485 }
486 return factor_axis;
487 }
488
489 IRModule mod_;
490 Buffer buffer_;
491 int factor_axis_;
492};
493
494class LoopPropertyError : public ScheduleError {
495 public:
496 enum ErrorType {
497 kDataParIterTouchRFactorLoop = 0,
498 kLoopTouchedByBothKindsOfBlockIters = 1,
499 kNotFirstChildBlockOfOutermostLoop = 2,
500 kUnboundLoopUnderReductionLoop = 3
501 };
502
503 explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type)
504 : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {}
505
506 String FastErrorString() const final {
507 switch (error_type_) {
508 case kDataParIterTouchRFactorLoop:
509 return "ScheduleError: The loop to be applied rfactor is required not to be touched by any "
510 "data parallel block iter of the block";
511 case kLoopTouchedByBothKindsOfBlockIters:
512 return "ScheduleError: The loops outside of the reduction block are required not to be "
513 "touched by both data parallel block iters and reduction block iters";
514 case kNotFirstChildBlockOfOutermostLoop:
515 return "ScheduleError: The reduction block should be the first child block of the "
516 "outermost loop outside of it";
517 case kUnboundLoopUnderReductionLoop:
518 return "ScheduleError: A loop who has extent greater than one and is not bound to any "
519 "block iter should not appear under a reduction loop";
520 }
521 ICHECK(false) << "Unreachable";
522 throw;
523 }
524
525 String DetailRenderTemplate() const final {
526 switch (error_type_) {
527 case kDataParIterTouchRFactorLoop:
528 return "The loop to be applied rfactor is {0}, which is required not to be touched by any "
529 "data parallel block iter of the block below. However, some of the block's data "
530 "parallel block iters touch this loop";
531 case kLoopTouchedByBothKindsOfBlockIters:
532 return "It is not allowed that the loop {0} is touched by both some data parallel block "
533 "iters and some reduction block iters";
534 case kNotFirstChildBlockOfOutermostLoop:
535 return "The first child block of the outermost loop {0} is not the reduction block.";
536 case kUnboundLoopUnderReductionLoop:
537 return "The loop {0} has extent greater than one, and is not bound to any block iter. "
538 "Therefore it shouldn't appear under a reduction loop";
539 }
540 ICHECK(false) << "Unreachable";
541 throw;
542 }
543
544 IRModule mod() const final { return mod_; }
545 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
546
547 static void CheckLoopProperty(const ScheduleState& self, const Array<For>& loops,
548 const ForNode* rf_loop, const Block& block,
549 const std::unordered_set<const VarNode*>& data_par_loop_vars,
550 const std::unordered_set<const VarNode*>& reduce_loop_vars) {
551 Array<BlockRealize> children_of_outermost_loop =
552 GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get()));
553 if (!children_of_outermost_loop[0]->block.same_as(block)) {
554 throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop);
555 }
556
557 bool meet_reduction_loop = false;
558 for (const For& loop : loops) {
559 bool data_par_touched = data_par_loop_vars.count(loop->loop_var.get());
560 bool reduction_touched = reduce_loop_vars.count(loop->loop_var.get());
561
562 if (data_par_touched && reduction_touched) {
563 throw LoopPropertyError(self->mod, loop, kLoopTouchedByBothKindsOfBlockIters);
564 } else if (data_par_touched) {
565 if (loop.get() == rf_loop) {
566 throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop);
567 }
568 continue;
569 } else if (reduction_touched) {
570 if (!meet_reduction_loop) {
571 CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get()));
572 meet_reduction_loop = true;
573 }
574 continue;
575 } else if (meet_reduction_loop && !is_one(loop->extent)) {
576 throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop);
577 }
578 }
579 }
580
581 IRModule mod_;
582 For loop_;
583 ErrorType error_type_;
584};
585
586/*!
587 * \brief For each loop in the given array of loop, associate its loop var with the loop itself
588 * using a mapping
589 * \param loops The loops to be analyzed
590 * \return A mapping from loops to their corresponding loop vars
591 */
592std::unordered_map<const VarNode*, For> GetLoopVar2LoopMap(const Array<For>& loops) {
593 std::unordered_map<const VarNode*, For> loop_vars2loop;
594 loop_vars2loop.reserve(loops.size());
595 for (const For& loop : loops) {
596 loop_vars2loop[loop->loop_var.get()] = loop;
597 }
598 return loop_vars2loop;
599}
600
601/*!
602 * \brief Create the intermediate rfactor buffers, which the rfactor block writes to and the
603 * write-back block reads from
604 * \param buf_stores The BufferStores of the original block, where the rfactor buffers will be
605 * created from
606 * \param factor_axis The `factor_axis` parameter of rfactor
607 * \param rf_loop The rfactor loop
608 * \return The new created intermediate rfactor buffer
609 */
610Array<Buffer> CreateRFactorBuffers(const Array<BufferStore>& buf_stores, int factor_axis,
611 const ForNode* rf_loop) {
612 Array<Buffer> rf_buffers;
613 rf_buffers.reserve(buf_stores.size());
614 for (const BufferStore& buf_store : buf_stores) {
615 Buffer buffer = buf_store->buffer;
616 Array<PrimExpr> rf_shape = buffer->shape;
617 rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent);
618
619 ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
620 n->shape = rf_shape;
621 n->name = buffer->name + ".rf";
622 n->data = buffer->data.copy_with_suffix(".rf");
623 rf_buffers.push_back(Buffer(n));
624 }
625 return rf_buffers;
626}
627
628/*!
629 * \brief The base class of the rfactor/write-back block creator, which creates the blocks in four
630 * steps:
631 * 1) Create the new block iters and the their iter bindings
632 * 2) Create the body and init of the new block
633 * 3) Create the read/write regions of the new block
634 * 4) Create the new block and the new block-realize
635 */
636class BaseBlockCreator {
637 public:
638 explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop,
639 Array<BufferStore> old_reduction_updates, CommReducer reducer,
640 Array<Buffer> rf_buffers, bool is_rf_block)
641 : old_block_realize_(std::move(old_block_realize)),
642 rf_loop_(std::move(rf_loop)),
643 old_reduction_updates_(std::move(old_reduction_updates)),
644 reducer_(std::move(reducer)),
645 rf_buffers_(std::move(rf_buffers)),
646 n_buffers_(static_cast<int>(rf_buffers_.size())),
647 is_rf_block_(is_rf_block) {
648 n_block_iters_ = static_cast<int>(old_block_realize_->iter_values.size());
649 update_buffers_.reserve(n_buffers_);
650 update_indices_.reserve(n_buffers_);
651 update_lhs_.reserve(n_buffers_);
652 update_rhs_.reserve(n_buffers_);
653 }
654
655 void CreateBlock() {
656 CreateAdditionalIter();
657 for (int i = 0; i < n_block_iters_; ++i) {
658 CreateNormalIters(i);
659 }
660 bool has_reduce_iter = false;
661 for (const IterVar& iter_var : iter_vars_) {
662 if (iter_var->iter_type == IterVarType::kCommReduce) {
663 has_reduce_iter = true;
664 break;
665 }
666 }
667
668 // The pre-processing finds out the buffers written in the block, the indices of the buffer
669 // accesses, and the reduction LHS and RHS of the stored values.
670 PreProcess();
671 Stmt block_body = Substitute(CreateBlockBody(has_reduce_iter), var_map_);
672 Optional<Stmt> block_init = CreateBlockInit(has_reduce_iter);
673 if (block_init.defined()) {
674 block_init = Substitute(block_init.value(), var_map_);
675 }
676 CreateReadWriteRegions();
677
678 String new_block_name = old_block_realize_->block->name_hint;
679 PrimExpr predicate = const_true();
680 if (is_rf_block_) {
681 new_block_name = new_block_name + "_rf";
682 predicate = old_block_realize_->predicate;
683 }
684 new_block_ = Block(
685 /*iter_vars=*/iter_vars_,
686 /*reads=*/read_regions_,
687 /*writes=*/write_regions_,
688 /*name_hint=*/new_block_name,
689 /*body=*/std::move(block_body),
690 /*init=*/std::move(block_init),
691 /*alloc_buffers=*/{},
692 /*match_buffers=*/{},
693 /*annotations=*/old_block_realize_->block->annotations);
694 new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_);
695 }
696
697 private:
698 virtual void CreateAdditionalIter() = 0;
699 virtual void CreateNormalIters(int idx) = 0;
700 virtual void PreProcess() = 0;
701 virtual void CreateReadWriteRegions() = 0;
702
703 Stmt CreateBlockBody(bool has_reduce_iter) {
704 Array<Stmt> buf_stores;
705 buf_stores.reserve(n_buffers_);
706
707 // Case 1. If the block has no reduction iterator, we just store the RHS values into the
708 // buffers.
709 if (!has_reduce_iter) {
710 for (int i = 0; i < n_buffers_; ++i) {
711 buf_stores.push_back(BufferStore(update_buffers_[i], update_rhs_[i], update_indices_[i]));
712 }
713 return n_buffers_ > 1 ? SeqStmt(buf_stores) : buf_stores[0];
714 }
715
716 // Case 2. If the reduction is for single buffer, the block body is a single BufferStore.
717 Array<PrimExpr> stored_values = (*reducer_.get())(update_lhs_, update_rhs_);
718 if (n_buffers_ == 1) {
719 return BufferStore(update_buffers_[0], stored_values[0], update_indices_[0]);
720 }
721
722 // Case 3. In case the reduction is for multiple buffers, we should create the reduction with
723 // LetStmt so that the reduction execution generates correct results.
724 Array<Var> let_vars;
725 let_vars.reserve(n_buffers_);
726 for (int i = 0; i < n_buffers_; ++i) {
727 Var var("v_" + update_buffers_[i]->name, PrimType(stored_values[i]->dtype));
728 let_vars.push_back(var);
729 buf_stores.push_back(BufferStore(update_buffers_[i], var, update_indices_[i]));
730 }
731 Stmt body = SeqStmt(buf_stores);
732 for (int i = n_buffers_ - 1; i >= 0; --i) {
733 body = LetStmt(let_vars[i], stored_values[i], std::move(body));
734 }
735 return body;
736 }
737
738 Optional<Stmt> CreateBlockInit(bool has_reduce_iter) {
739 if (!has_reduce_iter) {
740 return NullOpt;
741 }
742
743 Array<Stmt> inits;
744 inits.reserve(n_buffers_);
745 for (int i = 0; i < n_buffers_; ++i) {
746 inits.push_back(
747 BufferStore(update_buffers_[i], reducer_->identity_element[i], update_indices_[i]));
748 }
749 return n_buffers_ > 1 ? SeqStmt(inits) : inits[0];
750 }
751
752 public:
753 /*! \brief The new created block */
754 Block new_block_;
755 /*! \brief The new created block-realize */
756 BlockRealize new_block_realize_;
757 /*! \brief The indices used to access the intermediate rfactor buffer */
758 Array<PrimExpr> rf_buf_access_indices_;
759
760 protected:
761 /*! \brief The old block-realize */
762 BlockRealize old_block_realize_;
763 /*! \brief The number of block iters in the old block */
764 int n_block_iters_;
765 /*! \brief The rfactor loop */
766 For rf_loop_;
767 /*! \brief The update BufferStores of the old block */
768 Array<BufferStore> old_reduction_updates_;
769 /*! \brief The matched commutative reducer */
770 CommReducer reducer_;
771 /*! \brief The intermediate rfactor buffers */
772 Array<Buffer> rf_buffers_;
773 /*! \brief The number of rfactor buffers. */
774 const int n_buffers_;
775 /*!
776 * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced
777 * by the expressions in future substitution for the two blocks
778 */
779 Map<Var, PrimExpr> var_map_;
780
781 /*! \brief Whether we are creating the rfactor block or the write-back block */
782 bool is_rf_block_;
783 /*! \brief The new block iters of the new created block */
784 std::vector<IterVar> iter_vars_;
785 /*! \brief The new block iter bindings of the new created block-realize */
786 std::vector<PrimExpr> iter_values_;
787 /*! \brief The buffers updated in this block */
788 Array<Buffer> update_buffers_;
789 /*! \brief The indices of the buffers updated in this block, respectively */
790 Array<Array<PrimExpr>> update_indices_;
791 /*! \brief The LHS values of the reduction in this block */
792 Array<PrimExpr> update_lhs_;
793 /*! \brief THe RHS values of the reduction in this block */
794 Array<PrimExpr> update_rhs_;
795 /*! \brief The read regions of the new created block */
796 Array<BufferRegion> read_regions_;
797 /*! \brief The write regions of the new created block */
798 Array<BufferRegion> write_regions_;
799};
800
801/*!
802 * \brief The derived class of the rfactor block creator, which implements all virtual methods in
803 * the base creator
804 * \details Start constructing the rfactor block. The main difficulty to construct the rfactor block
805 * is to create its block iters. So here we introduce the algorithm to create the block iters.
806 * 1. Create a block iter for the rfactor loop. The block binding of this iter is the loop var, and
807 * the block iter is data parallel.
808 * 2. For all the old block's block iters, there are two cases:
809 * (a) If it is data parallel block iter, or a reduction block iter which doesn't touch the
810 * rfactor loop, we keep it and its block binding in the rfactor block.
811 * (b) Otherwise it is a reduction block iter which touches the rfactor loop. In this case, we
812 * "split" the block iter into one or more new block iters and do not keep the old block
813 * var. More specifically, we create a new reduction block iter for each loop var that
814 * appears in the reduction block iter's binding (except for the rfactor loop), and the
815 * binding of the new block iter is exactly the loop var. (Note that for each loop var, we
816 * create at most one block iter, even if there are multiple old block iters which touch
817 * both this loop and the rfactor loop).
818 * Then we substitute the appearances of the old block iter with the new created block
819 * iters by recording two mappings: one maps loops vars to new created block iters which
820 * is used for binding substitution, and another maps old block iters to new expressions
821 * which is used for substitutions of the old block iters.
822 */
823class RFactorBlockCreator : public BaseBlockCreator {
824 public:
825 explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop,
826 Array<BufferStore> old_reduction_updates, CommReducer reducer,
827 Array<Buffer> rf_buffers,
828 std::unordered_map<const VarNode*, For> loop_vars2loop,
829 int factor_axis, Array<PrimExpr> combiner_rhs)
830 : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop),
831 std::move(old_reduction_updates), std::move(reducer),
832 std::move(rf_buffers), true),
833 loop_vars2loop_(std::move(loop_vars2loop)),
834 factor_axis_(factor_axis),
835 combiner_rhs_(std::move(combiner_rhs)) {}
836
837 private:
838 void CreateAdditionalIter() final {
839 // Create a new data parallel block iter for the rfactor loop.
840 additional_iter_ =
841 IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar);
842 loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var;
843 iter_vars_.push_back(additional_iter_);
844 iter_values_.push_back(rf_loop_->loop_var);
845 }
846
847 void CreateNormalIters(int idx) final {
848 IterVar old_iter = old_block_realize_->block->iter_vars[idx];
849 PrimExpr old_binding = old_block_realize_->iter_values[idx];
850 if (old_iter->iter_type == IterVarType::kDataPar ||
851 !UsesVar(old_binding,
852 [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) {
853 // The old block iter is either a data parallel block iter, or a reduction block iter that
854 // doesn't touch the rfactor loop. In this case reuse the old reduction block iter and its
855 // corresponding binding.
856 iter_vars_.push_back(old_iter);
857 iter_values_.push_back(old_binding);
858 return;
859 }
860 ICHECK(old_iter->iter_type == kCommReduce);
861 // This block iter is a reduction block iter that touches the rfactor loop. So next we try to
862 // create a new block iter for all loop vars that appear in the old binding.
863 Array<Var> vars_in_old_binding = UndefinedVars(old_binding);
864 for (const Var& var : vars_in_old_binding) {
865 auto it = loop_vars2loop_.find(var.get());
866 if (it == loop_vars2loop_.end()) {
867 // `var` is not a loop var. So skip.
868 continue;
869 }
870 const For& loop = it->second;
871 if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) {
872 // We haven't created the new block iter for `var`. So here we create it, append it
873 // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively.
874 IterVar new_iter_var =
875 IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce);
876 loop_var2block_binding_[var.get()] = new_iter_var->var;
877 iter_vars_.push_back(new_iter_var);
878 iter_values_.push_back(var);
879 }
880 }
881 // Substitute the original binding with new block iters. Store the result expression
882 // in `rf_var_map` for future substitution.
883 var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
884 }
885
886 void PreProcess() final {
887 // The accessed indices for all reduction buffers are the same.
888 rf_buf_access_indices_ = old_reduction_updates_[0]->indices;
889 rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
890 additional_iter_->var);
891 for (int i = 0; i < n_buffers_; ++i) {
892 update_buffers_.push_back(rf_buffers_[i]);
893 update_indices_.push_back(rf_buf_access_indices_);
894 update_lhs_.push_back(BufferLoad(update_buffers_[i], rf_buf_access_indices_));
895 update_rhs_.push_back(combiner_rhs_[i]);
896 }
897 }
898
899 void CreateReadWriteRegions() final {
900 Map<Buffer, Buffer> buffer_map;
901 for (int i = 0; i < n_buffers_; ++i) {
902 buffer_map.Set(old_reduction_updates_[i]->buffer, rf_buffers_[i]);
903 }
904 const Block& old_block = old_block_realize_->block;
905 read_regions_.reserve(old_block->reads.size());
906 for (const BufferRegion& read_region : old_block->reads) {
907 read_regions_.push_back(
908 BufferRegion(read_region->buffer, Substitute(read_region->region, var_map_)));
909 }
910 write_regions_.reserve(old_block->writes.size());
911 for (const BufferRegion& write_region : old_block->writes) {
912 Array<Range> region = write_region->region;
913 region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1));
914 Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
915 ICHECK(rf_buffer.defined());
916 write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_)));
917 }
918 }
919
920 public:
921 /*! \brief The generated additional block iter in rfactor block for the rfactor loop */
922 IterVar additional_iter_;
923
924 private:
925 /*!
926 * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction
927 * block's outer loops
928 */
929 std::unordered_map<const VarNode*, For> loop_vars2loop_;
930 /*! \brief The factor_axis specified for rfactor */
931 int factor_axis_;
932 /*! \brief The RHS values of the reduction in the old block */
933 Array<PrimExpr> combiner_rhs_;
934 /*!
935 * \brief A mapping which maps loop vars to new created block iters. This map is used to
936 * substitute the loop vars which appear in the bindings of some old block iters with the new
937 * created block iters
938 */
939 std::unordered_map<const VarNode*, PrimExpr> loop_var2block_binding_;
940};
941
942/*!
943 * \brief The derived class of the write-back block creator, which implements all virtual methods in
944 * the base creator
945 */
946class WriteBackBlockCreator : public BaseBlockCreator {
947 public:
948 explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop,
949 Array<BufferStore> old_reduction_updates, CommReducer reducer,
950 Array<Buffer> rf_buffers, IterVar rf_additional_iter,
951 Array<PrimExpr> combiner_lhs,
952 Array<PrimExpr> rf_buf_access_indices)
953 : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop),
954 std::move(old_reduction_updates), std::move(reducer),
955 std::move(rf_buffers), false),
956 rf_additional_iter_(std::move(rf_additional_iter)),
957 combiner_lhs_(std::move(combiner_lhs)) {
958 iter_vars_.reserve(n_block_iters_);
959 iter_values_.reserve(n_block_iters_);
960 rf_buf_access_indices_ = std::move(rf_buf_access_indices);
961 }
962
963 private:
964 void CreateAdditionalIter() final {
965 // Create a new reduction block iter for the rfactor loop.
966 IterVar wb_new_block_iter =
967 IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce);
968 iter_vars_.push_back(wb_new_block_iter);
969 iter_values_.push_back(rf_loop_->loop_var);
970 var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var);
971 }
972
973 void CreateNormalIters(int idx) final {
974 IterVar old_block_iter = old_block_realize_->block->iter_vars[idx];
975 if (old_block_iter->iter_type == IterVarType::kDataPar) {
976 iter_vars_.emplace_back(old_block_iter->dom, old_block_iter->var.copy_with_suffix(""),
977 kDataPar);
978 iter_values_.push_back(old_block_realize_->iter_values[idx]);
979 var_map_.Set(old_block_iter->var, iter_vars_.back());
980 }
981 }
982
983 void PreProcess() final {
984 for (int i = 0; i < n_buffers_; ++i) {
985 PrimExpr rhs = BufferLoad(rf_buffers_[i], rf_buf_access_indices_);
986 update_buffers_.push_back(old_reduction_updates_[i]->buffer);
987 update_indices_.push_back(old_reduction_updates_[i]->indices);
988 update_lhs_.push_back(Substitute(combiner_lhs_[i], var_map_));
989 update_rhs_.push_back(Substitute(std::move(rhs), var_map_));
990 }
991 }
992
993 void CreateReadWriteRegions() final {
994 CreateRegion(update_rhs_, true);
995 CreateRegion(update_lhs_, false);
996 }
997
998 void CreateRegion(const Array<PrimExpr>& buf_loads, bool is_read) {
999 Array<BufferRegion>& buf_regions = is_read ? read_regions_ : write_regions_;
1000 for (const PrimExpr& expr : buf_loads) {
1001 const auto* buf_load = expr.as<BufferLoadNode>();
1002 ICHECK(buf_load != nullptr);
1003 Array<Range> region;
1004 region.reserve(buf_load->indices.size());
1005 for (const PrimExpr& index : buf_load->indices) {
1006 region.push_back(Range::FromMinExtent(index, 1));
1007 }
1008 buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region)));
1009 }
1010 }
1011
1012 private:
1013 /*! \brief The new created additional block iter of the rfactor block */
1014 IterVar rf_additional_iter_;
1015 /*! \brief The LHS values of the reduction in the old block */
1016 Array<PrimExpr> combiner_lhs_;
1017};
1018
1019/*!
1020 * \brief Create new outer loops for the rfactor block, meanwhile update the rfactor block's iter
1021 * bindings to use the new created loop vars
1022 * \param rf_block_realize The BlockRealize of the rfactor block
1023 * \param loops The loops to be wrapped over the rfactor block
1024 * \return A Stmt which is the wrapping result
1025 */
1026Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array<For>& loops) {
1027 int n_loops = static_cast<int>(loops.size());
1028
1029 // Step 1. Create new loop vars.
1030 Array<For> new_loops;
1031 std::unordered_map<const VarNode*, PrimExpr> new_loop_var_map;
1032 new_loops.reserve(n_loops);
1033 new_loop_var_map.reserve(n_loops);
1034 for (const For& old_loop : loops) {
1035 Var new_loop_var = old_loop->loop_var.copy_with_suffix("");
1036 new_loop_var_map[old_loop->loop_var.get()] = new_loop_var;
1037 }
1038
1039 // Step 2. Update the iter bindings and predicate of the rfactor block.
1040 Array<PrimExpr> new_bindings;
1041 new_bindings.reserve(rf_block_realize->iter_values.size());
1042 for (const PrimExpr& old_binding : rf_block_realize->iter_values) {
1043 new_bindings.push_back(Substitute(old_binding, new_loop_var_map));
1044 }
1045 {
1046 BlockRealizeNode* p_rf_block_realize = rf_block_realize.CopyOnWrite();
1047 p_rf_block_realize->iter_values = new_bindings;
1048 p_rf_block_realize->predicate = Substitute(rf_block_realize->predicate, new_loop_var_map);
1049 }
1050
1051 // Step 3. Wrap `rf_block_realize` with outer loops.
1052 Stmt rf_body = rf_block_realize;
1053 for (int i = n_loops - 1; i >= 0; --i) {
1054 ObjectPtr<ForNode> p_loop = make_object<ForNode>(*loops[i].get());
1055 p_loop->loop_var = Downcast<Var>(new_loop_var_map[loops[i]->loop_var.get()]);
1056 p_loop->body = rf_body;
1057 rf_body = For(std::move(p_loop));
1058 }
1059
1060 return rf_body;
1061}
1062
1063class BlockReplacer : public StmtMutator {
1064 public:
1065 /*!
1066 * \brief The replace takes the old scope root block as input, and does four things:
1067 * 1) replace the reduction block with the write-back block,
1068 * 2) remove loops outside the write-back block that are touched by reduction block iters, except
1069 * for the rfactor loop
1070 * 3) combine the rfactor block (wrapped with outer loops) and the transformed outermost loop
1071 * into a SeqStmt, and
1072 * 4) insert the rfactor buffer into the scope root block's `alloc_buffers`
1073 * After transformation, the function returns the new scope root block
1074 * \param scope_root_block The old scope root block
1075 * \param rf_body The rfactor block, which is already wrapped with outer loops
1076 * \param outermost_loop The loop that is outermost among all loops outside the reduction block
1077 * \param wb_block_realize The new created BlockRealize of the write-back block
1078 * \param old_block_realize The BlockRealize of the reduction block
1079 * \param rf_loop The rfactor loop, which should be kept outside the write-back block
1080 * \param reduce_loop_vars The loops that are touched by reduction block iters, used to remove
1081 * loops outside the write-back block
1082 * \param loop_vars2loop The mapping from loop vars to loops that are outside the reduction block,
1083 * which is used to reduce redundant recursive visits
1084 * \param rf_buffer The rfactor buffer to be added into the scope root's `alloc_buffers`
1085 * \return The transformed new scope root block
1086 */
1087 static Block Replace(Block scope_root_block, Stmt rf_body, For outermost_loop,
1088 BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop,
1089 std::unordered_set<const VarNode*> reduce_loop_vars,
1090 std::unordered_map<const VarNode*, For> loop_vars2loop,
1091 const Array<Buffer>& rf_buffers) {
1092 BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop),
1093 std::move(wb_block_realize), std::move(old_block_realize),
1094 std::move(rf_loop), std::move(reduce_loop_vars),
1095 std::move(loop_vars2loop));
1096 Block new_scope_root = Downcast<Block>(replacer(std::move(scope_root_block)));
1097 BlockNode* p = new_scope_root.CopyOnWrite();
1098 for (const Buffer& rf_buffer : rf_buffers) {
1099 p->alloc_buffers.push_back(rf_buffer);
1100 }
1101 return new_scope_root;
1102 }
1103
1104 private:
1105 explicit BlockReplacer(Stmt rf_body, For outermost_loop, BlockRealize wb_block_realize,
1106 BlockRealize old_block_realize, For rf_loop,
1107 std::unordered_set<const VarNode*> reduce_loop_vars,
1108 std::unordered_map<const VarNode*, For> loop_vars2loop)
1109 : rf_body_(std::move(rf_body)),
1110 outermost_loop_(std::move(outermost_loop)),
1111 wb_block_realize_(std::move(wb_block_realize)),
1112 old_block_realize_(std::move(old_block_realize)),
1113 rf_loop_(std::move(rf_loop)),
1114 reduce_loop_vars_(std::move(reduce_loop_vars)),
1115 loop_vars2loop_(std::move(loop_vars2loop)) {}
1116
1117 Stmt VisitStmt_(const ForNode* loop) final {
1118 // Step 1. Check whether this loop is outside the reduction block. Given that we've made sure
1119 // that the scope root block has stage-pipeline property, if this loop is not outside the
1120 // reduction block, there's no need to recursively mutate.
1121 if (!loop_vars2loop_.count(loop->loop_var.get())) {
1122 return GetRef<For>(loop);
1123 }
1124
1125 // Step 2. Recursively mutate.
1126 Stmt body = StmtMutator::VisitStmt(loop->body);
1127
1128 // Step 3. If this loop is the rfactor loop and isn't touched by any reduction block iter, it
1129 // should be kept outside the write-back block. Otherwise it shouldn't.
1130 if (loop == rf_loop_.get() || !reduce_loop_vars_.count(loop->loop_var.get())) {
1131 ObjectPtr<ForNode> p_loop = CopyOnWrite(loop);
1132 p_loop->body = body;
1133 body = Stmt(p_loop);
1134 }
1135
1136 // Step 4. If this loop is the outermost loop of the reduction block, return the combination of
1137 // `rf_body_` and the mutation result `body`. Otherwise return the mutation result.
1138 return loop == outermost_loop_.get() ? SeqStmt({rf_body_, body}) : body;
1139 }
1140
1141 Stmt VisitStmt_(const BlockRealizeNode* block_realize) final {
1142 // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's
1143 // block-realize. And we directly return the new `wb_block_realize`.
1144 ICHECK_EQ(block_realize, old_block_realize_.get());
1145 return wb_block_realize_;
1146 }
1147
1148 Stmt VisitStmt_(const SeqStmtNode* seq) final {
1149 Array<Stmt> new_stmts;
1150 new_stmts.reserve(static_cast<int>(seq->seq.size()));
1151
1152 for (const Stmt old_stmt : seq->seq) {
1153 new_stmts.push_back(VisitStmt(old_stmt));
1154 }
1155 return SeqStmt::Flatten(new_stmts);
1156 }
1157
1158 private:
1159 Stmt rf_body_;
1160 For outermost_loop_;
1161 BlockRealize wb_block_realize_;
1162 BlockRealize old_block_realize_;
1163 For rf_loop_;
1164 std::unordered_set<const VarNode*> reduce_loop_vars_;
1165 std::unordered_map<const VarNode*, For> loop_vars2loop_;
1166};
1167
/*!
 * \brief Apply rfactor on the loop `rf_loop_sref`. Its single-child reduction block is split into
 * two blocks: an rfactor block, which writes partial results into newly created rfactor buffers,
 * and a write-back block, which combines those rfactor-buffer values into the original reduction
 * output buffers.
 * \param self The schedule state, which is updated in place
 * \param rf_loop_sref The sref of the rfactor loop
 * \param factor_axis The position of the extra rfactor dimension in the rfactor buffers. It is
 * range-checked and normalized by FactorAxisOutOfRangeError::CheckAndUpdate before use.
 * \return The sref of the new rfactor block
 * \throws NotSerialLoopKindError if the rfactor loop is not serial. The helper checks called below
 * may raise further schedule errors.
 */
StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_axis) {
  // *****************************************************
  // *    Condition Checks and Information Collection    *
  // *****************************************************

  // Step 1. Check some basic conditions for rfactor. Get the block and block-realize.
  BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref);
  // NOTE(review): `block_sref` refers into `self->stmt2ref`; it is only used before the
  // `self->Replace` call below.
  const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get());
  const Block& block = block_realize->block;
  StmtSRef scope_root = GetScopeRoot(self, block_sref,  //
                                     /*require_stage_pipeline=*/true);
  CheckReductionBlock(self, block_sref, scope_root);
  const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
  if (rf_loop->kind != ForKind::kSerial) {
    throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
  }

  // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block
  // iters, respectively.
  std::unordered_set<const VarNode*> data_par_loop_vars;
  std::unordered_set<const VarNode*> reduce_loop_vars;
  GetVarsTouchedByBlockIters(block_realize, &data_par_loop_vars, &reduce_loop_vars);

  // Step 3. Collect the loops of the reduction block. Construct a mapping from loop vars to the
  // corresponding loops.
  Array<For> loops = LoopSRefs2Loops(GetLoops(block_sref));
  std::unordered_map<const VarNode*, For> loop_vars2loop = GetLoopVar2LoopMap(loops);

  // Step 4. Check four properties that the loops should have:
  // - the rfactor loop cannot be touched by any data parallel block iter;
  // - all the loops cannot be touched by both data parallel block iters and reduction block iters;
  // - the outermost loop should have the reduction block as its first child block;
  // - the outermost loop that is touched by some reduction block iters can only have one child
  // block.
  LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
                                       reduce_loop_vars);

  // Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the
  // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the
  // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs
  // will be used when constructing the rfactor block.
  Array<PrimExpr> init_values{nullptr};
  Array<BufferStore> updates{nullptr};
  CommReducer reducer{nullptr};
  Array<PrimExpr> combiner_lhs{nullptr};
  Array<PrimExpr> combiner_rhs{nullptr};
  std::tie(init_values, updates) = GetInitValuesAndUpdatesFromReductionBlock(self, block);
  std::tie(reducer, combiner_lhs, combiner_rhs) =
      GetReducerAndCombinerLhsRhs(self, init_values, updates);

  // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it
  // is negative.
  factor_axis =
      FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, updates[0]->buffer, factor_axis);

  // *****************************************************
  // *                 IR Manipulation                   *
  // *****************************************************
  // Since rfactor splits the reduction block into two, we call the first one "rfactor block", and
  // the latter one "write-back block", and the intermediate buffer is called "rfactor buffer".

  // Step 1. Create the intermediate buffers (a.k.a. rfactor buffers), which have an additional
  // dimension specified by `factor_axis` and `rf_loop`.
  Array<Buffer> rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop);

  // Step 2. Create the rfactor block.
  RFactorBlockCreator rf_block_creator(block_realize, GetRef<For>(rf_loop), updates, reducer,
                                       rf_buffers, loop_vars2loop, factor_axis,
                                       std::move(combiner_rhs));
  rf_block_creator.CreateBlock();

  // Step 3. Create the write-back block. `additional_iter_` and `rf_buf_access_indices_` are moved
  // out of `rf_block_creator`; afterwards only its `new_block_realize_` and `new_block_` are used.
  WriteBackBlockCreator wb_block_creator(block_realize, GetRef<For>(rf_loop), updates, reducer,
                                         rf_buffers, std::move(rf_block_creator.additional_iter_),
                                         std::move(combiner_lhs),
                                         std::move(rf_block_creator.rf_buf_access_indices_));
  wb_block_creator.CreateBlock();

  // Step 4. Wrap the rfactor block with copies of all the loops outside the reduction block.
  Stmt rf_body = CreateLoopOutsideRfactorBlock(rf_block_creator.new_block_realize_, loops);

  // *****************************************************
  // *           Schedule Replacement & Update           *
  // *****************************************************

  // Step 1. Substitute the old scope root block with the new scope root block. The reuse map sends
  // the old scope root to the new one and the old reduction block to the write-back block
  // (presumably so the old reduction block's sref is reused for the write-back block; confirm
  // against ScheduleState::Replace).
  Block old_scope_root_block = GetRef<Block>(scope_root->StmtAs<BlockNode>());
  Block new_scope_root_block = BlockReplacer::Replace(
      old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize,
      GetRef<For>(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers);
  self->Replace(
      scope_root, new_scope_root_block,
      {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}});

  // Step 2. Update scope information. The flags of both new blocks are set directly here rather
  // than recomputed.
  std::vector<StmtSRef> new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()),
                                        self->stmt2ref.at(wb_block_creator.new_block_.get())};
  for (const StmtSRef& new_block_sref : new_block_srefs) {
    BlockInfo& info = self->block_info[new_block_sref];
    info.affine_binding = true;
    info.region_cover = true;
    info.scope->stage_pipeline = true;
  }
  // The first entry is the rfactor block.
  return new_block_srefs[0];
}
1273
1274/******** InstructionKind Registration ********/
1275
1276struct DecomposeReductionTraits : public UnpackedInstTraits<DecomposeReductionTraits> {
1277 static constexpr const char* kName = "DecomposeReduction";
1278 static constexpr bool kIsPure = false;
1279
1280 private:
1281 static constexpr size_t kNumInputs = 2;
1282 static constexpr size_t kNumAttrs = 0;
1283 static constexpr size_t kNumDecisions = 0;
1284
1285 static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) {
1286 return sch->DecomposeReduction(block_rv, loop_rv);
1287 }
1288
1289 static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv) {
1290 PythonAPICall py("decompose_reduction");
1291 py.Input("block", block_rv);
1292 py.Input("loop", loop_rv);
1293 py.SingleOutput(outputs);
1294 return py.Str();
1295 }
1296
1297 template <typename>
1298 friend struct ::tvm::tir::UnpackedInstTraits;
1299};
1300
1301struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
1302 static constexpr const char* kName = "RFactor";
1303 static constexpr bool kIsPure = false;
1304
1305 private:
1306 static constexpr size_t kNumInputs = 1;
1307 static constexpr size_t kNumAttrs = 1;
1308 static constexpr size_t kNumDecisions = 0;
1309
1310 static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) {
1311 return sch->RFactor(loop_rv, factor_axis->value);
1312 }
1313
1314 static String UnpackedAsPython(Array<String> outputs, String loop_rv, Integer factor_axis) {
1315 PythonAPICall py("rfactor");
1316 py.Input("loop", loop_rv);
1317 py.Input("factor_axis", factor_axis->value);
1318 py.SingleOutput(outputs);
1319 return py.Str();
1320 }
1321
1322 template <typename>
1323 friend struct ::tvm::tir::UnpackedInstTraits;
1324};
1325
// Register the instruction-kind traits defined above for RFactor and DecomposeReduction.
TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits);
TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits);
1328
1329/******** FFI ********/
1330
1331TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer")
1332 .set_body_typed([](int n_buffers, PackedFunc combiner_getter, PackedFunc identity_getter) {
1333 ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter),
1334 std::move(identity_getter));
1335 });
1336
1337} // namespace tir
1338} // namespace tvm
1339