/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

static const char kErrBodyInline[] = R"(The body of the inlined block should be of the form
    `A[i, j, k, ...] = f(i, j, k, ...)`,
where the indices on the left-hand side are distinct atomic variables,
and no variables other than the index variables may appear)";

static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be of the form
    `B[...] = g(i, j, k, A[f(i, j, k, ...)], ...)`,
where A is the only buffer the block consumes, its indices are distinct atomic variables,
no variables other than the index variables may appear, f is a bijective affine mapping,
and the inlined block must not have a predicate. The iter domains of the inlined
block should be covered by the producer block.)";
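
// A hedged illustration of the two patterns above (hypothetical TensorIR, not
// part of this file): `compute_inline` accepts a producer body such as
//   B[i, j] = A[i, j] * 2.0
// since the left-hand indices `i, j` are distinct atomic variables, while
// `reverse_compute_inline` accepts a consumer body such as
//   C[i, j] = B[i, j] + 1.0
// where B is the only buffer the consumer reads and its indices form a
// bijective affine mapping of the block iters.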

class HasInitBlock : public ScheduleError {
 public:
  explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {}

  String FastErrorString() const final {
    return "ScheduleError: The block has an init statement";
  }

  String DetailRenderTemplate() const final {
    return "ScheduleError: The block has an init statement: {0}";
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  static void Check(const IRModule& mod, const Block& block) {
    if (block->init.defined()) {
      throw HasInitBlock(mod, block);
    }
  }

 private:
  IRModule mod_;
  Block block_;
};

class NotSingleReadWriteBuffer : public ScheduleError {
 public:
  explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block)
      : mod_(mod), is_read_(is_read), block_(std::move(block)) {}

  String FastErrorString() const final {
    return is_read_ ? "ScheduleError: The block is allowed to read only a single buffer region"
                    : "ScheduleError: The block is allowed to write only a single buffer region";
  }

  String DetailRenderTemplate() const final {
    if (is_read_) {
      int k = block_->reads.size();
      return "The block is only allowed to read a single buffer region, but it reads " +
             std::to_string(k) + " region(s): {0}";
    } else {
      int k = block_->writes.size();
      return "The block is only allowed to write a single buffer region, but it writes " +
             std::to_string(k) + " region(s): {0}";
    }
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  IRModule mod_;
  bool is_read_;
  Block block_;

  static Buffer GetSingleRead(const ScheduleState& self, const Block& block,
                              const StmtSRef& scope_root_sref) {
    const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>&
        buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers;
    const BufferNode* read_buffer = nullptr;
    for (const BufferRegion& read_region : block->reads) {
      const BufferNode* buffer = read_region->buffer.get();
      if (buffer == read_buffer) {
        continue;
      }
      if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) {
        if (read_buffer != nullptr) {
          throw NotSingleReadWriteBuffer(self->mod, true, block);
        }
        read_buffer = buffer;
      }
    }
    if (read_buffer == nullptr) {
      throw NotSingleReadWriteBuffer(self->mod, true, block);
    }
    return GetRef<Buffer>(read_buffer);
  }

  static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) {
    if (block->writes.size() != 1) {
      throw NotSingleReadWriteBuffer(self->mod, false, block);
    }
    return block->writes[0]->buffer;
  }
};

class BodyAnalysisError : public ScheduleError {
 public:
  explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block)
      : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {}

  String FastErrorString() const final {
    return "ScheduleError: The block cannot be inlined because its body pattern does not meet the "
           "condition for inlining";
  }

  String DetailRenderTemplate() const final {
    return is_reverse_ ? kErrBodyReverseInline : kErrBodyInline;
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  bool is_reverse_;
  IRModule mod_;
  Block block_;
};

class NonSingleProducerError : public ScheduleError {
 public:
  explicit NonSingleProducerError(IRModule mod, Block block)
      : mod_(mod), block_(std::move(block)) {}

  String FastErrorString() const final {
    return "ScheduleError: The consumer block to be inlined is required to have only a single "
           "producer block, and the producer block should be a complete block that has only a "
           "single consumer";
  }

  String DetailRenderTemplate() const final {
    return "The consumer block {0} to be inlined is required to have only a single "
           "producer block, and the producer block should be a complete block that has only a "
           "single consumer";
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  IRModule mod_;
  Block block_;

  /*!
   * \brief Check if the block has a single producer.
   * \param self The schedule state
   * \param consumer_block_sref The sref of the block to be checked
   * \param scope_root_sref The sref of the scope root
   * \return The sref of the producer block if the block has a single producer
   * \throw ScheduleError if the block does not have a single producer
   */
  static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
                        const StmtSRef& scope_root_sref) {
    BlockScope scope = self->GetBlockScope(scope_root_sref);
    Array<Dependency> producers = scope->GetDepsByDst(consumer_block_sref);
    StmtSRef producer_block_sref{nullptr};
    if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) {
      producer_block_sref = producers[0]->src;
      if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) {
        Array<Dependency> consumers = scope->GetDepsBySrc(producer_block_sref);
        if (consumers.size() == 1) {
          return producer_block_sref;
        }
      }
    }
    const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref);
    throw NonSingleProducerError(self->mod, GetRef<Block>(block));
  }
};

class OpaqueAccessError : public ScheduleError {
 public:
  explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref)
      : mod_(mod), scope_root_(nullptr) {
    const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
    this->scope_root_ = GetRef<Block>(scope_root);
  }

  String FastErrorString() const final {
    return "ScheduleError: The buffer to be inlined has opaque access (e.g. `B.data`), or its "
           "subregion is matched into other blocks";
  }

  String DetailRenderTemplate() const final {
    return "The buffer to be inlined has opaque access (e.g. `B.data`), or its "
           "subregion is matched into other blocks: {0}";
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {scope_root_}; }

  IRModule mod_;
  Block scope_root_;
};

class ProducerHasNonTrivialPredicateError : public ScheduleError {
 public:
  explicit ProducerHasNonTrivialPredicateError(IRModule mod, BlockRealize producer,
                                               PrimExpr new_predicate)
      : mod_(mod), producer_(producer), new_predicate_(new_predicate) {}

  String FastErrorString() const final {
    return "ScheduleError: The producer block has a non-trivial predicate.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "ScheduleError: The producer block {0} has a non-trivial predicate "
       << producer_->predicate << " that cannot be implied by the synthesized predicate "
       << new_predicate_ << " of the new inlined block.";
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }

  IRModule mod_;
  BlockRealize producer_;
  PrimExpr new_predicate_;
};

/*!
 * \brief The base class of the inliner, which handles:
 * 1) Substituting a subtree with the specific block being inlined
 * 2) Updating the block signature to reflect the changes of read/write/allocated buffers
 * 3) Maintaining a list of index variables and their substitutions for the buffer being inlined
 */
class BaseInliner : public StmtExprMutator {
 protected:
  explicit BaseInliner(const Buffer& inlined_buffer, const Block& inlined_block,
                       const StmtSRef& scope_root_sref)
      : inlined_buffer_(inlined_buffer),
        inlined_store_(inlined_block->body.as<BufferStoreNode>()),
        scope_root_sref_(scope_root_sref) {
    AddBuffersInBlockSignature(inlined_block.get());
  }

  PrimExpr VisitExpr_(const VarNode* var) final {
    CheckOpaqueAccess(var);
    return StmtExprMutator::VisitExpr_(var);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  Stmt VisitStmt_(const ForNode* loop) final {
    if (src_stmt.get() == loop) {
      loop = tgt_stmt.as<ForNode>();
      ICHECK(loop != nullptr);
    }
    return StmtExprMutator::VisitStmt_(loop);
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    CheckMatchBufferRegion(block);
    AddBuffersInBlockSignature(block);
    Block src_block = GetRef<Block>(block);
    if (src_block.same_as(src_stmt)) {
      block = tgt_stmt.as<BlockNode>();
      ICHECK(block != nullptr);
    }
    Block tgt_block = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
    bool is_scope_root = src_block.get() == scope_root_sref_->stmt;
    tgt_block = UpdateBuffersInBlockSignature(std::move(tgt_block), is_scope_root);
    block_reuse.Set(src_block, tgt_block);
    return std::move(tgt_block);
  }

  /*!
   * \brief Count the number of undefined variables that are not used
   * as buffer objects.
   *
   * This is used to determine whether inlining or reverse inlining is
   * possible. The only undefined variables present should be the
   * load/store indices, or buffer accesses based on those indices.
   *
   * \param stmt The statement in which to count undefined variables
   * \return The number of such variables
   */
  static int GetNumUndefinedNonpointerVars(const Stmt& stmt) {
    auto undefined_vars = UndefinedVars(stmt, {});
    // Buffer pointers and the inlined indices are allowed, but no
    // other variables may appear in the inlined block.
    int num_nonpointer_vars = 0;
    for (const auto& var : undefined_vars) {
      bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() &&
                        var->type_annotation.as<PointerTypeNode>();
      if (!is_pointer) {
        num_nonpointer_vars++;
      }
    }
    return num_nonpointer_vars;
  }
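
  // Illustrative example (not from the original source), assuming buffer
  // handles carry a PointerType annotation: for the hypothetical store
  //   B[i, j] = A[i, j] + 1.0
  // the undefined vars are {A.data, B.data, i, j}; the two `.data` handles
  // are pointers and are skipped, so this returns 2, the number of indices
  // expected on the left-hand side of the inlined store.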

 private:
  /*!
   * \brief Add the buffers in the block signature to the `buffer_var_map_`,
   * which is used for auto-completion of a block's read/write region
   * \param block The block whose signature is to be added
   */
  void AddBuffersInBlockSignature(const BlockNode* block) {
    for (const BufferRegion& buffer_region : block->reads) {
      const Buffer& buffer = buffer_region->buffer;
      buffer_var_map_.Set(buffer->data, buffer);
    }
    for (const BufferRegion& buffer_region : block->writes) {
      const Buffer& buffer = buffer_region->buffer;
      buffer_var_map_.Set(buffer->data, buffer);
    }
    for (const Buffer& buffer : block->alloc_buffers) {
      buffer_var_map_.Set(buffer->data, buffer);
    }
  }

  /*!
   * \brief Update the following parts of the block signature:
   * 1) T.alloc_buffer, if the block is the scope root
   * 2) T.reads, if the block is not the scope root
   * 3) T.writes, if the block is not the scope root
   * \param block The block to be updated
   * \param is_scope_root A flag indicating if the block is the scope root of the block to be
   * inlined
   * \return The updated block
   */
  Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) {
    // Step 1. Update `BlockNode::alloc_buffers`
    Array<Buffer> alloc_buffers;
    if (is_scope_root) {
      alloc_buffers.reserve(block->alloc_buffers.size());
      for (const Buffer& alloc_buffer : block->alloc_buffers) {
        if (!alloc_buffer.same_as(inlined_buffer_)) {
          alloc_buffers.push_back(alloc_buffer);
        }
      }
    } else {
      alloc_buffers = std::move(block->alloc_buffers);
    }
    // Step 2. Update `BlockNode::reads` and `BlockNode::writes`
    Array<BufferRegion> reads = std::move(block->reads);
    Array<BufferRegion> writes = std::move(block->writes);
    auto f_access_inline_buffer = [this](const BufferRegion& access) {
      return access->buffer.same_as(this->inlined_buffer_);
    };
    if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) ||
                           std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) {
      Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_);
      reads = std::move(inspected[0]);
      writes = std::move(inspected[1]);
    }
    // Step 3. Assemble the result
    BlockNode* n = block.CopyOnWrite();
    n->reads = std::move(reads);
    n->writes = std::move(writes);
    n->alloc_buffers = std::move(alloc_buffers);
    return block;
  }

  /*!
   * \brief Opaque access to the buffer to be inlined is disallowed.
   * This method checks if a buffer var belongs to the inlined buffer.
   * \param buffer_var The buffer var to be checked
   */
  void CheckOpaqueAccess(const VarNode* buffer_var) {
    if (inlined_buffer_->data.get() == buffer_var) {
      this->has_opaque_access = true;
    }
  }

  /*!
   * \brief The buffer to be inlined is not allowed to be region-matched.
   * This method checks if a block has the disallowed behavior of buffer region match.
   * \param block The block to be checked
   */
  void CheckMatchBufferRegion(const BlockNode* block) {
    for (const MatchBufferRegion& match_buffer_region : block->match_buffers) {
      const Buffer& matched = match_buffer_region->source->buffer;
      if (matched.same_as(inlined_buffer_)) {
        this->has_opaque_access = true;
      }
    }
  }

 protected:
  /*! \brief The buffer to be inlined */
  Buffer inlined_buffer_{nullptr};
  /*! \brief The body of the block to be inlined */
  const BufferStoreNode* inlined_store_{nullptr};
  /*! \brief The scope root */
  StmtSRef scope_root_sref_{nullptr};
  /*! \brief Maps a buffer's data var back to the buffer itself */
  Map<Var, Buffer> buffer_var_map_;
  /*! \brief The indices used for indexing the buffer to be inlined */
  std::vector<const VarNode*> idx_vars_;
  /*! \brief The mapping to substitute index variables to PrimExprs */
  std::unordered_map<const VarNode*, PrimExpr> idx_sub_;

 public:
  /*!
   * \brief The Stmt to be replaced when removing the leaf block
   * \note The pair (src_stmt, tgt_stmt) is produced by LeafBlockRemovalPlan to indicate a
   * transformation on top of the input AST. We take this approach to avoid changing the AST twice
   */
  Stmt src_stmt{nullptr};
  /*! \brief The Stmt to be replaced with when removing the leaf block */
  Stmt tgt_stmt{nullptr};
  /*! \brief The reuse mapping of block srefs */
  Map<Block, Block> block_reuse;
  /*! \brief Indicates if there is any opaque access of the inlined buffer */
  bool has_opaque_access{false};
};

/*!
 * \brief Helper to inline the producer block into its consumer(s)
 * The derived class implements the following functionalities:
 * 1) Substitute `BufferLoad` on the buffer to be inlined
 * with its value calculation in the producer block
 * 2) Analyze the producer block to determine the remapping of index variables
 */
class ComputeInliner : public BaseInliner {
 public:
  explicit ComputeInliner(const Buffer& inlined_buffer, const Block& producer_block,
                          const StmtSRef& scope_root_sref)
      : BaseInliner(inlined_buffer, producer_block, scope_root_sref) {}

  bool BodyPatternAllowInline(const Block& producer_block) {
    if (inlined_store_ == nullptr) {
      return false;
    }

    int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
    if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) {
      return false;
    }
    return true;
  }

 private:
  using BaseInliner::VisitExpr_;
  using BaseInliner::VisitStmt_;

  PrimExpr VisitExpr_(const BufferLoadNode* _load) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load));
    if (!load->buffer.same_as(inlined_buffer_)) {
      return std::move(load);
    }
    return ReplaceInlinedBuffer(std::move(load));
  }

  PrimExpr ReplaceInlinedBuffer(BufferLoad load) {
    SetIndexSubstitution(load->indices);
    return Substitute(inlined_store_->value, idx_sub_);
  }
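
  // A hypothetical trace of the substitution above (not from the original
  // source): if the inlined store is
  //   B[i, j] = A[i, j] * 2.0
  // and the consumer contains the load `B[x, y + 1]`, then
  // `SetIndexSubstitution` records {i -> x, j -> y + 1}, and the load is
  // rewritten to `A[x, y + 1] * 2.0`.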

  /*!
   * \brief Check if the indices are distinct atomic variables and the access is n-dimensional.
   * If so, set `self->idx_vars_` properly.
   * \param indices The indices to be extracted
   * \param expected_ndim The expected ndim of the access
   * \return A boolean flag indicating if the check is successful
   */
  bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
    int n = indices.size();
    if (n != expected_ndim) {
      // Failure: dimension mismatch
      return false;
    }
    std::vector<const VarNode*> result;
    result.reserve(n);
    for (const PrimExpr& i : indices) {
      if (const auto* var = i.as<VarNode>()) {
        result.push_back(var);
      } else {
        // Failure: the indexing expression is not a variable
        return false;
      }
    }
    using DistinctSet = std::unordered_set<const VarNode*>;
    int n_distinct = DistinctSet(result.begin(), result.end()).size();
    if (n != n_distinct) {
      // Failure: the indexing variables are not distinct
      return false;
    }
    if (idx_vars_.empty()) {
      idx_vars_ = std::move(result);
    } else if (!support::ArrayWithSameContent(idx_vars_, result)) {
      // Failure: the indexing variables are not consistent across different BufferLoads
      return false;
    }
    return true;
  }
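
  // Hedged examples of this check (hypothetical indices): `[i, j]` passes;
  // `[i, i]` fails because the variables are not distinct; `[i + 1, j]`
  // fails because `i + 1` is not an atomic variable.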

  /*!
   * \brief Set the mapping of index substitution `self->idx_sub_`
   * \param indices The expressions that the corresponding index variables are replaced with
   */
  void SetIndexSubstitution(const Array<PrimExpr>& indices) {
    ICHECK_EQ(indices.size(), idx_vars_.size());
    int n = idx_vars_.size();
    idx_sub_.reserve(n);
    for (int i = 0; i < n; ++i) {
      idx_sub_[idx_vars_[i]] = indices[i];
    }
  }
};

/*!
 * \brief Helper to inline the consumer block into its producer
 * The derived class implements the following functionalities:
 * 1) Analyze the consumer block to determine the remapping of index variables
 * 2) Substitute the `BufferStore` of the buffer to be inlined,
 * replacing it with a direct write to the buffer the consumer writes
 */
class ReverseComputeInliner : public BaseInliner {
  class Substituter : public StmtExprMutator {
   public:
    explicit Substituter(ReverseComputeInliner* self) : self_(self) {}

   private:
    PrimExpr VisitExpr_(const VarNode* var) final {
      auto it = self_->idx_sub_.find(var);
      ICHECK(it != self_->idx_sub_.end());
      return (*it).second;
    }

    PrimExpr VisitExpr_(const BufferLoadNode* _load) final {
      BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load));
      return load->buffer.same_as(self_->inlined_buffer_) ? self_->producer_rhs_ : load;
    }

    ReverseComputeInliner* self_;
  };

 public:
  explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
                                 const BlockRealize& consumer_block_realize,
                                 const StmtSRef& scope_root_sref, const IRModule& mod)
      : BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
        producer_block_(producer_block),
        consumer_block_(consumer_block_realize->block.get()),
        mod_(mod) {
    // Initialize the predicate that ensures the consumer block iters are in-bound
    consumer_iter_in_bound_ = Bool(true);
    for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
      consumer_iter_in_bound_ =
          consumer_iter_in_bound_ &&
          (iter->var >= iter->dom->min && iter->var < iter->dom->min + iter->dom->extent);
    }
  }
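
  // As a sketch (hypothetical domains): for consumer iters `i` in [0, n) and
  // `j` in [0, m), the predicate built above is
  //   i >= 0 && i < 0 + n && j >= 0 && j < 0 + m
  // It is later rewritten in terms of the producer iters and simplified in
  // BuildInlinedConsumerPredicate below.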

  bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) {
    const Block& consumer_block = consumer_block_realize->block;

    if (!is_one(consumer_block_realize->predicate)) {
      // Failure: a predicate in the consumer block is not supported
      return false;
    }
    if (inlined_store_ == nullptr) {
      // Failure: the block body is not a BufferStore
      return false;
    }
    std::vector<const BufferLoadNode*> loads = ExtractBufferLoad(inlined_buffer_, inlined_store_);
    if (loads.size() == 0) {
      // Failure: no BufferLoad from the `inlined_buffer_`
      return false;
    }

    // Collect the block iter domains and update the substitution map
    Map<Var, Range> consumer_iter_doms;
    for (const auto& iter_var : consumer_block->iter_vars) {
      consumer_iter_doms.Set(iter_var->var, iter_var->dom);
      // Set the default mapping for unit iters
      if (is_const_int(iter_var->dom->extent, 1) && is_const_int(iter_var->dom->min)) {
        idx_sub_[iter_var->var.get()] = iter_var->dom->min;
      }
    }

    for (const BufferLoadNode* load : loads) {
      if (!UpdateAndCheckIndexExprs(load->indices)) {
        return false;
      }
    }

    auto res = arith::DetectIterMap(
        /*indices=*/buffer_load_indices_,
        /*input_iters=*/consumer_iter_doms,
        /*predicate=*/true,
        /*check_level=*/arith::IterMapLevel::Bijective,
        /*analyzer=*/&analyzer_,
        /*simplify_trivial_iterators=*/false);
    buffer_load_iter_map_ = res->indices;
    if (buffer_load_iter_map_.empty()) {
      // Failure: the indices of the BufferLoad are not bijective affine
      return false;
    }

    const BufferStoreNode* producer_store = producer_block_->body.as<BufferStoreNode>();
    if (producer_store == nullptr) {
      // Failure: the producer block body is not a BufferStore
      return false;
    }
    CreateInverseMapping(producer_store->indices);
    if (!CheckConsumerCovered()) {
      // Failure: the consumer block iter domains are not covered by the producer block
      return false;
    }

    return true;
  }
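
  // A hypothetical consumer body that passes all of the checks above:
  //   C[i, j] = B[i, j] + 1.0
  // Here B is the inlined buffer, the load indices [i, j] are bijective
  // affine in the consumer block iters, and the producer body is a
  // BufferStore into B, so the inverse mapping and the coverage check
  // below can proceed.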

 private:
  using BaseInliner::VisitExpr_;
  using BaseInliner::VisitStmt_;

  /*! \brief Generate the predicate after inlining based on the consumer predicate */
  PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) {
    // Bind the producer block iter domains for simplification
    Map<Var, PrimExpr> subst_map;
    for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) {
      const IterVar& iter = producer_block_realize->block->iter_vars[i];
      analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
      subst_map.Set(iter->var, producer_block_realize->iter_values[i]);
    }
    // Substitute the consumer block iters with the corresponding iters in the producer block
    PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
    // Simplify the predicate using the producer block iter domains
    predicate = analyzer_.Simplify(predicate);
    // Substitute the producer block iters with their bindings, since the predicate in a
    // BlockRealize should not contain the block iters
    predicate = Substitute(predicate, subst_map);
    return predicate;
  }

  Stmt VisitStmt_(const BlockRealizeNode* op) final {
    BlockRealize new_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
    if (op->block.get() == producer_block_) {
      auto new_predicate = BuildInlinedConsumerPredicate(new_block_realize.get());

      With<arith::ConstraintContext> ctx(&analyzer_, new_predicate);
      if (!analyzer_.CanProve(op->predicate)) {
        // We do not allow cases where the new predicate for the inlined block cannot
        // imply the original predicate in the producer block.
        throw ProducerHasNonTrivialPredicateError(mod_, GetRef<BlockRealize>(op), new_predicate);
      }
      new_block_realize.CopyOnWrite()->predicate = new_predicate;
    }
    return std::move(new_block_realize);
  }

  Stmt VisitStmt_(const BufferStoreNode* _store) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_store));
    if (!store->buffer.same_as(inlined_buffer_)) {
      return std::move(store);
    }
    return ReplaceInlinedBuffer(std::move(store));
  }

  /*!
   * \brief Check that the consumer block iter domains are covered by the producer block iter
   * domains
   * \return Whether the consumer block iter domains are covered
   */
  bool CheckConsumerCovered() {
    Map<IterVar, arith::IntSet> producer_iter_doms;
    for (const IterVar& iter_var : producer_block_->iter_vars) {
      producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom));
    }
    // For each block iter in the consumer block, find the corresponding expression in the producer
    for (const IterVar& iter : consumer_block_->iter_vars) {
      if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) {
        const PrimExpr& producer_iter = it->second;
        arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms);
        if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) ||
            analyzer_.CanProve(producer_iter_range.max() <
                               iter->dom->min + iter->dom->extent - 1)) {
          return false;
        }
      } else {
        return false;
      }
    }
    return true;
  }
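
  // For instance (hypothetical domains): if a consumer iter spans [0, 128)
  // and its producer-side expression also evaluates over [0, 128), the iter
  // is covered; if the producer expression only spans [0, 64), the check
  // above proves max < 127, and the inlining is rejected.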

  /*!
   * \brief Apply the inverse of `buffer_load_iter_map_` to the producer indices and update
   * `idx_sub_` with the result. It is later used to transform the BufferStore indices of the
   * producer.
   * \param producer_indices The BufferStore indices of the producer.
   */
  void CreateInverseMapping(const Array<PrimExpr>& producer_indices) {
    auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices);
    for (const auto& pair : inverse_iter_map) {
      idx_sub_[pair.first.get()] = pair.second;
    }
  }
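
  // Sketch of the inversion (hypothetical shapes): if the consumer loads
  // B[i * 4 + j] with an iter map detected over (i, j), and the producer
  // stores B[k], then InverseAffineIterMap yields
  //   {i -> k floordiv 4, j -> k floormod 4}
  // which is recorded in `idx_sub_`.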

  Stmt ReplaceInlinedBuffer(BufferStore producer) {
    producer_rhs_ = producer->value;
    return Substituter(this)(GetRef<BufferStore>(inlined_store_));
  }

  /*!
   * \brief Extract the expressions that load a specific buffer
   * \param buffer The buffer to be loaded from
   * \param from The BufferStore statement to extract from
   * \return A list of `BufferLoad` expressions
   */
  static std::vector<const BufferLoadNode*> ExtractBufferLoad(const Buffer& buffer,
                                                              const BufferStoreNode* from) {
    struct Extractor : public ExprVisitor {
      void VisitExpr_(const BufferLoadNode* load) final {
        if (load->buffer.get() == buffer) {
          result.push_back(load);
        }
        ExprVisitor::VisitExpr_(load);
      }
      const BufferNode* buffer;
      std::vector<const BufferLoadNode*> result;
    } extractor;
    extractor.buffer = buffer.get();
    for (const PrimExpr& expr : from->indices) {
      extractor(expr);
    }
    extractor(from->value);
    return std::move(extractor.result);
  }

  /*!
   * \brief Update `buffer_load_indices_` with the given indices. If `buffer_load_indices_` is
   * already non-empty, check that it is consistent with the given indices.
   * \param indices The indices
   * \return A boolean flag indicating if the check is successful
   */
  bool UpdateAndCheckIndexExprs(const Array<PrimExpr>& indices) {
    if (buffer_load_indices_.empty()) {
      buffer_load_indices_ = indices;
    } else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(),
                           indices.begin(), indices.end(), ExprDeepEqual())) {
      // Failure: the indices are not consistent across different BufferLoads
      return false;
    }
    return true;
  }

  /*! \brief The RHS value of the producer's BufferStore statement */
  PrimExpr producer_rhs_{nullptr};
  /*! \brief The indices of the consumer's BufferLoad */
  Array<PrimExpr> buffer_load_indices_;
  /*! \brief The IterMap representing the indices of the consumer's BufferLoad */
  Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
  /*! \brief The producer block */
  const BlockNode* producer_block_{nullptr};
  /*! \brief The consumer block */
  const BlockNode* consumer_block_{nullptr};
  /*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted
   * as the predicate of the producer block after inlining.
   */
  PrimExpr consumer_iter_in_bound_{nullptr};
  /*! \brief The arithmetic analyzer */
  arith::Analyzer analyzer_;
  /*! \brief The target module, only used for error reporting. */
  const IRModule& mod_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
                       bool check_only = false) {
  const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref);
  Block producer_block = GetRef<Block>(_producer_block);
  HasInitBlock::Check(self->mod, producer_block);
  Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
  // Step 1. Get the scope block
  StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref,
                                          /*require_stage_pipeline=*/true);
  // Step 2. Check completeness
  CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
  CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
  // Step 3. Analyze the block body
  ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref);
  if (!inliner.BodyPatternAllowInline(producer_block)) {
    throw BodyAnalysisError(false, self->mod, producer_block);
  }
  // Step 4. Create a plan that removes the leaf block to be inlined
  LeafBlockRemovalPlan(self, producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt);
  // Step 5. Create an AST where the leaf block `producer_block_sref` points to is removed,
  // and update other blocks that read from the removed block
  Stmt tgt_stmt = inliner(GetRef<Stmt>(scope_root_sref->stmt));
  if (inliner.has_opaque_access) {
    throw OpaqueAccessError(self->mod, scope_root_sref);
  }
  // Step 6. Do the real mutation on the AST and the sref tree in the schedule state
  if (check_only) {
    return;
  }
  self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}
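
// A minimal usage sketch (hypothetical schedule and workload, not from this
// file): calling `sch.compute_inline(sch.get_block("B"))` on
//   B[i, j] = A[i, j] * 2.0
//   C[i, j] = B[i, j] + 1.0
// removes the block producing B and rewrites the consumer to
//   C[i, j] = A[i, j] * 2.0 + 1.0
// following Steps 1-6 above.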

void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
  ComputeInlineImpl(self, producer_block_sref);
}

bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) {
  try {
    ComputeInlineImpl(self, producer_block_sref, true);
  } catch (const tvm::runtime::Error& e) {
    return false;
  }
  return true;
}

void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
                              bool check_only = false) {
  const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
  Block consumer_block = GetRef<Block>(_consumer_block);
  BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref);
  HasInitBlock::Check(self->mod, consumer_block);
  // Step 1. Get the scope block
  StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref,
                                          /*require_stage_pipeline=*/true);
  Buffer inlined_buffer =
      NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
  // Step 2. Check completeness
  CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
  // Step 3. Check if the consumer has a single complete producer
  StmtSRef producer_block_sref =
      NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
  // Step 4. Analyze the block body
  ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
                                consumer_block_realize, scope_root_sref, self->mod);
  if (!inliner.BodyPatternAllowInline(consumer_block_realize)) {
    throw BodyAnalysisError(true, self->mod, consumer_block);
  }
  // Step 5. Create a plan that removes the leaf block to be inlined
  LeafBlockRemovalPlan(self, consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt);
  // Step 6. Create an AST where the leaf block `consumer_block_sref` points to is removed,
  // and update other blocks that read from the removed block
  Stmt tgt_stmt = inliner(GetRef<Stmt>(scope_root_sref->stmt));
  if (inliner.has_opaque_access) {
    throw OpaqueAccessError(self->mod, scope_root_sref);
  }
  // Step 7. Do the real mutation on the AST and the sref tree in the schedule state
  if (check_only) {
    return;
  }
  self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}
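
// A minimal usage sketch (hypothetical workload): calling
// `sch.reverse_compute_inline(sch.get_block("C"))` on
//   B[i, j] = A[i, j] * 2.0
//   C[i, j] = B[i, j] + 1.0
// removes the consumer block and rewrites the producer to write
//   C[i, j] = A[i, j] * 2.0 + 1.0
// directly into C's buffer instead of the temporary B.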

bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
  try {
    ReverseComputeInlineImpl(self, block_sref, true);
  } catch (const tvm::runtime::Error& e) {
    return false;
  }
  return true;
}

void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
  ReverseComputeInlineImpl(self, consumer_block_sref);
}

/******** InstructionKind Registration ********/

struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
  static constexpr const char* kName = "ComputeInline";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    return sch->ComputeInline(block_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall py("compute_inline");
    py.Input("block", block_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct ReverseComputeInlineTraits : public UnpackedInstTraits<ReverseComputeInlineTraits> {
  static constexpr const char* kName = "ReverseComputeInline";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    return sch->ReverseComputeInline(block_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall py("reverse_compute_inline");
    py.Input("block", block_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(ComputeInlineTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeInlineTraits);

}  // namespace tir
}  // namespace tvm