/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Plan where buffers are to be allocated and update the AST accordingly.
 * \file plan_update_buffer_allocation_location.cc
 */
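//
// Roughly, this pass attaches each block-managed buffer allocation to the statement that is the
// lowest common ancestor (LCA) of all accesses to that buffer.  An illustrative sketch in
// pseudo-TIR (buffers A, B, C and the shapes are hypothetical; this is not an exact test case):
//
//   Before:
//     block "root":
//       alloc_buffer B[n]
//       for i in range(n):
//         B[i] = A[i] + 1.0
//         C[i] = B[i] * 2.0
//
//   After:
//     block "root":
//       for i in range(n):
//         block "" (opaque):
//           alloc_buffer B[n]
//           B[i] = A[i] + 1.0
//           C[i] = B[i] * 2.0
//
// Allocations created by Allocate/AllocateConst nodes and buffers bound to function arguments
// are treated as unmanaged and are left where they are.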

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/var.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

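/*! \brief Collect the buffer variables of Allocate/AllocateConst nodes, i.e. allocations that are
 * not managed by blocks and therefore must not be relocated by BufferAllocationLocator. */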
class CollectUnmanagedAllocations : public StmtExprVisitor {
 public:
  void VisitStmt_(const AllocateNode* op) final {
    unmanaged_allocations.insert(op->buffer_var.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AllocateConstNode* op) final {
    unmanaged_allocations.insert(op->buffer_var.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved by
   * BufferAllocationLocator. */
  std::unordered_set<const VarNode*> unmanaged_allocations;
};

/*! \brief Collect the order in which buffers are allocated. */
class BufferAllocateOrderCollector : public StmtExprVisitor {
 public:
  static Array<Buffer> Collect(const PrimFunc& func) {
    BufferAllocateOrderCollector collector;
    for (const auto& kv : func->buffer_map) {
      collector.buffer_alloc_recorder_.push_back(kv.second);
    }
    collector(func->body);
    return std::move(collector.buffer_alloc_recorder_);
  }

 private:
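  /*! \brief Check whether \p buf has already been recorded. */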
  bool find(const Buffer& buf) {
    return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) !=
           buffer_alloc_recorder_.end();
  }

  void VisitStmt_(const BlockNode* op) final {
    for (const Buffer& buffer : op->alloc_buffers) {
      buffer_alloc_recorder_.push_back(buffer);
    }
    // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes.
    // These buffers only appear in read and match_buffer regions.
    for (const auto& region : op->match_buffers) {
      if (!find(region->source->buffer)) {
        buffer_alloc_recorder_.push_back(region->source->buffer);
      }
    }

    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    if (!find(op->buffer)) {
      buffer_alloc_recorder_.push_back(op->buffer);
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    if (!find(op->buffer)) {
      buffer_alloc_recorder_.push_back(op->buffer);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  /*! \brief The recorded buffer allocation order. */
  Array<Buffer> buffer_alloc_recorder_;
};

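/*! \brief Mutator that re-attaches each managed buffer allocation to the statement (block or
 * loop) that is the lowest common ancestor of all accesses to the buffer. */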
class BufferAllocationLocator : public StmtExprMutator {
 public:
  explicit BufferAllocationLocator(const PrimFunc& func) {
    Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
    // The buffer_alloc_recorder Array is used to keep the buffer allocation order
    // since the buffer_lca Map is unordered.
    Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
    std::unordered_set<const VarNode*> arg_buffer_vars;
    CollectUnmanagedAllocations collector;
    collector(func->body);
    unmanaged_allocations_ = collector.unmanaged_allocations;

    for (const Var& param : func->params) {
      if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
        unmanaged_allocations_.insert(param.get());
      }
    }

    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      arg_buffer_vars.emplace(buffer->data.get());
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    // Record the buffers to be allocated at each stmt.
    for (const auto& buffer : buffer_alloc_recorder) {
      auto it = buffer_lca.find(buffer);
      if (it != buffer_lca.end()) {
        const StmtNode* stmt = (*it).second.get();
        if (arg_buffer_vars.count(buffer->data.get())) {
          continue;
        }
        if (!unmanaged_allocations_.count(buffer->data.get())) {
          alloc_buffers_[stmt].push_back(buffer);
        }
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
    }
  }

 private:
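  /*! \brief For a loop that is the LCA of some buffers, wrap the loop body into an opaque block
   * that owns those allocations. */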
  Stmt VisitStmt_(const ForNode* op) final {
    auto it = alloc_buffers_.find(op);
    if (it == alloc_buffers_.end()) {
      return StmtMutator::VisitStmt_(op);
    }
    for (const Buffer& buf : it->second) {
      buffer_data_to_buffer_.Set(buf->data, buf);
    }
    auto node = Downcast<For>(StmtMutator::VisitStmt_(op));

    Array<Buffer> new_block_alloc_bufs;
    for (const Buffer& buf : it->second) {
      if (!unmanaged_allocations_.count(buf->data.get())) {
        buffer_data_to_buffer_.erase(buf->data);
        new_block_alloc_bufs.push_back(buf);
      }
    }

    if (new_block_alloc_bufs.size()) {
      node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs);
    }

    return std::move(node);
  }

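  /*! \brief Attach the buffers whose LCA is this block to its alloc_buffers, and drop access
   * regions that refer to buffers now allocated or match_buffer-bound inside the block. */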
  Stmt VisitStmt_(const BlockNode* op) final {
    ICHECK(!op->init.defined());
    Array<Buffer> alloc_buffers;
    auto it = alloc_buffers_.find(op);
    if (it != alloc_buffers_.end()) {
      alloc_buffers = it->second;
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.Set(buf->data, buf);
      }
    }
    for (const MatchBufferRegion match_buffer : op->match_buffers) {
      const Var& target_var = match_buffer->buffer->data;
      const Var& source_var = match_buffer->source->buffer->data;
      ICHECK(buffer_data_to_buffer_.count(source_var));
      buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
    }
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<BlockNode>();
    ICHECK(op != nullptr);

    // No longer consider buffers created by match_buffer inside the block when updating access
    // region.
    for (const MatchBufferRegion match_buffer : op->match_buffers) {
      const Var& target_var = match_buffer->buffer->data;
      buffer_data_to_buffer_.erase(target_var);
    }
    // No longer consider buffers allocated inside the block when updating access region.
    if (it != alloc_buffers_.end()) {
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.erase(buf->data);
      }
    }

    ObjectPtr<BlockNode> n = CopyOnWrite(op);
    n->alloc_buffers = std::move(alloc_buffers);
    // Erase buffers allocated inside the block from the access regions.
    n->reads = RemoveRedundantBufferRegion(n->reads);
    n->writes = RemoveRedundantBufferRegion(n->writes);
    return Stmt(n);
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
    throw;
  }

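  /*! \brief Wrap \p body in an opaque block that owns \p alloc_buffers, with read/write regions
   * recomputed from the wrapped body. */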
  Stmt InjectOpaqueBlock(Stmt body, const Array<Buffer>& alloc_buffers) {
    ICHECK(!alloc_buffers.empty());
    Block opaque_block(/*iter_vars=*/{},
                       /*reads=*/{},
                       /*writes=*/{},
                       /*name_hint=*/"",
                       /*body=*/std::move(body),
                       /*init=*/NullOpt,
                       /*alloc_buffers=*/alloc_buffers);
    ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_);
    n->reads = access[0];
    n->writes = access[1];
    BlockRealize realize({}, Bool(true), Block(n));
    return std::move(realize);
  }

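  /*! \brief Keep only the regions whose buffers are still visible outside the block, i.e. drop
   * regions of buffers that are allocated or match_buffer-bound within it. */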
  Array<BufferRegion> RemoveRedundantBufferRegion(const Array<BufferRegion>& region) const {
    Array<BufferRegion> result;
    for (const BufferRegion& buffer_region : region) {
      if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) {
        result.push_back(buffer_region);
      }
    }
    return result;
  }

  /*! \brief The map from stmt to the buffers to be allocated under it. */
  std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
  /*! \brief The buffers already allocated while recursively visiting. */
  Map<Var, Buffer> buffer_data_to_buffer_;
  /*! \brief Buffers that are allocated outside of the BlockNode and should not be moved. */
  std::unordered_set<const VarNode*> unmanaged_allocations_;
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
  // Only apply this pass to TIR that is not from TE schedules
  if (!IsFromLegacyTESchedule(func)) {
    auto fptr = func.CopyOnWrite();
    BufferAllocationLocator locator(func);
    fptr->body = locator(fptr->body);
    return func;
  } else {
    return func;
  }
}

namespace transform {

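/*!
 * \brief Create the pass that plans and updates buffer allocation locations.
 *
 * A minimal usage sketch from Python, assuming the pass is exposed through the usual
 * tvm.tir.transform FFI path:
 *
 * \code{.py}
 *   mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
 * \endcode
 */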
Pass PlanAndUpdateBufferAllocationLocation() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return PlanAndUpdateBufferAllocationLocation(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {});
}

TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation")
    .set_body_typed(PlanAndUpdateBufferAllocationLocation);

}  // namespace transform

}  // namespace tir
}  // namespace tvm