1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24using support::NDIntSet;
25
26/******** Error Classes ********/
27
28/*!
29 * \brief An error raised when not all required blocks are under the given loop.
30 * \tparam is_consumer Indicates if all the required blocks are consumers or producers
31 */
template <bool is_consumer>
class NotAllRequiredBlocksAreVisitedError : public ScheduleError {
 public:
  /*!
   * \brief Constructor
   * \param mod The IRModule where the error occurs
   * \param num_not_visited The number of required blocks that are not under the target loop
   * \param required The srefs of all the required blocks, stored as Blocks for error rendering
   */
  explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited,
                                               const Array<StmtSRef>& required)
      : mod_(mod), num_not_visited_(num_not_visited) {
    required_.reserve(required.size());
    for (const StmtSRef& block_sref : required) {
      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
      required_.push_back(GetRef<Block>(block));
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: Not all required blocks are under the loop scope";
  }

  String DetailRenderTemplate() const final {
    String relation = is_consumer ? "consumer(s)" : "producer(s)";
    std::ostringstream os;
    os << "The primitive requires all the " << relation
       << " of the given block to be present under the target loop. However, there are "
       << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the "
       << relation << ":";
    // Emit one "{i}" placeholder per required block; each placeholder is substituted with the
    // corresponding entry of LocationsOfInterest() when the error is rendered.
    for (int i = 0, n = required_.size(); i < n; ++i) {
      os << "{" << i << "}";
    }
    return os.str();
  }

  IRModule mod() const final { return mod_; }

  Array<ObjectRef> LocationsOfInterest() const final {
    return {required_.begin(), required_.end()};
  }

 private:
  /*! \brief The IRModule where the error occurs */
  IRModule mod_;
  /*! \brief The number of required blocks that are not under the target loop */
  int num_not_visited_;
  /*! \brief All the required blocks (both visited and not visited) */
  Array<Block> required_;
};
73
74/*!
75 * \brief An error raised when the given block is not in the same block scope as the given loop,
76 * or the given loop is the ancestor of the given block.
77 */
class NotInSameScopeError : public ScheduleError {
 public:
  /*!
   * \brief Check that the block and the loop are under the same block scope and that the loop
   * is not an ancestor of the block; meanwhile bind the domain of every loop on the path from
   * the given loop up to the scope root into the analyzer.
   * \param self The schedule state
   * \param block_sref The sref of the block to be moved
   * \param loop_sref The sref of the target loop
   * \param scope_root_sref The sref of the root of the block scope
   * \param analyzer The arithmetic analyzer the loop domains are bound into
   * \throws NotInSameScopeError if either of the two conditions is violated
   */
  static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, const StmtSRef& scope_root_sref,
                                     arith::Analyzer* analyzer) {
    // Walk upwards from the loop. Every node on the path must be a For until the scope root is
    // reached; meeting any non-loop node before the scope root means the loop lives in a
    // different block scope.
    for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) {
      if (const ForNode* loop = p->StmtAs<ForNode>()) {
        analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
      } else if (p != scope_root_sref.get()) {
        throw NotInSameScopeError(self->mod, block_sref, loop_sref);
      } else {
        break;
      }
    }
    // Walk upwards from the block. If the loop appears on the path to the scope root, then the
    // loop is an ancestor of the block, which is not allowed.
    for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) {
      if (p == loop_sref.get()) {
        throw NotInSameScopeError(self->mod, block_sref, loop_sref);
      }
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: Expected the block and loop to be under the same block scope, and loop "
           "not to be the ancestor of block";
  }
  String DetailRenderTemplate() const final {
    return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, "
           "and loop not to be the ancestor of block";
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; }

 private:
  // Private: this error is only constructed from CheckAndBindLoopDomain above
  explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref)
      : mod_(mod),
        block_(GetRef<Block>(block_sref->StmtAs<BlockNode>())),
        loop_(GetRef<For>(loop_sref->StmtAs<ForNode>())) {}

  IRModule mod_;
  Block block_;
  For loop_;
};
120
121/******** Helper Functions/Classes ********/
122
123/*!
124 * \brief Find a point where the block can be inserted under the loop
125 * \tparam require_all_producers_visited Requires all producer blocks to be present under the loop
126 * \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop
127 * \param self The schedule state
128 * \param subtrees The subtrees under the loop, among which the insertion points are sought
129 * \param producer_srefs The producer blocks
130 * \param consumer_srefs The consumer blocks
131 * \param block2realize A cache that maps a block to its realize
132 * \param index The block index of the loop body subtree blocks:
133 * - `index = -1` means inserted into the last possible insertion point;
134 * - `index = -2` means inserted into the first possible insertion point;
135 * - Otherwise, `index` is a nonnegative number that indicates the insertion point
136 * \return The possible position the new block can be inserted into, and the
137 * producer-consumer-relationship is still satisfied.
138 * \throws ScheduleError if there is no such insertion point found
139 */
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(const ScheduleState& self, const Array<Stmt>& subtrees,
                       const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs,
                       std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
                       int index) {
  // Split the subtrees into producers and consumers, recording the position of the last
  // producer and the first consumer
  ProducerConsumerSplit split =
      ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
  // Step 1. Check if all the producers are visited in the subtrees, if required to
  if (require_all_producers_visited) {
    int num_producers = producer_srefs.size();
    if (split.n_producers_visited < num_producers) {
      // `is_consumer = false`: the missing required blocks are producers
      throw NotAllRequiredBlocksAreVisitedError<false>(
          self->mod, num_producers - split.n_producers_visited, producer_srefs);
    }
  }
  // Step 2. Check if all the consumers are visited in the subtrees, if required to
  if (require_all_consumers_visited) {
    int num_consumers = consumer_srefs.size();
    if (split.n_consumers_visited < num_consumers) {
      // `is_consumer = true`: the missing required blocks are consumers
      throw NotAllRequiredBlocksAreVisitedError<true>(
          self->mod, num_consumers - split.n_consumers_visited, consumer_srefs);
    }
  }
  // Step 3. Check if there is at least one index of the position can be inserted into
  // The valid indices are: (last_producer_position, first_consumer_position]
  ICHECK(split.last_producer_position < split.first_consumer_position);
  // Step 4. Return the possible insertion point according to index
  int insert_position;
  if (index == -1) {
    // Last possible insertion point: right before the first consumer
    insert_position = split.first_consumer_position;
  } else if (index == -2) {
    // First possible insertion point: right after the last producer
    insert_position = split.last_producer_position + 1;
  } else if (index >= 0 && index >= split.last_producer_position + 1 &&
             index <= split.first_consumer_position) {
    // Explicit index: only accepted when it falls in the valid range
    insert_position = index;
  } else {
    LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", "
               << split.first_consumer_position << "]), "
               << "current index=" << index;
    throw;
  }
  return insert_position;
}
183
184/*!
185 * \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound)
186 * The bound region may not get directly intersected with dom region, instead we try to generate
187 * extra predicates for non-trivial bound. The domain info class can also union with each other.
188 */
struct BlockVarDomainInfo {
  arith::IntSet dom{arith::IntSet::Nothing()};  // dom is ensured to be bounded
  arith::IntSet bound{arith::IntSet::Nothing()};

  /*! \brief Relaxed union operation */
  void Union(const BlockVarDomainInfo& other) {
    // just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1)
    dom = arith::Union({dom, other.dom});
    bound = arith::Union({bound, other.bound});
  }

  /*! \brief Simplify domain info */
  void Simplify(arith::Analyzer* analyzer) {
    // Simplify an IntSet's min/max, but only the sides for which a bound actually exists
    auto to_simplified = [analyzer](const arith::IntSet& set) {
      PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min();
      PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max();
      return arith::IntSet::Interval(min, max);
    };
    // if no dom specified, try use bound as dom
    if (dom.IsNothing()) {
      // Only a fully bounded `bound` can serve as `dom` (dom must be bounded)
      if (bound.HasLowerBound() && bound.HasUpperBound()) {
        bound = to_simplified(bound);
        std::swap(dom, bound);
      }
      return;
    }
    // simplify intset
    dom = to_simplified(dom);
    bound = to_simplified(bound);
    // if can proof the dom is within bound, remove bound
    auto intersect = to_simplified(arith::Intersect({dom, bound}));
    if (analyzer->CanProveEqual(dom.min(), intersect.min()) &&
        analyzer->CanProveEqual(dom.max(), intersect.max())) {
      // dom is provably inside bound: the bound imposes no extra constraint, drop it
      bound = arith::IntSet::Nothing();
    } else if (analyzer->CanProveEqual(bound.min(), intersect.min()) &&
               analyzer->CanProveEqual(bound.max(), intersect.max())) {
      // bound is provably inside dom: the bound itself is the effective domain
      dom = bound;
      bound = arith::IntSet::Nothing();
    }
  }
};
230
231/*!
232 * \brief A helper to reconstruct the block scope where the given block is moved under the given
233 * loop, and the given block's induced loop nest is regenerated to satisfy the required region.
234 */
class ScopeReconstructor : private StmtMutator {
 public:
  /*!
   * \brief Constructor
   * \param scope_root The root block of the block scope being reconstructed
   * \param block The block to be moved
   * \param loop The loop the block is to be moved under
   */
  explicit ScopeReconstructor(Block scope_root, Block block, For loop)
      : scope_root_(scope_root), block_(block), loop_(loop) {}

  using StmtMutator::operator();

  /*!
   * \brief Create the loop nest on top of the block, induced by the given block var's domain
   * \param insert_position The position among the subtrees where the block and its induced loop
   * nest is inserted
   * \param iter_doms The domain of each block var
   * \param analyzer The arithmetic analyzer
   * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
   */
  void MakeNewLoop(int insert_position, std::vector<BlockVarDomainInfo> iter_doms,
                   arith::Analyzer* analyzer, bool preserve_unit_loops) {
    int n_iters = iter_doms.size();
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Array<PrimExpr> iter_values;
    loop_vars.reserve(n_iters);
    loop_extents.reserve(n_iters);
    iter_values.reserve(n_iters);
    PrimExpr predicate = const_true();
    // For each block var, either generate a fresh loop over its domain, or, when the domain is
    // a single point (and unit loops are not preserved), bind the block var to that point
    for (int i = 0; i < n_iters; ++i) {
      Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom);
      if (preserve_unit_loops || !is_one(iter_dom->extent)) {
        // Use a bit width wide enough for both the min and the extent of the domain
        int bits = std::max(iter_dom->min.dtype().bits(), iter_dom->extent.dtype().bits());
        Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(bits));
        loop_vars.push_back(var);
        loop_extents.push_back(analyzer->Simplify(iter_dom->extent));
        // The new loop starts at 0, so the binding is offset by the domain's min
        iter_values.push_back(iter_dom->min + var);
        analyzer->Bind(var, Range::FromMinExtent(IntImm(var.dtype(), 0), iter_dom->extent));
      } else {
        iter_values.push_back(iter_dom->min);
      }
      // Turn any non-trivial bound of this block var into a predicate on the block realize
      const arith::IntSet& pred_bound = iter_doms[i].bound;
      if (!pred_bound.IsNothing()) {
        if (pred_bound.HasLowerBound()) {
          PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
          predicate = predicate && lower_bound;
        }
        if (pred_bound.HasUpperBound()) {
          PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
          predicate = predicate && upper_bound;
        }
      }
    }
    this->new_block_realize_ =
        BlockRealize(std::move(iter_values), analyzer->Simplify(predicate), std::move(block_));
    // Wrap the new block realize with the generated loops, innermost first
    Stmt new_subtree = this->new_block_realize_;
    for (int i = static_cast<int>(loop_vars.size()) - 1; i >= 0; --i) {
      const Var& loop_var = loop_vars[i];
      const PrimExpr& loop_extent = loop_extents[i];
      new_subtree = For(/*loop_var=*/loop_var,
                        /*min=*/Integer(0),
                        /*extent=*/loop_extent,
                        /*ForKind=*/ForKind::kSerial,
                        /*body=*/std::move(new_subtree));
    }
    // Insert the new subtree among the subtrees of the target loop's body
    Array<Stmt> subtrees = AsArray(loop_->body);
    subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree));
    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop_.get());
    new_loop->body = SeqStmt(std::move(subtrees));
    this->new_loop_ = For(std::move(new_loop));
  }

 private:
  Stmt VisitStmt_(const BlockNode* block) final {
    // Only recurse into the scope root; all other blocks are left untouched
    if (block != scope_root_.get()) {
      return GetRef<Block>(block);
    }
    // Apply the removal plan: substitute `rm_src_stmt_` with `rm_tgt_stmt_`
    if (block == rm_src_stmt_.get()) {
      block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode);
    }
    return StmtMutator::VisitStmt_(block);
  }

  Stmt VisitStmt_(const ForNode* loop) final {
    // Apply the removal plan: substitute `rm_src_stmt_` with `rm_tgt_stmt_`
    if (loop == rm_src_stmt_.get()) {
      loop = TVM_TYPE_AS(rm_tgt_stmt_, ForNode);
    }
    // Replace the target loop with the reconstructed loop built by MakeNewLoop
    if (loop == loop_.get()) {
      return new_loop_;
    }
    return StmtMutator::VisitStmt_(loop);
  }

 public:
  /*! \brief The root block of the block scope */
  Block scope_root_;
  /*! \brief The given block to be moved */
  Block block_;
  /*! \brief The given loop the block and its loop nest to be put under */
  For loop_;
  /*! \brief The new loop to replace the original loop */
  For new_loop_{nullptr};
  /*! \brief The new block realize to the moved block */
  BlockRealize new_block_realize_{nullptr};
  /*! \brief The plan to remove the given block by replacing this loop/block in the AST */
  Stmt rm_src_stmt_{nullptr};
  /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */
  Stmt rm_tgt_stmt_{nullptr};
};
340
341/*!
342 * \brief Calculate a list of accessed buffer regions under a path of loops
343 * \tparam relax_storage_scope Whether to relax beyond the path according to the storage and
344 * execution scope
345 * \param binding The block binding, used to unbind the buffer regions
346 * \param buffer_regions The buffer regions to be calculated
347 * \param relax_path_low_inclusive The lowest point in the loop path, inclusive
348 * \param relax_path_high_exclusive The highest point in the loop path, exclusive
349 * \param relaxed Where the calculation result is stored
350 */
template <bool relax_storage_scope>
void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
                        const Array<BufferRegion>& buffer_regions,
                        const StmtSRef& relax_path_low_inclusive,
                        const StmtSRef& relax_path_high_exclusive,
                        std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* relaxed) {
  runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""};
  // We cache the variable domains
  runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal;
  Optional<Map<Var, arith::IntSet>> var_dom = NullOpt;
  // Enumerate every buffer region
  for (const BufferRegion& buffer_region : buffer_regions) {
    const Buffer& buffer = buffer_region->buffer;
    const Array<Range>& region = buffer_region->region;
    // Skip the buffer regions we are not interested in: only buffers already present as keys
    // in `relaxed` are collected
    auto it = relaxed->find(buffer.get());
    if (it == relaxed->end()) {
      continue;
    }
    std::vector<NDIntSet>& relaxed_regions = it->second;
    // Check and update the cached `var_dom`: the loop domains are recomputed only when the
    // storage rank changes, since the rank determines the extra scope to relax beyond the path
    runtime::StorageScope scope =
        relax_storage_scope ? runtime::StorageScope::Create(buffer.scope()) : global_scope;
    runtime::StorageRank rank = scope.rank;
    if (rank != previous_rank || !var_dom.defined()) {
      previous_rank = rank;
      var_dom = arith::AsIntSet(LoopDomainOfSRefTreePath(
          /*low_inclusive=*/relax_path_low_inclusive,
          /*high_exclusive=*/relax_path_high_exclusive,
          /*extra_relax_scope=*/scope));
    }
    // Relax the region: substitute the block vars with their bindings, then evaluate the
    // region over the collected loop var domains
    Array<arith::IntSet> relaxed_region =
        arith::EvalSet(Substitute(region, binding), var_dom.value());
    relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()});
  }
}
388
389/*!
390 * \brief Calculate the iteration domain of a provided integer set to fully cover the required
391 * domain
392 * \param provided The provided integer set to cover the required domain
393 * \param required The required domain to be covered
394 * \param analyzer The arithmetic analyzer
395 */
std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet& provided,
                                                  const arith::IntSet& required,
                                                  arith::Analyzer* analyzer) {
  PrimExpr provided_min = analyzer->Simplify(provided.min());
  PrimExpr provided_max = analyzer->Simplify(provided.max());
  PrimExpr required_min = analyzer->Simplify(required.min());
  PrimExpr required_max = analyzer->Simplify(required.max());
  PrimExpr dom_min{nullptr}, dom_max{nullptr};
  Var dom_var{ObjectPtr<VarNode>{nullptr}};
  arith::PVar<Var> p_v;
  arith::PVar<PrimExpr> p_e;
  if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
    // Case 1. Strided access `x * e`: a <= x * e <= b ==> a // e <= x <= b // e
    PrimExpr e = p_e.Eval();
    dom_var = p_v.Eval();
    dom_min = floordiv(required_min, e);
    dom_max = floordiv(required_max, e);
  } else if (analyzer->CanProveEqual(provided_min, provided_max)) {
    // Case 2. The provided region is a single point
    if (p_v.Match(provided_min)) {
      // Case 2.1. The point is the var itself: the var's domain is exactly the required range
      dom_var = p_v.Eval();
      dom_min = required_min;
      dom_max = required_max;
    } else {
      arith::PVar<PrimExpr> p_f;
      if ((floordiv(p_v, p_f)).Match(provided_min)) {
        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
        PrimExpr fac = p_f.Eval();
        if (analyzer->CanProveGreaterEqual(fac, 1)) {
          dom_var = p_v.Eval();
          dom_min = required_min * fac;
          dom_max = analyzer->Simplify(required_max * fac + fac - 1);
        }
      } else if ((floormod(p_v, p_f).Match(provided_min))) {
        // generally domain of (x % fac) enforce no constraints to domain of x
        dom_var = p_v.Eval();
        return std::make_pair(dom_var, arith::IntSet::Nothing());
      }
    }
  }
  // None of the supported access patterns matched, so the domain cannot be solved
  ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
  return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
}
437
438/*!
439 * \brief Calculate and update the iteration domain info to fully cover the required domain
440 * \param provided The provided integer set to cover the required domain
441 * \param required The required domain to be covered
442 * \param required_bound The additional region bound of the required domain to be covered
443 * \param iter_doms The result iteration domains to be updated
444 * \param analyzer The arithmetic analyzer
445 */
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
                          const arith::IntSet& required_bound,
                          std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
                          arith::Analyzer* analyzer) {
  // If the provided region is a constant single point, the required region and bound must be
  // that same point, and there is no block var domain to update
  if (provided.IsSinglePoint() && is_const_int(provided.min())) {
    ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
    ICHECK(required_bound.IsSinglePoint() &&
           analyzer->CanProveEqual(provided.min(), required_bound.min()));
    return;
  }
  // Solve the var's domain against both the required region and the required bound;
  // both solutions must resolve to the same var
  auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
  auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer);
  const Var& var = var_with_dom.first;
  const auto& var_dom = var_with_dom.second;
  const auto& var_bound = var_with_bound.second;
  ICHECK(var.same_as(var_with_bound.first));
  auto it = iter_doms->find(var.get());
  if (it != iter_doms->end()) {
    it->second.Union({var_dom, var_bound});
  } else {
    // The solved var is not tracked in `iter_doms` (presumably not one of the block vars —
    // see caller); in that case the provided region must already cover the required one exactly
    ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
    ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
  }
}
470
471/*!
472 * \brief Calculate the domain of block vars to cover the required region
473 * \param iter_vars The list of block vars to cover the required region
474 * \param provided_regions The region provided by one iteration instance of the block vars
475 * \param required_regions The region required to be covered
476 * \param analyzer The arithmetic analyzer
477 * \return A list of iteration domain info corresponding to the given list of block vars
478 */
std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
    const Array<IterVar>& iter_vars,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions,
    arith::Analyzer* analyzer) {
  int n_iters = iter_vars.size();
  // Step 1. Construct the mapping from block var to their iteration domain (initialized to empty)
  std::unordered_map<const VarNode*, BlockVarDomainInfo> iter_doms;
  iter_doms.reserve(n_iters);
  for (const IterVar& iter_var : iter_vars) {
    iter_doms[iter_var->var.get()] = BlockVarDomainInfo();
  }
  // Step 2. For each buffer, update the domain according to the provided and required regions
  for (const auto& kv : provided_regions) {
    const BufferNode* buffer = kv.first;
    const std::vector<NDIntSet>& many_provided_regions = kv.second;
    // Calculate `provided_region` and `required_region`
    auto it = required_regions.find(buffer);
    if (it == required_regions.end() || it->second.empty()) {
      // No dependent block requires this buffer, so it imposes no constraint
      continue;
    }
    NDIntSet required_region = support::NDIntSetUnion(it->second);
    NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
    ICHECK_EQ(provided_region.size(), buffer->shape.size());
    ICHECK_EQ(required_region.size(), buffer->shape.size());
    // For each dimension, update the iteration domain
    int ndim = buffer->shape.size();
    for (int i = 0; i < ndim; ++i) {
      arith::IntSet provided = provided_region[i];
      arith::IntSet required = required_region[i];
      // The extra bound is the buffer shape itself: accesses must stay within [0, shape)
      arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i]);
      UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, analyzer);
    }
  }
  // Union the iter var domains, put them in the same order of block vars, and return
  std::vector<BlockVarDomainInfo> result;
  result.reserve(n_iters);
  for (const IterVar& iter_var : iter_vars) {
    BlockVarDomainInfo& info = iter_doms.at(iter_var->var.get());
    // Intersect the accumulated bound with the block var's original domain
    if (info.bound.IsNothing()) {
      info.bound = arith::IntSet::FromRange(iter_var->dom);
    } else {
      info.bound = arith::Intersect({info.bound, arith::IntSet::FromRange(iter_var->dom)});
    }
    info.Simplify(analyzer);
    ICHECK(!info.dom.IsNothing());
    result.push_back(info);
  }
  return result;
}
529
530/*!
531 * \brief Calculate the provided region of the given block by one single of its execution instance,
532 * as well as the required buffer regions relaxed to the given loop
533 * \tparam is_compute_at Indicates if the operation is compute-at or reverse-compute-at
534 * \param block The given block that provides buffer regions
535 * \param loop_sref The given loop under which the block is going to be moved to
536 * \param block2realize Maps a block to its corresponding BlockRealize
537 * \param producer_srefs The producers of the given block
538 * \param consumer_srefs The consumers of the given block
539 * \param provided_regions The calculated regions provided by the block
540 * \param required_regions The calculated regions required by its consumers (in compute-at) or
541 * producers (in reverse-compute-at)
542 */
template <bool is_compute_at>
void CalculateProvidedRequiredRegions(
    const BlockNode* block, const StmtSRef& loop_sref,
    std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize,
    Array<StmtSRef> producer_srefs, Array<StmtSRef> consumer_srefs,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* provided_regions,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* required_regions) {
  // Step 1. Calculate the region provided by a single execution instance of `block`.
  // In compute-at the block provides what it writes; in reverse-compute-at, what it reads.
  const Array<BufferRegion>& provided_buffers = is_compute_at ? block->writes : block->reads;
  provided_regions->reserve(provided_buffers.size());
  required_regions->reserve(provided_buffers.size());
  for (const BufferRegion& provided_buffer_region : provided_buffers) {
    const BufferNode* buffer = provided_buffer_region->buffer.get();
    const Array<Range>& region = provided_buffer_region->region;
    (*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region));
    // Seed the required-region map with an empty entry, so that RelaxBufferRegions below
    // only collects regions for these buffers (it skips buffers absent from the map)
    (*required_regions)[buffer].clear();
  }
  // Step 2. Calculate the region required by dependent blocks under `loop`:
  // consumers for compute-at, producers for reverse-compute-at
  for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) {
    const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref);
    ICHECK(block2realize.count(required_block));
    RelaxBufferRegions</*relax_storage_scope=*/is_compute_at>(
        /*binding=*/GetBindings(GetRef<BlockRealize>(block2realize.at(required_block))),
        /*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes,
        /*relax_path_low_inclusive=*/GetRef<StmtSRef>(required_block_sref->parent),
        /*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions);
  }
}
571
572/******** Main Implementation ********/
573
/*!
 * \brief The shared implementation of compute-at and reverse-compute-at
 * \tparam is_compute_at True for compute-at, false for reverse-compute-at
 * \param self The schedule state
 * \param block_sref The sref of the block to be moved
 * \param loop_sref The sref of the loop to move the block under
 * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
 * \param analyzer The arithmetic analyzer
 * \param check_only If true, run all the checks and the reconstruction, but do not mutate the IR
 * \param index The block index of the loop body subtree blocks:
 * - `index = -1` means inserted into the last possible insertion point;
 * - `index = -2` means inserted into the first possible insertion point;
 * - Otherwise, `index` is a nonnegative number that indicates the insertion point
 * \throws ScheduleError if any of the checks fails
 */
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, bool preserve_unit_loops,
                                     arith::Analyzer* analyzer, bool check_only = false,
                                     int index = -1) {
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
  // Step 1. Bunch of checks
  // Check condition 1) : scope stage pipeline
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
                                          /*require_stage_pipeline=*/true);
  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  BlockScope scope = self->GetBlockScope(scope_root_sref);
  Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
  Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
  // Check condition 2) : `block` is a complete or reduction block
  CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref);
  // Check condition 3): `block` and `loop` are under the same scope,
  // and `loop` is not the ancestor of `block`
  NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
                                              analyzer);
  // Check condition 4): `block` is not an output block
  if (is_compute_at) {
    CheckNotOutputBlock(self, block_sref, scope_root_sref);
  }
  // Step 2. Plan for the removal of `block`
  ScopeReconstructor reconstructor(scope_root, GetRef<Block>(block), GetRef<For>(loop));
  LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_);
  // Step 3. Find the insertion point under `loop`
  // Check condition 5): all the required block are under the given loop
  std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
  block2realize.reserve(self->block_info.size());
  // Compute-at requires all consumers under the loop; reverse-compute-at all producers
  int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
      /*self=*/self,
      /*subtrees=*/AsArray(loop->body),
      /*producer_srefs=*/producer_srefs,
      /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
      /*index=*/index);
  // Step 4. Calculate the region provided by a single execution instance of `block`,
  // as well as the region required by dependent blocks under `loop`.
  // Here is the definition of `provide` and `require`:
  // - In compute-at, `provide` means `produce`, and `require` means `consume`
  // - In reverse-compute-at, `provide` means `consume`, and `require` means `produce`
  std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions;
  std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions;
  CalculateProvidedRequiredRegions<is_compute_at>(
      /*block=*/block, /*loop_sref=*/loop_sref, /*block2realize=*/std::move(block2realize),
      /*producer_srefs=*/std::move(producer_srefs),
      /*consumer_srefs=*/std::move(consumer_srefs),
      /*provided_regions=*/&provided_regions, /*required_regions=*/&required_regions);
  // Step 5. Calculate the iteration domain for each block var
  std::vector<BlockVarDomainInfo> iter_doms =
      CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
                              /*provided_regions=*/std::move(provided_regions),
                              /*required_regions=*/std::move(required_regions),
                              /*analyzer=*/analyzer);
  // Step 6. Create the new scope according to the iteration domain
  reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
                            /*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops);
  Block new_scope_root = Downcast<Block>(reconstructor(scope_root));

  // Step 7. Do the actual replacement
  if (check_only) {
    return;
  }
  self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
  // Step 8. Update the cached flags
  BlockInfo& block_info = self->block_info[block_sref];
  block_info.affine_binding = IsAffineBinding(
      /*realize=*/reconstructor.new_block_realize_,
      /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
      /*analyzer=*/analyzer);
}
647
648void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
649 bool preserve_unit_loops, int index) {
650 arith::Analyzer analyzer;
651 ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
652 false, index);
653}
654
655void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
656 bool preserve_unit_loops, int index) {
657 arith::Analyzer analyzer;
658 ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
659 &analyzer, false, index);
660}
661
662bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
663 bool preserve_unit_loops) {
664 arith::Analyzer analyzer;
665 try {
666 ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
667 &analyzer, true);
668 } catch (const tvm::runtime::Error& e) {
669 return false;
670 }
671 return true;
672}
673
674bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
675 const StmtSRef& loop_sref, bool preserve_unit_loops) {
676 arith::Analyzer analyzer;
677 try {
678 ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
679 &analyzer, true);
680 } catch (const tvm::runtime::Error& e) {
681 return false;
682 }
683 return true;
684}
685
686/******** InstructionKind Registration ********/
687
/*! \brief Instruction traits for the `ComputeAt` schedule instruction */
struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {
  static constexpr const char* kName = "ComputeAt";
  // Not pure: applying the instruction mutates the schedule
  static constexpr bool kIsPure = false;

 private:
  // Inputs: block_rv, loop_rv; attrs: preserve_unit_loops, index
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 2;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction to a concrete schedule */
  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
                                      Bool preserve_unit_loops, IntImm index) {
    return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value);
  }

  /*! \brief Render the instruction as a Python API call */
  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
                                 Bool preserve_unit_loops, IntImm index) {
    PythonAPICall py("compute_at");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
    py.Input("index", index);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
715
/*! \brief Instruction traits for the `ReverseComputeAt` schedule instruction */
struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits> {
  static constexpr const char* kName = "ReverseComputeAt";
  // Not pure: applying the instruction mutates the schedule
  static constexpr bool kIsPure = false;

 private:
  // Inputs: block_rv, loop_rv; attrs: preserve_unit_loops, index
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 2;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction to a concrete schedule */
  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
                                      Bool preserve_unit_loops, IntImm index) {
    return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
                                 index->value);
  }

  /*! \brief Render the instruction as a Python API call */
  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
                                 Bool preserve_unit_loops, IntImm index) {
    PythonAPICall py("reverse_compute_at");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
    py.Input("index", index);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
744
// Register the instruction-kind traits declared above
TVM_REGISTER_INST_KIND_TRAITS(ComputeAtTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeAtTraits);
747
748} // namespace tir
749} // namespace tvm
750