1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <optional>
20#include <unordered_set>
21
22#include "../utils.h"
23
24namespace tvm {
25namespace tir {
26
27/*!
28 * \brief Collect the block and index where the buffer is read.
29 * \note The buffer is expected to be read by only one BufferLoad
30 */
31class BufferReadPosCollector : public StmtExprVisitor {
32 public:
33 explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {}
34
35 const std::pair<Block, int>& GetBufferLocation() const { return buffer_loc_; }
36
37 const Optional<IndexMap> GetBufferIndexMap() const { return buffer_index_map_; }
38
39 private:
40 void VisitStmt_(const ForNode* op) final {
41 loop_stack_.push_back(GetRef<For>(op));
42 StmtVisitor::VisitStmt_(op);
43 loop_stack_.pop_back();
44 }
45
46 void VisitStmt_(const BlockRealizeNode* op) final {
47 BlockRealize outer_block_realize = GetRef<BlockRealize>(op);
48 std::swap(outer_block_realize, cur_realize_);
49 StmtVisitor::VisitStmt_(op);
50 std::swap(cur_realize_, outer_block_realize);
51 }
52
53 void VisitExpr_(const BufferLoadNode* op) final {
54 CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";
55
56 const Buffer& buffer = op->buffer;
57 if (buffer_ == buffer.get()) {
58 Map<Var, PrimExpr> subst_map;
59 for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
60 const Var& var = cur_realize_->block->iter_vars[i]->var;
61 const PrimExpr& value = cur_realize_->iter_values[i];
62 subst_map.Set(var, value);
63 }
64 Array<PrimExpr> subst_indices;
65 for (const PrimExpr& e : op->indices) {
66 subst_indices.push_back(Substitute(e, subst_map));
67 }
68 buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer, //
69 /*indices=*/subst_indices, //
70 /*loops=*/loop_stack_, //
71 /*predicate=*/cur_realize_->predicate, //
72 /*analyzer=*/&analyzer_);
73 int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
74 ICHECK(buffer_index != -1);
75 buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index);
76 }
77 }
78
79 static int GetReadBufferIndex(const Block& block, const Buffer& buffer) {
80 for (size_t i = 0; i < block->reads.size(); i++) {
81 if (block->reads[i]->buffer.same_as(buffer)) {
82 return i;
83 }
84 }
85 return -1;
86 }
87
88 private:
89 /*! \brief The buffer of interest. */
90 const BufferNode* buffer_;
91 /*! \brief The block that consumes the buffer and the corresponding read index. */
92 std::pair<Block, int> buffer_loc_;
93 /*! \brief The proposed IndexMap. */
94 Optional<IndexMap> buffer_index_map_;
95
96 /*! \brief Loop stack for calculating IndexMap. */
97 Array<For> loop_stack_;
98 /*! \brief Arithmetic analyzer. */
99 arith::Analyzer analyzer_;
100 /*! \brief Current BlockRealize scope, used in recursive visit */
101 BlockRealize cur_realize_;
102};
103
104class LayoutFreeBufferCollector : public StmtVisitor {
105 public:
106 void VisitStmt_(const BlockNode* block) final {
107 StmtVisitor::VisitStmt_(block);
108 if (Optional<ObjectRef> ann = block->annotations.Get("layout_free_placeholders")) {
109 for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
110 buffers.insert(buffer);
111 }
112 }
113 }
114
115 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffers;
116};
117
118Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) {
119 // Only rewrite PrimFuncs with attr "layout_free_buffers"
120 Array<Integer> layout_free_buffer_index =
121 func->GetAttr(attr::layout_free_buffers, Array<Integer>()).value();
122
123 Array<Buffer> layout_free_buffers;
124 for (const Integer& index : layout_free_buffer_index) {
125 ICHECK(static_cast<size_t>(index->value) < func->params.size());
126 const Var& param = func->params[index->value];
127 layout_free_buffers.push_back(func->buffer_map.at(param));
128 }
129
130 LayoutFreeBufferCollector collector;
131 collector(func->body);
132
133 for (auto buf : collector.buffers) {
134 layout_free_buffers.push_back(buf);
135 }
136 return layout_free_buffers;
137}
138
139std::optional<std::tuple<Block, int, IndexMap>> GetSuggestedIndexMap(
140 Buffer buffer, const PrimFuncNode* prim_func) {
141 BufferReadPosCollector collector(buffer);
142 collector(prim_func->body);
143
144 const auto& index_map = collector.GetBufferIndexMap();
145
146 if (!index_map.defined() || !index_map) {
147 return std::nullopt;
148 }
149
150 const auto& [anchor_block, buffer_index] = collector.GetBufferLocation();
151
152 return std::make_tuple(anchor_block, buffer_index, index_map.value());
153}
154
155/*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */
156std::vector<std::string> GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) {
157 class BufferReadChainCollector : public StmtVisitor {
158 public:
159 explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {}
160
161 void VisitStmt_(const BlockNode* op) final {
162 // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_.
163 if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 &&
164 op->reads[0]->buffer.get() == cur_buffer_) {
165 cache_read_chain.push_back(op->name_hint);
166 cur_buffer_ = op->writes[0]->buffer.get();
167 }
168 StmtVisitor::VisitStmt_(op);
169 }
170
171 std::vector<std::string> cache_read_chain;
172
173 private:
174 const BufferNode* cur_buffer_;
175 };
176
177 BufferReadChainCollector collector(buf);
178 collector(prim_func->body);
179 return collector.cache_read_chain;
180}
181
182bool RewriteLayout(const Schedule& sch) {
183 std::vector<std::pair<StmtSRef, String>> results;
184 auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) {
185 BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global");
186 sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
187 };
188
189 for (const auto& [g_var, base_func] : sch->mod()->functions) {
190 const String& func_name = g_var->name_hint;
191 const auto* prim_func = base_func.as<PrimFuncNode>();
192 // Only consider PrimFunc
193 if (prim_func == nullptr) {
194 continue;
195 }
196
197 for (auto buffer : CollectLayoutFreeBuffers(prim_func)) {
198 const auto cache_read_chain = GetCacheReadChain(buffer, prim_func);
199 if (cache_read_chain.empty()) {
200 // The common case, where the layout-free buffer is directly consumed by an anchor op such
201 // as conv2d or dense.
202 auto tup_opt = GetSuggestedIndexMap(buffer, prim_func);
203 if (tup_opt == std::nullopt) continue;
204
205 auto [anchor_block, buffer_index, index_map] = *tup_opt;
206 auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name);
207 add_layout_rewrite_block(anchor_block_rv, buffer_index);
208 sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map,
209 NullOpt);
210 } else {
211 // When the layout-free buffer is consumed by cache_read, we need to find the index map
212 // for a cache-read buffer that is directly consumed by an anchor op. The last buffer
213 // in cache_read_chain corresponds to that buffer.
214 Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name));
215 ICHECK_EQ(cache_read_block->writes.size(), 1);
216 auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func);
217 if (tup_opt == std::nullopt) continue;
218
219 auto [anchor_block, buffer_index, index_map] = *tup_opt;
220 // Transform the layout of the last cache-read buffer.
221 sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index,
222 BufferIndexType::kRead, index_map, NullOpt);
223
224 // Propagate the layout transformation over cache_read_chain, starting from
225 // the next-to-last cache-read buffer.
226 for (int i = static_cast<int>(cache_read_chain.size()) - 1; i >= 0; --i) {
227 BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name);
228 if (i == 0) {
229 // Before the first cache_read that consumes the layout-free buffer, insert
230 // a layout-rewrite block. Another cache-read buffer is added, and its layout is
231 // transformed by TransformLayout below.
232 add_layout_rewrite_block(cache_read_block_rv, 0);
233 }
234 sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt);
235 }
236 }
237 }
238 }
239 return true;
240}
241
242} // namespace tir
243
244namespace meta_schedule {
245/*! \brief Layout Rewrite. */
246class RewriteLayoutNode : public PostprocNode {
247 public:
248 // Inherited from PostprocNode
249 void InitializeWithTuneContext(const TuneContext& context) final {}
250
251 // Inherited from PostprocNode
252 bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
253
254 Postproc Clone() const {
255 ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
256 return Postproc(n);
257 }
258
259 static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
260 TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
261};
262
263Postproc Postproc::RewriteLayout() {
264 auto n = make_object<RewriteLayoutNode>();
265 return Postproc(n);
266}
267
268TVM_REGISTER_NODE_TYPE(RewriteLayoutNode);
269TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout);
270
271} // namespace meta_schedule
272} // namespace tvm
273