/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file plan_update_buffer_allocation_location.cc
 * \brief Plan where buffers are to be allocated and update the AST accordingly.
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/var.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "ir_utils.h"

namespace tvm {
namespace tir {

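/*! \brief Collect buffer variables defined by Allocate/AllocateConst nodes. Such allocations are
 * already explicit in the IR and must not be relocated by BufferAllocationLocator. */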
class CollectUnmanagedAllocations : public StmtExprVisitor {
 public:
  void VisitStmt_(const AllocateNode* op) final {
    unmanaged_allocations.insert(op->buffer_var.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AllocateConstNode* op) final {
    unmanaged_allocations.insert(op->buffer_var.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  /*! \brief Buffer variables that are allocated outside of any BlockNode and therefore should
   * not be moved by BufferAllocationLocator. */
  std::unordered_set<const VarNode*> unmanaged_allocations;
};

/*! \brief Collect buffers in the order in which they are allocated or first accessed. */
class BufferAllocateOrderCollector : public StmtExprVisitor {
 public:
  static Array<Buffer> Collect(const PrimFunc& func) {
    BufferAllocateOrderCollector collector;
    for (const auto& kv : func->buffer_map) {
      collector.buffer_alloc_recorder_.push_back(kv.second);
    }
    collector(func->body);
    return std::move(collector.buffer_alloc_recorder_);
  }

 private:
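  // Return true if the buffer has already been recorded.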
  bool find(const Buffer& buf) {
    return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) !=
           buffer_alloc_recorder_.end();
  }

  void VisitStmt_(const BlockNode* op) final {
    for (const Buffer& buffer : op->alloc_buffers) {
      buffer_alloc_recorder_.push_back(buffer);
    }
    // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes.
    // These buffers only appear in read and match_buffer regions.
    for (const auto& region : op->match_buffers) {
      if (!find(region->source->buffer)) {
        buffer_alloc_recorder_.push_back(region->source->buffer);
      }
    }

    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    if (!find(op->buffer)) {
      buffer_alloc_recorder_.push_back(op->buffer);
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    if (!find(op->buffer)) {
      buffer_alloc_recorder_.push_back(op->buffer);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  /*! \brief Records the buffers in their allocation order. */
  Array<Buffer> buffer_alloc_recorder_;
};

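/*! \brief Mutator that moves each managed buffer allocation to the lowest common ancestor (LCA)
 * of its accesses, either into an existing block's alloc_buffers or into an opaque block
 * injected under a loop, and updates block access regions accordingly. */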
class BufferAllocationLocator : public StmtExprMutator {
 public:
  explicit BufferAllocationLocator(const PrimFunc& func) {
    Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
    // The buffer_alloc_recorder Array is used to keep the buffer allocation order
    // since the buffer_lca Map is unordered.
    Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
    std::unordered_set<const VarNode*> arg_buffer_vars;
    CollectUnmanagedAllocations collector;
    collector(func->body);
    unmanaged_allocations_ = collector.unmanaged_allocations;

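    // Treat pointer-typed parameters as unmanaged allocations so they are never relocated.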
    for (const Var& param : func->params) {
      if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
        unmanaged_allocations_.insert(param.get());
      }
    }

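    // Buffers bound in the buffer_map are function arguments; record their vars so they are
    // skipped when planning allocation sites.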
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      arg_buffer_vars.emplace(buffer->data.get());
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    // Record the buffers to be allocated at each stmt.
    for (const auto& buffer : buffer_alloc_recorder) {
      auto it = buffer_lca.find(buffer);
      if (it != buffer_lca.end()) {
        const StmtNode* stmt = (*it).second.get();
        if (arg_buffer_vars.count(buffer->data.get())) {
          continue;
        }
        if (!unmanaged_allocations_.count(buffer->data.get())) {
          alloc_buffers_[stmt].push_back(buffer);
        }
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
    }
  }

 private:
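  // If buffers are planned to be allocated at this loop, wrap the mutated loop body in an
  // opaque block that carries those allocations.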
  Stmt VisitStmt_(const ForNode* op) final {
    auto it = alloc_buffers_.find(op);
    if (it == alloc_buffers_.end()) {
      return StmtMutator::VisitStmt_(op);
    }
    for (const Buffer& buf : it->second) {
      buffer_data_to_buffer_.Set(buf->data, buf);
    }
    auto node = Downcast<For>(StmtMutator::VisitStmt_(op));

    Array<Buffer> new_block_alloc_bufs;
    for (const Buffer& buf : it->second) {
      if (!unmanaged_allocations_.count(buf->data.get())) {
        buffer_data_to_buffer_.erase(buf->data);
        new_block_alloc_bufs.push_back(buf);
      }
    }

    if (!new_block_alloc_bufs.empty()) {
      node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs);
    }

    return std::move(node);
  }

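  // Attach the buffers planned for this block to its alloc_buffers, then drop the regions of
  // buffers that are allocated or matched inside the block.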
  Stmt VisitStmt_(const BlockNode* op) final {
    ICHECK(!op->init.defined());
    Array<Buffer> alloc_buffers;
    auto it = alloc_buffers_.find(op);
    if (it != alloc_buffers_.end()) {
      alloc_buffers = it->second;
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.Set(buf->data, buf);
      }
    }
    for (const MatchBufferRegion& match_buffer : op->match_buffers) {
      const Var& target_var = match_buffer->buffer->data;
      const Var& source_var = match_buffer->source->buffer->data;
      ICHECK(buffer_data_to_buffer_.count(source_var));
      buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
    }
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<BlockNode>();
    ICHECK(op != nullptr);

    // No longer consider buffers created by match_buffer inside the block when updating access
    // regions.
    for (const MatchBufferRegion& match_buffer : op->match_buffers) {
      const Var& target_var = match_buffer->buffer->data;
      buffer_data_to_buffer_.erase(target_var);
    }
    // No longer consider buffers allocated inside the block when updating access regions.
    if (it != alloc_buffers_.end()) {
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.erase(buf->data);
      }
    }

    ObjectPtr<BlockNode> n = CopyOnWrite(op);
    n->alloc_buffers = std::move(alloc_buffers);
    // Erase buffers allocated inside the block from the access regions.
    n->reads = RemoveRedundantBufferRegion(n->reads);
    n->writes = RemoveRedundantBufferRegion(n->writes);
    return Stmt(n);
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
    throw;
  }

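  // Wrap `body` in an opaque block (without iter_vars) that owns `alloc_buffers`; the block's
  // read/write regions are recomputed from the wrapped body.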
  Stmt InjectOpaqueBlock(Stmt body, const Array<Buffer>& alloc_buffers) {
    ICHECK(!alloc_buffers.empty());
    Block opaque_block(/*iter_vars=*/{},
                       /*reads=*/{},
                       /*writes=*/{},
                       /*name_hint=*/"",
                       /*body=*/std::move(body),
                       /*init=*/NullOpt,
                       /*alloc_buffers=*/alloc_buffers);
    ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_);
    n->reads = access[0];
    n->writes = access[1];
    BlockRealize realize({}, Bool(true), Block(n));
    return std::move(realize);
  }

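  // Keep only the regions whose buffer is still visible in the enclosing scope.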
  Array<BufferRegion> RemoveRedundantBufferRegion(const Array<BufferRegion>& region) const {
    Array<BufferRegion> result;
    for (const BufferRegion& buffer_region : region) {
      if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) {
        result.push_back(buffer_region);
      }
    }
    return result;
  }

  /*! \brief The map from stmt to the buffers to be allocated under it. */
  std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
  /*! \brief The map from buffer data var to buffer, for the buffers visible in the current
   * visiting scope. */
  Map<Var, Buffer> buffer_data_to_buffer_;
  /*! \brief Buffer variables that are allocated outside of any BlockNode and should not be
   * moved. */
  std::unordered_set<const VarNode*> unmanaged_allocations_;
};

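/*!
 * \brief Plan and update the buffer allocation locations for a single PrimFunc.
 *
 * Illustrative sketch (pseudo-TIR, not exact syntax): a buffer B that is only accessed
 * inside loop i is moved from its original scope into an opaque block injected under
 * that loop:
 *
 *   before:                 after:
 *     alloc_buffer B          for i:
 *     for i:                    block:
 *       B[i] = ...                alloc_buffer B
 *                                 B[i] = ...
 */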
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
  // Only apply this pass to TIR that is not from TE schedules.
  if (!IsFromLegacyTESchedule(func)) {
    auto fptr = func.CopyOnWrite();
    BufferAllocationLocator locator(func);
    fptr->body = locator(fptr->body);
    return func;
  } else {
    return func;
  }
}

namespace transform {

Pass PlanAndUpdateBufferAllocationLocation() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return PlanAndUpdateBufferAllocationLocation(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {});
}

TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation")
    .set_body_typed(PlanAndUpdateBufferAllocationLocation);

}  // namespace transform

}  // namespace tir
}  // namespace tvm