1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../ir_comparator.h"
20#include "../utils.h"
21
22namespace tvm {
23namespace tir {
24
25/******** IR Module ********/
26
27const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
28 GlobalVar* result_g_var) {
29 for (const auto& kv : mod->functions) {
30 const GlobalVar& g_var = kv.first;
31 const BaseFunc& base_func = kv.second;
32 if (const auto* func = base_func.as<PrimFuncNode>()) {
33 if (const auto* realize = func->body.as<BlockRealizeNode>()) {
34 if (realize->block.get() == root_block) {
35 if (result_g_var != nullptr) {
36 *result_g_var = g_var;
37 }
38 return func;
39 }
40 }
41 }
42 }
43 LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the "
44 "statement:\n"
45 << GetRef<Stmt>(root_block);
46 throw;
47}
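
// Illustrative usage (editor's sketch, not part of the original source): given an IRModule
// `mod` and the root block statement `root_block_stmt` of a sref tree (both names assumed
// here), the helper above recovers the enclosing PrimFunc and, optionally, its GlobalVar:
//
//   GlobalVar g_var;
//   const PrimFuncNode* func = GetRootPrimFunc(mod, root_block_stmt, &g_var);
//   // `func->body` is the root BlockRealize whose block is `root_block_stmt`.
//   // Pass nullptr instead of &g_var when the GlobalVar is not needed.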
48
49/******** Scope ********/
50
51StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,
52 bool require_stage_pipeline) {
53 class RootBlockError : public ScheduleError {
54 public:
55 explicit RootBlockError(IRModule mod) : mod_(mod) {}
56 IRModule mod() const final { return mod_; }
57 String FastErrorString() const final {
58 return "ScheduleError: The primitive does not operate on the root block";
59 }
60 String DetailRenderTemplate() const final {
61 return "The primitive does not operate on the root block";
62 }
63 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
64 IRModule mod_;
65 };
66
67 class NotStagePipelineError : public ScheduleError {
68 public:
69 explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {}
70 IRModule mod() const final { return mod_; }
71 String FastErrorString() const final {
72 return "ScheduleError: The scope root is not a stage pipeline";
73 }
74 String DetailRenderTemplate() const final {
75 return R"(The scope {0} is not a stage pipeline.
76Definition of a scope that is a stage pipeline:
77- The region cover property holds for every of its child blocks
78- No write-after-read dependency or opaque dependency,
79- only read-after-write and write-after-write are allowed
80- All the statements in the scope are schedulable statements, i.e. Block and For
81)";
82 }
83 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
84 IRModule mod_;
85 Block block_;
86 };
87
88 StmtSRef scope_root_sref{nullptr};
89 StmtSRef scope_root_subtree{nullptr};
90 // Step 1. Find the scope root and the subtree that the given sref is in
91 {
92 const StmtSRefNode* p = sref->parent;
93 const StmtSRefNode* subtree = sref.get();
94 for (; p != nullptr; subtree = p, p = p->parent) {
95 if (p->stmt->IsInstance<BlockNode>()) {
96 scope_root_sref = GetRef<StmtSRef>(p);
97 scope_root_subtree = GetRef<StmtSRef>(subtree);
98 break;
99 }
100 }
101 if (p == nullptr) {
102 throw RootBlockError(self->mod);
103 }
104 }
105 // Step 2. Handle `require_stage_pipeline`
106 if (require_stage_pipeline) {
107 bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
108 if (stage_pipeline == false) {
109 const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
110 throw NotStagePipelineError(self->mod, GetRef<Block>(block));
111 }
112 }
113 return scope_root_sref;
114}
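
// Illustrative usage (editor's sketch, not part of the original source): a schedule
// primitive that must run inside a stage pipeline typically starts by locating the scope
// root of the block it operates on. Here `self` is the ScheduleState and `block_sref` is
// assumed to come from `self->stmt2ref`:
//
//   StmtSRef scope_root_sref =
//       GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
//   // Throws RootBlockError if `block_sref` is the root block itself, and
//   // NotStagePipelineError if the enclosing scope is not a stage pipeline.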
115
116ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) {
117 struct Collector : public StmtVisitor {
118 void VisitStmt_(const BlockRealizeNode* realize) final {
119 result.realizes.push_back(GetRef<BlockRealize>(realize));
120 const Array<IterVar>& iter_vars = realize->block->iter_vars;
121 const Array<PrimExpr>& iter_values = realize->iter_values;
122 ICHECK_EQ(iter_vars.size(), iter_values.size());
123 int n = realize->iter_values.size();
124 for (int i = 0; i < n; ++i) {
125 const IterVar& iter_var = iter_vars[i];
126 const PrimExpr& iter_value = iter_values[i];
127 std::unordered_set<const VarNode*>* vars = nullptr;
128 if (iter_var->iter_type == IterVarType::kDataPar) {
129 vars = &result.spatial_vars;
130 } else {
131 vars = &result.non_spatial_vars;
132 }
133 PostOrderVisit(iter_value, [vars](const ObjectRef& obj) {
134 if (const VarNode* var = obj.as<VarNode>()) {
135 vars->insert(var);
136 }
137 });
138 }
139 }
140
141 ScopeBlockLoopInfo result;
142 } visitor;
143 visitor(scope_block->body);
144 return std::move(visitor.result);
145}
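
// Illustrative usage (editor's sketch, not part of the original source): the collected info
// separates variables that only bind spatial (data-parallel) block iters from the rest,
// which helps decide which loops are safe to parallelize. `scope_block` and `var_node` are
// assumed names:
//
//   ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(scope_block);
//   bool purely_spatial =
//       info.spatial_vars.count(var_node) && !info.non_spatial_vars.count(var_node);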
146
147/*!
148 * \brief Check whether the given sref_a is higher than or equal to sref_b.
149 */
150void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) {
151 const StmtSRefNode* p = sref_b.get();
152 for (; p != nullptr; p = p->parent) {
153 if (p == sref_a.get()) {
154 return;
155 }
156 }
  CHECK(false) << "Expected StmtSRef " << sref_a << " to be higher than or equal to " << sref_b;
158}
159
160/*!
161 * \brief Check the dominant property of a block:
162 * the block is the only writer of its output, dominating the reader of its output buffers under the
163 * given root scope.
164 * \param self The schedule state.
165 * \param scope_root_sref The StmtSRef corresponding to the root scope.
166 * \param block_sref The block whose dominant property is to be checked.
167 * \return A boolean indicating if the block is a dominant block.
168 */
169bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
170 const StmtSRef& block_sref) {
171 std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
172 CheckSRefHigherOrEqual(scope_root_sref, block_sref);
173 const BlockNode* maybe_root_block = scope_root_sref->StmtAs<BlockNode>();
174 if (maybe_root_block) {
175 BlockScope scope = self->GetBlockScope(scope_root_sref);
176 buffer_writers = scope->buffer_writers;
177 } else {
178 // Collect all child blocks of root sub-tree, and merge their buffer writers.
179 Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref);
180 for (const StmtSRef& child_block_sref : child_block_srefs) {
181 BlockScope child_scope = self->GetBlockScope(child_block_sref);
182 for (const auto& it : child_scope->buffer_writers) {
183 buffer_writers.insert(it);
184 }
185 }
186 }
187 // Check whether the input block is the only writer of its outputs
188 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
189 for (const BufferRegion& write_region : block->writes) {
190 if (buffer_writers.count(write_region->buffer)) {
191 if (buffer_writers.at(write_region->buffer).size() != 1) {
192 return false;
193 }
194 }
195 }
196 return true;
197}
198
199/*!
200 * \brief A helper function that checks whether a given block is a complete block under the scope,
 * or returns the condition it violates if it is not a complete block
202 * \param self The schedule state
203 * \param block_sref The block to be checked
204 * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in
205 * \return 0 if the block is a complete block, or a positive integer indicating which condition is
206 * first violated
207 */
208int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
209 const StmtSRef& scope_root_sref) {
210 // Cond 1. All block vars are data parallel
211 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
212 for (const IterVar& iter_var : block->iter_vars) {
213 if (iter_var->iter_type != kDataPar) {
214 return 1;
215 }
216 }
217 // Cond 2. Dominant: the block is the only writer of its output,
218 // dominating the reader of its output buffers
219 if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
220 return 2;
221 }
222 // Cond 3. No overlap between the buffers the block reads and writes
223 std::unordered_set<const BufferNode*> written_buffers;
224 written_buffers.reserve(block->writes.size());
225 for (const BufferRegion& write : block->writes) {
226 written_buffers.insert(write->buffer.get());
227 }
228 for (const BufferRegion& read : block->reads) {
229 if (written_buffers.count(read->buffer.get())) {
230 return 3;
231 }
232 }
233 return 0;
234}
235
236static const char* kCompleteBlockDefinition = R"(Definition of a complete block:
2371) All block vars are data parallel
2382) Dominant: the block is the only writer of its output, dominating the reader of its output buffers
2393) No overlap between the buffers the block reads and writes)";
240
241static const char* kReductionBlockDefinition = R"(Definition of a reduction block:
2421) The block has the `init` statement
2432) All the block bindings are quasi-affine expressions
2443) All block vars are either data parallel block vars or reduction block vars
2454) Dominant: the block is the only writer of its output, dominating the reader of its output buffers
2465) The reduction block vars are not used to index the output buffers)";
247
248static const char* kLocalCompleteBlockDefinition = R"(Definition of a local complete block:
2491) All block vars are data parallel
2502) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
2513) No overlap between the buffers the block reads and writes)";
252
static const char* kLocalReductionBlockDefinition = R"(Definition of a local reduction block:
2541) The block has the `init` statement
2552) All the block bindings are quasi-affine expressions
2563) All block vars are either data parallel block vars or reduction block vars
2574) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
2585) The reduction block vars are not used to index the output buffers)";
259
260bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
261 const StmtSRef& scope_root_sref) {
262 return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0;
263}
264
265void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
266 const StmtSRef& scope_root_sref) {
267 class IncompleteBlockError : public ScheduleError {
268 public:
269 explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond)
270 : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {}
271 String FastErrorString() const final { return "ScheduleError: Incomplete block"; }
272 String DetailRenderTemplate() const final {
273 std::ostringstream os;
274 os << "The block {0} is not a complete block - it violates condition #" << violated_cond_;
275 os << ".\n" << kCompleteBlockDefinition;
276 return os.str();
277 }
278 IRModule mod() const final { return mod_; }
279 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
280 IRModule mod_;
281 Block block_;
282 int violated_cond_;
283 };
284
285 int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref);
286 if (error_code != 0) {
287 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
288 throw IncompleteBlockError(self->mod, GetRef<Block>(block), error_code);
289 }
290}
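
// Illustrative usage (editor's sketch, not part of the original source): the boolean
// predicate and the throwing check above share CheckCompleteBlockErrorCode, so callers can
// either query first or let the primitive raise a ScheduleError directly:
//
//   if (IsCompleteBlock(self, block_sref, scope_root_sref)) {
//     // safe to proceed without raising
//   }
//   CheckCompleteBlock(self, block_sref, scope_root_sref);  // throws IncompleteBlockError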
291
292/*!
293 * \brief A helper function that checks whether a given block is a reduction block under the scope,
 * or returns the condition it violates if it is not a reduction block
295 * \param self The schedule state
296 * \param block_sref The block to be checked
297 * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in
298 * \return 0 if the block is a reduction block, or a positive integer indicating which condition is
299 * first violated
300 */
301int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
302 const StmtSRef& scope_root_sref) {
303 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
304 // Cond 1. The block has the `init` statement.
305 if (!block->init.defined()) {
306 return 1;
307 }
308 // Cond 2. All the block bindings are quasi-affine expressions.
309 if (!self->IsAffineBlockBinding(block_sref)) {
310 return 2;
311 }
312 // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile,
313 // we collect all the reduction block vars.
314 if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
315 return 3;
316 }
317 // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
318 // output buffers.
319 if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
320 return 4;
321 }
322 // Cond 5. The reduction block vars are not used to index the output buffers.
323 return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block)) ? 0 : 5;
324}
325
326bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
327 const StmtSRef& scope_root_sref) {
328 return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0;
329}
330
331void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
332 const StmtSRef& scope_root_sref) {
333 class NotReductionBlockError : public ScheduleError {
334 public:
335 explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond)
336 : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {}
337 String FastErrorString() const final { return "ScheduleError: Not a reduction block"; }
338 String DetailRenderTemplate() const final {
339 std::ostringstream os;
340 os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_;
341 os << ".\n" << kReductionBlockDefinition;
342 return os.str();
343 }
344 IRModule mod() const final { return mod_; }
345 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
346 IRModule mod_;
347 Block block_;
348 int violated_cond_;
349 };
350
351 int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref);
352 if (error_code != 0) {
353 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
354 throw NotReductionBlockError(self->mod, GetRef<Block>(block), error_code);
355 }
356}
357
358void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
359 const StmtSRef& scope_root_sref) {
360 class NotCompleteOrReductionBlockError : public ScheduleError {
361 public:
362 explicit NotCompleteOrReductionBlockError(IRModule mod, Block block,
363 int complete_block_error_code,
364 int reduction_block_error_code)
365 : mod_(mod),
366 block_(block),
367 complete_block_error_code_(complete_block_error_code),
368 reduction_block_error_code_(reduction_block_error_code) {}
369
370 String FastErrorString() const final {
371 return "ScheduleError: Not a complete or reduction block";
372 }
373 String DetailRenderTemplate() const final {
374 std::ostringstream os;
375 os << "The block {0} is not a complete block - it violates condition #"
376 << complete_block_error_code_;
377 os << ".\n" << kCompleteBlockDefinition;
378 os << "\nThe block is not a reduction block either - it violates condition #"
379 << reduction_block_error_code_;
380 os << ".\n" << kReductionBlockDefinition;
381 return os.str();
382 }
383 IRModule mod() const final { return mod_; }
384 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
385
386 IRModule mod_;
387 Block block_;
388 int complete_block_error_code_;
389 int reduction_block_error_code_;
390 };
391
392 int complete_block_error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref);
393 if (complete_block_error_code == 0) {
394 return;
395 }
396 int reduction_block_error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref);
397 if (reduction_block_error_code == 0) {
398 return;
399 }
400 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
401 throw NotCompleteOrReductionBlockError(self->mod, GetRef<Block>(block), complete_block_error_code,
402 reduction_block_error_code);
403}
404
405void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) {
406 class NotCompactDataFlowError : public ScheduleError {
407 public:
408 explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block,
409 int local_complete_block_code, int local_reduction_block_code)
410 : mod_(std::move(mod)),
411 subtree_root_(std::move(subtree_root)),
412 violate_block_(std::move(violate_block)),
413 local_complete_block_code_(local_complete_block_code),
414 local_reduction_block_code_(local_reduction_block_code) {
415 ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
416 }
417 String FastErrorString() const final {
418 return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
419 "because some of its child block on SRef tree is neither a local complete block nor a "
420 "local reduction block.";
421 }
422 String DetailRenderTemplate() const final {
423 std::ostringstream os;
424 os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
425 "its child block {1} on SRef tree is neither a local complete block nor a local "
426 "reduction block.\n";
427 os << "It violates condition #" << local_complete_block_code_
428 << " as a local complete block.\n";
429 os << kLocalCompleteBlockDefinition << "\n";
430 os << "It violates condition #" << local_reduction_block_code_
431 << " as a local reduction block.\n";
432 os << kLocalReductionBlockDefinition << "\n";
433 return os.str();
434 }
435 IRModule mod() const final { return mod_; }
436 Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }
437
438 IRModule mod_;
439 Stmt subtree_root_;
440 Block violate_block_;
441 int local_complete_block_code_;
442 int local_reduction_block_code_;
443 };
444
445 Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root);
446 for (const StmtSRef& block_sref : child_block_srefs) {
447 int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root),
448 local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root);
449 if (local_complete_block_code != 0 && local_reduction_block_code != 0) {
450 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
451 throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
452 GetRef<Block>(block), local_complete_block_code,
453 local_reduction_block_code);
454 }
455 }
456}
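
// Illustrative usage (editor's sketch, not part of the original source): primitives that
// relocate a subtree can guard themselves by requiring every child block of the subtree to
// be a local complete or local reduction block. `subtree_root_sref` is an assumed name:
//
//   CheckSubtreeCompactDataflow(self, subtree_root_sref);
//   // Throws NotCompactDataFlowError naming the first offending child block.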
457
458bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
459 const StmtSRef& scope_root_sref) {
460 const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
461 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
462 std::unordered_set<const BufferNode*> scope_allocated;
463 scope_allocated.reserve(scope_root->alloc_buffers.size());
464 for (const Buffer& buffer : scope_root->alloc_buffers) {
465 scope_allocated.insert(buffer.get());
466 }
467 for (const BufferRegion& buffer_region : block->writes) {
468 if (!scope_allocated.count(buffer_region->buffer.get())) {
469 return true;
470 }
471 }
472 return false;
473}
474
475void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
476 const StmtSRef& scope_root_sref) {
477 class OutputBlockError : public ScheduleError {
478 public:
479 explicit OutputBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {}
480 String FastErrorString() const final {
481 return "ScheduleError: Cannot operate on an output block";
482 }
483 String DetailRenderTemplate() const final { return "The block {0} is an output block"; }
484 IRModule mod() const final { return mod_; }
485 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
486
487 IRModule mod_;
488 Block block_;
489 };
490 if (IsOutputBlock(self, block_sref, scope_root_sref)) {
491 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
492 throw OutputBlockError(self->mod, GetRef<Block>(block));
493 }
494}
495
496std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
497 std::vector<IterVarType> results;
498 results.reserve(block->iter_vars.size());
499 for (const IterVar& iter_var : block->iter_vars) {
500 results.push_back(iter_var->iter_type);
501 }
502 return results;
503}
504
505std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
506 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
507 return GetBlockVarTypes(block);
508}
509
510bool IsWriteCache(const StmtSRef& block_sref) {
511 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
512 if (block->writes.size() != 1) {
513 return false;
514 }
515 const BufferRegion& write_region = block->writes[0];
516 for (const BufferRegion& read_region : block->reads) {
517 auto [exists, surjective, injective, ordered, no_const_read, no_shift_read] =
518 AnalyzeReadWritePattern(read_region, write_region);
519 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
520 (void)exists;
521 (void)surjective;
522 (void)no_const_read;
523 (void)no_shift_read;
524 if (!(injective && ordered)) {
525 return false;
526 }
527 }
528 return true;
529}
530
531/******** Binding ********/
532
533bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges,
534 arith::Analyzer* analyzer) {
535 if (loop_var_ranges.empty()) {
536 return true;
537 }
538 auto res = arith::DetectIterMap(
539 /*indices=*/realize->iter_values,
540 /*input_iters=*/loop_var_ranges,
541 /*predicate=*/realize->predicate,
542 /*check_level=*/arith::IterMapLevel::Surjective,
543 /*analyzer=*/analyzer,
544 /*simplify_trivial_iterators=*/false);
545 if (res->indices.empty()) {
546 return false;
547 }
548 for (const arith::IterSumExpr& sum_expr : res->indices) {
549 const Array<arith::IterSplitExpr>& args = sum_expr->args;
550 if (!args.empty() && !is_one(args[0]->scale)) {
551 return false;
552 }
553 }
554 return true;
555}
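
// Illustrative usage (editor's sketch, not part of the original source): the loop domain is
// usually assembled from the sref tree path above the block, mirroring what
// CheckPartialAffineBinding below does; passing a global StorageScope keeps the extra
// thread-scope relaxation disabled:
//
//   arith::Analyzer analyzer;
//   BlockRealize realize = GetBlockRealize(self, block_sref);
//   Map<Var, Range> dom = LoopDomainOfSRefTreePath(
//       /*low_inclusive=*/GetRef<StmtSRef>(block_sref->parent),
//       /*high_exclusive=*/NullOpt,
//       /*extra_relax_scope=*/runtime::StorageScope());
//   bool affine = IsAffineBinding(realize, dom, &analyzer);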
556
557void CheckPartialAffineBinding(const ScheduleState& self, Block block,
558 const Optional<StmtSRef>& high_exclusive) {
559 class NotAffineBindingError : public ScheduleError {
560 public:
561 explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive)
562 : mod_(std::move(mod)), block_(std::move(block)) {
563 if (high_exclusive.defined()) {
564 high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>();
565 }
566 }
567 String FastErrorString() const final {
568 std::ostringstream ss;
569 if (high_exclusive_loop_) {
570 ss << "ScheduleError: The block is required to have an partial affine binding under "
571 << high_exclusive_loop_->loop_var;
572 } else {
573 ss << "ScheduleError: The block is required to have an affine binding";
574 }
575 return ss.str();
576 }
577 String DetailRenderTemplate() const final {
578 std::ostringstream ss;
579 if (high_exclusive_loop_) {
580 ss << "The block {0} is required to have an partial affine binding under "
581 << high_exclusive_loop_->loop_var;
582 } else {
583 ss << "The block {0} is required to have an affine binding";
584 }
585 return ss.str();
586 }
587 IRModule mod() const final { return mod_; }
588 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
589 IRModule mod_;
590 Block block_;
591 const ForNode* high_exclusive_loop_{nullptr};
592 };
593
594 StmtSRef block_sref = self->stmt2ref.at(block.get());
595 if (self->IsAffineBlockBinding(block_sref)) {
596 // check block cached state for global affineness
597 return;
598 }
599 if (block_sref->parent && high_exclusive.defined()) {
    // The binding is not globally affine; check affineness under `high_exclusive` instead.
601 arith::Analyzer analyzer;
602 Map<Var, Range> dom_map =
603 LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive);
604 if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) {
605 return;
606 }
607 }
608 throw NotAffineBindingError(self->mod, std::move(block), high_exclusive);
609}
610
611void CheckAffineBinding(const ScheduleState& self, Block block) {
612 CheckPartialAffineBinding(self, std::move(block), NullOpt);
613}
614
615void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
616 class NotTrivialBindingError : public ScheduleError {
617 public:
618 explicit NotTrivialBindingError(IRModule mod, Block block)
619 : mod_(std::move(mod)), block_(std::move(block)) {}
620
621 String FastErrorString() const final {
622 return "ScheduleError: The binding values of the block are not variables of outer loops.";
623 }
624
625 String DetailRenderTemplate() const final {
626 std::ostringstream os;
627 os << "The binding values of the {0} are not variables of outer loops.";
628 return os.str();
629 }
630
631 IRModule mod() const final { return mod_; }
632 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
633
634 private:
635 IRModule mod_;
636 Block block_;
637 };
638
639 if (!IsTrivialBinding(self, block_sref)) {
640 throw NotTrivialBindingError(self->mod, GetRef<Block>(block_sref->StmtAs<BlockNode>()));
641 }
642}
643
644Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
645 const Optional<StmtSRef>& high_exclusive,
646 const runtime::StorageScope& extra_relax_scope) {
647 Map<Var, Range> result;
648 const StmtSRefNode* p = low_inclusive.get();
649 const StmtSRefNode* limit = static_cast<const StmtSRefNode*>(high_exclusive.get());
650 for (; p != limit; p = p->parent) {
651 const ForNode* loop = p->StmtAs<ForNode>();
652 if (loop == nullptr) {
653 break;
654 }
655 result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
656 }
657 if (extra_relax_scope.rank != runtime::StorageRank::kGlobal) {
658 for (; p; p = p->parent) {
659 if (const ForNode* loop = p->StmtAs<ForNode>()) {
660 if (loop->kind == ForKind::kThreadBinding) {
661 const String& thread_tag = loop->thread_binding.value()->thread_tag;
662 if (CanRelaxStorageUnderThread(extra_relax_scope,
663 runtime::ThreadScope::Create(thread_tag))) {
664 result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
665 }
666 }
667 }
668 }
669 }
670 return result;
671}
672
673Map<Var, PrimExpr> GetBindings(const BlockRealize& realize) {
674 const BlockNode* block = realize->block.get();
675 const Array<IterVar>& all_lhs = block->iter_vars;
676 const Array<PrimExpr>& all_rhs = realize->iter_values;
677 ICHECK_EQ(all_lhs.size(), all_rhs.size());
678 Map<Var, PrimExpr> result;
679 for (int i = 0, n = all_lhs.size(); i < n; ++i) {
680 const IterVar& lhs = all_lhs[i];
681 const PrimExpr& rhs = all_rhs[i];
682 result.Set(lhs->var, rhs);
683 }
684 return result;
685}
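
// Illustrative example (editor's sketch, not part of the original source): for a realize
// whose block iters are (vi, vj) with iter_values (i, j * 4), the returned map is
// {vi -> i, vj -> j * 4}; a typical call site is:
//
//   Map<Var, PrimExpr> binding = GetBindings(GetBlockRealize(self, block_sref));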
686
687bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
688 std::unordered_set<const VarNode*>* data_par_vars,
689 std::unordered_set<const VarNode*>* reduce_vars) {
690 Block block = block_realize->block;
691 ICHECK(block_realize->block.same_as(block))
692 << "ValueError: The input `block_realize` is required to be the exact BlockRealize of the "
693 "input block";
694
695 bool has_block_vars_of_other_types = false;
696 ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size());
697 int n = static_cast<int>(block->iter_vars.size());
698 for (int i = 0; i < n; ++i) {
699 const IterVar& iter_var = block->iter_vars[i];
700 const PrimExpr& iter_value = block_realize->iter_values[i];
701 std::unordered_set<const VarNode*>* set = nullptr;
702 if (iter_var->iter_type == IterVarType::kDataPar) {
703 set = data_par_vars;
704 } else if (iter_var->iter_type == IterVarType::kCommReduce) {
705 set = reduce_vars;
706 } else {
707 has_block_vars_of_other_types = true;
708 }
709 if (set == nullptr) {
710 continue;
711 }
712 Array<Var> vars_in_binding = UndefinedVars(iter_value);
713 for (const Var& var : vars_in_binding) {
714 set->insert(var.get());
715 }
716 }
717
718 return has_block_vars_of_other_types;
719}
720
721/******** Loop properties ********/
722
723void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
724 arith::Analyzer* analyzer) {
725 class LoopNotStartWithZeroError : public ScheduleError {
726 public:
727 explicit LoopNotStartWithZeroError(IRModule mod, For loop)
728 : mod_(mod), loop_(std::move(loop)) {}
729
730 String FastErrorString() const final {
731 return "ScheduleError: The primitive only supports loop starting with 0";
732 }
733
734 String DetailRenderTemplate() const final {
735 return "The loop {0} does not start with 0, which is not supported";
736 }
737
738 IRModule mod() const final { return mod_; }
739 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
740
741 IRModule mod_;
742 For loop_;
743 };
744 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
745 if (!analyzer->CanProve(loop->min == 0)) {
746 throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
747 }
748}
749
750/******** Block-loop relation ********/
751
752Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self,
753 const StmtSRef& parent_sref) {
754 Array<BlockRealize> child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref);
755 Array<StmtSRef> child_block_srefs;
756 child_block_srefs.reserve(child_block_realize.size());
757
758 for (BlockRealize realize : child_block_realize) {
759 child_block_srefs.push_back(self->stmt2ref.at(realize->block.get()));
760 }
761 return child_block_srefs;
762}
763
764Array<BlockRealize> GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) {
765 struct Collector : public StmtVisitor {
766 static Array<BlockRealize> Collect(const Stmt& stmt) {
767 Collector collector;
768 collector(stmt);
769 return std::move(collector.result_);
770 }
771
772 void VisitStmt_(const BlockRealizeNode* block_realize) final {
773 result_.push_back(GetRef<BlockRealize>(block_realize));
774 }
775
776 Array<BlockRealize> result_;
777 };
778
779 if (parent_sref->stmt->IsInstance<ForNode>()) {
780 const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
781 return Collector::Collect(loop->body);
782 } else if (parent_sref->stmt->IsInstance<BlockNode>()) {
783 const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
784 return Collector::Collect(block->body);
785 }
786 ICHECK(false) << "Unreachable";
787 throw;
788}
789
790BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self,
791 const StmtSRef& parent_sref) {
792 class NonSingleChildBlockError : public ScheduleError {
793 public:
794 explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref)
795 : mod_(std::move(mod)), stmt_(GetRef<Stmt>(sref->stmt)) {
796 sref_type_ = stmt_.as<BlockNode>() != nullptr ? "block" : "loop";
797 }
798
799 String FastErrorString() const final {
800 std::ostringstream os;
801 os << "ScheduleError: The " << sref_type_ << " is required to have only one child block";
802 return os.str();
803 }
804
805 String DetailRenderTemplate() const final {
806 std::ostringstream os;
807 os << "The " << sref_type_ << " {0} is required to have only one child block";
808 return os.str();
809 }
810
811 IRModule mod() const final { return mod_; }
812 Array<ObjectRef> LocationsOfInterest() const final { return {stmt_}; }
813
814 IRModule mod_;
815 Stmt stmt_;
816 String sref_type_;
817 };
818
819 Array<BlockRealize> child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref);
820 if (child_block_realize.size() != 1) {
821 throw NonSingleChildBlockError(self->mod, parent_sref);
822 }
823 return child_block_realize[0];
824}
825
826BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) {
827 struct BlockRealizeFinder : public StmtVisitor {
828 explicit BlockRealizeFinder(const BlockNode* target_block)
829 : target_block(target_block), result(nullptr) {}
830
831 void VisitStmt(const Stmt& stmt) final {
832 if (result != nullptr) {
833 return;
834 }
835 StmtVisitor::VisitStmt(stmt);
836 }
837
838 void VisitStmt_(const BlockRealizeNode* block_realize) final {
839 if (block_realize->block.get() == target_block) {
840 result = block_realize;
841 }
      // No need to visit recursively, since deeper BlockRealizes cannot be the result.
843 }
844
845 const BlockNode* target_block;
846 const BlockRealizeNode* result;
847 };
848
849 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
850 if (block_sref->parent == nullptr) {
851 const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr);
852 return Downcast<BlockRealize>(func->body);
853 } else {
854 BlockRealizeFinder finder(block);
855 finder(GetRef<Stmt>(block_sref->parent->stmt));
856 ICHECK(finder.result != nullptr)
857 << "InternalError: Cannot find the BlockRealize of block " << GetRef<Block>(block);
858 return GetRef<BlockRealize>(finder.result);
859 }
860}
861
862IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
863 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
864 const Var& loop_var = loop->loop_var;
865 int n_spatial = 0;
866 int n_reduce = 0;
867 int n_other = 0;
868 auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool {
869 if (const auto* realize = obj.as<BlockRealizeNode>()) {
870 const BlockNode* block = realize->block.get();
871 // Number of block vars and their bindings
872 ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size());
873 size_t n = realize->iter_values.size();
874 for (size_t i = 0; i < n; ++i) {
875 const IterVar& iter_var = block->iter_vars[i];
876 const PrimExpr& binding = realize->iter_values[i];
877 // Categorize the current block var
878 int* ref = nullptr;
879 if (iter_var->iter_type == IterVarType::kDataPar) {
880 ref = &n_spatial;
881 } else if (iter_var->iter_type == IterVarType::kCommReduce) {
882 ref = &n_reduce;
883 } else {
884 ref = &n_other;
885 }
886 // Visit the binding to see if `loop_var` appears
887 PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void {
888 if (obj.same_as(loop_var)) {
889 (*ref) += 1;
890 }
891 });
892 }
893 return false;
894 }
895 return true;
896 };
897 PreOrderVisit(loop->body, f_visit);
898 if (n_other) {
899 return IterVarType::kOpaque;
900 } else if (n_spatial && n_reduce) {
901 return IterVarType::kOpaque;
902 } else if (n_reduce) {
903 return IterVarType::kCommReduce;
904 } else {
905 return IterVarType::kDataPar;
906 }
907}
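
// Illustrative note (editor's sketch, not part of the original source): the classification
// above can be read as follows for a loop sref:
//
//   IterVarType t = GetLoopIterType(loop_sref);
//   // kDataPar   : the loop var only appears in bindings of data-parallel block iters
//   // kCommReduce: the loop var only appears in bindings of reduction block iters
//   // kOpaque    : mixed spatial/reduction usage, or any block iter of another type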
908
909StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs) {
910 CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref";
911
912 std::unordered_map<const StmtSRefNode*, size_t> sref_visited_cnt;
913 for (const StmtSRef& sref : srefs) {
914 const StmtSRefNode* p = sref.get();
915 while (p != nullptr) {
916 ++sref_visited_cnt[p];
917 p = p->parent;
918 }
919 }
920 size_t n_sref = srefs.size();
921 const StmtSRefNode* p = srefs[0].get();
922 while (p != nullptr && sref_visited_cnt[p] != n_sref) {
923 p = p->parent;
924 }
925 ICHECK(p != nullptr);
926 return GetRef<StmtSRef>(p);
927}
928
929bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) {
930 return tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined();
931}
932
933std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
934 const StmtSRef& block_sref) {
935 Array<StmtSRef> location_srefs;
936 std::vector<int> location_indices;
937
938 // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can
939 // be inlined.
940 if (CanComputeInline(self, block_sref)) {
941 location_srefs.push_back(StmtSRef::InlineMark());
942 location_indices.push_back(-2);
943 }
944 location_srefs.push_back(StmtSRef::RootMark());
945 location_indices.push_back(-1);
946
947 // Step 2. If the block has no consumer, there is no more candidate.
948 Array<StmtSRef> consumers = GetConsumers(self, block_sref);
949 if (consumers.empty()) {
950 return std::make_pair(location_srefs, location_indices);
951 }
952
953 // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If
954 // such a loop cannot be found, there is no more candidate and we just return.
955 StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers)
956 : GetRef<StmtSRef>(consumers[0]->parent);
957 if (loop_boundary->StmtAs<ForNode>() == nullptr) {
958 return std::make_pair(location_srefs, location_indices);
959 }
960
961 // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position
962 // of the boundary loop reveals the number of possible additional candidates.
963 Array<StmtSRef> loop_srefs = GetLoops(consumers[0]);
964 size_t lca_pos =
965 std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin();
966 ICHECK_LT(lca_pos, loop_srefs.size());
967 size_t n_candidate = lca_pos + 1;
968
969 // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This
970 // position is used for removing the unwanted candidates from the perspective of performance.
971 std::vector<IterVarType> loop_iter_types;
972 loop_iter_types.reserve(n_candidate);
973 int i_last_datapar = -1;
974 for (size_t i = 0; i < n_candidate; ++i) {
975 // TODO(siyuan): improve the performance
976 IterVarType iter_type = GetLoopIterType(loop_srefs[i]);
977 loop_iter_types.push_back(iter_type);
978 if (iter_type == IterVarType::kDataPar) {
979 i_last_datapar = i;
980 }
981 }
982 // Step 6. Check and add the candidates in turn according to the following rules:
983 // - skip the unit loops (loops with extent 1);
984 // - do not consider the data-parallel loops after a not-data-parallel loop;
985 // - do not consider the trailing not-data-parallel loops.
986 location_srefs.reserve(n_candidate + 2);
987 location_indices.reserve(n_candidate + 2);
988 bool visited_reduce = false;
989 for (size_t i = 0; i < n_candidate; ++i) {
990 const int64_t* loop_extent = GetLoopIntExtent(loop_srefs[i]);
991 if (loop_extent != nullptr && *loop_extent == 1) {
992 continue;
993 }
994
995 if (loop_iter_types[i] == IterVarType::kDataPar) {
996 if (visited_reduce) {
997 break;
998 }
999 } else {
1000 visited_reduce = true;
1001 if (static_cast<int>(i) > i_last_datapar) {
1002 break;
1003 }
1004 }
1005 if (CanComputeAt(self, block_sref, loop_srefs[i], true)) {
1006 location_srefs.push_back(loop_srefs[i]);
1007 location_indices.push_back(i);
1008 }
1009 }
1010
1011 return std::make_pair(location_srefs, location_indices);
1012}
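
// Illustrative note (editor's sketch, not part of the original source): the two returned
// arrays are parallel. An index of -2 denotes the "compute-inline" candidate
// (StmtSRef::InlineMark), -1 denotes "compute-root" (StmtSRef::RootMark), and a
// non-negative index i refers to the i-th loop above the first consumer:
//
//   auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref);
//   ICHECK_EQ(location_srefs.size(), location_indices.size());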
1013
1014/******** Producer-consumer relation ********/
1015
1016Array<StmtSRef> GetProducers(const StmtSRef& block_sref, const BlockScope& scope) {
1017 Array<Dependency> edges = scope->GetDepsByDst(block_sref);
1018 Array<StmtSRef> results;
1019 std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> result_set;
1020 results.reserve(edges.size());
1021 for (const Dependency& edge : edges) {
1022 if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) &&
1023 !result_set.count(edge->src)) {
1024 results.push_back(edge->src);
1025 result_set.emplace(edge->src);
1026 }
1027 }
1028 return results;
1029}
1030
1031Array<StmtSRef> GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) {
1032 Array<Dependency> edges = scope->GetDepsBySrc(block_sref);
1033 Array<StmtSRef> results;
1034 std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> result_set;
1035 results.reserve(edges.size());
1036 for (const Dependency& edge : edges) {
1037 if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) &&
1038 !result_set.count(edge->dst)) {
1039 results.push_back(edge->dst);
1040 result_set.emplace(edge->dst);
1041 }
1042 }
1043 return results;
1044}
1045
1046ProducerConsumerSplit ProducerConsumerSplit::Find(
1047 const ScheduleState& self, const Array<Stmt>& subtrees,
1048 const Array<StmtSRef>& producer_block_srefs, const Array<StmtSRef>& consumer_block_srefs,
1049 std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
1050 class InsertionPointNotFoundError : public ScheduleError {
1051 public:
1052 explicit InsertionPointNotFoundError(IRModule mod, int last_producer_position,
1053 int first_consumer_position)
1054 : mod_(mod),
1055 last_producer_position_(last_producer_position),
1056 first_consumer_position_(first_consumer_position) {}
1057
1058 String FastErrorString() const final {
1059 return "ScheduleError: Cannot find the insertion point that satisfies the producer-consumer "
1060 "constraint";
1061 }
1062
1063 String DetailRenderTemplate() const final {
1064 return "Cannot find the insertion point that satisfies the producer-consumer constraint. In "
1065 "0-based indexing, the last producer appears in subtree " +
1066 std::to_string(last_producer_position_) +
1067 ", and the first consumer appears in subtree " +
1068 std::to_string(first_consumer_position_);
1069 }
1070
1071 IRModule mod() const final { return mod_; }
1072
1073 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
1074
1075 private:
1076 IRModule mod_;
1077 int last_producer_position_;
1078 int first_consumer_position_;
1079 };
1080
1081 class Finder : public StmtVisitor {
1082 public:
1083 void VisitStmt_(const BlockRealizeNode* realize) final {
1084 const BlockNode* block = realize->block.get();
1085 if (block2realize_) {
1086 block2realize_->emplace(block, realize);
1087 }
1088 if (producer_blocks_.count(block)) {
1089 ++this->n_producers_visited_;
1090 }
1091 if (consumer_blocks_.count(block)) {
1092 ++this->n_consumers_visited_;
1093 }
1094 }
1095
1096 std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize_;
1097 std::unordered_set<const StmtNode*> producer_blocks_;
1098 std::unordered_set<const StmtNode*> consumer_blocks_;
1099 int n_producers_visited_ = 0;
1100 int n_consumers_visited_ = 0;
1101 };
1102
1103 Finder finder;
1104 finder.block2realize_ = block2realize;
1105 // Set up the lookup table for producers
1106 finder.producer_blocks_.reserve(producer_block_srefs.size());
1107 for (const StmtSRef& block_sref : producer_block_srefs) {
1108 finder.producer_blocks_.insert(block_sref->stmt);
1109 }
1110 // Set up the lookup table for consumers
1111 finder.consumer_blocks_.reserve(consumer_block_srefs.size());
1112 for (const StmtSRef& block_sref : consumer_block_srefs) {
1113 finder.consumer_blocks_.insert(block_sref->stmt);
1114 }
1115 // Visit the subtrees
1116 int n = subtrees.size();
1117 int last_producer_position = -1;
1118 int first_consumer_position = n;
1119 for (int i = 0; i < n; ++i) {
1120 int n_producers_visited_before = finder.n_producers_visited_;
1121 int n_consumers_visited_before = finder.n_consumers_visited_;
1122 finder(subtrees[i]);
    // Check if the subtree contains at least one producer
1124 if (finder.n_producers_visited_ != n_producers_visited_before) {
1125 last_producer_position = i;
1126 }
    // Check if the subtree contains at least one consumer
1128 if (finder.n_consumers_visited_ != n_consumers_visited_before) {
1129 if (first_consumer_position == n) {
1130 first_consumer_position = i;
1131 }
1132 }
1133 }
1134 if (last_producer_position >= first_consumer_position) {
1135 throw InsertionPointNotFoundError(self->mod, last_producer_position, first_consumer_position);
1136 }
1137 return ProducerConsumerSplit{last_producer_position, //
1138 first_consumer_position, //
1139 finder.n_producers_visited_, //
1140 finder.n_consumers_visited_};
1141}
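
// Illustrative usage (editor's sketch, not part of the original source): for the subtrees
// of a loop body, the split tells where a new stage may be placed so that it stays after
// every producer and before every consumer; `subtrees` and the sref arrays are assumed:
//
//   std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
//   ProducerConsumerSplit split = ProducerConsumerSplit::Find(
//       self, subtrees, producer_block_srefs, consumer_block_srefs, &block2realize);
//   // Roughly, valid insertion points lie after subtree `last_producer_position` and no
//   // later than subtree `first_consumer_position` (see the compute-at primitives for the
//   // exact convention).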
1142
1143/******** Block-buffer relation ********/
1144
1145BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n,
1146 BufferIndexType index_type) {
1147 class BufferIndexOutOfRangeError : public ScheduleError {
1148 public:
1149 explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index,
1150 BufferIndexType index_type)
1151 : mod_(std::move(mod)),
1152 block_(std::move(block)),
1153 buffer_index_(buffer_index),
1154 index_type_(index_type) {}
1155
1156 String FastErrorString() const final {
1157 if (index_type_ == BufferIndexType::kWrite) {
1158 return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
1159 "range "
1160 "[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
1161 "written by the block.";
1162 } else {
1163 return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
1164 "range "
1165 "[0, num_read_regions) where `num_read_regions` is the number of buffer regions "
1166 "read by the block.";
1167 }
1168 }
1169
1170 String DetailRenderTemplate() const final {
1171 std::ostringstream os;
1172 size_t num =
1173 index_type_ == BufferIndexType::kWrite ? block_->writes.size() : block_->reads.size();
1174 os << "The block {0} has " << num << " " << BufferIndexType2Str(index_type_)
1175 << " regions, so `buffer_index` is required to be in [0, " << num
1176 << "). However, the input `buffer_index` is " << buffer_index_
1177 << ", which is out of the expected range.";
1178 return os.str();
1179 }
1180
1181 IRModule mod() const final { return mod_; }
1182 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
1183
1184 private:
1185 IRModule mod_;
1186 Block block_;
1187 int buffer_index_;
1188 BufferIndexType index_type_;
1189 };
1190
1191 const Array<BufferRegion>& access_region =
1192 index_type == BufferIndexType::kWrite ? block->writes : block->reads;
1193
1194 if (n < 0 || static_cast<int>(access_region.size()) <= n) {
1195 throw BufferIndexOutOfRangeError(self->mod, block, n, index_type);
1196 }
1197 return access_region[n];
1198}
1199
1200Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
1201 BufferIndexType index_type) {
1202 return GetNthAccessBufferRegion(self, block, n, index_type)->buffer;
1203}
1204
1205std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
1206 const Buffer& buffer) {
1207 // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or
1208 // match_buffers.
1209 const StmtSRefNode* defining_site_sref = block_sref.get();
1210 while (defining_site_sref != nullptr) {
1211 const auto* block = defining_site_sref->StmtAs<BlockNode>();
1212 // If this sref is not a block sref, skip it.
1213 if (block == nullptr) {
1214 defining_site_sref = defining_site_sref->parent;
1215 continue;
1216 }
    // Try to find the buffer in `alloc_buffers`
1218 for (const Buffer& alloc_buffer : block->alloc_buffers) {
1219 if (buffer.same_as(alloc_buffer)) {
1220 return {GetRef<StmtSRef>(defining_site_sref), true};
1221 }
1222 }
    // Then check `match_buffers`; a buffer defined by buffer matching is not an allocation.
    for (const MatchBufferRegion& match_buffer : block->match_buffers) {
      if (buffer.same_as(match_buffer->buffer)) {
1226 return {GetRef<StmtSRef>(defining_site_sref), false};
1227 }
1228 }
1229 defining_site_sref = defining_site_sref->parent;
1230 }
1231 // If we cannot find the defining site block, it means that the buffer must be in the function's
1232 // buffer_map, which isn't an intermediate buffer.
1233 return {NullOpt, false};
1234}
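
// Illustrative usage (editor's sketch, not part of the original source): combined with
// GetNthAccessBuffer above, this answers "which block defines the buffer that `block`
// writes", which cache/inline primitives need before relocating an allocation. `block` and
// `block_sref` are assumed names:
//
//   Buffer buffer = GetNthAccessBuffer(self, block, /*n=*/0, BufferIndexType::kWrite);
//   auto [defining_site, is_alloc] = GetBufferDefiningSite(block_sref, buffer);
//   if (!defining_site.defined()) { /* the buffer comes from the PrimFunc's buffer_map */ }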
1235
1236/******** SRef Tree Related ********/
1237
1238StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
1239 const StmtSRefNode* p = sref.get();
1240 for (; p->parent != nullptr; p = p->parent) {
1241 }
1242 return GetRef<StmtSRef>(p);
1243}
1244
1245/******** Misc ********/
1246
1247bool HasOp(const Stmt& stmt, const Array<Op>& ops) {
1248 std::unordered_set<const Object*> op_set;
1249 op_set.reserve(ops.size());
1250 for (const Op& op : ops) {
1251 op_set.insert(op.operator->());
1252 }
1253 bool found = false;
1254 PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool {
1255 if (found) {
1256 return false;
1257 }
1258 if (const auto* call = obj.as<CallNode>()) {
1259 if (op_set.count(call->op.operator->())) {
1260 found = true;
1261 }
1262 }
1263 return !found;
1264 });
1265 return found;
1266}
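
// Illustrative usage (editor's sketch, not part of the original source): checking whether a
// statement contains specific intrinsic calls before applying an intrinsic-sensitive
// rewrite; `stmt` is an assumed name:
//
//   static const Op& op_if_then_else = Op::Get("tir.if_then_else");
//   bool calls_if_then_else = HasOp(stmt, {op_if_then_else});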
1267
1268bool HasIfThenElse(const Stmt& stmt) {
1269 bool has_branch = false;
1270 auto f_visit = [&has_branch](const ObjectRef& obj) -> bool {
1271 if (has_branch) {
1272 // stop visiting
1273 return false;
1274 }
1275 if (const auto* realize = obj.as<BlockRealizeNode>()) {
1276 // Case 1: BlockRealize
1277 if (!is_one(realize->predicate)) {
1278 has_branch = true;
1279 }
1280 } else if (obj->IsInstance<IfThenElseNode>() || obj->IsInstance<SelectNode>()) {
1281 // Case 2: IfThenElse / Select
1282 has_branch = true;
1283 } else if (const auto* call = obj.as<CallNode>()) {
1284 // Case 3: Call the `if_then_else` operator
1285 static const Op& op_if_then_else = Op::Get("tir.if_then_else");
1286 if (call->op.same_as(op_if_then_else)) {
1287 has_branch = true;
1288 }
1289 }
1290 return !has_branch;
1291 };
1292 PreOrderVisit(stmt, f_visit);
1293 return has_branch;
1294}
1295
1296std::tuple</*exists=*/bool,
1297 /*surjective=*/bool,
1298 /*injective=*/bool,
1299 /*ordered=*/bool,
1300 /*no_const_read=*/bool,
1301 /*no_shift_read=*/bool>
1302AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) {
1303 static constexpr const std::tuple<bool, bool, bool, bool, bool, bool> kNotExist =
1304 std::make_tuple(false, false, false, false, false, false);
1305 // Step 1. Extract the write indices
1306 int w_dim = write_region->buffer->shape.size();
1307 std::unordered_map<const VarNode*, int> var2idx;
1308 var2idx.reserve(w_dim);
1309 for (int i = 0; i < w_dim; ++i) {
1310 const Range& dom = write_region->region[i];
1311 if (as_const_int(dom->extent) == nullptr) {
1312 return kNotExist;
1313 }
1314 if (const auto* v = dom->min.as<VarNode>()) {
1315 var2idx.emplace(v, i);
1316 } else {
1317 return kNotExist;
1318 }
1319 }
1320 // Step 2. Map each read index to a write index
1321 bool no_const_read = true;
1322 bool no_shift_read = true;
1323 int r_dim = read_region->buffer->shape.size();
1324 std::vector<int> mapped(r_dim, -1);
1325 for (int i = 0; i < r_dim; ++i) {
1326 const Range& dom = read_region->region[i];
1327 if (as_const_int(dom->extent) == nullptr) {
1328 return kNotExist;
1329 }
1330 // Case 1. Read index is a constant
1331 if (as_const_int(dom->min) != nullptr) {
1332 no_const_read = false;
1333 continue;
1334 }
1335 // Case 2. Read index cannot be recognized as `var +/- const`
1336 // where `var` is a write index and `const` is an optional constant shift
1337 Optional<IntImm> opt_const = NullOpt;
1338 const VarNode* var =
1339 static_cast<const VarNode*>(AnalyzeVarWithShift(dom->min, &opt_const).get());
1340 if (var == nullptr || !var2idx.count(var)) {
1341 return kNotExist;
1342 }
1343 // Case 3. Read index is `var +/- const`
1344 mapped[i] = var2idx.at(var);
1345 if (opt_const.defined()) {
1346 no_shift_read = false;
1347 }
1348 }
1349 // Step 3. Check if the mapping is ordered, and count how many times each var is mapped
1350 std::vector<int> mapped_counter(w_dim, 0);
1351 bool ordered = true;
1352 int last_mapped = -1;
1353 for (int i : mapped) {
1354 if (i != -1) {
1355 ++mapped_counter[i];
1356 if (last_mapped != -1 && last_mapped > i) {
1357 ordered = false;
1358 }
1359 last_mapped = i;
1360 }
1361 }
1362 // Step 4. Check if the mapping is surjective or injective
1363 // Surjective: each write index is mapped at least once
1364 // Injective: each write index is mapped at most once
1365 bool surjective = true;
1366 bool injective = true;
1367 for (int cnt : mapped_counter) {
1368 if (cnt == 0) {
1369 surjective = false;
1370 } else if (cnt >= 2) {
1371 injective = false;
1372 }
1373 }
1374 return std::make_tuple(/*exist=*/true, surjective, injective, ordered, no_const_read,
1375 no_shift_read);
1376}
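
// Illustrative usage (editor's sketch, not part of the original source): IsWriteCache above
// consumes this analysis; callers unpack the six flags with structured bindings and, as
// there, may silence the unused ones explicitly:
//
//   auto [exists, surjective, injective, ordered, no_const_read, no_shift_read] =
//       AnalyzeReadWritePattern(read_region, write_region);
//   if (exists && injective && ordered) { /* read indices map one-to-one and in order */ }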
1377
1378/******** Storage Scope ********/
1379
1380void CheckStorageScope(const ScheduleState& self, String storage_scope) {
1381 class InvalidStorageScopeError : public ScheduleError {
1382 public:
1383 explicit InvalidStorageScopeError(IRModule mod, String storage_scope)
1384 : mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {}
1385
1386 String FastErrorString() const final {
1387 return "ScheduleError: The input storage scope is invalid";
1388 }
1389
1390 String DetailRenderTemplate() const final {
1391 return "The input storage scope \"" + storage_scope_ + "\" is invalid.";
1392 }
1393
1394 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
1395 IRModule mod() const final { return mod_; }
1396
1397 private:
1398 IRModule mod_;
1399 String storage_scope_;
1400 };
1401
1402 try {
1403 runtime::StorageScope::Create(std::string(storage_scope));
1404 } catch (...) {
1405 throw InvalidStorageScopeError(self->mod, std::move(storage_scope));
1406 }
1407}
1408
1409bool IsSpatial(const StmtSRef& block_sref) {
1410 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
1411 for (const IterVar& iter_var : block->iter_vars) {
1412 if (iter_var->iter_type != IterVarType::kDataPar) {
1413 return false;
1414 }
1415 }
1416 return true;
1417}
1418
1419bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
1420 TVM_SREF_TO_BLOCK(block_sref);
1421 Array<StmtSRef> loops = GetLoops(block_sref);
1422 Array<PrimExpr> binds = GetBlockRealize(self, block_sref)->iter_values;
1423 if (loops.size() != binds.size()) {
1424 return false;
1425 }
1426 for (int i = 0, n = loops.size(); i < n; ++i) {
1427 const ForNode* loop = TVM_SREF_TO_FOR(loops[i]);
1428 if (binds[i].get() != loop->loop_var.get()) {
1429 return false;
1430 }
1431 }
1432 return true;
1433}
1434
1435bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) {
1436 if (HasBeenMultiLevelTiled(block_sref)) {
1437 return false;
1438 }
1439 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
1440 if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) ||
1441 !IsTrivialBinding(self, block_sref)) {
1442 return false;
1443 }
1444 const BufferNode* write_buffer = block->writes[0]->buffer.get();
  // Step 1. Sort out spatial block variables. Skip block iters with domain [0, 1), since such
  // block iters distract the following check of unused block iters.
1447 std::vector<const VarNode*> spatial_block_vars;
1448 spatial_block_vars.reserve(block->iter_vars.size());
1449 for (const IterVar& block_var : block->iter_vars) {
1450 const int64_t* dom_min = as_const_int(block_var->dom->min);
1451 const int64_t* dom_extent = as_const_int(block_var->dom->extent);
1452 bool has_trivial_dom =
1453 dom_min != nullptr && dom_extent != nullptr && *dom_min == 0 && *dom_extent == 1;
1454 if (block_var->iter_type == IterVarType::kDataPar && !has_trivial_dom) {
1455 spatial_block_vars.push_back(block_var->var.get());
1456 }
1457 }
1458 // Step 2. Enumerate each read region, check the number of block vars that are not used
1459 // to index the read region
1460 int total_unused_block_vars = 0;
1461 std::unordered_set<const BufferNode*> read_buffers;
1462 read_buffers.reserve(block->reads.size());
1463 for (const BufferRegion& buffer_region : block->reads) {
1464 const BufferNode* buffer = buffer_region->buffer.get();
1465 const Array<Range>& regions = buffer_region->region;
    // Step 2.1. Duplication of read buffers is not allowed
1467 if (read_buffers.insert(buffer).second == false) {
1468 return false;
1469 }
1470 // Step 2.2. Skip the reduction buffer
1471 if (buffer == write_buffer) {
1472 continue;
1473 }
1474 // Step 2.3. Collect the block vars that are used to index the read region
1475 std::unordered_set<const VarNode*> vars;
1476 for (const Range& range : regions) {
1477 if (as_const_int(range->extent) == nullptr) {
1478 return false;
1479 }
1480 for (const Var& var : UndefinedVars(range->min)) {
1481 vars.insert(var.get());
1482 }
1483 }
1484 // Step 2.4. Check if the block vars are not used to index the read region
1485 int n_unused_block_vars = 0;
1486 for (const VarNode* block_var : spatial_block_vars) {
1487 if (vars.count(block_var) == 0) {
1488 ++n_unused_block_vars;
1489 }
1490 }
1491 total_unused_block_vars += n_unused_block_vars;
1492 }
1493 return total_unused_block_vars >= 1;
1494}
1495
1496bool IsSpatialPrimFunc(const PrimFunc& func) {
1497 bool result = true;
1498 PreOrderVisit(func->body, [&result](const ObjectRef& obj) {
1499 if (result == false) {
1500 return false;
1501 }
1502 if (const auto* block = obj.as<BlockNode>()) {
1503 for (const IterVar& iter_var : block->iter_vars) {
1504 if (iter_var->iter_type != IterVarType::kDataPar) {
1505 result = false;
1506 return false;
1507 }
1508 }
1509 }
1510 return true;
1511 });
1512 return result;
1513}
1514
1515std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
1516 const tir::StmtSRef& block_sref) {
1517 Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
1518 int64_t cum_space_len = 1, cum_reduce_len = 1;
1519 /*
1520 * Return (-1, -1) if
1521 * 1. there is some loop with type other than kDataPar and kCommReduce;
1522 * 2. there is some loop which is dynamic.
1523 */
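/*
 * Worked example (illustrative only): for the loop nest
 *   for i in 128 (kDataPar):
 *     for j in 64 (kDataPar):
 *       for k in 32 (kCommReduce):
 * the function returns (128 * 64, 32) = (8192, 32).
 */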
1524 for (const tir::StmtSRef& loop_sref : loops) {
1525 tir::IterVarType type = GetLoopIterType(loop_sref);
1526 if (type == tir::kDataPar) {
1527 const int64_t* extent = GetLoopIntExtent(loop_sref);
1528 if (extent != nullptr) {
1529 cum_space_len *= *extent;
1530 } else {
1531 return std::make_pair(-1, -1);
1532 }
1533 } else if (type == tir::kCommReduce) {
1534 const int64_t* extent = GetLoopIntExtent(loop_sref);
1535 if (extent != nullptr) {
1536 cum_reduce_len *= *extent;
1537 } else {
1538 return std::make_pair(-1, -1);
1539 }
1540 } else {
1541 return std::make_pair(-1, -1);
1542 }
1543 }
1544 return std::make_pair(cum_space_len, cum_reduce_len);
1545}
1546
1547bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
1548 const tir::StmtSRef& block_sref, //
1549 int64_t max_parallel_extent, //
1550 int64_t max_parallel_basic) {
1551 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
1552 Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
1553
1554 // Cond 1. The block must have at least one write buffer
1555 if (block->writes.size() == 0) {
1556 return false;
1557 }
1558
1559 // Cond 2. The block is a reduction block with trivial bindings and has not been multi-level tiled.
1560 const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,
1561 /*require_stage_pipeline=*/false);
1562 if (!IsReductionBlock(self, block_sref, scope_sref) //
1563 || !IsTrivialBinding(self, block_sref) //
1564 || HasBeenMultiLevelTiled(block_sref)) {
1565 return false;
1566 }
1567
1568 // Cond 3. Every loop axis must be either a spatial axis or a reduction axis.
1569 for (const tir::StmtSRef& loop_sref : loops) {
1570 const tir::IterVarType& type = GetLoopIterType(loop_sref);
1571 if (type != tir::kDataPar && type != tir::kCommReduce) {
1572 return false;
1573 }
1574 }
1575
1576 // Cond 4. There is at least one reduction loop.
1577 // Cond 5. The loops are perfectly nested, and the body of the innermost loop is exactly the block.
1578 bool has_reduction_loop = false;
1579 for (size_t i = 0; i < loops.size(); ++i) {
1580 // Cond 4.
1581 if (GetLoopIterType(loops[i]) == tir::kCommReduce) {
1582 has_reduction_loop = true;
1583 }
1584
1585 // Cond 5.
1586 const ForNode* loop_i = TVM_SREF_TO_FOR(loops[i]);
1587 if (i < loops.size() - 1) {
1588 const ForNode* loop_i1 = TVM_SREF_TO_FOR(loops[i + 1]);
1589 if (loop_i->body.get() != loop_i1) {
1590 return false;
1591 }
1592 } else {
1593 const auto* block_realize = loop_i->body.as<tir::BlockRealizeNode>();
1594 if (!block_realize || block_realize->block.get() != block) {
1595 return false;
1596 }
1597 }
1598 }
1599 if (!has_reduction_loop) {
1600 return false;
1601 }
1602
1603 // Cond 6. The cumulative loop lengths can be successfully computed.
1604 auto [cum_space_len, cum_reduce_len] = GetCumulativeSpaceAndReductionLength(self, block_sref);
1605 if (cum_space_len == -1 || cum_reduce_len == -1) {
1606 return false;
1607 }
1608
1609 // Cond 7. Decide based on the cumulative space/reduction lengths.
1610 if (NeedsMultiLevelTiling(self, block_sref)) {
1611 // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops.
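// Numeric illustration (assumed values): with cum_space_len = 16, cum_reduce_len = 4096 and
// max_parallel_extent = 1024, the spatial extent neither dominates the reduction extent nor
// exceeds the parallel budget, so the function returns true.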
1612 return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent);
1613 } else {
1614 // Always try rfactor/cross-thread-reduction for other reduction blocks.
1615 return cum_reduce_len > 1;
1616 }
1617}
1618
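/*!
 * \brief Simplify the given expression, but return it unchanged if the simplified form
 * degenerates to an integer constant. For example, `i - i` simplifies to the integer 0,
 * so the original `i - i` is returned instead.
 */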
1619PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) {
1620 auto simplified = analyzer->Simplify(expr);
1621 if (simplified->IsInstance<IntImmNode>()) {
1622 return expr;
1623 } else {
1624 return simplified;
1625 }
1626}
1627
1628TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
1629
1630/*! \brief Auxiliary data structure of information extracted from tensor intrin description */
1631struct TensorIntrinDescInfo {
1632 /*! \brief The block of the description function, which is the (unique) direct child of the root
1633 * block.
1634 */
1635 const BlockRealizeNode* desc_block = nullptr;
1636 /*! \brief The loops of the description function, in the order from outer loops to inner ones. */
1637 std::vector<const tir::ForNode*> desc_loops;
1638 /*! \brief The loop variables. */
1639 std::unordered_set<const tir::VarNode*> desc_loop_vars;
1640};
1641
1642/*!
1643 * \brief Extract auxiliary information from the tensor intrin description.
1644 * \param analyzer The arithmetic analyzer
1645 * \param desc_func The description PrimFunc
1646 * \return The auxiliary information
1647 */
1648TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,
1649 const PrimFunc& desc_func) {
1650 TensorIntrinDescInfo info;
1651 const auto* desc_scope_realize = desc_func->body.as<BlockRealizeNode>();
1652 ICHECK(desc_scope_realize);
1653 {
1654 auto f_visit = [&](const ObjectRef& obj) -> bool {
1655 // Extract the block
1656 if (const auto* block = obj.as<BlockRealizeNode>()) {
1657 info.desc_block = block;
1658 return false;
1659 }
1660 // Extract the loops
1661 if (const auto* loop = obj.as<ForNode>()) {
1662 info.desc_loops.push_back(loop);
1663 info.desc_loop_vars.insert(loop->loop_var.get());
1664 if (!analyzer->CanProve(loop->min == 0)) {
1665 return false;
1666 }
1667 }
1668 return true;
1669 };
1670 tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
1671 std::reverse(info.desc_loops.begin(), info.desc_loops.end());
1672 ICHECK(info.desc_block);
1673 }
1674 return info;
1675}
1676
1677Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
1678 const tir::StmtSRef& block_sref,
1679 const tir::PrimFunc& desc_func,
1680 bool allow_padding) {
1681 arith::Analyzer analyzer;
1682 const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
1683 // Step 1. Analyze desc_func, extract its block, loops and loop vars
1684 TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func);
1685 // Step 2. Collect loops from block_sref
1686 const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
1687 TVM_SREF_TO_BLOCK(scope_sref);
1688 std::vector<const tir::ForNode*> block_loops;
1689 std::unordered_set<const tir::VarNode*> block_loop_vars;
1690 {
1691 for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
1692 const auto* loop = loop_sref->StmtAs<tir::ForNode>();
1693 if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
1694 break;
1695 }
1696 block_loops.push_back(loop);
1697 block_loop_vars.insert(loop->loop_var.get());
1698 if (!analyzer.CanProve(loop->min == 0)) {
1699 return NullOpt;
1700 }
1701 }
1702 std::reverse(block_loops.begin(), block_loops.end());
1703 }
1704 // Step 3. Map from block loops to desc block loops
1705 const std::vector<const ForNode*>& desc_loops = desc_info.desc_loops;
1706 const std::unordered_set<const VarNode*>& desc_loop_vars = desc_info.desc_loop_vars;
1707 const BlockRealizeNode* desc_block = desc_info.desc_block;
1708 ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
1709 const int n_block_vars = block->iter_values.size();
1710 const int n_desc_vars = desc_block->iter_values.size();
1711 const int offset = n_block_vars - n_desc_vars;
1712
1713 std::unordered_map<int, int> block_index_to_padding; // padding of each block iter if necessary
1714
1715 if (offset < 0) {
1716 return NullOpt;
1717 }
1718
1719 const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
1720 const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
1721
1722 ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
1723 ICHECK(block_loops.size() == iter_types_block.size());
1724
1725 // We assume that the orders of iter_vars in the target and the desc block are consistent.
1726 // Based on that assumption, the following logic supports arbitrary permutations of a loop order,
1727 // such as
1728
1729 // for k:
1730 // for i:
1731 // for j:
1732 // C[i, j] += A[i, k] * B[k, j]
1733
1734 // or
1735
1736 // for i:
1737 // for j:
1738 // for k:
1739 // C[i, j] += A[i, k] * B[k, j]
1740
1741 int next_block_ind = block_loops.size() - 1;
1742 for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
1743 // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
1744 const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
1745 const tir::ForNode* desc_loop = nullptr;
1746 IterVarType iter_type_desc = iter_types_desc[i_desc];
1747 for (int i = 0, n = desc_loops.size(); i < n; ++i) {
1748 // Check if desc_bind = desc_loops[i]->loop_var + (an expression independent of the desc loop vars)
1749 PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
1750 if (!UsesVar(residual,
1751 [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
1752 desc_loop = desc_loops[i];
1753 iter_type_desc = iter_types_desc[i];
1754 break;
1755 }
1756 }
1757 if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
1758 return NullOpt;
1759 }
1760
1761 const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
1762
1763 // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
1764 PrimExpr block_bind;
1765 int current_block_ind = next_block_ind;
1766 for (; current_block_ind >= 0; --current_block_ind) {
1767 if (iter_types_block[current_block_ind] == iter_type_desc) {
1768 next_block_ind = current_block_ind - 1;
1769 block_bind = block->iter_values[current_block_ind];
1770 break;
1771 }
1772 }
1773
1774 if (!block_bind.defined()) return NullOpt;
1775
1776 // Step 3.3. Find the corresponding loop of the target block
1777 for (int i = 0, n = block_loops.size(); i < n; ++i) {
1778 // Check if block_bind = block_loops[i]->loop_var + (an expression independent of the block loop vars)
1779 const tir::ForNode* block_loop = block_loops[i];
1780 const tir::StmtSRef& block_loop_sref = self->stmt2ref.at(block_loop);
1781 // Skip i-th loop if it has already been mapped
1782 if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue;
1783
1784 PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
1785 if (UsesVar(residual,
1786 [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
1787 continue;
1788 }
1789 // padding is allowed only when the block has trivial bindings
1790 if (allow_padding && !is_zero(residual)) {
1791 allow_padding = false;
1792 }
1793
1794 const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
1795
1796 // Check divisibility
1797 if (!int_block_extent) {
1798 return NullOpt;
1799 }
1800 int64_t remainder = int_block_extent->value % int_desc_extent->value;
1801 if (remainder != 0) {
1802 if (allow_padding) {
1803 // The block loop extent is not divisible by the desc loop extent; record the padding
1804 // needed to make it divisible.
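// Illustrative numbers (assumed): a block loop of extent 130 matched against a desc loop of
// extent 16 leaves remainder 2, so a padding of 16 - 2 = 14 is recorded for this block iter.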
1805 block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder;
1806 } else {
1807 return NullOpt;
1808 }
1809 }
1810
1811 ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
1812 break;
1813 }
1814 }
1815
1816 for (int i = 0, n = desc_loops.size(); i < n; ++i) {
1817 ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
1818 }
1819 if (!block_index_to_padding.empty()) {
1820 if (!allow_padding) {
1821 return NullOpt;
1822 }
1823 Array<Integer> paddings;
1824 for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) {
1825 const IterVar& iter_var = block->block->iter_vars[i];
1826 if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) {
1827 paddings.push_back(IntImm(iter_var->var.dtype(), it->second));
1828 } else {
1829 paddings.push_back(IntImm(iter_var->var.dtype(), 0));
1830 }
1831 }
1832 ret->block_iter_paddings = std::move(paddings);
1833 }
1834
1835 return TensorizeInfo(ret);
1836}
1837
1838TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
1839TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
1840 .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) {
1841 return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding);
1842 });
1843
1844/******** Auto Tensorization ********/
1845
1846/*! \brief IndexMap proposer for layout transformation in auto tensorization. */
1847class AutoTensorizeMappingProposer {
1848 public:
1849 static Array<IndexMap> ProposeMappings(const AutoTensorizeComparator* extractor,
1850 arith::Analyzer* analyzer) {
1851 AutoTensorizeMappingProposer proposer(extractor, analyzer);
1852 proposer.CollectFeasibleSet();
1853 return proposer.ProposeAllFuseMapping();
1854 }
1855
1856 private:
1857 explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* extractor,
1858 arith::Analyzer* analyzer)
1859 : extractor_(extractor), analyzer_(analyzer) {}
1860
1861 using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
1862
1863 void CollectFeasibleSet() {
1864 // Collect the set of potential iter var mapping between the workload and the tensor intrin.
1865 // We analyze the appearance of each variable in the buffer indices of each buffer on LHS and
1866 // RHS. The appearance of a variable in the buffer indices is encoded as bit-masks (BufferMask).
1867 // Variables on the LHS and the RHS with the same bit-mask and the same iter type are potential
1868 // mappings.
1869 //
1870 // For example, consider the conv2d case. We will try to match the workload
1871 // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c]
1872 // against a matmul tensor intrin
1873 // C[m, n] = sum_{k} A[m, k] * B[k, n]
1874 // First we extract the correspondence of the buffers: conv2d <=> C, A <=> X, B <=> W.
1875 // Then for each variable, we extract the buffers where it is used for indexing.
1876 // Take the variable m on the RHS as an example. m is used to index buffers A and C. On the LHS,
1877 // we look for the variables that are used to index only the exactly corresponding buffers X and
1878 // conv2d (such a variable is not allowed to index any other buffer). In this case, n, h and w
1879 // index both conv2d and X, and no other buffer. Therefore, {n, h, w} <=> m is a potential
1880 // mapping.
1881
1882 // Note: the mapping is not unique when multiple variables on the RHS have the same bit-mask.
1883 // This is currently not supported.
1884
1885 using BufferMask = std::vector<bool>;
1886
1887 // Step 1: Assign an index to each buffer in LHS and RHS
1888 std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual> rhs_buffer_index;
1889 std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual> lhs_buffer_index;
1890 {
1891 int i = 0;
1892 for (const auto& kv : extractor_->rhs_buffer_map_) {
1893 const Buffer& rhs_buffer = kv.first;
1894 const Buffer& lhs_buffer = kv.second;
1895 rhs_buffer_index[rhs_buffer] = i;
1896 lhs_buffer_index[lhs_buffer] = i;
1897 ++i;
1898 }
1899 }
1900
1901 // Step 2: Compute the buffer mask
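// Illustrative example (for exposition only): with three matched buffer pairs indexed 0, 1 and 2,
// a variable that appears in the indices of buffers 0 and 2 gets the mask {true, false, true};
// only variables with identical masks (and identical iter types, see Step 3) can be matched.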
1902 ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size());
1903 int num_buffers = rhs_buffer_index.size();
1904 std::unordered_map<const VarNode*, std::vector<bool>> rhs_buffer_masks, lhs_buffer_masks;
1905 // helper function to initialize or update the buffer mask
1906 auto update_mask = [&](const VarNode* var,
1907 std::unordered_map<const VarNode*, std::vector<bool>>* masks, int i) {
1908 if (!masks->count(var)) {
1909 (*masks)[var].resize(num_buffers);
1910 }
1911 (*masks)[var][i] = true;
1912 };
1913
1914 for (const auto& it : extractor_->rhs_buffer_indices_map_) {
1915 const Buffer& rhs_buffer = it.first;
1916 for (const PrimExpr& rhs_index : it.second) {
1917 if (const VarNode* var_node = rhs_index.as<VarNode>()) {
1918 update_mask(var_node, &rhs_buffer_masks, rhs_buffer_index.at(rhs_buffer));
1919 } else {
1920 LOG(FATAL) << "ValueError: Buffer index " << rhs_index
1921 << " other than a variable is not supported in tensor intrinsics.";
1922 }
1923 }
1924
1925 auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer);
1926 ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end());
1927 const Buffer& lhs_buffer = lhs_buffer_it->second;
1928 for (const PrimExpr& index : extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) {
1929 PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
1930 if (const VarNode* var = obj.as<VarNode>()) {
1931 update_mask(var, &lhs_buffer_masks, lhs_buffer_index.at(lhs_buffer));
1932 }
1933 return true;
1934 });
1935 }
1936 }
1937
1938 // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure LHS and RHS vars
1939 // have the same iter type.
1940 std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars;
1941 for (const auto& kv : rhs_buffer_masks) {
1942 const VarNode* rhs_var = kv.first;
1943 const BufferMask& mask = kv.second;
1944 mask_to_rhs_vars[mask].insert(GetRef<Var>(rhs_var));
1945 }
1946 std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type;
1947 for (const auto& iter : extractor_->rhs_iters_) {
1948 rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type);
1949 }
1950 for (const auto& iter : extractor_->lhs_iters_) {
1951 auto& potential_mappings = lhs_feasible_vars_[iter->var];
1952 VarSet rhs_candidates = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
1953 std::copy_if(
1954 rhs_candidates.begin(), rhs_candidates.end(),
1955 std::inserter(potential_mappings, potential_mappings.begin()),
1956 [&](const Var& var) { return rhs_var_iter_type.at(var.get()) == iter->iter_type; });
1957 }
1958 }
1959
1960 Array<IndexMap> ProposeAllFuseMapping() {
1961 // Now we have calculated the potential mappings for each iter var on LHS. For iters on LHS mapped to
1962 // the same iter on RHS, they will be fused in the original order in LHS block iters. We will
1963 // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped
1964 // to the same iter var on RHS, we will produce index map `lambda n, h, w: fuse(n, h, w)`, where
1965 // fuse(v0, ..., vn) = (...((v0 * v1_extent + v1) * v2_extent + v2) ... ) * vn_extent + vn
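// Concrete sketch (assumed extents): if n, h, w on the LHS with extents 2, 8, 8 all map to m on
// the RHS, the proposed index map is `lambda n, h, w: (n * 8 + h) * 8 + w`, i.e. the three LHS
// iters are fused into a single iter of extent 2 * 8 * 8 = 128.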
1966
1967 // the parameters of the result index map, each parameter corresponds to a LHS iter
1968 Array<Var> index_map_src;
1969 // the outputs of the result index map
1970 Array<PrimExpr> index_map_tgt;
1971
1972 // Step 1: Collect extents of LHS iters and prepare the initial indices of the IndexMap
1973 Map<Var, PrimExpr> lhs_iter_extents;
1974 for (const auto& iter : extractor_->lhs_iters_) {
1975 lhs_iter_extents.Set(iter->var, iter->dom->extent);
1976 index_map_src.push_back(iter->var.copy_with_suffix(""));
1977 }
1978
1979 // Step 2: Each iter on RHS has a group of corresponding iters on LHS. Initialize the fusion
1980 // result for each group of iters on LHS.
1981 Map<Var, PrimExpr> fused_lhs_iters;
1982 for (const auto& iter : extractor_->rhs_iters_) {
1983 fused_lhs_iters.Set(iter->var, 0);
1984 }
1985
1986 // Step 3: Fuse LHS iters mapped to the same RHS iter
1987 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_rhs_vars;
1988 for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
1989 const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
1990 const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
1991 if (rhs_candidates.empty()) {
1992 // put unmapped iters at the beginning
1993 index_map_tgt.push_back(index_map_src[i]);
1994 } else if (rhs_candidates.size() == 1) {
1995 Var rhs_var = *rhs_candidates.begin();
1996 PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var);
1997 PrimExpr updated_fused_lhs =
1998 fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
1999 fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
2000 used_rhs_vars.insert(rhs_var);
2001 } else {
2002 // non-unique mapping is not supported
2003 return {};
2004 }
2005 }
2006 for (const auto& iter : extractor_->rhs_iters_) {
2007 if (!used_rhs_vars.count(iter->var)) {
2008 return {};
2009 }
2010 index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
2011 }
2012 // At most one mapping is supported.
2013 return {IndexMap(index_map_src, index_map_tgt)};
2014 }
2015
2016 private:
2017 // The extractor that has extracted information for auto tensorization from the workload and the
2018 // tensor intrin.
2019 const AutoTensorizeComparator* extractor_;
2020 // The arithmetic analyzer.
2021 arith::Analyzer* analyzer_;
2022 /*! \brief Potential mappings on RHS for each variable on LHS */
2023 std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> lhs_feasible_vars_;
2024};
2025
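/*!
 * \brief Check if the given block can be matched against the block of the tensor intrin
 * description function, collecting the comparison state into the given extractor.
 * \param state The schedule state
 * \param block_sref The sref of the block to be checked
 * \param desc_func The description PrimFunc of the tensor intrin
 * \param extractor The comparator used to perform and record the structural comparison
 * \return Whether the block matches the description block
 */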
2026bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRef& block_sref,
2027 const tir::PrimFunc& desc_func,
2028 AutoTensorizeComparator* extractor) {
2029 // Step 1. Analyze desc_func, extract its block, loops and loop vars
2030 // Step 2. Check if `desc_block` matches `block`
2031 // Ignore the scope of buffers when comparing, since we can do cache_read/write
2032 const BlockRealize& block = tir::GetBlockRealize(state, block_sref);
2033 arith::Analyzer analyzer;
2034 auto desc_info = tir::ExtractTensorIntrinDescInfo(&analyzer, desc_func);
2035
2036 return extractor->VisitStmt(block->block, desc_info.desc_block->block);
2037}
2038
2039bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& block_rv,
2040 const tir::PrimFunc& desc_func) {
2041 AutoTensorizeComparator extractor(sch->state()->mod);
2042 return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor);
2043}
2044
2045Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const tir::ScheduleState& self,
2046 const tir::StmtSRef& block_sref,
2047 const tir::PrimFunc& desc_func) {
2048 AutoTensorizeComparator extractor(self->mod);
2049 if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) {
2050 return NullOpt;
2051 }
2052 arith::Analyzer analyzer;
2053 Array<IndexMap> mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer);
2054 if (mappings.empty()) {
2055 return NullOpt;
2056 }
2057 ObjectPtr<AutoTensorizeMappingInfoNode> ret = make_object<AutoTensorizeMappingInfoNode>();
2058 ret->mappings = std::move(mappings);
2059 ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_);
2060 ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_);
2061 ret->lhs_iters = std::move(extractor.lhs_iters_);
2062 ret->rhs_iters = std::move(extractor.rhs_iters_);
2063 return AutoTensorizeMappingInfo(ret);
2064}
2065
2066TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode);
2067
2068TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
2069 .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
2070 return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func);
2071 });
2072
2073TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock);
2074
2075} // namespace tir
2076} // namespace tvm
2077