/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file flatten_buffer.cc
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Transform multi-dimensional BufferLoad/BufferStore into flat,
 * device-supported accesses, for TIR that does not contain opaque blocks.
 */
class BufferFlattener : public StmtExprMutator {
 public:
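  /*! \brief Entry point: flatten all buffer accesses and allocations in the body of \p func. */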
  static PrimFunc Flatten(PrimFunc func) {
    auto pass = BufferFlattener();
    auto writer = func.CopyOnWrite();
    writer->body = pass.VisitStmt(func->body);
    // The buffers in func->buffer_map are deliberately left
    // unflattened, as they are used for validation of user-provided
    // arguments. The flattened buffers used in the updated
    // function body alias the argument buffers.
    return func;
  }

 private:
  BufferFlattener() {}

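  // Flatten the buffers allocated by a block, and rewrite its read/write
  // regions to refer to the flattened buffers. MatchBufferRegions must
  // already have been lowered by tir.transform.LowerMatchBuffer.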
  Stmt VisitStmt_(const BlockNode* op) final {
    ICHECK_EQ(op->match_buffers.size(), 0)
        << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. "
        << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer.";

    Block block = GetRef<Block>(op);

    Array<Buffer> alloc_buffers = op->alloc_buffers;
    alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); });
    if (!alloc_buffers.same_as(op->alloc_buffers)) {
      block.CopyOnWrite()->alloc_buffers = alloc_buffers;
    }

    Array<BufferRegion> reads = op->reads;
    reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
    if (!reads.same_as(op->reads)) {
      block.CopyOnWrite()->reads = reads;
    }

    Array<BufferRegion> writes = op->writes;
    writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
    if (!writes.same_as(op->writes)) {
      block.CopyOnWrite()->writes = writes;
    }

    return StmtExprMutator::VisitStmt_(block.get());
  }
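  // Flatten an allocation. 1-d allocations are kept as-is; N-d allocations
  // with a matching DeclBuffer are flattened according to that buffer's
  // axis separators; otherwise the extents collapse to their product
  // (e.g. extents [4, 8] become the single extent [32]).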
  Stmt VisitStmt_(const AllocateNode* op) final {
    Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
    // TODO(Lunderberg): Move the handling of boolean into a
    // dedicated pass.
    if (alloc->dtype == DataType::Bool()) {
      auto writer = alloc.CopyOnWrite();
      writer->dtype = DataType::Int(8);
    }

    if (alloc->extents.size() == 1) {
      // No flattening required for buffers that are already flat

      // TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it
      // out in the current implementation as not all lowering passes
      // support DeclBuffer.
      if (auto* decl_buffer = alloc->body.as<DeclBufferNode>()) {
        alloc.CopyOnWrite()->body = std::move(decl_buffer->body);
      }

      return std::move(alloc);
    }

    if (auto* decl_buffer = alloc->body.as<DeclBufferNode>();
        decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) {
      // N-d buffer, use the DeclBuffer inside to determine how it
      // should be flattened.
      auto& buffer = decl_buffer->buffer;
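      // The DeclBuffer matches the allocation only if the dtype and every
      // extent agree exactly with the buffer's shape.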
      bool matching_buffer = [&]() {
        if (alloc->dtype != buffer->dtype) {
          return false;
        }
        if (alloc->extents.size() != buffer->shape.size()) {
          return false;
        }
        ExprDeepEqual expr_equal;
        for (size_t i = 0; i < alloc->extents.size(); i++) {
          if (!expr_equal(alloc->extents[i], buffer->shape[i])) {
            return false;
          }
        }
        return true;
      }();

      if (matching_buffer) {
        Buffer flattened = GetFlattenedBuffer(buffer);

        auto n = alloc.CopyOnWrite();
        // TODO(rfc-70): Update the DeclBuffer node instead of
        // stripping it out. Stripping it out in the current
        // implementation as not all lowering passes support
        // DeclBuffer.
        //
        // n->body = DeclBuffer(flattened, std::move(decl_buffer->body));
        n->body = std::move(decl_buffer->body);
        n->extents = flattened->shape;
        return std::move(alloc);
      } else {
        ICHECK(decl_buffer->buffer->axis_separators.empty())
            << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be "
               "flattened to 1-d physical memory";
      }
    }

    // Fallback, this is an allocation without a matching DeclBuffer
    PrimExpr flat_extent = 1;
    for (const auto& dim : alloc->extents) {
      flat_extent *= dim;
    }

    auto n = alloc.CopyOnWrite();
    n->extents = {flat_extent};
    return std::move(alloc);
  }
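  // Return the flattened version of buf, memoized in buffer_remap_ so that
  // repeated accesses to the same buffer share a single flattened buffer.
  // Boolean buffers are backed by int8 storage.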
  Buffer GetFlattenedBuffer(Buffer buf) {
    auto it = buffer_remap_.find(buf);
    if (it != buffer_remap_.end()) {
      return it->second;
    }
    auto flattened = buf.GetFlattenedBuffer();

    // TODO(Lunderberg): Move the handling of boolean into a
    // dedicated pass.
    if (flattened->dtype == DataType::Bool()) {
      auto writer = flattened.CopyOnWrite();
      writer->dtype = DataType::Int(8);
    }

    buffer_remap_[buf] = flattened;
    return flattened;
  }

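  // Flatten the buffer and indices of a store. Boolean values are cast to
  // the int8 dtype of the backing array before being stored.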
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    bool store_returns_bool = (op->value.dtype() == DataType::Bool());
    store = VisitBufferAccess(store);

    // Handle casts from the value's dtype to the dtype of the
    // backing array.
    // TODO(Lunderberg): Move the handling of boolean into a
    // dedicated pass.
    if (store_returns_bool) {
      ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
          << "Expected int8 backing array for boolean tensor";
      auto writer = store.CopyOnWrite();
      writer->value = tvm::cast(DataType::Int(8), store->value);
      return std::move(store);
    }
    return std::move(store);
  }

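  // Flatten the buffer and indices of a load. Boolean results are read as
  // int8 from the backing array, then cast back to bool.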
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    bool load_returns_bool = (op->dtype == DataType::Bool());
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    load = VisitBufferAccess(load);
    // Handle casts from the dtype of the backing array to the value's dtype.
    // TODO(Lunderberg): Move the handling of boolean into a
    // dedicated pass.
    if (load_returns_bool) {
      ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
          << "Expected int8 backing array for boolean tensor";
      load.CopyOnWrite()->dtype = DataType::Int(8);
      return tvm::cast(DataType::Bool(), load);
    } else {
      return std::move(load);
    }
  }

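  // Rewrite a BufferLoad/BufferStore node to use the flattened buffer and
  // flattened indices. For example, for a buffer of shape [4, 8] with no
  // axis separators, indices [i, j] flatten to the single index [i * 8 + j].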
  template <typename Node>
  Node VisitBufferAccess(Node node) {
    ICHECK(node->buffer.defined());
    auto flattened_indices = node->buffer->ElemOffset(node->indices);
    Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);

    auto writer = node.CopyOnWrite();
    writer->buffer = flattened_buffer;
    writer->indices = flattened_indices;
    return node;
  }

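  // Flatten a block read/write region by flattening the region's minimum
  // and maximum corners. For multi-dimensional sub-regions this yields a
  // contiguous span that may conservatively over-approximate the set of
  // elements actually touched.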
  BufferRegion MutateBufferRegion(BufferRegion region) {
    Buffer orig_buf = region->buffer;
    Buffer flattened_buf = GetFlattenedBuffer(orig_buf);
    if (flattened_buf.same_as(orig_buf)) {
      return region;
    }

    Array<PrimExpr> min_values;
    Array<PrimExpr> max_values;
    for (const auto& range : region->region) {
      min_values.push_back(range->min);
      max_values.push_back(range->min + range->extent - 1);
    }

    Array<PrimExpr> flattened_min = orig_buf->ElemOffset(min_values);
    Array<PrimExpr> flattened_max = orig_buf->ElemOffset(max_values);

    Array<Range> flattened_ranges;
    ICHECK_EQ(flattened_min.size(), flattened_max.size());
    for (size_t i = 0; i < flattened_min.size(); i++) {
      flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1));
    }

    return BufferRegion(flattened_buf, flattened_ranges);
  }

  /*! \brief Map of buffers being remapped. */
  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

  /*! \brief The updated external buffer map. */
  Map<Var, Buffer> updated_extern_buffer_map_;
};

PrimFunc FlattenBuffer(PrimFunc f) {
  // Only apply this pass to TIR that is not from TE schedules
  if (!IsFromLegacyTESchedule(f)) {
    return BufferFlattener::Flatten(f);
  } else {
    return f;
  }
}

namespace transform {

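// The pass is exposed through the global registry below; a minimal usage
// sketch from Python (assuming a lowered IRModule `mod`) would be:
//   mod = tvm.tir.transform.FlattenBuffer()(mod)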
Pass FlattenBuffer() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return FlattenBuffer(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {});
}

TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer);

}  // namespace transform

}  // namespace tir
}  // namespace tvm