1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <optional> |
20 | #include <unordered_set> |
21 | |
22 | #include "../utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | /*! |
28 | * \brief Collect the block and index where the buffer is read. |
29 | * \note The buffer is expected to be read by only one BufferLoad |
30 | */ |
31 | class BufferReadPosCollector : public StmtExprVisitor { |
32 | public: |
33 | explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {} |
34 | |
35 | const std::pair<Block, int>& GetBufferLocation() const { return buffer_loc_; } |
36 | |
37 | const Optional<IndexMap> GetBufferIndexMap() const { return buffer_index_map_; } |
38 | |
39 | private: |
40 | void VisitStmt_(const ForNode* op) final { |
41 | loop_stack_.push_back(GetRef<For>(op)); |
42 | StmtVisitor::VisitStmt_(op); |
43 | loop_stack_.pop_back(); |
44 | } |
45 | |
46 | void VisitStmt_(const BlockRealizeNode* op) final { |
47 | BlockRealize outer_block_realize = GetRef<BlockRealize>(op); |
48 | std::swap(outer_block_realize, cur_realize_); |
49 | StmtVisitor::VisitStmt_(op); |
50 | std::swap(cur_realize_, outer_block_realize); |
51 | } |
52 | |
53 | void VisitExpr_(const BufferLoadNode* op) final { |
54 | CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block" ; |
55 | |
56 | const Buffer& buffer = op->buffer; |
57 | if (buffer_ == buffer.get()) { |
58 | Map<Var, PrimExpr> subst_map; |
59 | for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) { |
60 | const Var& var = cur_realize_->block->iter_vars[i]->var; |
61 | const PrimExpr& value = cur_realize_->iter_values[i]; |
62 | subst_map.Set(var, value); |
63 | } |
64 | Array<PrimExpr> subst_indices; |
65 | for (const PrimExpr& e : op->indices) { |
66 | subst_indices.push_back(Substitute(e, subst_map)); |
67 | } |
68 | buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer, // |
69 | /*indices=*/subst_indices, // |
70 | /*loops=*/loop_stack_, // |
71 | /*predicate=*/cur_realize_->predicate, // |
72 | /*analyzer=*/&analyzer_); |
73 | int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer); |
74 | ICHECK(buffer_index != -1); |
75 | buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index); |
76 | } |
77 | } |
78 | |
79 | static int GetReadBufferIndex(const Block& block, const Buffer& buffer) { |
80 | for (size_t i = 0; i < block->reads.size(); i++) { |
81 | if (block->reads[i]->buffer.same_as(buffer)) { |
82 | return i; |
83 | } |
84 | } |
85 | return -1; |
86 | } |
87 | |
88 | private: |
89 | /*! \brief The buffer of interest. */ |
90 | const BufferNode* buffer_; |
91 | /*! \brief The block that consumes the buffer and the corresponding read index. */ |
92 | std::pair<Block, int> buffer_loc_; |
93 | /*! \brief The proposed IndexMap. */ |
94 | Optional<IndexMap> buffer_index_map_; |
95 | |
96 | /*! \brief Loop stack for calculating IndexMap. */ |
97 | Array<For> loop_stack_; |
98 | /*! \brief Arithmetic analyzer. */ |
99 | arith::Analyzer analyzer_; |
100 | /*! \brief Current BlockRealize scope, used in recursive visit */ |
101 | BlockRealize cur_realize_; |
102 | }; |
103 | |
104 | class LayoutFreeBufferCollector : public StmtVisitor { |
105 | public: |
106 | void VisitStmt_(const BlockNode* block) final { |
107 | StmtVisitor::VisitStmt_(block); |
108 | if (Optional<ObjectRef> ann = block->annotations.Get("layout_free_placeholders" )) { |
109 | for (Buffer buffer : Downcast<Array<Buffer>>(ann)) { |
110 | buffers.insert(buffer); |
111 | } |
112 | } |
113 | } |
114 | |
115 | std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffers; |
116 | }; |
117 | |
118 | Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) { |
119 | // Only rewrite PrimFuncs with attr "layout_free_buffers" |
120 | Array<Integer> layout_free_buffer_index = |
121 | func->GetAttr(attr::layout_free_buffers, Array<Integer>()).value(); |
122 | |
123 | Array<Buffer> layout_free_buffers; |
124 | for (const Integer& index : layout_free_buffer_index) { |
125 | ICHECK(static_cast<size_t>(index->value) < func->params.size()); |
126 | const Var& param = func->params[index->value]; |
127 | layout_free_buffers.push_back(func->buffer_map.at(param)); |
128 | } |
129 | |
130 | LayoutFreeBufferCollector collector; |
131 | collector(func->body); |
132 | |
133 | for (auto buf : collector.buffers) { |
134 | layout_free_buffers.push_back(buf); |
135 | } |
136 | return layout_free_buffers; |
137 | } |
138 | |
139 | std::optional<std::tuple<Block, int, IndexMap>> GetSuggestedIndexMap( |
140 | Buffer buffer, const PrimFuncNode* prim_func) { |
141 | BufferReadPosCollector collector(buffer); |
142 | collector(prim_func->body); |
143 | |
144 | const auto& index_map = collector.GetBufferIndexMap(); |
145 | |
146 | if (!index_map.defined() || !index_map) { |
147 | return std::nullopt; |
148 | } |
149 | |
150 | const auto& [anchor_block, buffer_index] = collector.GetBufferLocation(); |
151 | |
152 | return std::make_tuple(anchor_block, buffer_index, index_map.value()); |
153 | } |
154 | |
155 | /*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */ |
156 | std::vector<std::string> GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) { |
157 | class BufferReadChainCollector : public StmtVisitor { |
158 | public: |
159 | explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {} |
160 | |
161 | void VisitStmt_(const BlockNode* op) final { |
162 | // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_. |
163 | if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 && |
164 | op->reads[0]->buffer.get() == cur_buffer_) { |
165 | cache_read_chain.push_back(op->name_hint); |
166 | cur_buffer_ = op->writes[0]->buffer.get(); |
167 | } |
168 | StmtVisitor::VisitStmt_(op); |
169 | } |
170 | |
171 | std::vector<std::string> cache_read_chain; |
172 | |
173 | private: |
174 | const BufferNode* cur_buffer_; |
175 | }; |
176 | |
177 | BufferReadChainCollector collector(buf); |
178 | collector(prim_func->body); |
179 | return collector.cache_read_chain; |
180 | } |
181 | |
182 | bool RewriteLayout(const Schedule& sch) { |
183 | std::vector<std::pair<StmtSRef, String>> results; |
184 | auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) { |
185 | BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global" ); |
186 | sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true()); |
187 | }; |
188 | |
189 | for (const auto& [g_var, base_func] : sch->mod()->functions) { |
190 | const String& func_name = g_var->name_hint; |
191 | const auto* prim_func = base_func.as<PrimFuncNode>(); |
192 | // Only consider PrimFunc |
193 | if (prim_func == nullptr) { |
194 | continue; |
195 | } |
196 | |
197 | for (auto buffer : CollectLayoutFreeBuffers(prim_func)) { |
198 | const auto cache_read_chain = GetCacheReadChain(buffer, prim_func); |
199 | if (cache_read_chain.empty()) { |
200 | // The common case, where the layout-free buffer is directly consumed by an anchor op such |
201 | // as conv2d or dense. |
202 | auto tup_opt = GetSuggestedIndexMap(buffer, prim_func); |
203 | if (tup_opt == std::nullopt) continue; |
204 | |
205 | auto [anchor_block, buffer_index, index_map] = *tup_opt; |
206 | auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name); |
207 | add_layout_rewrite_block(anchor_block_rv, buffer_index); |
208 | sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map, |
209 | NullOpt); |
210 | } else { |
211 | // When the layout-free buffer is consumed by cache_read, we need to find the index map |
212 | // for a cache-read buffer that is directly consumed by an anchor op. The last buffer |
213 | // in cache_read_chain corresponds to that buffer. |
214 | Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name)); |
215 | ICHECK_EQ(cache_read_block->writes.size(), 1); |
216 | auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func); |
217 | if (tup_opt == std::nullopt) continue; |
218 | |
219 | auto [anchor_block, buffer_index, index_map] = *tup_opt; |
220 | // Transform the layout of the last cache-read buffer. |
221 | sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index, |
222 | BufferIndexType::kRead, index_map, NullOpt); |
223 | |
224 | // Propagate the layout transformation over cache_read_chain, starting from |
225 | // the next-to-last cache-read buffer. |
226 | for (int i = static_cast<int>(cache_read_chain.size()) - 1; i >= 0; --i) { |
227 | BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name); |
228 | if (i == 0) { |
229 | // Before the first cache_read that consumes the layout-free buffer, insert |
230 | // a layout-rewrite block. Another cache-read buffer is added, and its layout is |
231 | // transformed by TransformLayout below. |
232 | add_layout_rewrite_block(cache_read_block_rv, 0); |
233 | } |
234 | sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt); |
235 | } |
236 | } |
237 | } |
238 | } |
239 | return true; |
240 | } |
241 | |
242 | } // namespace tir |
243 | |
244 | namespace meta_schedule { |
245 | /*! \brief Layout Rewrite. */ |
246 | class RewriteLayoutNode : public PostprocNode { |
247 | public: |
248 | // Inherited from PostprocNode |
249 | void InitializeWithTuneContext(const TuneContext& context) final {} |
250 | |
251 | // Inherited from PostprocNode |
252 | bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); } |
253 | |
254 | Postproc Clone() const { |
255 | ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this); |
256 | return Postproc(n); |
257 | } |
258 | |
259 | static constexpr const char* _type_key = "meta_schedule.RewriteLayout" ; |
260 | TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode); |
261 | }; |
262 | |
263 | Postproc Postproc::RewriteLayout() { |
264 | auto n = make_object<RewriteLayoutNode>(); |
265 | return Postproc(n); |
266 | } |
267 | |
268 | TVM_REGISTER_NODE_TYPE(RewriteLayoutNode); |
269 | TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout" ).set_body_typed(Postproc::RewriteLayout); |
270 | |
271 | } // namespace meta_schedule |
272 | } // namespace tvm |
273 | |