1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
/*! \brief Detail message rendered by BodyAnalysisError for (forward) compute-inline. */
static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
'A[i, j, k, ...] = f(i, j, k, ...)',
where the indices on the left are distinct atomic variables,
and there should be no variables other than the index variables)";

/*! \brief Detail message rendered by BodyAnalysisError for reverse compute-inline. */
static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should be no variables other than the index variables), and f is a bijective affine
mapping and there should not be predicates in the inlined block. The iter domains of the inlined
block should be covered by the producer block.)";
35 | |
36 | class HasInitBlock : public ScheduleError { |
37 | public: |
38 | explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {} |
39 | |
40 | String FastErrorString() const final { return "ScheduleError: The block has init statement" ; } |
41 | |
42 | String DetailRenderTemplate() const final { |
43 | return "ScheduleError: The block has init statement: {0}" ; |
44 | } |
45 | |
46 | IRModule mod() const final { return mod_; } |
47 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
48 | |
49 | static void Check(const IRModule& mod, const Block& block) { |
50 | if (block->init.defined()) { |
51 | throw HasInitBlock(mod, block); |
52 | } |
53 | } |
54 | |
55 | private: |
56 | IRModule mod_; |
57 | Block block_; |
58 | }; |
59 | |
60 | class NotSingleReadWriteBuffer : public ScheduleError { |
61 | public: |
62 | explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) |
63 | : mod_(mod), is_read_(is_read), block_(std::move(block)) {} |
64 | |
65 | String FastErrorString() const final { |
66 | return is_read_ ? "ScheduleError: The block is allowed to read only a single buffer region" |
67 | : "ScheduleError: The block is allowed to write only a single buffer region" ; |
68 | } |
69 | |
70 | String DetailRenderTemplate() const final { |
71 | if (is_read_) { |
72 | int k = block_->reads.size(); |
73 | return "The block is only allowed to read a single buffer region, but it reads " + |
74 | std::to_string(k) + " region(s): {0}" ; |
75 | } else { |
76 | int k = block_->writes.size(); |
77 | return "The block is only allowed to write a single buffer region, but it writes " + |
78 | std::to_string(k) + " region(s): {0}" ; |
79 | } |
80 | } |
81 | |
82 | IRModule mod() const final { return mod_; } |
83 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
84 | |
85 | IRModule mod_; |
86 | bool is_read_; |
87 | Block block_; |
88 | |
89 | static Buffer GetSingleRead(const ScheduleState& self, const Block& block, |
90 | const StmtSRef& scope_root_sref) { |
91 | const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>& |
92 | buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers; |
93 | const BufferNode* read_buffer = nullptr; |
94 | for (const BufferRegion& read_region : block->reads) { |
95 | const BufferNode* buffer = read_region->buffer.get(); |
96 | if (buffer == read_buffer) { |
97 | continue; |
98 | } |
99 | if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) { |
100 | if (read_buffer != nullptr) { |
101 | throw NotSingleReadWriteBuffer(self->mod, true, block); |
102 | } |
103 | read_buffer = buffer; |
104 | } |
105 | } |
106 | if (read_buffer == nullptr) { |
107 | throw NotSingleReadWriteBuffer(self->mod, true, block); |
108 | } |
109 | return GetRef<Buffer>(read_buffer); |
110 | } |
111 | |
112 | static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { |
113 | if (block->writes.size() != 1) { |
114 | throw NotSingleReadWriteBuffer(self->mod, false, block); |
115 | } |
116 | return block->writes[0]->buffer; |
117 | } |
118 | }; |
119 | |
120 | class BodyAnalysisError : public ScheduleError { |
121 | public: |
122 | explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block) |
123 | : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {} |
124 | |
125 | String FastErrorString() const final { |
126 | return "ScheduleError: The block cannot be inlined because its body pattern does not meet the " |
127 | "condition for inlining" ; |
128 | } |
129 | |
130 | String DetailRenderTemplate() const final { |
131 | return is_reverse_ ? kErrBodyReverseInline : kErrBodyInline; |
132 | } |
133 | |
134 | IRModule mod() const final { return mod_; } |
135 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
136 | |
137 | bool is_reverse_; |
138 | IRModule mod_; |
139 | Block block_; |
140 | }; |
141 | |
142 | class NonSingleProducerError : public ScheduleError { |
143 | public: |
144 | explicit NonSingleProducerError(IRModule mod, Block block) |
145 | : mod_(mod), block_(std::move(block)) {} |
146 | |
147 | String FastErrorString() const final { |
148 | return "ScheduleError: The consumer block to be inlined is required to have only a single " |
149 | "producer block, and the producer block should be a complete block who has only a " |
150 | "single consumer" ; |
151 | } |
152 | |
153 | String DetailRenderTemplate() const final { |
154 | return "The consumer block {0} to be inlined is required to have only a single " |
155 | "producer block, and the producer block should be a complete block who has only a " |
156 | "single consumer" ; |
157 | } |
158 | |
159 | IRModule mod() const final { return mod_; } |
160 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
161 | |
162 | IRModule mod_; |
163 | Block block_; |
164 | |
165 | /*! |
166 | * \brief Check if the block has a single producer. |
167 | * \param self The schedule state |
168 | * \param block_sref The sref of the block to be checked |
169 | * \param scope_root_sref The sref of the scope root |
170 | * \return The sref of the producer block if the block has a single producer |
171 | * \throw ScheduleError if the block does not have a single producer |
172 | */ |
173 | static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, |
174 | const StmtSRef& scope_root_sref) { |
175 | BlockScope scope = self->GetBlockScope(scope_root_sref); |
176 | Array<Dependency> producers = scope->GetDepsByDst(consumer_block_sref); |
177 | StmtSRef producer_block_sref{nullptr}; |
178 | if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) { |
179 | producer_block_sref = producers[0]->src; |
180 | if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) { |
181 | Array<Dependency> consumers = scope->GetDepsBySrc(producer_block_sref); |
182 | if (consumers.size() == 1) { |
183 | return producer_block_sref; |
184 | } |
185 | } |
186 | } |
187 | const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref); |
188 | throw NonSingleProducerError(self->mod, GetRef<Block>(block)); |
189 | } |
190 | }; |
191 | |
192 | class OpaqueAccessError : public ScheduleError { |
193 | public: |
194 | explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) |
195 | : mod_(mod), scope_root_(nullptr) { |
196 | const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); |
197 | this->scope_root_ = GetRef<Block>(scope_root); |
198 | } |
199 | |
200 | String FastErrorString() const final { |
201 | return "ScheduleError: The buffer to be inlined has opaque access (e.g. `B.data`), or its " |
202 | "subregion is matched into other blocks" ; |
203 | } |
204 | |
205 | String DetailRenderTemplate() const final { |
206 | return "The buffer to be inlined has opaque access (e.g. `B.data`), or its " |
207 | "subregion is matched into other blocks: {0}" ; |
208 | } |
209 | |
210 | IRModule mod() const final { return mod_; } |
211 | Array<ObjectRef> LocationsOfInterest() const final { return {scope_root_}; } |
212 | |
213 | IRModule mod_; |
214 | Block scope_root_; |
215 | }; |
216 | |
217 | class ProducerHasNonTrivialPredicateError : public ScheduleError { |
218 | public: |
219 | explicit ProducerHasNonTrivialPredicateError(IRModule mod, BlockRealize producer, |
220 | PrimExpr new_predicate) |
221 | : mod_(mod), producer_(producer), new_predicate_(new_predicate) {} |
222 | |
223 | String FastErrorString() const final { |
224 | return "ScheduleError: The producer block has a non-trivial predicate." ; |
225 | } |
226 | |
227 | String DetailRenderTemplate() const final { |
228 | std::ostringstream os; |
229 | os << "ScheduleError: The producer block {0} has a non-trivial predicate " |
230 | << producer_->predicate << " that cannot be implied by the synthesized predicate " |
231 | << new_predicate_ << " of the new inlined block." ; |
232 | return os.str(); |
233 | } |
234 | |
235 | IRModule mod() const final { return mod_; } |
236 | Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; } |
237 | |
238 | IRModule mod_; |
239 | BlockRealize producer_; |
240 | PrimExpr new_predicate_; |
241 | }; |
242 | |
243 | /*! |
244 | * \brief The base class of the inliner, which handles: |
245 | * 1) Substitute a subtree with the specific block being inlined |
246 | * 2) Update the block signature to reflect the changes of read/write/allocated buffers |
247 | * 3) Maintain a list of index variables and their substitution of the buffer being inlined |
248 | */ |
249 | class BaseInliner : public StmtExprMutator { |
250 | protected: |
251 | explicit BaseInliner(const Buffer& inlined_buffer, const Block& inlined_block, |
252 | const StmtSRef& scope_root_sref) |
253 | : inlined_buffer_(inlined_buffer), |
254 | inlined_store_(inlined_block->body.as<BufferStoreNode>()), |
255 | scope_root_sref_(scope_root_sref) { |
256 | AddBuffersInBlockSignature(inlined_block.get()); |
257 | } |
258 | |
259 | PrimExpr VisitExpr_(const VarNode* var) final { |
260 | CheckOpaqueAccess(var); |
261 | return StmtExprMutator::VisitExpr_(var); |
262 | } |
263 | |
264 | PrimExpr VisitExpr_(const LoadNode* op) final { |
265 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
266 | } |
267 | |
268 | Stmt VisitStmt_(const StoreNode* op) final { |
269 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
270 | } |
271 | |
272 | Stmt VisitStmt_(const ForNode* loop) final { |
273 | if (src_stmt.get() == loop) { |
274 | loop = tgt_stmt.as<ForNode>(); |
275 | ICHECK(loop != nullptr); |
276 | } |
277 | return StmtExprMutator::VisitStmt_(loop); |
278 | } |
279 | |
280 | Stmt VisitStmt_(const BlockNode* block) final { |
281 | CheckMatchBufferRegion(block); |
282 | AddBuffersInBlockSignature(block); |
283 | Block src_block = GetRef<Block>(block); |
284 | if (src_block.same_as(src_stmt)) { |
285 | block = tgt_stmt.as<BlockNode>(); |
286 | ICHECK(block != nullptr); |
287 | } |
288 | Block tgt_block = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); |
289 | bool is_scope_root = src_block.get() == scope_root_sref_->stmt; |
290 | tgt_block = UpdateBuffersInBlockSignature(std::move(tgt_block), is_scope_root); |
291 | block_reuse.Set(src_block, tgt_block); |
292 | return std::move(tgt_block); |
293 | } |
294 | |
295 | /*! |
296 | * \brief Count the number of undefined variables that are not used |
297 | * as buffer objects. |
298 | * |
299 | * This is used to determine whether inlining or reverse inlining is |
300 | * possible. The only undefined variables present should be the |
301 | * load/store indices, or buffer access based on those indices. |
302 | * |
303 | * \param stmt The statement in which to count undefined variables |
304 | */ |
305 | static int GetNumUndefinedNonpointerVars(const Stmt& stmt) { |
306 | auto undefined_vars = UndefinedVars(stmt, {}); |
307 | // Buffer pointers and the inlined indices are allowed, but no |
308 | // other variables may appear in the inlined block. |
309 | int num_nonpointer_vars = 0; |
310 | for (const auto& var : undefined_vars) { |
311 | bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() && |
312 | var->type_annotation.as<PointerTypeNode>(); |
313 | if (!is_pointer) { |
314 | num_nonpointer_vars++; |
315 | } |
316 | } |
317 | return num_nonpointer_vars; |
318 | } |
319 | |
320 | private: |
321 | /*! |
322 | * \brief Add the buffers in the block signature to the `buffer_var_map_`, |
323 | * which is used for auto-completion of a block's read/write region |
324 | * \param block The block whose signature to be added |
325 | */ |
326 | void AddBuffersInBlockSignature(const BlockNode* block) { |
327 | for (const BufferRegion& buffer_region : block->reads) { |
328 | const Buffer& buffer = buffer_region->buffer; |
329 | buffer_var_map_.Set(buffer->data, buffer); |
330 | } |
331 | for (const BufferRegion& buffer_region : block->writes) { |
332 | const Buffer& buffer = buffer_region->buffer; |
333 | buffer_var_map_.Set(buffer->data, buffer); |
334 | } |
335 | for (const Buffer& buffer : block->alloc_buffers) { |
336 | buffer_var_map_.Set(buffer->data, buffer); |
337 | } |
338 | } |
339 | |
340 | /*! |
341 | * \brief Update the following block signature: |
342 | * 1) T.alloc_buffer, if the block is scope root |
343 | * 2) T.reads, if the block is not scope root |
344 | * 3) T.writes, if the block is not scope root |
345 | * \param block The block to be updated |
346 | * \param is_scope_root A flag indicating if a block is the scope root of the block to be inlined |
347 | * \return The updated block |
348 | */ |
349 | Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) { |
350 | // Step 1. Update `BlockNode::alloc_buffers` |
351 | Array<Buffer> alloc_buffers; |
352 | if (is_scope_root) { |
353 | alloc_buffers.reserve(block->alloc_buffers.size()); |
354 | for (const Buffer& alloc_buffer : block->alloc_buffers) { |
355 | if (!alloc_buffer.same_as(inlined_buffer_)) { |
356 | alloc_buffers.push_back(alloc_buffer); |
357 | } |
358 | } |
359 | } else { |
360 | alloc_buffers = std::move(block->alloc_buffers); |
361 | } |
362 | // Step 2. Update `BlockNode::reads` and `BlockNode::writes` |
363 | Array<BufferRegion> reads = std::move(block->reads); |
364 | Array<BufferRegion> writes = std::move(block->writes); |
365 | auto f_access_inline_buffer = [this](const BufferRegion& access) { |
366 | return access->buffer.same_as(this->inlined_buffer_); |
367 | }; |
368 | if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) || |
369 | std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) { |
370 | Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_); |
371 | reads = std::move(inspected[0]); |
372 | writes = std::move(inspected[1]); |
373 | } |
374 | // Step 3. Assemble the result |
375 | BlockNode* n = block.CopyOnWrite(); |
376 | n->reads = std::move(reads); |
377 | n->writes = std::move(writes); |
378 | n->alloc_buffers = std::move(alloc_buffers); |
379 | return block; |
380 | } |
381 | |
382 | /*! |
383 | * \brief Opaque access to the buffer to be inlined is disallowed. |
384 | * This method checks if a buffer var belongs to the buffer |
385 | * \param buffer_var The buffer var to be checked |
386 | */ |
387 | void CheckOpaqueAccess(const VarNode* buffer_var) { |
388 | if (inlined_buffer_->data.get() == buffer_var) { |
389 | this->has_opaque_access = true; |
390 | } |
391 | } |
392 | |
393 | /*! |
394 | * \brief The buffer to be inlined is not allowed to be region matched. |
395 | * This method checks if a block has the disallowed behavior of buffer region match. |
396 | * \param block The block to be checked |
397 | */ |
398 | void CheckMatchBufferRegion(const BlockNode* block) { |
399 | for (const MatchBufferRegion& match_buffer_region : block->match_buffers) { |
400 | const Buffer& matched = match_buffer_region->source->buffer; |
401 | if (matched.same_as(inlined_buffer_)) { |
402 | this->has_opaque_access = true; |
403 | } |
404 | } |
405 | } |
406 | |
407 | protected: |
408 | /*! \brief The buffer to be inlined */ |
409 | Buffer inlined_buffer_{nullptr}; |
410 | /*! \brief The body of the block to be inlined */ |
411 | const BufferStoreNode* inlined_store_{nullptr}; |
412 | /*! \brief The scope root */ |
413 | StmtSRef scope_root_sref_{nullptr}; |
414 | /*! \brief Maps a buffer's data field to itself */ |
415 | Map<Var, Buffer> buffer_var_map_; |
416 | /*! \brief The indices used for indexing the buffer to be inlined */ |
417 | std::vector<const VarNode*> idx_vars_; |
418 | /*! \brief The mapping to substitute index variables to PrimExprs */ |
419 | std::unordered_map<const VarNode*, PrimExpr> idx_sub_; |
420 | |
421 | public: |
422 | /*! |
423 | * \brief The Stmt to be replaced when removing the leaf block |
424 | * \note The pair (src_stmt, tgt_stmt) are produced by LeafBlockRemovalPlan to indicate a |
425 | * transformation on top of the input AST. We take this approach to avoid changing the AST twice |
426 | */ |
427 | Stmt src_stmt{nullptr}; |
428 | /*! \brief The Stmt to be replaced to when removing the leaf block */ |
429 | Stmt tgt_stmt{nullptr}; |
430 | /*! \brief The reuse mapping of block srefs */ |
431 | Map<Block, Block> block_reuse; |
432 | /*! \brief Indicates if there is any opaque access of the inlined buffer */ |
433 | bool has_opaque_access{false}; |
434 | }; |
435 | |
436 | /*! |
437 | * \brief Helper to inline the producer block into its consumer(s) |
438 | * The derived class implements the following functionalities: |
439 | * 1) Substitute `BufferLoad` on the buffer to be inlined |
440 | * to its value calculation in the producer block |
441 | * 2) Analyze the producer block to determine the remapping of index variables |
442 | */ |
443 | class ComputeInliner : public BaseInliner { |
444 | public: |
445 | explicit ComputeInliner(const Buffer& inlined_buffer, const Block& producer_block, |
446 | const StmtSRef& scope_root_sref) |
447 | : BaseInliner(inlined_buffer, producer_block, scope_root_sref) {} |
448 | |
449 | bool BodyPatternAllowInline(const Block& producer_block) { |
450 | if (inlined_store_ == nullptr) { |
451 | return false; |
452 | } |
453 | |
454 | int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_)); |
455 | if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { |
456 | return false; |
457 | } |
458 | return true; |
459 | } |
460 | |
461 | private: |
462 | using BaseInliner::VisitExpr_; |
463 | using BaseInliner::VisitStmt_; |
464 | |
465 | PrimExpr VisitExpr_(const BufferLoadNode* _load) final { |
466 | BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load)); |
467 | if (!load->buffer.same_as(inlined_buffer_)) { |
468 | return std::move(load); |
469 | } |
470 | return ReplaceInlinedBuffer(std::move(load)); |
471 | } |
472 | |
473 | PrimExpr ReplaceInlinedBuffer(BufferLoad load) { |
474 | SetIndexSubstitution(load->indices); |
475 | return Substitute(inlined_store_->value, idx_sub_); |
476 | } |
477 | |
478 | /*! |
479 | * \brief Check if the indices are atomic distinct variables and the access is n-dimensional. |
480 | * If so, set `self->idx_vars_` properly. |
481 | * \param indices The indices to be extracted |
482 | * \param expected_ndim The expected ndim of the access |
483 | * \return A boolean flag indicating if the check is successful |
484 | */ |
485 | bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) { |
486 | int n = indices.size(); |
487 | if (n != expected_ndim) { |
488 | // Failure: dimension mismatch |
489 | return false; |
490 | } |
491 | std::vector<const VarNode*> result; |
492 | result.reserve(n); |
493 | for (const PrimExpr& i : indices) { |
494 | if (const auto* var = i.as<VarNode>()) { |
495 | result.push_back(var); |
496 | } else { |
497 | // Failure: indexing expression is not a variable |
498 | return false; |
499 | } |
500 | } |
501 | using DistinctSet = std::unordered_set<const VarNode*>; |
502 | int n_distinct = DistinctSet(result.begin(), result.end()).size(); |
503 | if (n != n_distinct) { |
504 | // Failure: indexing variables are not distinct |
505 | return false; |
506 | } |
507 | if (idx_vars_.empty()) { |
508 | idx_vars_ = std::move(result); |
509 | } else if (!support::ArrayWithSameContent(idx_vars_, result)) { |
510 | // Failure: indexing variables are not consitent in different BufferLoads |
511 | return false; |
512 | } |
513 | return true; |
514 | } |
515 | |
516 | /*! |
517 | * \brief Set the mapping of index substitution `self->idx_sub_` |
518 | * \param indices The expressions that the corresponding index variables are replaced to |
519 | */ |
520 | void SetIndexSubstitution(const Array<PrimExpr>& indices) { |
521 | ICHECK_EQ(indices.size(), idx_vars_.size()); |
522 | int n = idx_vars_.size(); |
523 | idx_sub_.reserve(n); |
524 | for (int i = 0; i < n; ++i) { |
525 | idx_sub_[idx_vars_[i]] = indices[i]; |
526 | } |
527 | } |
528 | }; |
529 | |
530 | /*! |
531 | * \brief Helper to inline the consumer block into its producer |
532 | * The derived class implements the following functionalities: |
533 | * 1) Analyze the consumer block to determine the remapping of index variables |
534 | * 2) Substitute `BufferStore` of the buffer to be inlined, |
535 | * replacing it with direct writing to the buffer that consumer writes |
536 | */ |
537 | class ReverseComputeInliner : public BaseInliner { |
538 | class Substituter : public StmtExprMutator { |
539 | public: |
540 | explicit Substituter(ReverseComputeInliner* self) : self_(self) {} |
541 | |
542 | private: |
543 | PrimExpr VisitExpr_(const VarNode* var) final { |
544 | auto it = self_->idx_sub_.find(var); |
545 | ICHECK(it != self_->idx_sub_.end()); |
546 | return (*it).second; |
547 | } |
548 | |
549 | PrimExpr VisitExpr_(const BufferLoadNode* _load) final { |
550 | BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load)); |
551 | return load->buffer.same_as(self_->inlined_buffer_) ? self_->producer_rhs_ : load; |
552 | } |
553 | |
554 | ReverseComputeInliner* self_; |
555 | }; |
556 | |
557 | public: |
558 | explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block, |
559 | const BlockRealize& consumer_block_realize, |
560 | const StmtSRef& scope_root_sref, const IRModule& mod) |
561 | : BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref), |
562 | producer_block_(producer_block), |
563 | consumer_block_(consumer_block_realize->block.get()), |
564 | mod_(mod) { |
565 | // Initialize the predicates to ensure consumer block iters are in-bound |
566 | consumer_iter_in_bound_ = Bool(true); |
567 | for (const IterVar& iter : consumer_block_realize->block->iter_vars) { |
568 | consumer_iter_in_bound_ = |
569 | consumer_iter_in_bound_ && |
570 | (iter->var >= iter->dom->min && iter->var < iter->dom->min + iter->dom->extent); |
571 | } |
572 | } |
573 | |
574 | bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) { |
575 | const Block& consumer_block = consumer_block_realize->block; |
576 | |
577 | if (!is_one(consumer_block_realize->predicate)) { |
578 | // Failure: Predicate is the consumer block is not supported |
579 | return false; |
580 | } |
581 | if (inlined_store_ == nullptr) { |
582 | // Failure: block body is not BufferStore |
583 | return false; |
584 | } |
585 | std::vector<const BufferLoadNode*> loads = ExtractBufferLoad(inlined_buffer_, inlined_store_); |
586 | if (loads.size() == 0) { |
587 | // Failure: no BufferLoad from the `inlined_buffer_` |
588 | return false; |
589 | } |
590 | |
591 | // Collect block iter domains and update the substition map |
592 | Map<Var, Range> consumer_iter_doms; |
593 | for (const auto& iter_var : consumer_block->iter_vars) { |
594 | consumer_iter_doms.Set(iter_var->var, iter_var->dom); |
595 | // Set default mapping for unit iters |
596 | if (is_const_int(iter_var->dom->extent, 1) && is_const_int(iter_var->dom->min)) { |
597 | idx_sub_[iter_var->var.get()] = iter_var->dom->min; |
598 | } |
599 | } |
600 | |
601 | for (const BufferLoadNode* load : loads) { |
602 | if (!UpdateAndCheckIndexExprs(load->indices)) { |
603 | return false; |
604 | } |
605 | } |
606 | |
607 | auto res = arith::DetectIterMap( |
608 | /*indices=*/buffer_load_indices_, |
609 | /*input_iters=*/consumer_iter_doms, |
610 | /*predicate=*/true, |
611 | /*check_level=*/arith::IterMapLevel::Bijective, |
612 | /*analyzer=*/&analyzer_, |
613 | /*simplify_trivial_iterators=*/false); |
614 | buffer_load_iter_map_ = res->indices; |
615 | if (buffer_load_iter_map_.empty()) { |
616 | // Failure: indices of BufferLoad are not bijective affine |
617 | return false; |
618 | } |
619 | |
620 | const BufferStoreNode* producer_store = producer_block_->body.as<BufferStoreNode>(); |
621 | if (producer_store == nullptr) { |
622 | // Failure: producer block body is not BufferStore |
623 | return false; |
624 | } |
625 | CreateInverseMapping(producer_store->indices); |
626 | if (!CheckConsumerCovered()) { |
627 | // Failure: consumer block iter domains are not covered by the producer block |
628 | return false; |
629 | } |
630 | |
631 | return true; |
632 | } |
633 | |
634 | private: |
635 | using BaseInliner::VisitExpr_; |
636 | using BaseInliner::VisitStmt_; |
637 | |
638 | /*! \brief Generate the predicate after inlining based on the consumer predicate */ |
639 | PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) { |
640 | // Bind the producer block iter domains for simplification |
641 | Map<Var, PrimExpr> subst_map; |
642 | for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) { |
643 | const IterVar& iter = producer_block_realize->block->iter_vars[i]; |
644 | analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); |
645 | subst_map.Set(iter->var, producer_block_realize->iter_values[i]); |
646 | } |
647 | // Substitute the consumer block iters with the corresponding iters in the producer blocks |
648 | PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_); |
649 | // Simplify the predicate using the producer block iter domains |
650 | predicate = analyzer_.Simplify(predicate); |
651 | // Substitute the producer block iters with the its bindings since the predicate in BlockRealize |
652 | // should not contain the block iters |
653 | predicate = Substitute(predicate, subst_map); |
654 | return predicate; |
655 | } |
656 | |
657 | Stmt VisitStmt_(const BlockRealizeNode* op) final { |
658 | BlockRealize new_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op)); |
659 | if (op->block.get() == producer_block_) { |
660 | auto new_predicate = BuildInlinedConsumerPredicate(new_block_realize.get()); |
661 | |
662 | With<arith::ConstraintContext> ctx(&analyzer_, new_predicate); |
663 | if (!analyzer_.CanProve(op->predicate)) { |
664 | // We do not allow cases where the new predicate for the inlined block cannot |
665 | // imply the original predicate in the producer block. |
666 | throw ProducerHasNonTrivialPredicateError(mod_, GetRef<BlockRealize>(op), new_predicate); |
667 | } |
668 | new_block_realize.CopyOnWrite()->predicate = new_predicate; |
669 | } |
670 | return std::move(new_block_realize); |
671 | } |
672 | |
673 | Stmt VisitStmt_(const BufferStoreNode* _store) final { |
674 | BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_store)); |
675 | if (!store->buffer.same_as(inlined_buffer_)) { |
676 | return std::move(store); |
677 | } |
678 | return ReplaceInlinedBuffer(std::move(store)); |
679 | } |
680 | |
681 | /*! |
682 | * \brief Check the consumer block iter domains are covered by the producer block iter domains |
683 | * \return Whether the consumer block iter domains are covered |
684 | */ |
685 | bool CheckConsumerCovered() { |
686 | Map<IterVar, arith::IntSet> producer_iter_doms; |
687 | for (const IterVar& iter_var : producer_block_->iter_vars) { |
688 | producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom)); |
689 | } |
690 | // For each block iter in the consumer block, find the corresponding expression in the producer |
691 | for (const IterVar& iter : consumer_block_->iter_vars) { |
692 | if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) { |
693 | const PrimExpr& producer_iter = it->second; |
694 | arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms); |
695 | if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) || |
696 | analyzer_.CanProve(producer_iter_range.max() < |
697 | iter->dom->min + iter->dom->extent - 1)) { |
698 | return false; |
699 | } |
700 | } else { |
701 | return false; |
702 | } |
703 | } |
704 | return true; |
705 | } |
706 | |
707 | /*! |
708 | * \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with |
709 | * the result. It will be later used to transform the BufferStore indices of the producer. |
710 | * \param producer_indices The BufferStore indices of the producer. |
711 | */ |
712 | void CreateInverseMapping(const Array<PrimExpr> producer_indices) { |
713 | auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices); |
714 | for (const auto& pair : inverse_iter_map) { |
715 | idx_sub_[pair.first.get()] = pair.second; |
716 | } |
717 | } |
718 | |
719 | Stmt ReplaceInlinedBuffer(BufferStore producer) { |
720 | producer_rhs_ = producer->value; |
721 | return Substituter(this)(GetRef<BufferStore>(inlined_store_)); |
722 | } |
723 | |
724 | /*! |
725 | * \brief Extracts expressions that loads a specific buffer |
726 | * \param buffer The buffer to be loaded from |
727 | * \param from The BufferStore statement to be extracted from |
728 | * \return A list of `BufferLoad` expressions |
729 | */ |
730 | static std::vector<const BufferLoadNode*> (const Buffer& buffer, |
731 | const BufferStoreNode* from) { |
732 | struct : public ExprVisitor { |
733 | void VisitExpr_(const BufferLoadNode* load) final { |
734 | if (load->buffer.get() == buffer) { |
735 | result.push_back(load); |
736 | } |
737 | ExprVisitor::VisitExpr_(load); |
738 | } |
739 | const BufferNode* buffer; |
740 | std::vector<const BufferLoadNode*> result; |
741 | } ; |
742 | extractor.buffer = buffer.get(); |
743 | for (const PrimExpr& expr : from->indices) { |
744 | extractor(expr); |
745 | } |
746 | extractor(from->value); |
747 | return std::move(extractor.result); |
748 | } |
749 | |
750 | /*! |
751 | * \brief Update `buffer_load_indices_` with the given indices. If `buffer_load_indices_` is |
752 | * already non-empty, check it is consistent with the given indices. |
753 | * \param indices The indices |
754 | * \param expected_ndim The expected ndim of the access |
755 | * \return A boolean flag indicating if the check is successful |
756 | */ |
757 | bool UpdateAndCheckIndexExprs(const Array<PrimExpr>& indices) { |
758 | if (buffer_load_indices_.empty()) { |
759 | buffer_load_indices_ = indices; |
760 | } else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(), |
761 | indices.begin(), indices.end(), ExprDeepEqual())) { |
762 | // Failure: indices are not consistent in different BufferLoads |
763 | return false; |
764 | } |
765 | return true; |
766 | } |
767 | |
  /*! \brief The RHS value of the producer's BufferStore statement */
  PrimExpr producer_rhs_{nullptr};
  /*! \brief The indices of the consumer's BufferLoad */
  Array<PrimExpr> buffer_load_indices_;
  /*! \brief The IterMap representing the indices of the consumer's BufferLoad */
  Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
  /*! \brief The producer block */
  const BlockNode* producer_block_{nullptr};
  /*! \brief The consumer block */
  const BlockNode* consumer_block_{nullptr};
  /*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted
   * as the predicate of the producer block after inlining.
   */
  PrimExpr consumer_iter_in_bound_{nullptr};
  /*! \brief The arithmetic analyzer */
  arith::Analyzer analyzer_;
  /*! \brief The target module, only used for error reporting. */
  const IRModule& mod_;
786 | }; |
787 | |
788 | void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, |
789 | bool check_only = false) { |
790 | const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref); |
791 | Block producer_block = GetRef<Block>(_producer_block); |
792 | HasInitBlock::Check(self->mod, producer_block); |
793 | Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); |
794 | // Step 1. Get the scope block |
795 | StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref, |
796 | /*require_stage_pipeline=*/true); |
797 | // Step 2. Check completeness |
798 | CheckNotOutputBlock(self, producer_block_sref, scope_root_sref); |
799 | CheckCompleteBlock(self, producer_block_sref, scope_root_sref); |
800 | // Step 3. Analyze the block body |
801 | ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref); |
802 | if (!inliner.BodyPatternAllowInline(producer_block)) { |
803 | throw BodyAnalysisError(false, self->mod, producer_block); |
804 | } |
805 | // Step 4. Create a plan that removes the leaf block to be inlined |
806 | LeafBlockRemovalPlan(self, producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); |
807 | // Step 5. Create an AST where the leaf `producer_block_sref` points to is removed, |
808 | // and update other blocks who read from the removed block |
809 | Stmt tgt_stmt = inliner(GetRef<Stmt>(scope_root_sref->stmt)); |
810 | if (inliner.has_opaque_access) { |
811 | throw OpaqueAccessError(self->mod, scope_root_sref); |
812 | } |
813 | // Step 6. Do the real mutation on the AST and the sref tree in the schedule state |
814 | if (check_only) { |
815 | return; |
816 | } |
817 | self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); |
818 | } |
819 | |
820 | void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { |
821 | ComputeInlineImpl(self, producer_block_sref); |
822 | } |
823 | |
824 | bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) { |
825 | try { |
826 | ComputeInlineImpl(self, producer_block_sref, true); |
827 | } catch (const tvm::runtime::Error& e) { |
828 | return false; |
829 | } |
830 | return true; |
831 | } |
832 | |
833 | void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, |
834 | bool check_only = false) { |
835 | const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); |
836 | Block consumer_block = GetRef<Block>(_consumer_block); |
837 | BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref); |
838 | HasInitBlock::Check(self->mod, consumer_block); |
839 | // Step 1. Get the scope block |
840 | StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // |
841 | /*require_stage_pipeline=*/true); |
842 | Buffer inlined_buffer = |
843 | NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref); |
844 | // Step 2. Check completeness |
845 | CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); |
846 | // Step 3. Check if the consumer has a single complete producer |
847 | StmtSRef producer_block_sref = |
848 | NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); |
849 | // Step 4. Analyze the block body |
850 | ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(), |
851 | consumer_block_realize, scope_root_sref, self->mod); |
852 | if (!inliner.BodyPatternAllowInline(consumer_block_realize)) { |
853 | throw BodyAnalysisError(true, self->mod, consumer_block); |
854 | } |
855 | // Step 5. Create a plan that removes the leaf block to be inlined |
856 | LeafBlockRemovalPlan(self, consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); |
857 | // Step 6. Create an AST where the leaf `consumer_block_sref` points to is removed, |
858 | // and update other blocks who read from the removed block |
859 | Stmt tgt_stmt = inliner(GetRef<Stmt>(scope_root_sref->stmt)); |
860 | if (inliner.has_opaque_access) { |
861 | throw OpaqueAccessError(self->mod, scope_root_sref); |
862 | } |
863 | // Step 7. Do the real mutation on the AST and the sref tree in the schedule state |
864 | if (check_only) { |
865 | return; |
866 | } |
867 | self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); |
868 | } |
869 | |
870 | bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { |
871 | try { |
872 | ReverseComputeInlineImpl(self, block_sref, true); |
873 | } catch (const tvm::runtime::Error& e) { |
874 | return false; |
875 | } |
876 | return true; |
877 | } |
878 | |
879 | void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { |
880 | ReverseComputeInlineImpl(self, consumer_block_sref); |
881 | } |
882 | |
883 | /******** InstructionKind Registration ********/ |
884 | |
/*! \brief Instruction traits describing the `ComputeInline` schedule primitive. */
struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
  /*! \brief The registered name of the instruction. */
  static constexpr const char* kName = "ComputeInline" ;
  /*! \brief Not pure: applying the instruction mutates the schedule state. */
  static constexpr bool kIsPure = false;

 private:
  // One input (the block to inline), no attributes, no sampling decisions.
  // NOTE(review): these counts are consumed by UnpackedInstTraits via the friend
  // declaration below.
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction to a schedule by forwarding to `Schedule::ComputeInline`. */
  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    return sch->ComputeInline(block_rv);
  }

  /*! \brief Render the instruction as the python call `compute_inline(block=...)`. */
  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall py("compute_inline" );
    py.Input("block" , block_rv);
    return py.Str();
  }

  // Grant UnpackedInstTraits access to the private static members above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
907 | |
/*! \brief Instruction traits describing the `ReverseComputeInline` schedule primitive. */
struct ReverseComputeInlineTraits : public UnpackedInstTraits<ReverseComputeInlineTraits> {
  /*! \brief The registered name of the instruction. */
  static constexpr const char* kName = "ReverseComputeInline" ;
  /*! \brief Not pure: applying the instruction mutates the schedule state. */
  static constexpr bool kIsPure = false;

 private:
  // One input (the block to reverse-inline), no attributes, no sampling decisions.
  // NOTE(review): these counts are consumed by UnpackedInstTraits via the friend
  // declaration below.
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction to a schedule by forwarding to
   * `Schedule::ReverseComputeInline`. */
  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    return sch->ReverseComputeInline(block_rv);
  }

  /*! \brief Render the instruction as the python call `reverse_compute_inline(block=...)`. */
  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall py("reverse_compute_inline" );
    py.Input("block" , block_rv);
    return py.Str();
  }

  // Grant UnpackedInstTraits access to the private static members above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
930 | |
// Register both instruction kinds with TVM's instruction registry.
TVM_REGISTER_INST_KIND_TRAITS(ComputeInlineTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeInlineTraits);
933 | |
934 | } // namespace tir |
935 | } // namespace tvm |
936 | |