1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <utility>

#include "../utils.h"
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | using support::NDIntSet; |
25 | |
26 | /******** Error Classes ********/ |
27 | |
28 | /*! |
29 | * \brief An error raised when not all required blocks are under the given loop. |
30 | * \tparam is_consumer Indicates if all the required blocks are consumers or producers |
31 | */ |
32 | template <bool is_consumer> |
33 | class NotAllRequiredBlocksAreVisitedError : public ScheduleError { |
34 | public: |
35 | explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited, |
36 | const Array<StmtSRef>& required) |
37 | : mod_(mod), num_not_visited_(num_not_visited) { |
38 | required_.reserve(required.size()); |
39 | for (const StmtSRef& block_sref : required) { |
40 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
41 | required_.push_back(GetRef<Block>(block)); |
42 | } |
43 | } |
44 | |
45 | String FastErrorString() const final { |
46 | return "ScheduleError: Not all required blocks are under the loop scope" ; |
47 | } |
48 | |
49 | String DetailRenderTemplate() const final { |
50 | String relation = is_consumer ? "consumer(s)" : "producer(s)" ; |
51 | std::ostringstream os; |
52 | os << "The primitive requires all the " << relation |
53 | << " of the given block to be present under the target loop. However, there are " |
54 | << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the " |
55 | << relation << ":" ; |
56 | for (int i = 0, n = required_.size(); i < n; ++i) { |
57 | os << "{" << i << "}" ; |
58 | } |
59 | return os.str(); |
60 | } |
61 | |
62 | IRModule mod() const final { return mod_; } |
63 | |
64 | Array<ObjectRef> LocationsOfInterest() const final { |
65 | return {required_.begin(), required_.end()}; |
66 | } |
67 | |
68 | private: |
69 | IRModule mod_; |
70 | int num_not_visited_; |
71 | Array<Block> required_; |
72 | }; |
73 | |
74 | /*! |
75 | * \brief An error raised when the given block is not in the same block scope as the given loop, |
76 | * or the given loop is the ancestor of the given block. |
77 | */ |
class NotInSameScopeError : public ScheduleError {
 public:
  /*!
   * \brief Check that the block and the loop are in the same block scope and that the loop is not
   * an ancestor of the block; meanwhile bind the domain of every loop var on the path from the
   * loop up to the scope root into the analyzer.
   * \param self The schedule state
   * \param block_sref The sref of the block to be moved
   * \param loop_sref The sref of the target loop
   * \param scope_root_sref The sref of the scope root block
   * \param analyzer The analyzer that receives the loop-var range bindings
   * \throws NotInSameScopeError if either condition is violated
   */
  static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, const StmtSRef& scope_root_sref,
                                     arith::Analyzer* analyzer) {
    // Walk upwards from the loop: every ancestor must be a For until the scope root is reached.
    for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) {
      if (const ForNode* loop = p->StmtAs<ForNode>()) {
        analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
      } else if (p != scope_root_sref.get()) {
        // Hit a non-loop node that is not the scope root => the loop lives in a different scope.
        throw NotInSameScopeError(self->mod, block_sref, loop_sref);
      } else {
        break;
      }
    }
    // Walk upwards from the block: encountering the loop means the loop is an ancestor of it.
    for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) {
      if (p == loop_sref.get()) {
        throw NotInSameScopeError(self->mod, block_sref, loop_sref);
      }
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: Expected the block and loop to be under the same block scope, and loop "
           "not to be the ancestor of block";
  }
  String DetailRenderTemplate() const final {
    return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, "
           "and loop not to be the ancestor of block";
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; }

 private:
  explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref)
      : mod_(mod),
        block_(GetRef<Block>(block_sref->StmtAs<BlockNode>())),
        loop_(GetRef<For>(loop_sref->StmtAs<ForNode>())) {}

  /*! \brief The IRModule the error occurs in */
  IRModule mod_;
  /*! \brief The block to be moved */
  Block block_;
  /*! \brief The target loop */
  For loop_;
};
120 | |
121 | /******** Helper Functions/Classes ********/ |
122 | |
123 | /*! |
124 | * \brief Find a point where the block can be inserted under the loop |
125 | * \tparam require_all_producers_visited Requires all producer blocks to be present under the loop |
126 | * \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop |
127 | * \param self The schedule state |
128 | * \param subtrees The subtrees under the loop, among which the insertion points are sought |
129 | * \param producer_srefs The producer blocks |
130 | * \param consumer_srefs The consumer blocks |
131 | * \param block2realize A cache that maps a block to its realize |
132 | * \param index The block index of the loop body subtree blocks: |
133 | * - `index = -1` means inserted into the last possible insertion point; |
134 | * - `index = -2` means inserted into the first possible insertion point; |
135 | * - Otherwise, `index` is a nonnegative number that indicates the insertion point |
136 | * \return The possible position the new block can be inserted into, and the |
137 | * producer-consumer-relationship is still satisfied. |
138 | * \throws ScheduleError if there is no such insertion point found |
139 | */ |
140 | template <bool require_all_producers_visited, bool require_all_consumers_visited> |
141 | int FindInsertionPoint(const ScheduleState& self, const Array<Stmt>& subtrees, |
142 | const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs, |
143 | std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize, |
144 | int index) { |
145 | ProducerConsumerSplit split = |
146 | ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize); |
147 | // Step 1. Check if all the producers are visited in the subtrees, if required to |
148 | if (require_all_producers_visited) { |
149 | int num_producers = producer_srefs.size(); |
150 | if (split.n_producers_visited < num_producers) { |
151 | throw NotAllRequiredBlocksAreVisitedError<false>( |
152 | self->mod, num_producers - split.n_producers_visited, producer_srefs); |
153 | } |
154 | } |
155 | // Step 2. Check if all the consumers are visited in the subtrees, if required to |
156 | if (require_all_consumers_visited) { |
157 | int num_consumers = consumer_srefs.size(); |
158 | if (split.n_consumers_visited < num_consumers) { |
159 | throw NotAllRequiredBlocksAreVisitedError<true>( |
160 | self->mod, num_consumers - split.n_consumers_visited, consumer_srefs); |
161 | } |
162 | } |
163 | // Step 3. Check if there is at least one index of the position can be inserted into |
164 | // The valid indices are: (last_producer_position, first_consumer_position] |
165 | ICHECK(split.last_producer_position < split.first_consumer_position); |
166 | // Step 4. Return the possible insertion point according to index |
167 | int insert_position; |
168 | if (index == -1) { |
169 | insert_position = split.first_consumer_position; |
170 | } else if (index == -2) { |
171 | insert_position = split.last_producer_position + 1; |
172 | } else if (index >= 0 && index >= split.last_producer_position + 1 && |
173 | index <= split.first_consumer_position) { |
174 | insert_position = index; |
175 | } else { |
176 | LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", " |
177 | << split.first_consumer_position << "]), " |
178 | << "current index=" << index; |
179 | throw; |
180 | } |
181 | return insert_position; |
182 | } |
183 | |
184 | /*! |
185 | * \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound) |
186 | * The bound region may not get directly intersected with dom region, instead we try to generate |
187 | * extra predicates for non-trivial bound. The domain info class can also union with each other. |
188 | */ |
189 | struct BlockVarDomainInfo { |
190 | arith::IntSet dom{arith::IntSet::Nothing()}; // dom is ensured to be bounded |
191 | arith::IntSet bound{arith::IntSet::Nothing()}; |
192 | |
193 | /*! \brief Relaxed union operation */ |
194 | void Union(const BlockVarDomainInfo& other) { |
195 | // just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1) |
196 | dom = arith::Union({dom, other.dom}); |
197 | bound = arith::Union({bound, other.bound}); |
198 | } |
199 | |
200 | /*! \brief Simplify domain info */ |
201 | void Simplify(arith::Analyzer* analyzer) { |
202 | auto to_simplified = [analyzer](const arith::IntSet& set) { |
203 | PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min(); |
204 | PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max(); |
205 | return arith::IntSet::Interval(min, max); |
206 | }; |
207 | // if no dom specified, try use bound as dom |
208 | if (dom.IsNothing()) { |
209 | if (bound.HasLowerBound() && bound.HasUpperBound()) { |
210 | bound = to_simplified(bound); |
211 | std::swap(dom, bound); |
212 | } |
213 | return; |
214 | } |
215 | // simplify intset |
216 | dom = to_simplified(dom); |
217 | bound = to_simplified(bound); |
218 | // if can proof the dom is within bound, remove bound |
219 | auto intersect = to_simplified(arith::Intersect({dom, bound})); |
220 | if (analyzer->CanProveEqual(dom.min(), intersect.min()) && |
221 | analyzer->CanProveEqual(dom.max(), intersect.max())) { |
222 | bound = arith::IntSet::Nothing(); |
223 | } else if (analyzer->CanProveEqual(bound.min(), intersect.min()) && |
224 | analyzer->CanProveEqual(bound.max(), intersect.max())) { |
225 | dom = bound; |
226 | bound = arith::IntSet::Nothing(); |
227 | } |
228 | } |
229 | }; |
230 | |
231 | /*! |
232 | * \brief A helper to reconstruct the block scope where the given block is moved under the given |
233 | * loop, and the given block's induced loop nest is regenerated to satisfy the required region. |
234 | */ |
class ScopeReconstructor : private StmtMutator {
 public:
  explicit ScopeReconstructor(Block scope_root, Block block, For loop)
      : scope_root_(scope_root), block_(block), loop_(loop) {}

  using StmtMutator::operator();

  /*!
   * \brief Create the loop nest on top of the block, induced by the given block var's domain
   * \param insert_position The position among the subtrees where the block and its induced loop
   * nest is inserted
   * \param iter_doms The domain of each block var
   * \param analyzer The arithmetic analyzer
   * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
   */
  void MakeNewLoop(int insert_position, std::vector<BlockVarDomainInfo> iter_doms,
                   arith::Analyzer* analyzer, bool preserve_unit_loops) {
    int n_iters = iter_doms.size();
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Array<PrimExpr> iter_values;
    loop_vars.reserve(n_iters);
    loop_extents.reserve(n_iters);
    iter_values.reserve(n_iters);
    PrimExpr predicate = const_true();
    for (int i = 0; i < n_iters; ++i) {
      // The range the i-th block var must cover; falls back to the var's declared domain.
      Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom);
      if (preserve_unit_loops || !is_one(iter_dom->extent)) {
        // Materialize a fresh loop var `ax{k}`, with a dtype wide enough for min and extent.
        int bits = std::max(iter_dom->min.dtype().bits(), iter_dom->extent.dtype().bits());
        Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(bits));
        loop_vars.push_back(var);
        loop_extents.push_back(analyzer->Simplify(iter_dom->extent));
        iter_values.push_back(iter_dom->min + var);
        analyzer->Bind(var, Range::FromMinExtent(IntImm(var.dtype(), 0), iter_dom->extent));
      } else {
        // Unit extent and unit loops not preserved: bind the block var to the single point.
        iter_values.push_back(iter_dom->min);
      }
      // Any non-trivial bound becomes a predicate on the regenerated BlockRealize.
      const arith::IntSet& pred_bound = iter_doms[i].bound;
      if (!pred_bound.IsNothing()) {
        if (pred_bound.HasLowerBound()) {
          PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
          predicate = predicate && lower_bound;
        }
        if (pred_bound.HasUpperBound()) {
          // max() is inclusive, hence the `< max + 1` form.
          PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
          predicate = predicate && upper_bound;
        }
      }
    }
    this->new_block_realize_ =
        BlockRealize(std::move(iter_values), analyzer->Simplify(predicate), std::move(block_));
    // Wrap the new block realize with the regenerated loop nest, innermost loop first.
    Stmt new_subtree = this->new_block_realize_;
    for (int i = static_cast<int>(loop_vars.size()) - 1; i >= 0; --i) {
      const Var& loop_var = loop_vars[i];
      const PrimExpr& loop_extent = loop_extents[i];
      new_subtree = For(/*loop_var=*/loop_var,
                        /*min=*/Integer(0),
                        /*extent=*/loop_extent,
                        /*ForKind=*/ForKind::kSerial,
                        /*body=*/std::move(new_subtree));
    }
    // Splice the new subtree into the target loop's body at the chosen position.
    Array<Stmt> subtrees = AsArray(loop_->body);
    subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree));
    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop_.get());
    new_loop->body = SeqStmt(std::move(subtrees));
    this->new_loop_ = For(std::move(new_loop));
  }

 private:
  Stmt VisitStmt_(const BlockNode* block) final {
    // Only mutate inside the scope root; any other block is returned untouched.
    if (block != scope_root_.get()) {
      return GetRef<Block>(block);
    }
    // Apply the block-removal plan when the scope root itself is the source of the plan.
    if (block == rm_src_stmt_.get()) {
      block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode);
    }
    return StmtMutator::VisitStmt_(block);
  }

  Stmt VisitStmt_(const ForNode* loop) final {
    // Apply the block-removal plan when this loop is the source of the plan.
    if (loop == rm_src_stmt_.get()) {
      loop = TVM_TYPE_AS(rm_tgt_stmt_, ForNode);
    }
    // Substitute the target loop with the regenerated loop built by MakeNewLoop.
    if (loop == loop_.get()) {
      return new_loop_;
    }
    return StmtMutator::VisitStmt_(loop);
  }

 public:
  /*! \brief The root block of the block scope */
  Block scope_root_;
  /*! \brief The given block to be moved */
  Block block_;
  /*! \brief The given loop the block and its loop nest to be put under */
  For loop_;
  /*! \brief The new loop to replace the original loop */
  For new_loop_{nullptr};
  /*! \brief The new block realize to the moved block */
  BlockRealize new_block_realize_{nullptr};
  /*! \brief The plan to remove the given block by replacing this loop/block in the AST */
  Stmt rm_src_stmt_{nullptr};
  /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */
  Stmt rm_tgt_stmt_{nullptr};
};
340 | |
341 | /*! |
342 | * \brief Calculate a list of accessed buffer regions under a path of loops |
343 | * \tparam relax_storage_scope Whether to relax beyond the path according to the storage and |
344 | * execution scope |
345 | * \param binding The block binding, used to unbind the buffer regions |
346 | * \param buffer_regions The buffer regions to be calculated |
347 | * \param relax_path_low_inclusive The lowest point in the loop path, inclusive |
348 | * \param relax_path_high_exclusive The highest point in the loop path, exclusive |
349 | * \param relaxed Where the calculation result is stored |
350 | */ |
351 | template <bool relax_storage_scope> |
352 | void RelaxBufferRegions(const Map<Var, PrimExpr>& binding, |
353 | const Array<BufferRegion>& buffer_regions, |
354 | const StmtSRef& relax_path_low_inclusive, |
355 | const StmtSRef& relax_path_high_exclusive, |
356 | std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* relaxed) { |
357 | runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, "" }; |
358 | // We cache the variable domains |
359 | runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal; |
360 | Optional<Map<Var, arith::IntSet>> var_dom = NullOpt; |
361 | // Enumerate every buffer region |
362 | for (const BufferRegion& buffer_region : buffer_regions) { |
363 | const Buffer& buffer = buffer_region->buffer; |
364 | const Array<Range>& region = buffer_region->region; |
365 | // Skip the buffer regions we are not interested in |
366 | auto it = relaxed->find(buffer.get()); |
367 | if (it == relaxed->end()) { |
368 | continue; |
369 | } |
370 | std::vector<NDIntSet>& relaxed_regions = it->second; |
371 | // Check and update the cached `var_dom` |
372 | runtime::StorageScope scope = |
373 | relax_storage_scope ? runtime::StorageScope::Create(buffer.scope()) : global_scope; |
374 | runtime::StorageRank rank = scope.rank; |
375 | if (rank != previous_rank || !var_dom.defined()) { |
376 | previous_rank = rank; |
377 | var_dom = arith::AsIntSet(LoopDomainOfSRefTreePath( |
378 | /*low_inclusive=*/relax_path_low_inclusive, |
379 | /*high_exclusive=*/relax_path_high_exclusive, |
380 | /*extra_relax_scope=*/scope)); |
381 | } |
382 | // Relax the region |
383 | Array<arith::IntSet> relaxed_region = |
384 | arith::EvalSet(Substitute(region, binding), var_dom.value()); |
385 | relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()}); |
386 | } |
387 | } |
388 | |
389 | /*! |
390 | * \brief Calculate the iteration domain of a provided integer set to fully cover the required |
391 | * domain |
392 | * \param provided The provided integer set to cover the required domain |
393 | * \param required The required domain to be covered |
394 | * \param analyzer The arithmetic analyzer |
395 | */ |
std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet& provided,
                                                  const arith::IntSet& required,
                                                  arith::Analyzer* analyzer) {
  PrimExpr provided_min = analyzer->Simplify(provided.min());
  PrimExpr provided_max = analyzer->Simplify(provided.max());
  PrimExpr required_min = analyzer->Simplify(required.min());
  PrimExpr required_max = analyzer->Simplify(required.max());
  PrimExpr dom_min{nullptr}, dom_max{nullptr};
  Var dom_var{ObjectPtr<VarNode>{nullptr}};
  arith::PVar<Var> p_v;
  arith::PVar<PrimExpr> p_e;
  // Case 1: the provided lower bound has the form `v * e` (or `e * v`):
  // cover [required_min, required_max] by dividing both ends by `e`.
  if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
    PrimExpr e = p_e.Eval();
    dom_var = p_v.Eval();
    dom_min = floordiv(required_min, e);
    dom_max = floordiv(required_max, e);
  } else if (analyzer->CanProveEqual(provided_min, provided_max)) {
    // The provided region is a single point; pattern-match its expression.
    if (p_v.Match(provided_min)) {
      // Case 2: the point is the var itself — the required range maps directly.
      dom_var = p_v.Eval();
      dom_min = required_min;
      dom_max = required_max;
    } else {
      arith::PVar<PrimExpr> p_f;
      if ((floordiv(p_v, p_f)).Match(provided_min)) {
        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
        PrimExpr fac = p_f.Eval();
        if (analyzer->CanProveGreaterEqual(fac, 1)) {
          dom_var = p_v.Eval();
          dom_min = required_min * fac;
          dom_max = analyzer->Simplify(required_max * fac + fac - 1);
        }
      } else if ((floormod(p_v, p_f).Match(provided_min))) {
        // generally domain of (x % fac) enforce no constraints to domain of x
        dom_var = p_v.Eval();
        return std::make_pair(dom_var, arith::IntSet::Nothing());
      }
    }
  }
  // If no pattern matched, the buffer access form is unsupported.
  ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
  return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
}
437 | |
438 | /*! |
439 | * \brief Calculate and update the iteration domain info to fully cover the required domain |
440 | * \param provided The provided integer set to cover the required domain |
441 | * \param required The required domain to be covered |
442 | * \param required_bound The additional region bound of the required domain to be covered |
443 | * \param iter_doms The result iteration domains to be updated |
444 | * \param analyzer The arithmetic analyzer |
445 | */ |
446 | void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required, |
447 | const arith::IntSet& required_bound, |
448 | std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms, |
449 | arith::Analyzer* analyzer) { |
450 | if (provided.IsSinglePoint() && is_const_int(provided.min())) { |
451 | ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min())); |
452 | ICHECK(required_bound.IsSinglePoint() && |
453 | analyzer->CanProveEqual(provided.min(), required_bound.min())); |
454 | return; |
455 | } |
456 | auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer); |
457 | auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer); |
458 | const Var& var = var_with_dom.first; |
459 | const auto& var_dom = var_with_dom.second; |
460 | const auto& var_bound = var_with_bound.second; |
461 | ICHECK(var.same_as(var_with_bound.first)); |
462 | auto it = iter_doms->find(var.get()); |
463 | if (it != iter_doms->end()) { |
464 | it->second.Union({var_dom, var_bound}); |
465 | } else { |
466 | ICHECK(analyzer->CanProveEqual(provided.min(), required.min())); |
467 | ICHECK(analyzer->CanProveEqual(provided.max(), required.max())); |
468 | } |
469 | } |
470 | |
471 | /*! |
472 | * \brief Calculate the domain of block vars to cover the required region |
473 | * \param iter_vars The list of block vars to cover the required region |
474 | * \param provided_regions The region provided by one iteration instance of the block vars |
475 | * \param required_regions The region required to be covered |
476 | * \param analyzer The arithmetic analyzer |
477 | * \return A list of iteration domain info corresponding to the given list of block vars |
478 | */ |
479 | std::vector<BlockVarDomainInfo> CalculateBlockVarDomain( |
480 | const Array<IterVar>& iter_vars, |
481 | std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions, |
482 | std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions, |
483 | arith::Analyzer* analyzer) { |
484 | int n_iters = iter_vars.size(); |
485 | // Step 1. Construct the mapping from block var to their iteration domain (initialized to empty) |
486 | std::unordered_map<const VarNode*, BlockVarDomainInfo> iter_doms; |
487 | iter_doms.reserve(n_iters); |
488 | for (const IterVar& iter_var : iter_vars) { |
489 | iter_doms[iter_var->var.get()] = BlockVarDomainInfo(); |
490 | } |
491 | // Step 2. For each buffer, update the domain according to the provided and required regions |
492 | for (const auto& kv : provided_regions) { |
493 | const BufferNode* buffer = kv.first; |
494 | const std::vector<NDIntSet>& many_provided_regions = kv.second; |
495 | // Calculate `provided_region` and `required_region` |
496 | auto it = required_regions.find(buffer); |
497 | if (it == required_regions.end() || it->second.empty()) { |
498 | continue; |
499 | } |
500 | NDIntSet required_region = support::NDIntSetUnion(it->second); |
501 | NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions); |
502 | ICHECK_EQ(provided_region.size(), buffer->shape.size()); |
503 | ICHECK_EQ(required_region.size(), buffer->shape.size()); |
504 | // For each dimension, update the iteration domain |
505 | int ndim = buffer->shape.size(); |
506 | for (int i = 0; i < ndim; ++i) { |
507 | arith::IntSet provided = provided_region[i]; |
508 | arith::IntSet required = required_region[i]; |
509 | arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i]); |
510 | UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, analyzer); |
511 | } |
512 | } |
513 | // Union the iter var domains, put them in the same order of block vars, and return |
514 | std::vector<BlockVarDomainInfo> result; |
515 | result.reserve(n_iters); |
516 | for (const IterVar& iter_var : iter_vars) { |
517 | BlockVarDomainInfo& info = iter_doms.at(iter_var->var.get()); |
518 | if (info.bound.IsNothing()) { |
519 | info.bound = arith::IntSet::FromRange(iter_var->dom); |
520 | } else { |
521 | info.bound = arith::Intersect({info.bound, arith::IntSet::FromRange(iter_var->dom)}); |
522 | } |
523 | info.Simplify(analyzer); |
524 | ICHECK(!info.dom.IsNothing()); |
525 | result.push_back(info); |
526 | } |
527 | return result; |
528 | } |
529 | |
530 | /*! |
531 | * \brief Calculate the provided region of the given block by one single of its execution instance, |
532 | * as well as the required buffer regions relaxed to the given loop |
533 | * \tparam is_compute_at Indicates if the operation is compute-at or reverse-compute-at |
534 | * \param block The given block that provides buffer regions |
535 | * \param loop_sref The given loop under which the block is going to be moved to |
536 | * \param block2realize Maps a block to its corresponding BlockRealize |
537 | * \param producer_srefs The producers of the given block |
538 | * \param consumer_srefs The consumers of the given block |
539 | * \param provided_regions The calculated regions provided by the block |
540 | * \param required_regions The calculated regions required by its consumers (in compute-at) or |
541 | * producers (in reverse-compute-at) |
542 | */ |
template <bool is_compute_at>
void CalculateProvidedRequiredRegions(
    const BlockNode* block, const StmtSRef& loop_sref,
    std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize,
    Array<StmtSRef> producer_srefs, Array<StmtSRef> consumer_srefs,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* provided_regions,
    std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* required_regions) {
  // Step 1. Calculate the region provided by a single execution instance of `block`.
  // In compute-at the moved block provides what it writes; in reverse-compute-at,
  // what it reads.
  const Array<BufferRegion>& provided_buffers = is_compute_at ? block->writes : block->reads;
  provided_regions->reserve(provided_buffers.size());
  required_regions->reserve(provided_buffers.size());
  for (const BufferRegion& provided_buffer_region : provided_buffers) {
    const BufferNode* buffer = provided_buffer_region->buffer.get();
    const Array<Range>& region = provided_buffer_region->region;
    (*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region));
    // Registering the buffer here makes RelaxBufferRegions pick it up in Step 2.
    (*required_regions)[buffer].clear();
  }
  // Step 2. Calculate the region required by dependent blocks under `loop`:
  // the consumers' reads (compute-at) or the producers' writes (reverse-compute-at),
  // relaxed over the loops between each dependent block and `loop_sref`.
  for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) {
    const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref);
    ICHECK(block2realize.count(required_block));
    RelaxBufferRegions</*relax_storage_scope=*/is_compute_at>(
        /*binding=*/GetBindings(GetRef<BlockRealize>(block2realize.at(required_block))),
        /*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes,
        /*relax_path_low_inclusive=*/GetRef<StmtSRef>(required_block_sref->parent),
        /*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions);
  }
}
571 | |
572 | /******** Main Implementation ********/ |
573 | |
/*!
 * \brief The implementation shared by compute-at and reverse-compute-at
 * \tparam is_compute_at If true, move the block (a producer) under the loop of its consumers
 * (compute-at); otherwise move the block (a consumer) under the loop of its producers
 * (reverse-compute-at)
 * \param self The schedule state
 * \param block_sref The sref of the block to be moved
 * \param loop_sref The sref of the loop the block is to be moved under
 * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
 * \param analyzer The arithmetic analyzer
 * \param check_only If true, run all the checks but do not mutate the IR
 * \param index The insertion position among the loop's subtrees (-1: last, -2: first,
 * otherwise a nonnegative position)
 * \throws ScheduleError if any of the preconditions is violated
 */
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                     const StmtSRef& loop_sref, bool preserve_unit_loops,
                                     arith::Analyzer* analyzer, bool check_only = false,
                                     int index = -1) {
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
  // Step 1. Bunch of checks
  // Check condition 1) : scope stage pipeline
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
                                          /*require_stage_pipeline=*/true);
  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  BlockScope scope = self->GetBlockScope(scope_root_sref);
  Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
  Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
  // Check condition 2) : `block` is a complete or reduction block
  CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref);
  // Check condition 3): `block` and `loop` are under the same scope,
  // and `loop` is not the ancestor of `block`
  NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
                                              analyzer);
  // Check condition 4): `block` is not an output block
  if (is_compute_at) {
    CheckNotOutputBlock(self, block_sref, scope_root_sref);
  }
  // Step 2. Plan for the removal of `block`
  ScopeReconstructor reconstructor(scope_root, GetRef<Block>(block), GetRef<For>(loop));
  LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_);
  // Step 3. Find the insertion point under `loop`
  // Check condition 5): all the required block are under the given loop
  std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
  block2realize.reserve(self->block_info.size());
  // In compute-at all producers must be visited; in reverse-compute-at all consumers.
  int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
      /*self=*/self,
      /*subtrees=*/AsArray(loop->body),
      /*producer_srefs=*/producer_srefs,
      /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
      /*index=*/index);
  // Step 4. Calculate the region provided by a single execution instance of `block`,
  // as well as the region required by dependent blocks under `loop`.
  // Here is the definition of `provide` and `require`:
  // - In compute-at, `provide` means `produce`, and `require` means `consume`
  // - In reverse-compute-at, `provide` means `consume`, and `require` means `produce`
  std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions;
  std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions;
  CalculateProvidedRequiredRegions<is_compute_at>(
      /*block=*/block, /*loop_sref=*/loop_sref, /*block2realize=*/std::move(block2realize),
      /*producer_srefs=*/std::move(producer_srefs),
      /*consumer_srefs=*/std::move(consumer_srefs),
      /*provided_regions=*/&provided_regions, /*required_regions=*/&required_regions);
  // Step 5. Calculate the iteration domain for each block var
  std::vector<BlockVarDomainInfo> iter_doms =
      CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
                              /*provided_regions=*/std::move(provided_regions),
                              /*required_regions=*/std::move(required_regions),
                              /*analyzer=*/analyzer);
  // Step 6. Create the new scope according to the iteration domain
  reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
                            /*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops);
  Block new_scope_root = Downcast<Block>(reconstructor(scope_root));

  // Step 7. Do the actual replacement
  if (check_only) {
    return;
  }
  self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
  // Step 8. Update the cached flags
  BlockInfo& block_info = self->block_info[block_sref];
  block_info.affine_binding = IsAffineBinding(
      /*realize=*/reconstructor.new_block_realize_,
      /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
      /*analyzer=*/analyzer);
}
647 | |
648 | void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, |
649 | bool preserve_unit_loops, int index) { |
650 | arith::Analyzer analyzer; |
651 | ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer, |
652 | false, index); |
653 | } |
654 | |
655 | void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, |
656 | bool preserve_unit_loops, int index) { |
657 | arith::Analyzer analyzer; |
658 | ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops, |
659 | &analyzer, false, index); |
660 | } |
661 | |
662 | bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, |
663 | bool preserve_unit_loops) { |
664 | arith::Analyzer analyzer; |
665 | try { |
666 | ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, |
667 | &analyzer, true); |
668 | } catch (const tvm::runtime::Error& e) { |
669 | return false; |
670 | } |
671 | return true; |
672 | } |
673 | |
674 | bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, |
675 | const StmtSRef& loop_sref, bool preserve_unit_loops) { |
676 | arith::Analyzer analyzer; |
677 | try { |
678 | ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops, |
679 | &analyzer, true); |
680 | } catch (const tvm::runtime::Error& e) { |
681 | return false; |
682 | } |
683 | return true; |
684 | } |
685 | |
686 | /******** InstructionKind Registration ********/ |
687 | |
688 | struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> { |
689 | static constexpr const char* kName = "ComputeAt" ; |
690 | static constexpr bool kIsPure = false; |
691 | |
692 | private: |
693 | static constexpr size_t kNumInputs = 2; |
694 | static constexpr size_t kNumAttrs = 2; |
695 | static constexpr size_t kNumDecisions = 0; |
696 | |
697 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv, |
698 | Bool preserve_unit_loops, IntImm index) { |
699 | return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); |
700 | } |
701 | |
702 | static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv, |
703 | Bool preserve_unit_loops, IntImm index) { |
704 | PythonAPICall py("compute_at" ); |
705 | py.Input("block" , block_rv); |
706 | py.Input("loop" , loop_rv); |
707 | py.Input("preserve_unit_loops" , preserve_unit_loops.operator bool()); |
708 | py.Input("index" , index); |
709 | return py.Str(); |
710 | } |
711 | |
712 | template <typename> |
713 | friend struct ::tvm::tir::UnpackedInstTraits; |
714 | }; |
715 | |
716 | struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits> { |
717 | static constexpr const char* kName = "ReverseComputeAt" ; |
718 | static constexpr bool kIsPure = false; |
719 | |
720 | private: |
721 | static constexpr size_t kNumInputs = 2; |
722 | static constexpr size_t kNumAttrs = 2; |
723 | static constexpr size_t kNumDecisions = 0; |
724 | |
725 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv, |
726 | Bool preserve_unit_loops, IntImm index) { |
727 | return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), |
728 | index->value); |
729 | } |
730 | |
731 | static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv, |
732 | Bool preserve_unit_loops, IntImm index) { |
733 | PythonAPICall py("reverse_compute_at" ); |
734 | py.Input("block" , block_rv); |
735 | py.Input("loop" , loop_rv); |
736 | py.Input("preserve_unit_loops" , preserve_unit_loops.operator bool()); |
737 | py.Input("index" , index); |
738 | return py.Str(); |
739 | } |
740 | |
741 | template <typename> |
742 | friend struct ::tvm::tir::UnpackedInstTraits; |
743 | }; |
744 | |
// Register both instruction kinds so they can be looked up by name (kName) at runtime.
TVM_REGISTER_INST_KIND_TRAITS(ComputeAtTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeAtTraits);
747 | |
748 | } // namespace tir |
749 | } // namespace tvm |
750 | |