1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/******** Pattern Matcher ********/
25
26/*!
27 * \brief PrimExpr pattern matcher.
28 *
29 * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated
30 * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific
31 * patterns.
32 *
33 * The code below shows how to use the pattern matcher.
34 *
35 * \code
36 *
37 * Var x("x"), y("y");
38 * // use PrimExpr to declare patterns, x, y are holes that can be filled with
39 * PatternMatcher pattern_matcher(x + y);
40 * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match
41 * pattern_matcher.Match(expr);
42 *
43 * if (pattern_matcher.Success()) {
44 * pattern_matcher.Eval(x) // C[i, j]
45 * pattern_matcher.Eval(y) // A[i, k] * B[k, j]
46 * }
47 *
48 * \endcode
49 */
50class PatternMatcher : public ExprVisitor {
51 public:
52 explicit PatternMatcher(Array<PrimExpr> pattern) : pattern_(std::move(pattern)) {}
53
54 void VisitExpr_(const VarNode* op) final {
55 auto it = filled_map_.find(op);
56 if (it == filled_map_.end()) {
57 filled_map_[op] = expr_to_match_;
58 } else {
59 ExprDeepEqual equal;
60 if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return;
61 match_success_ = false;
62 }
63 }
64
65 void VisitExpr_(const LoadNode* op) final {
66 const auto* ptr = expr_to_match_.as<LoadNode>();
67 if (ptr == nullptr) {
68 match_success_ = false;
69 } else {
70 if (!op->buffer_var.same_as(ptr->buffer_var)) {
71 match_success_ = false;
72 } else {
73 PrimExpr tmp = expr_to_match_;
74 expr_to_match_ = ptr->predicate;
75 VisitExpr(op->predicate);
76 expr_to_match_ = ptr->index;
77 VisitExpr(op->index);
78 std::swap(expr_to_match_, tmp);
79 }
80 }
81 }
82
83 void VisitExpr_(const LetNode* op) final {
84 const auto* ptr = expr_to_match_.as<LetNode>();
85 if (ptr == nullptr) {
86 match_success_ = false;
87 } else {
88 PrimExpr tmp = expr_to_match_;
89 expr_to_match_ = ptr->var;
90 VisitExpr(op->var);
91 expr_to_match_ = ptr->value;
92 VisitExpr(op->value);
93 expr_to_match_ = ptr->body;
94 VisitExpr(op->body);
95 std::swap(expr_to_match_, tmp);
96 }
97 }
98
99 void VisitExpr_(const CallNode* op) final {
100 const auto* ptr = expr_to_match_.as<CallNode>();
101 if (ptr == nullptr) {
102 match_success_ = false;
103 } else {
104 if (!op->op.same_as(ptr->op)) {
105 match_success_ = false;
106 } else {
107 PrimExpr tmp = expr_to_match_;
108 for (size_t i = 0; i < op->args.size(); ++i) {
109 expr_to_match_ = ptr->args[i];
110 VisitExpr(op->args[i]);
111 }
112 std::swap(expr_to_match_, tmp);
113 }
114 }
115 }
116
117#define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \
118 void VisitExpr_(const OpName* op) { \
119 const auto* ptr = expr_to_match_.as<OpName>(); \
120 if (ptr == nullptr) { \
121 match_success_ = false; \
122 } else { \
123 PrimExpr current = expr_to_match_; \
124 expr_to_match_ = ptr->a; \
125 VisitExpr(op->a); \
126 expr_to_match_ = ptr->b; \
127 VisitExpr(op->b); \
128 std::swap(expr_to_match_, current); \
129 } \
130 }
131
132 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode);
133 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode);
134 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode);
135 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode);
136 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode);
137 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode);
138 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode);
139 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode);
140 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode);
141 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode);
142 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode);
143 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode);
144 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode);
145 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode);
146 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode);
147 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode);
148 TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode);
149
150 void VisitExpr_(const CastNode* op) final {
151 const auto* ptr = expr_to_match_.as<CastNode>();
152 if (ptr == nullptr) {
153 match_success_ = false;
154 } else {
155 if (!runtime::TypeEqual(op->dtype, ptr->dtype)) {
156 match_success_ = false;
157 } else {
158 PrimExpr tmp = expr_to_match_;
159 expr_to_match_ = ptr->value;
160 VisitExpr(op->value);
161 std::swap(expr_to_match_, tmp);
162 }
163 }
164 }
165
166 void VisitExpr_(const NotNode* op) final {
167 const auto* ptr = expr_to_match_.as<NotNode>();
168 if (ptr == nullptr) {
169 match_success_ = false;
170 } else {
171 PrimExpr tmp = expr_to_match_;
172 expr_to_match_ = ptr->a;
173 VisitExpr(op->a);
174 std::swap(expr_to_match_, tmp);
175 }
176 }
177
178 void VisitExpr_(const SelectNode* op) final {
179 const auto* ptr = expr_to_match_.as<SelectNode>();
180 if (ptr == nullptr) {
181 match_success_ = false;
182 } else {
183 PrimExpr tmp = expr_to_match_;
184 expr_to_match_ = ptr->condition;
185 VisitExpr(op->condition);
186 expr_to_match_ = ptr->true_value;
187 VisitExpr(op->true_value);
188 expr_to_match_ = ptr->false_value;
189 VisitExpr(op->false_value);
190 std::swap(expr_to_match_, tmp);
191 }
192 }
193
194 void VisitExpr_(const RampNode* op) final {
195 const auto* ptr = expr_to_match_.as<RampNode>();
196 if (ptr == nullptr) {
197 match_success_ = false;
198 } else {
199 if (op->lanes != ptr->lanes) {
200 match_success_ = false;
201 } else {
202 PrimExpr tmp = expr_to_match_;
203 expr_to_match_ = ptr->base;
204 VisitExpr(op->base);
205 expr_to_match_ = ptr->stride;
206 VisitExpr(op->stride);
207 std::swap(expr_to_match_, tmp);
208 }
209 }
210 }
211
212 void VisitExpr_(const BroadcastNode* op) final {
213 const auto* ptr = expr_to_match_.as<BroadcastNode>();
214 if (ptr == nullptr) {
215 match_success_ = false;
216 } else {
217 if (op->lanes != ptr->lanes) {
218 match_success_ = false;
219 } else {
220 PrimExpr tmp = expr_to_match_;
221 expr_to_match_ = ptr->value;
222 VisitExpr(op->value);
223 std::swap(expr_to_match_, tmp);
224 }
225 }
226 }
227
228 void VisitExpr_(const ShuffleNode* op) final {
229 const auto* ptr = expr_to_match_.as<ShuffleNode>();
230 if (ptr == nullptr) {
231 match_success_ = false;
232 } else {
233 if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) {
234 match_success_ = false;
235 } else {
236 PrimExpr tmp = expr_to_match_;
237 for (size_t i = 0; i < op->indices.size(); ++i) {
238 expr_to_match_ = ptr->indices[i];
239 VisitExpr(op->indices[i]);
240 }
241 for (size_t i = 0; i < op->vectors.size(); ++i) {
242 expr_to_match_ = ptr->vectors[i];
243 VisitExpr(op->vectors[i]);
244 }
245 std::swap(expr_to_match_, tmp);
246 }
247 }
248 }
249
250 void VisitExpr_(const IntImmNode* op) final {
251 const auto* ptr = expr_to_match_.as<IntImmNode>();
252 match_success_ = ptr != nullptr && op->value == ptr->value;
253 }
254
255 void VisitExpr_(const FloatImmNode* op) final {
256 const auto* ptr = expr_to_match_.as<FloatImmNode>();
257 match_success_ = ptr != nullptr && op->value == ptr->value;
258 }
259
260 void VisitExpr_(const StringImmNode* op) final {
261 const auto* ptr = expr_to_match_.as<StringImmNode>();
262 match_success_ = ptr != nullptr && op->value == ptr->value;
263 }
264
265 void VisitExpr_(const BufferLoadNode* op) final {
266 const auto* ptr = expr_to_match_.as<BufferLoadNode>();
267 if (ptr == nullptr) {
268 match_success_ = false;
269 } else {
270 if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) {
271 match_success_ = false;
272 } else {
273 PrimExpr tmp = expr_to_match_;
274 for (size_t i = 0; i < op->indices.size(); ++i) {
275 expr_to_match_ = ptr->indices[i];
276 VisitExpr(op->indices[i]);
277 }
278 std::swap(expr_to_match_, tmp);
279 }
280 }
281 }
282
283 void Match(const Array<PrimExpr>& exprs_to_match) {
284 this->match_success_ = true;
285 this->filled_map_.clear();
286
287 ICHECK_EQ(pattern_.size(), exprs_to_match.size());
288 int n_buffers = pattern_.size();
289 for (int i = 0; i < n_buffers; ++i) {
290 this->expr_to_match_ = exprs_to_match[i];
291 this->operator()(pattern_[i]);
292 }
293 }
294
295 PrimExpr Eval(const Var& var) {
296 auto it = filled_map_.find(var.operator->());
297 ICHECK(it != filled_map_.end()) << "Unknown pattern variable";
298 ICHECK(match_success_) << "Match failed";
299 return it->second;
300 }
301
302 bool Success() const { return match_success_; }
303
304 private:
305 bool match_success_{true};
306 Array<PrimExpr> pattern_;
307 PrimExpr expr_to_match_;
308 std::unordered_map<const VarNode*, PrimExpr> filled_map_;
309};
310
311/******** Reduction Block Related ********/
312
/*!
 * \brief Human-readable checklist of the conditions a reduction block must satisfy for
 * RFactor / Cross-Thread Reduction, appended to the error messages below. The numbered
 * items correspond to the `violated_cond` codes reported throughout this file.
 */
static const char* kRFactorCrossThreadReductionApplicableBlockDef =
    R"(Definition of a reduction block that is applicable by RFactor and Cross-Thread Reduction:
1) The block init should be a single BufferStore or a SeqStmt of BufferStores
2) The buffers initialized in the block init should be all different
3) The number of consecutive LetStmts in the block body (if any) should equal the number of BufferStores in the block init
4) The variables of the LetStmts in the block body should be all different
5) The body of the innermost LetStmt should be a single BufferStore or a SeqStmt of BufferStores
6) The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of LetStmts above
7) The variables bound by the LetStmts in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables
8) The variables stored by the BufferStores in the block body should be all different
9) The buffers written by the BufferStores in the block body should be all different
10) The buffers initialized in the block init and written in the block body should match
11) The buffers written by the block should have same shape
12) The indices of all BufferStores in the reduction block should be the same)";
327
328void ErrorRFactorCrossThreadReductionNotApplicable(const Optional<ScheduleState>& self, Block block,
329 int violated_cond) {
330 class RFactorNotApplicableError : public ScheduleError {
331 public:
332 explicit RFactorNotApplicableError(IRModule mod, Block block, int violated_cond)
333 : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {}
334
335 String FastErrorString() const final {
336 return "ScheduleError: RFactor cannot be applied to the block since the block does not meet "
337 "the requirements";
338 }
339
340 String DetailRenderTemplate() const final {
341 std::ostringstream os;
342 os << "RFactor cannot be applied to block {0}, because the block violates condition #"
343 << violated_cond_ << ".\n"
344 << kRFactorCrossThreadReductionApplicableBlockDef;
345 return os.str();
346 }
347
348 IRModule mod() const final { return mod_; }
349 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
350
351 IRModule mod_;
352 Block block_;
353 int violated_cond_;
354 };
355
356 if (self.defined()) {
357 throw RFactorNotApplicableError(self.value()->mod, std::move(block), violated_cond);
358 } else {
359 LOG(FATAL) << "ValueError: Cross-thread reduction cannot be applied to the block "
360 << block->name_hint << " because the block violates the condition #" << violated_cond
361 << ".\n"
362 << kRFactorCrossThreadReductionApplicableBlockDef;
363 }
364}
365
366/*!
367 * \brief Extract the BufferStores, which serve as the reduction updates, from the given LetStmt and
368 * the BufferStores inside. And meanwhile set the buffer order of the reduction
369 * \param self The schedule state, used for error reporting
370 * \param block The reduction block, used for error reporting
371 * \param let The LetStmt from which the reduction updates are extracted
372 * \param n_buffers The number of buffers participating in the reduction
373 * \param updates The extracted reduction updates
374 * \param buf2index A mapping from reduction buffers to their indices of the reduction order
375 * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block
376 */
377void ExtractReductionUpdates(const Optional<ScheduleState>& self, Block block,
378 const LetStmtNode* let, int n_buffers, Array<BufferStore>* updates,
379 std::unordered_map<const BufferNode*, int>* buf2index) {
380 std::unordered_map<const VarNode*, int> var2index;
381 Array<PrimExpr> let_values;
382 let_values.reserve(n_buffers);
383 updates->resize(n_buffers);
384
385 // Step 1.
386 // - Extract the BufferStore values from the LetStmts.
387 // - Construct the mapping from let variables to the index.
388 for (int i = 0; i < n_buffers; ++i) {
389 if (let == nullptr) {
390 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3);
391 }
392
393 let_values.push_back(let->value);
394 auto insert_result = var2index.insert(std::make_pair(let->var.get(), i));
395 if (!insert_result.second) {
396 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/4);
397 }
398 if (i != n_buffers - 1) {
399 let = let->body.as<LetStmtNode>();
400 }
401 }
402
403 // There should be no more LetStmt.
404 if (let->body->IsInstance<LetStmtNode>()) {
405 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3);
406 }
407
408 // Now `let` is expected to be the innermost LetStmt, whose body should either be a SeqStmt or
409 // a BufferStore
410 const auto* p_seq = let->body.as<SeqStmtNode>();
411 const auto* p_buf_store = let->body.as<BufferStoreNode>();
412 if (p_seq == nullptr && p_buf_store == nullptr) {
413 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5);
414 }
415 SeqStmt seq =
416 p_seq != nullptr ? GetRef<SeqStmt>(p_seq) : SeqStmt({GetRef<BufferStore>(p_buf_store)});
417 if (static_cast<int>(seq->seq.size()) != n_buffers) {
418 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6);
419 }
420
421 // Step 2.
422 // - Create BufferStores according to the variables being stored.
423 // - Construct the mapping from reduction buffers to the index.
424 for (const Stmt& stmt : seq->seq) {
425 const auto* buf_store = stmt.as<BufferStoreNode>();
426 if (buf_store == nullptr) {
427 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5);
428 }
429 const auto* var = buf_store->value.as<VarNode>();
430 if (var == nullptr) {
431 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/7);
432 }
433 auto it = var2index.find(var);
434 if (it == var2index.end()) {
435 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/7);
436 }
437 int idx = it->second;
438 if ((*updates)[idx].defined()) {
439 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/8);
440 }
441 updates->Set(idx, BufferStore(buf_store->buffer, let_values[idx], buf_store->indices));
442 auto insert_result = buf2index->insert(std::make_pair(buf_store->buffer.get(), idx));
443 if (!insert_result.second) {
444 ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/9);
445 }
446 }
447 for (int i = 0; i < n_buffers; ++i) {
448 ICHECK((*updates)[i].defined());
449 }
450}
451
/*!
 * \brief Decompose a reduction block into its per-buffer init values and update BufferStores.
 * \param self The schedule state, used for error reporting (may be undefined)
 * \param block The reduction block to decompose
 * \return A pair of (init values, update BufferStores), aligned so that element i of both
 *         refers to the same reduction buffer (ordered by the updates in the block body)
 * \throw ScheduleError If the block violates any applicability condition (see
 *        kRFactorCrossThreadReductionApplicableBlockDef for the numbered conditions)
 */
std::pair<Array<PrimExpr>, Array<BufferStore>> GetInitValuesAndUpdatesFromReductionBlock(
    const Optional<ScheduleState>& self, Block block) {
  Array<BufferStore> inits;
  Array<BufferStore> updates;

  // Step 1. Extract the BufferStores serving as block inits.
  if (const auto* init = block->init.as<BufferStoreNode>()) {
    inits.push_back(GetRef<BufferStore>(init));
  } else if (const auto* seq_init = block->init.as<SeqStmtNode>()) {
    std::unordered_set<const BufferNode*> init_buffers;
    for (const Stmt& stmt : seq_init->seq) {
      // NOTE: `init` declared in the if-condition above is still in scope in this
      // else-branch (it is nullptr here) and is deliberately reused as the loop cursor.
      init = stmt.as<BufferStoreNode>();
      if (init == nullptr) {
        // The init is a SeqStmt containing something other than a BufferStore.
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1);
      }
      auto insert_result = init_buffers.insert(init->buffer.get());
      if (!insert_result.second) {
        // Two init stores target the same buffer.
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/2);
      }
      inits.push_back(GetRef<BufferStore>(init));
    }
  } else {
    // The init is neither a BufferStore nor a SeqStmt.
    ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1);
  }

  // Step 2. Extract the block updates, in the form of BufferStores.
  int n_buffers = inits.size();
  std::unordered_map<const BufferNode*, int> buf2index;
  if (const auto* update = block->body.as<BufferStoreNode>()) {
    // Single-buffer case: the body is a bare BufferStore, no LetStmt chain.
    updates.push_back(GetRef<BufferStore>(update));
    buf2index[update->buffer.get()] = 0;
  } else {
    // Multi-buffer case: the body is a chain of LetStmts ending in the stores.
    const auto* let = block->body.as<LetStmtNode>();
    ExtractReductionUpdates(self, block, let, n_buffers, &updates, &buf2index);
  }
  ICHECK_EQ(updates.size(), n_buffers);

  // Step 3. Set the init values according to the buffer order in `updates`, with the help of the
  // mapping `buf2index`.
  Array<PrimExpr> init_values;
  init_values.resize(n_buffers);

  // - Check all buffers have the same shape
  // - Check all indices of the BufferStores are the same
  // - Check buffers written in the block init and the block body can match
  // - Check buffers do not duplicate
  const Array<PrimExpr>& expected_shape = updates[0]->buffer->shape;
  const Array<PrimExpr>& expected_indices = updates[0]->indices;
  ICHECK_EQ(expected_shape.size(), expected_indices.size());
  int n_dim = expected_indices.size();
  arith::Analyzer ana;
  for (int i = 0; i < n_buffers; ++i) {
    if (static_cast<int>(updates[i]->buffer->shape.size()) != n_dim) {
      ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11);
    }
    if (static_cast<int>(inits[i]->indices.size()) != n_dim ||
        static_cast<int>(updates[i]->indices.size()) != n_dim) {
      ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12);
    }
    for (int d = 0; d < n_dim; ++d) {
      // Shapes and indices are compared symbolically via the arithmetic analyzer.
      if (!ana.CanProveEqual(updates[i]->buffer->shape[d], expected_shape[d])) {
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11);
      }
      if (!ana.CanProveEqual(inits[i]->indices[d], expected_indices[d]) ||
          !ana.CanProveEqual(updates[i]->indices[d], expected_indices[d])) {
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12);
      }
    }

    // Reorder init values into the update order via the buffer they initialize.
    auto it = buf2index.find(inits[i]->buffer.get());
    if (it == buf2index.end()) {
      // An initialized buffer is never written in the block body.
      ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/10);
    }
    int idx = it->second;
    ICHECK(updates[idx]->buffer.same_as(inits[i]->buffer));
    ICHECK(!init_values[idx].defined());
    init_values.Set(idx, inits[i]->value);
  }
  for (int i = 0; i < n_buffers; ++i) {
    ICHECK(init_values[i].defined());
  }

  return std::make_pair(init_values, updates);
}
536
537bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters) {
538 for (const IterVar& iter_var : iters) {
539 if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
540 return false;
541 }
542 }
543 return true;
544}
545
/*!
 * \brief Check that no BufferStore inside the block indexes a written (output) buffer with
 * any reduction block iter, i.e. each output position is invariant over the reduction.
 * \param block The reduction block to check
 * \return True if no store into an output buffer uses a reduction iter in its indices
 */
bool ReductionIterNotIndexOutputBuffer(const Block& block) {
  // Step 1. Collect the reduction block iters.
  std::unordered_set<const VarNode*> reduction_block_iters;
  reduction_block_iters.reserve(block->iter_vars.size());
  for (const IterVar& iter_var : block->iter_vars) {
    if (iter_var->iter_type == kCommReduce) {
      reduction_block_iters.insert(iter_var->var.get());
    }
  }
  // Step 2. Check if the reduction block iters are used to index the output buffer.
  std::unordered_set<const BufferNode*> buffer_written;
  buffer_written.reserve(block->writes.size());
  for (const BufferRegion& write_region : block->writes) {
    buffer_written.insert(write_region->buffer.get());
  }
  // Predicate: does `expr` reference any reduction block iter?
  auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool {
    return UsesVar(expr, [&](const VarNode* var) {  //
      return reduction_block_iters.count(var);
    });
  };

  // Map match_buffer aliases back to their source buffers so stores through an alias are
  // attributed to the underlying written buffer.
  std::unordered_map<const BufferNode*, const BufferNode*> match_buffer_sources;
  for (const MatchBufferRegion& region : block->match_buffers) {
    match_buffer_sources[region->buffer.get()] = region->source->buffer.get();
  }
  bool affected = false;
  PreOrderVisit(block->body, [&](const ObjectRef& obj) {
    // Short-circuit the traversal once a violating store has been found.
    if (affected) {
      return false;
    }
    // Nested blocks may introduce further match_buffer aliases; record them on the way down.
    const auto* block_node = obj.as<BlockNode>();
    if (block_node) {
      for (const MatchBufferRegion& region : block_node->match_buffers) {
        match_buffer_sources[region->buffer.get()] = region->source->buffer.get();
      }
    }
    const auto* store = obj.as<BufferStoreNode>();
    if (!store) {
      // Not a store: keep descending into children.
      return true;
    }

    // Every store must target either a declared output buffer or an alias of one.
    bool write_is_covered_by_match_buffer =
        match_buffer_sources.count(store->buffer.get()) &&
        buffer_written.count(match_buffer_sources.find(store->buffer.get())->second);
    ICHECK(buffer_written.count(store->buffer.get()) || write_is_covered_by_match_buffer)
        << "ValueError: The buffer \"" << store->buffer
        << "\" is written in the block but is not in the block's signature nor is it covered by "
           "a match_buffer";
    // A reduction iter appearing in a store index means the output depends on the
    // reduction position — not a valid reduction block for rfactor.
    for (const PrimExpr& index : store->indices) {
      if (f_uses_reduction_block_var(index)) {
        affected = true;
        return false;
      }
    }
    return false;
  });
  return !affected;
}
604
605class NoMatchedReducerError : public ScheduleError {
606 public:
607 explicit NoMatchedReducerError(IRModule mod, Array<PrimExpr> identities,
608 Array<BufferStore> combiners)
609 : mod_(std::move(mod)),
610 identities_(std::move(identities)),
611 combiners_(std::move(combiners)) {}
612
613 String FastErrorString() const final {
614 return "ScheduleError: No matched reducer for the identity and the combiner of this reduction "
615 "block. So rfactor and cross-thread reduction cannot be applied.";
616 }
617
618 String DetailRenderTemplate() const final {
619 std::ostringstream os;
620 os << "No matched reducer for identity " << identities_ << " and combiner " << combiners_
621 << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for "
622 "default reducers or registering new reducers.";
623 return os.str();
624 }
625
626 IRModule mod() const final { return mod_; }
627 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
628
629 IRModule mod_;
630 Array<PrimExpr> identities_;
631 Array<BufferStore> combiners_;
632};
633
634std::tuple<CommReducer, Array<PrimExpr>, Array<PrimExpr>> GetReducerAndCombinerLhsRhs(
635 const Optional<ScheduleState>& self, const Array<PrimExpr>& identities,
636 const Array<BufferStore>& combiners) {
637 CommReducer reducer{nullptr};
638 Array<PrimExpr> combiner_lhs, combiner_rhs;
639 bool matched =
640 FromIdentityCombiner(identities, combiners, &reducer, &combiner_lhs, &combiner_rhs);
641 if (!matched) {
642 if (self.defined()) {
643 throw NoMatchedReducerError(self.value()->mod, identities, combiners);
644 } else {
645 LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the "
646 "reduction block. So rfactor and cross-thread reduction cannot be applied.";
647 }
648 }
649 return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs));
650}
651
652/******** Commutative Reducer ********/
653
654bool MatchReducer(const CommReducer& reducer, const Array<PrimExpr>& identities,
655 const Array<PrimExpr>& combined_values, const Array<BufferLoad>& buf_loads,
656 Array<PrimExpr>* lhs, Array<PrimExpr>* rhs) {
657 ExprDeepEqual equal;
658 ICHECK_EQ(identities.size(), combined_values.size());
659 int n_buffers = identities.size();
660 for (int i = 0; i < n_buffers; ++i) {
661 if (!equal(reducer->identity_element[i], identities[i])) {
662 return false;
663 }
664 }
665
666 PatternMatcher pattern_matcher(reducer->result);
667 pattern_matcher.Match(combined_values);
668 Array<PrimExpr> lhs_tmp, rhs_tmp;
669 lhs_tmp.reserve(n_buffers);
670 rhs_tmp.reserve(n_buffers);
671 if (!pattern_matcher.Success()) {
672 return false;
673 }
674
675 for (int i = 0; i < n_buffers; ++i) {
676 PrimExpr l = pattern_matcher.Eval(reducer->lhs[i]);
677 PrimExpr r = pattern_matcher.Eval(reducer->rhs[i]);
678 if (!equal(buf_loads[i], l)) {
679 return false;
680 }
681 lhs_tmp.push_back(l);
682 rhs_tmp.push_back(r);
683 }
684 *lhs = std::move(lhs_tmp);
685 *rhs = std::move(rhs_tmp);
686 return true;
687}
688
689bool FromIdentityCombiner(const Array<PrimExpr>& identities, const Array<BufferStore>& combiners,
690 CommReducer* result_reducer, Array<PrimExpr>* lhs, Array<PrimExpr>* rhs) {
691 int n = identities.size();
692 Array<BufferLoad> buf_loads;
693 Array<PrimExpr> stored_values;
694 buf_loads.reserve(n);
695 stored_values.reserve(n);
696
697 for (int i = 0; i < n; ++i) {
698 buf_loads.push_back(BufferLoad(combiners[i]->buffer, combiners[i]->indices));
699 stored_values.push_back(combiners[i]->value);
700 }
701
702 // Check reduction patterns.
703 for (const TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>& reducer_getter :
704 GetReducerGetters()) {
705 Optional<CommReducer> reducer = reducer_getter(identities);
706 if (!reducer.defined()) {
707 continue;
708 }
709 if (MatchReducer(reducer.value(), identities, stored_values, buf_loads, lhs, rhs)) {
710 *result_reducer = reducer.value();
711 return true;
712 }
713 }
714 return false;
715}
716
717} // namespace tir
718} // namespace tvm
719