/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /******** Pattern Matcher ********/ |
25 | |
26 | /*! |
27 | * \brief PrimExpr pattern matcher. |
28 | * |
29 | * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated |
30 | * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific |
31 | * patterns. |
32 | * |
33 | * The code below shows how to use the pattern matcher. |
34 | * |
35 | * \code |
36 | * |
37 | * Var x("x"), y("y"); |
38 | * // use PrimExpr to declare patterns, x, y are holes that can be filled with |
39 | * PatternMatcher pattern_matcher(x + y); |
40 | * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match |
41 | * pattern_matcher.Match(expr); |
42 | * |
43 | * if (pattern_matcher.Success()) { |
44 | * pattern_matcher.Eval(x) // C[i, j] |
45 | * pattern_matcher.Eval(y) // A[i, k] * B[k, j] |
46 | * } |
47 | * |
48 | * \endcode |
49 | */ |
50 | class PatternMatcher : public ExprVisitor { |
51 | public: |
52 | explicit PatternMatcher(Array<PrimExpr> pattern) : pattern_(std::move(pattern)) {} |
53 | |
54 | void VisitExpr_(const VarNode* op) final { |
55 | auto it = filled_map_.find(op); |
56 | if (it == filled_map_.end()) { |
57 | filled_map_[op] = expr_to_match_; |
58 | } else { |
59 | ExprDeepEqual equal; |
60 | if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return; |
61 | match_success_ = false; |
62 | } |
63 | } |
64 | |
65 | void VisitExpr_(const LoadNode* op) final { |
66 | const auto* ptr = expr_to_match_.as<LoadNode>(); |
67 | if (ptr == nullptr) { |
68 | match_success_ = false; |
69 | } else { |
70 | if (!op->buffer_var.same_as(ptr->buffer_var)) { |
71 | match_success_ = false; |
72 | } else { |
73 | PrimExpr tmp = expr_to_match_; |
74 | expr_to_match_ = ptr->predicate; |
75 | VisitExpr(op->predicate); |
76 | expr_to_match_ = ptr->index; |
77 | VisitExpr(op->index); |
78 | std::swap(expr_to_match_, tmp); |
79 | } |
80 | } |
81 | } |
82 | |
83 | void VisitExpr_(const LetNode* op) final { |
84 | const auto* ptr = expr_to_match_.as<LetNode>(); |
85 | if (ptr == nullptr) { |
86 | match_success_ = false; |
87 | } else { |
88 | PrimExpr tmp = expr_to_match_; |
89 | expr_to_match_ = ptr->var; |
90 | VisitExpr(op->var); |
91 | expr_to_match_ = ptr->value; |
92 | VisitExpr(op->value); |
93 | expr_to_match_ = ptr->body; |
94 | VisitExpr(op->body); |
95 | std::swap(expr_to_match_, tmp); |
96 | } |
97 | } |
98 | |
99 | void VisitExpr_(const CallNode* op) final { |
100 | const auto* ptr = expr_to_match_.as<CallNode>(); |
101 | if (ptr == nullptr) { |
102 | match_success_ = false; |
103 | } else { |
104 | if (!op->op.same_as(ptr->op)) { |
105 | match_success_ = false; |
106 | } else { |
107 | PrimExpr tmp = expr_to_match_; |
108 | for (size_t i = 0; i < op->args.size(); ++i) { |
109 | expr_to_match_ = ptr->args[i]; |
110 | VisitExpr(op->args[i]); |
111 | } |
112 | std::swap(expr_to_match_, tmp); |
113 | } |
114 | } |
115 | } |
116 | |
117 | #define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \ |
118 | void VisitExpr_(const OpName* op) { \ |
119 | const auto* ptr = expr_to_match_.as<OpName>(); \ |
120 | if (ptr == nullptr) { \ |
121 | match_success_ = false; \ |
122 | } else { \ |
123 | PrimExpr current = expr_to_match_; \ |
124 | expr_to_match_ = ptr->a; \ |
125 | VisitExpr(op->a); \ |
126 | expr_to_match_ = ptr->b; \ |
127 | VisitExpr(op->b); \ |
128 | std::swap(expr_to_match_, current); \ |
129 | } \ |
130 | } |
131 | |
132 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode); |
133 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode); |
134 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode); |
135 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode); |
136 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode); |
137 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode); |
138 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode); |
139 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode); |
140 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode); |
141 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode); |
142 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode); |
143 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode); |
144 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode); |
145 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode); |
146 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode); |
147 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode); |
148 | TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode); |
149 | |
150 | void VisitExpr_(const CastNode* op) final { |
151 | const auto* ptr = expr_to_match_.as<CastNode>(); |
152 | if (ptr == nullptr) { |
153 | match_success_ = false; |
154 | } else { |
155 | if (!runtime::TypeEqual(op->dtype, ptr->dtype)) { |
156 | match_success_ = false; |
157 | } else { |
158 | PrimExpr tmp = expr_to_match_; |
159 | expr_to_match_ = ptr->value; |
160 | VisitExpr(op->value); |
161 | std::swap(expr_to_match_, tmp); |
162 | } |
163 | } |
164 | } |
165 | |
166 | void VisitExpr_(const NotNode* op) final { |
167 | const auto* ptr = expr_to_match_.as<NotNode>(); |
168 | if (ptr == nullptr) { |
169 | match_success_ = false; |
170 | } else { |
171 | PrimExpr tmp = expr_to_match_; |
172 | expr_to_match_ = ptr->a; |
173 | VisitExpr(op->a); |
174 | std::swap(expr_to_match_, tmp); |
175 | } |
176 | } |
177 | |
178 | void VisitExpr_(const SelectNode* op) final { |
179 | const auto* ptr = expr_to_match_.as<SelectNode>(); |
180 | if (ptr == nullptr) { |
181 | match_success_ = false; |
182 | } else { |
183 | PrimExpr tmp = expr_to_match_; |
184 | expr_to_match_ = ptr->condition; |
185 | VisitExpr(op->condition); |
186 | expr_to_match_ = ptr->true_value; |
187 | VisitExpr(op->true_value); |
188 | expr_to_match_ = ptr->false_value; |
189 | VisitExpr(op->false_value); |
190 | std::swap(expr_to_match_, tmp); |
191 | } |
192 | } |
193 | |
194 | void VisitExpr_(const RampNode* op) final { |
195 | const auto* ptr = expr_to_match_.as<RampNode>(); |
196 | if (ptr == nullptr) { |
197 | match_success_ = false; |
198 | } else { |
199 | if (op->lanes != ptr->lanes) { |
200 | match_success_ = false; |
201 | } else { |
202 | PrimExpr tmp = expr_to_match_; |
203 | expr_to_match_ = ptr->base; |
204 | VisitExpr(op->base); |
205 | expr_to_match_ = ptr->stride; |
206 | VisitExpr(op->stride); |
207 | std::swap(expr_to_match_, tmp); |
208 | } |
209 | } |
210 | } |
211 | |
212 | void VisitExpr_(const BroadcastNode* op) final { |
213 | const auto* ptr = expr_to_match_.as<BroadcastNode>(); |
214 | if (ptr == nullptr) { |
215 | match_success_ = false; |
216 | } else { |
217 | if (op->lanes != ptr->lanes) { |
218 | match_success_ = false; |
219 | } else { |
220 | PrimExpr tmp = expr_to_match_; |
221 | expr_to_match_ = ptr->value; |
222 | VisitExpr(op->value); |
223 | std::swap(expr_to_match_, tmp); |
224 | } |
225 | } |
226 | } |
227 | |
228 | void VisitExpr_(const ShuffleNode* op) final { |
229 | const auto* ptr = expr_to_match_.as<ShuffleNode>(); |
230 | if (ptr == nullptr) { |
231 | match_success_ = false; |
232 | } else { |
233 | if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) { |
234 | match_success_ = false; |
235 | } else { |
236 | PrimExpr tmp = expr_to_match_; |
237 | for (size_t i = 0; i < op->indices.size(); ++i) { |
238 | expr_to_match_ = ptr->indices[i]; |
239 | VisitExpr(op->indices[i]); |
240 | } |
241 | for (size_t i = 0; i < op->vectors.size(); ++i) { |
242 | expr_to_match_ = ptr->vectors[i]; |
243 | VisitExpr(op->vectors[i]); |
244 | } |
245 | std::swap(expr_to_match_, tmp); |
246 | } |
247 | } |
248 | } |
249 | |
250 | void VisitExpr_(const IntImmNode* op) final { |
251 | const auto* ptr = expr_to_match_.as<IntImmNode>(); |
252 | match_success_ = ptr != nullptr && op->value == ptr->value; |
253 | } |
254 | |
255 | void VisitExpr_(const FloatImmNode* op) final { |
256 | const auto* ptr = expr_to_match_.as<FloatImmNode>(); |
257 | match_success_ = ptr != nullptr && op->value == ptr->value; |
258 | } |
259 | |
260 | void VisitExpr_(const StringImmNode* op) final { |
261 | const auto* ptr = expr_to_match_.as<StringImmNode>(); |
262 | match_success_ = ptr != nullptr && op->value == ptr->value; |
263 | } |
264 | |
265 | void VisitExpr_(const BufferLoadNode* op) final { |
266 | const auto* ptr = expr_to_match_.as<BufferLoadNode>(); |
267 | if (ptr == nullptr) { |
268 | match_success_ = false; |
269 | } else { |
270 | if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) { |
271 | match_success_ = false; |
272 | } else { |
273 | PrimExpr tmp = expr_to_match_; |
274 | for (size_t i = 0; i < op->indices.size(); ++i) { |
275 | expr_to_match_ = ptr->indices[i]; |
276 | VisitExpr(op->indices[i]); |
277 | } |
278 | std::swap(expr_to_match_, tmp); |
279 | } |
280 | } |
281 | } |
282 | |
283 | void Match(const Array<PrimExpr>& exprs_to_match) { |
284 | this->match_success_ = true; |
285 | this->filled_map_.clear(); |
286 | |
287 | ICHECK_EQ(pattern_.size(), exprs_to_match.size()); |
288 | int n_buffers = pattern_.size(); |
289 | for (int i = 0; i < n_buffers; ++i) { |
290 | this->expr_to_match_ = exprs_to_match[i]; |
291 | this->operator()(pattern_[i]); |
292 | } |
293 | } |
294 | |
295 | PrimExpr Eval(const Var& var) { |
296 | auto it = filled_map_.find(var.operator->()); |
297 | ICHECK(it != filled_map_.end()) << "Unknown pattern variable" ; |
298 | ICHECK(match_success_) << "Match failed" ; |
299 | return it->second; |
300 | } |
301 | |
302 | bool Success() const { return match_success_; } |
303 | |
304 | private: |
305 | bool match_success_{true}; |
306 | Array<PrimExpr> pattern_; |
307 | PrimExpr expr_to_match_; |
308 | std::unordered_map<const VarNode*, PrimExpr> filled_map_; |
309 | }; |
310 | |
/******** Reduction Block Related ********/

/*!
 * \brief The twelve conditions a reduction block must satisfy for rfactor / cross-thread
 * reduction. Appended verbatim to the error messages raised below, so the text is part of the
 * user-facing behavior and must not be reworded casually.
 */
static const char* kRFactorCrossThreadReductionApplicableBlockDef =
    R"(Definition of a reduction block that is applicable by RFactor and Cross-Thread Reduction:
1) The block init should be a single BufferStore or a SeqStmt of BufferStores
2) The buffers initialized in the block init should be all different
3) The number of consecutive LetStmts in the block body (if any) should equal the number of BufferStores in the block init
4) The variables of the LetStmts in the block body should be all different
5) The body of the innermost LetStmt should be a single BufferStore or a SeqStmt of BufferStores
6) The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of LetStmts above
7) The variables bound by the LetStmts in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables
8) The variables stored by the BufferStores in the block body should be all different
9) The buffers written by the BufferStores in the block body should be all different
10) The buffers initialized in the block init and written in the block body should match
11) The buffers written by the block should have same shape
12) The indices of all BufferStores in the reduction block should be the same)";
327 | |
328 | void ErrorRFactorCrossThreadReductionNotApplicable(const Optional<ScheduleState>& self, Block block, |
329 | int violated_cond) { |
330 | class RFactorNotApplicableError : public ScheduleError { |
331 | public: |
332 | explicit RFactorNotApplicableError(IRModule mod, Block block, int violated_cond) |
333 | : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} |
334 | |
335 | String FastErrorString() const final { |
336 | return "ScheduleError: RFactor cannot be applied to the block since the block does not meet " |
337 | "the requirements" ; |
338 | } |
339 | |
340 | String DetailRenderTemplate() const final { |
341 | std::ostringstream os; |
342 | os << "RFactor cannot be applied to block {0}, because the block violates condition #" |
343 | << violated_cond_ << ".\n" |
344 | << kRFactorCrossThreadReductionApplicableBlockDef; |
345 | return os.str(); |
346 | } |
347 | |
348 | IRModule mod() const final { return mod_; } |
349 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
350 | |
351 | IRModule mod_; |
352 | Block block_; |
353 | int violated_cond_; |
354 | }; |
355 | |
356 | if (self.defined()) { |
357 | throw RFactorNotApplicableError(self.value()->mod, std::move(block), violated_cond); |
358 | } else { |
359 | LOG(FATAL) << "ValueError: Cross-thread reduction cannot be applied to the block " |
360 | << block->name_hint << " because the block violates the condition #" << violated_cond |
361 | << ".\n" |
362 | << kRFactorCrossThreadReductionApplicableBlockDef; |
363 | } |
364 | } |
365 | |
366 | /*! |
367 | * \brief Extract the BufferStores, which serve as the reduction updates, from the given LetStmt and |
368 | * the BufferStores inside. And meanwhile set the buffer order of the reduction |
369 | * \param self The schedule state, used for error reporting |
370 | * \param block The reduction block, used for error reporting |
371 | * \param let The LetStmt from which the reduction updates are extracted |
372 | * \param n_buffers The number of buffers participating in the reduction |
373 | * \param updates The extracted reduction updates |
374 | * \param buf2index A mapping from reduction buffers to their indices of the reduction order |
375 | * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block |
376 | */ |
377 | void (const Optional<ScheduleState>& self, Block block, |
378 | const LetStmtNode* let, int n_buffers, Array<BufferStore>* updates, |
379 | std::unordered_map<const BufferNode*, int>* buf2index) { |
380 | std::unordered_map<const VarNode*, int> var2index; |
381 | Array<PrimExpr> let_values; |
382 | let_values.reserve(n_buffers); |
383 | updates->resize(n_buffers); |
384 | |
385 | // Step 1. |
386 | // - Extract the BufferStore values from the LetStmts. |
387 | // - Construct the mapping from let variables to the index. |
388 | for (int i = 0; i < n_buffers; ++i) { |
389 | if (let == nullptr) { |
390 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); |
391 | } |
392 | |
393 | let_values.push_back(let->value); |
394 | auto insert_result = var2index.insert(std::make_pair(let->var.get(), i)); |
395 | if (!insert_result.second) { |
396 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/4); |
397 | } |
398 | if (i != n_buffers - 1) { |
399 | let = let->body.as<LetStmtNode>(); |
400 | } |
401 | } |
402 | |
403 | // There should be no more LetStmt. |
404 | if (let->body->IsInstance<LetStmtNode>()) { |
405 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); |
406 | } |
407 | |
408 | // Now `let` is expected to be the innermost LetStmt, whose body should either be a SeqStmt or |
409 | // a BufferStore |
410 | const auto* p_seq = let->body.as<SeqStmtNode>(); |
411 | const auto* p_buf_store = let->body.as<BufferStoreNode>(); |
412 | if (p_seq == nullptr && p_buf_store == nullptr) { |
413 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); |
414 | } |
415 | SeqStmt seq = |
416 | p_seq != nullptr ? GetRef<SeqStmt>(p_seq) : SeqStmt({GetRef<BufferStore>(p_buf_store)}); |
417 | if (static_cast<int>(seq->seq.size()) != n_buffers) { |
418 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); |
419 | } |
420 | |
421 | // Step 2. |
422 | // - Create BufferStores according to the variables being stored. |
423 | // - Construct the mapping from reduction buffers to the index. |
424 | for (const Stmt& stmt : seq->seq) { |
425 | const auto* buf_store = stmt.as<BufferStoreNode>(); |
426 | if (buf_store == nullptr) { |
427 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); |
428 | } |
429 | const auto* var = buf_store->value.as<VarNode>(); |
430 | if (var == nullptr) { |
431 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/7); |
432 | } |
433 | auto it = var2index.find(var); |
434 | if (it == var2index.end()) { |
435 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/7); |
436 | } |
437 | int idx = it->second; |
438 | if ((*updates)[idx].defined()) { |
439 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/8); |
440 | } |
441 | updates->Set(idx, BufferStore(buf_store->buffer, let_values[idx], buf_store->indices)); |
442 | auto insert_result = buf2index->insert(std::make_pair(buf_store->buffer.get(), idx)); |
443 | if (!insert_result.second) { |
444 | ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/9); |
445 | } |
446 | } |
447 | for (int i = 0; i < n_buffers; ++i) { |
448 | ICHECK((*updates)[i].defined()); |
449 | } |
450 | } |
451 | |
/*!
 * \brief Get the init values and the update BufferStores of a reduction block.
 * \param self The schedule state, used for error reporting
 * \param block The reduction block to analyze
 * \return A pair of (init values, update BufferStores), both ordered by the update order
 * \throw ScheduleError (via ErrorRFactorCrossThreadReductionNotApplicable) if the block violates
 *        any of the conditions in kRFactorCrossThreadReductionApplicableBlockDef
 */
std::pair<Array<PrimExpr>, Array<BufferStore>> GetInitValuesAndUpdatesFromReductionBlock(
    const Optional<ScheduleState>& self, Block block) {
  Array<BufferStore> inits;
  Array<BufferStore> updates;

  // Step 1. Extract the BufferStores serving as block inits.
  if (const auto* init = block->init.as<BufferStoreNode>()) {
    inits.push_back(GetRef<BufferStore>(init));
  } else if (const auto* seq_init = block->init.as<SeqStmtNode>()) {
    std::unordered_set<const BufferNode*> init_buffers;
    for (const Stmt& stmt : seq_init->seq) {
      // Reuses `init` declared in the `if` condition above; each element must be a BufferStore.
      init = stmt.as<BufferStoreNode>();
      if (init == nullptr) {
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1);
      }
      // Each buffer may be initialized at most once.
      auto insert_result = init_buffers.insert(init->buffer.get());
      if (!insert_result.second) {
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/2);
      }
      inits.push_back(GetRef<BufferStore>(init));
    }
  } else {
    ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1);
  }

  // Step 2. Extract the block updates, in the form of BufferStores.
  int n_buffers = inits.size();
  std::unordered_map<const BufferNode*, int> buf2index;
  if (const auto* update = block->body.as<BufferStoreNode>()) {
    // Single-buffer reduction: the body is one BufferStore.
    updates.push_back(GetRef<BufferStore>(update));
    buf2index[update->buffer.get()] = 0;
  } else {
    // Multi-buffer reduction: the body is a chain of LetStmts over the stores.
    const auto* let = block->body.as<LetStmtNode>();
    ExtractReductionUpdates(self, block, let, n_buffers, &updates, &buf2index);
  }
  ICHECK_EQ(updates.size(), n_buffers);

  // Step 3. Set the init values according to the buffer order in `updates`, with the help of the
  // mapping `buf2index`.
  Array<PrimExpr> init_values;
  init_values.resize(n_buffers);

  // - Check all buffers have the same shape
  // - Check all indices of the BufferStores are the same
  // - Check buffers written in the block init and the block body can match
  // - Check buffers do not duplicate
  const Array<PrimExpr>& expected_shape = updates[0]->buffer->shape;
  const Array<PrimExpr>& expected_indices = updates[0]->indices;
  ICHECK_EQ(expected_shape.size(), expected_indices.size());
  int n_dim = expected_indices.size();
  arith::Analyzer ana;
  for (int i = 0; i < n_buffers; ++i) {
    if (static_cast<int>(updates[i]->buffer->shape.size()) != n_dim) {
      ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11);
    }
    if (static_cast<int>(inits[i]->indices.size()) != n_dim ||
        static_cast<int>(updates[i]->indices.size()) != n_dim) {
      ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12);
    }
    // Shapes and indices are compared symbolically (CanProveEqual), not just structurally.
    for (int d = 0; d < n_dim; ++d) {
      if (!ana.CanProveEqual(updates[i]->buffer->shape[d], expected_shape[d])) {
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11);
      }
      if (!ana.CanProveEqual(inits[i]->indices[d], expected_indices[d]) ||
          !ana.CanProveEqual(updates[i]->indices[d], expected_indices[d])) {
        ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12);
      }
    }

    // Reorder init values into the update order.
    auto it = buf2index.find(inits[i]->buffer.get());
    if (it == buf2index.end()) {
      ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/10);
    }
    int idx = it->second;
    ICHECK(updates[idx]->buffer.same_as(inits[i]->buffer));
    ICHECK(!init_values[idx].defined());
    init_values.Set(idx, inits[i]->value);
  }
  for (int i = 0; i < n_buffers; ++i) {
    ICHECK(init_values[i].defined());
  }

  return std::make_pair(init_values, updates);
}
536 | |
537 | bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters) { |
538 | for (const IterVar& iter_var : iters) { |
539 | if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { |
540 | return false; |
541 | } |
542 | } |
543 | return true; |
544 | } |
545 | |
/*!
 * \brief Check whether no reduction block iter of the given block is used to index any buffer
 * the block writes (either directly, or through a match_buffer alias of a written buffer).
 * \param block The block to check
 * \return True if no store index into a written buffer depends on a reduction block iter
 */
bool ReductionIterNotIndexOutputBuffer(const Block& block) {
  // Step 1. Collect the reduction block iters.
  std::unordered_set<const VarNode*> reduction_block_iters;
  reduction_block_iters.reserve(block->iter_vars.size());
  for (const IterVar& iter_var : block->iter_vars) {
    if (iter_var->iter_type == kCommReduce) {
      reduction_block_iters.insert(iter_var->var.get());
    }
  }
  // Step 2. Check if the reduction block iters are used to index the output buffer.
  std::unordered_set<const BufferNode*> buffer_written;
  buffer_written.reserve(block->writes.size());
  for (const BufferRegion& write_region : block->writes) {
    buffer_written.insert(write_region->buffer.get());
  }
  // True iff `expr` references any reduction block iter.
  auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool {
    return UsesVar(expr, [&](const VarNode* var) {  //
      return reduction_block_iters.count(var);
    });
  };

  // Map match_buffer aliases back to their source buffers, so stores through an alias can be
  // attributed to the written buffer they cover.
  std::unordered_map<const BufferNode*, const BufferNode*> match_buffer_sources;
  for (const MatchBufferRegion& region : block->match_buffers) {
    match_buffer_sources[region->buffer.get()] = region->source->buffer.get();
  }
  bool affected = false;
  PreOrderVisit(block->body, [&](const ObjectRef& obj) {
    if (affected) {
      // Short-circuit: a violating store was already found; stop descending.
      return false;
    }
    // Also collect match_buffer aliases declared by nested blocks.
    const auto* block_node = obj.as<BlockNode>();
    if (block_node) {
      for (const MatchBufferRegion& region : block_node->match_buffers) {
        match_buffer_sources[region->buffer.get()] = region->source->buffer.get();
      }
    }
    const auto* store = obj.as<BufferStoreNode>();
    if (!store) {
      // Not a store: keep visiting children.
      return true;
    }

    bool write_is_covered_by_match_buffer =
        match_buffer_sources.count(store->buffer.get()) &&
        buffer_written.count(match_buffer_sources.find(store->buffer.get())->second);
    ICHECK(buffer_written.count(store->buffer.get()) || write_is_covered_by_match_buffer)
        << "ValueError: The buffer \"" << store->buffer
        << "\" is written in the block but is not in the block's signature nor is it covered by "
           "a match_buffer";
    for (const PrimExpr& index : store->indices) {
      if (f_uses_reduction_block_var(index)) {
        affected = true;
        return false;
      }
    }
    // Do not descend into the store's sub-expressions.
    return false;
  });
  return !affected;
}
604 | |
605 | class NoMatchedReducerError : public ScheduleError { |
606 | public: |
607 | explicit NoMatchedReducerError(IRModule mod, Array<PrimExpr> identities, |
608 | Array<BufferStore> combiners) |
609 | : mod_(std::move(mod)), |
610 | identities_(std::move(identities)), |
611 | combiners_(std::move(combiners)) {} |
612 | |
613 | String FastErrorString() const final { |
614 | return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " |
615 | "block. So rfactor and cross-thread reduction cannot be applied." ; |
616 | } |
617 | |
618 | String DetailRenderTemplate() const final { |
619 | std::ostringstream os; |
620 | os << "No matched reducer for identity " << identities_ << " and combiner " << combiners_ |
621 | << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " |
622 | "default reducers or registering new reducers." ; |
623 | return os.str(); |
624 | } |
625 | |
626 | IRModule mod() const final { return mod_; } |
627 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
628 | |
629 | IRModule mod_; |
630 | Array<PrimExpr> identities_; |
631 | Array<BufferStore> combiners_; |
632 | }; |
633 | |
634 | std::tuple<CommReducer, Array<PrimExpr>, Array<PrimExpr>> GetReducerAndCombinerLhsRhs( |
635 | const Optional<ScheduleState>& self, const Array<PrimExpr>& identities, |
636 | const Array<BufferStore>& combiners) { |
637 | CommReducer reducer{nullptr}; |
638 | Array<PrimExpr> combiner_lhs, combiner_rhs; |
639 | bool matched = |
640 | FromIdentityCombiner(identities, combiners, &reducer, &combiner_lhs, &combiner_rhs); |
641 | if (!matched) { |
642 | if (self.defined()) { |
643 | throw NoMatchedReducerError(self.value()->mod, identities, combiners); |
644 | } else { |
645 | LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " |
646 | "reduction block. So rfactor and cross-thread reduction cannot be applied." ; |
647 | } |
648 | } |
649 | return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); |
650 | } |
651 | |
652 | /******** Commutative Reducer ********/ |
653 | |
654 | bool MatchReducer(const CommReducer& reducer, const Array<PrimExpr>& identities, |
655 | const Array<PrimExpr>& combined_values, const Array<BufferLoad>& buf_loads, |
656 | Array<PrimExpr>* lhs, Array<PrimExpr>* rhs) { |
657 | ExprDeepEqual equal; |
658 | ICHECK_EQ(identities.size(), combined_values.size()); |
659 | int n_buffers = identities.size(); |
660 | for (int i = 0; i < n_buffers; ++i) { |
661 | if (!equal(reducer->identity_element[i], identities[i])) { |
662 | return false; |
663 | } |
664 | } |
665 | |
666 | PatternMatcher pattern_matcher(reducer->result); |
667 | pattern_matcher.Match(combined_values); |
668 | Array<PrimExpr> lhs_tmp, rhs_tmp; |
669 | lhs_tmp.reserve(n_buffers); |
670 | rhs_tmp.reserve(n_buffers); |
671 | if (!pattern_matcher.Success()) { |
672 | return false; |
673 | } |
674 | |
675 | for (int i = 0; i < n_buffers; ++i) { |
676 | PrimExpr l = pattern_matcher.Eval(reducer->lhs[i]); |
677 | PrimExpr r = pattern_matcher.Eval(reducer->rhs[i]); |
678 | if (!equal(buf_loads[i], l)) { |
679 | return false; |
680 | } |
681 | lhs_tmp.push_back(l); |
682 | rhs_tmp.push_back(r); |
683 | } |
684 | *lhs = std::move(lhs_tmp); |
685 | *rhs = std::move(rhs_tmp); |
686 | return true; |
687 | } |
688 | |
689 | bool FromIdentityCombiner(const Array<PrimExpr>& identities, const Array<BufferStore>& combiners, |
690 | CommReducer* result_reducer, Array<PrimExpr>* lhs, Array<PrimExpr>* rhs) { |
691 | int n = identities.size(); |
692 | Array<BufferLoad> buf_loads; |
693 | Array<PrimExpr> stored_values; |
694 | buf_loads.reserve(n); |
695 | stored_values.reserve(n); |
696 | |
697 | for (int i = 0; i < n; ++i) { |
698 | buf_loads.push_back(BufferLoad(combiners[i]->buffer, combiners[i]->indices)); |
699 | stored_values.push_back(combiners[i]->value); |
700 | } |
701 | |
702 | // Check reduction patterns. |
703 | for (const TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>& reducer_getter : |
704 | GetReducerGetters()) { |
705 | Optional<CommReducer> reducer = reducer_getter(identities); |
706 | if (!reducer.defined()) { |
707 | continue; |
708 | } |
709 | if (MatchReducer(reducer.value(), identities, stored_values, buf_loads, lhs, rhs)) { |
710 | *result_reducer = reducer.value(); |
711 | return true; |
712 | } |
713 | } |
714 | return false; |
715 | } |
716 | |
717 | } // namespace tir |
718 | } // namespace tvm |
719 | |