1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file flatten_buffer.cc |
22 | */ |
23 | |
24 | #include <tvm/tir/analysis.h> |
25 | #include <tvm/tir/stmt_functor.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include "ir_utils.h" |
29 | |
30 | namespace tvm { |
31 | namespace tir { |
32 | |
33 | /*! |
34 | * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension |
35 | * for the TIR not contains opaque block. |
36 | */ |
37 | class BufferFlattener : public StmtExprMutator { |
38 | public: |
39 | static PrimFunc Flatten(PrimFunc func) { |
40 | auto pass = BufferFlattener(); |
41 | auto writer = func.CopyOnWrite(); |
42 | writer->body = pass.VisitStmt(func->body); |
43 | // The buffers in func->buffer_map are deliberately left |
44 | // unflattened, as they are used for validation of user-provided |
45 | // arguments. The flattened buffers used in the updated |
46 | // function body alias the argument buffers. |
47 | return func; |
48 | } |
49 | |
50 | private: |
51 | BufferFlattener() {} |
52 | |
53 | Stmt VisitStmt_(const BlockNode* op) final { |
54 | ICHECK_EQ(op->match_buffers.size(), 0) |
55 | << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. " |
56 | << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer." ; |
57 | |
58 | Block block = GetRef<Block>(op); |
59 | |
60 | Array<Buffer> alloc_buffers = op->alloc_buffers; |
61 | alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); |
62 | if (!alloc_buffers.same_as(op->alloc_buffers)) { |
63 | block.CopyOnWrite()->alloc_buffers = alloc_buffers; |
64 | } |
65 | |
66 | Array<BufferRegion> reads = op->reads; |
67 | reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); |
68 | if (!reads.same_as(op->reads)) { |
69 | block.CopyOnWrite()->reads = reads; |
70 | } |
71 | |
72 | Array<BufferRegion> writes = op->writes; |
73 | writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); |
74 | if (!writes.same_as(op->writes)) { |
75 | block.CopyOnWrite()->writes = writes; |
76 | } |
77 | |
78 | return StmtExprMutator::VisitStmt_(block.get()); |
79 | } |
80 | |
81 | Stmt VisitStmt_(const AllocateNode* op) final { |
82 | Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op)); |
83 | // TODO(Lunderberg): Move the handling of boolean into a |
84 | // dedicated pass. |
85 | if (alloc->dtype == DataType::Bool()) { |
86 | auto writer = alloc.CopyOnWrite(); |
87 | writer->dtype = DataType::Int(8); |
88 | } |
89 | |
90 | if (alloc->extents.size() == 1) { |
91 | // No flattening required for buffers that are already flat |
92 | |
93 | // TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it |
94 | // out in the current implementation as not all lowering passes |
95 | // support DeclBuffer. |
96 | if (auto* decl_buffer = alloc->body.as<DeclBufferNode>()) { |
97 | alloc.CopyOnWrite()->body = std::move(decl_buffer->body); |
98 | } |
99 | |
100 | return std::move(alloc); |
101 | } |
102 | |
103 | if (auto* decl_buffer = alloc->body.as<DeclBufferNode>(); |
104 | decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) { |
105 | // N-d buffer, use the DeclBuffer inside to determine how it |
106 | // should be flattened. |
107 | auto& buffer = decl_buffer->buffer; |
108 | bool matching_buffer = [&]() { |
109 | if (alloc->dtype != buffer->dtype) { |
110 | return false; |
111 | } |
112 | if (alloc->extents.size() != buffer->shape.size()) { |
113 | return false; |
114 | } |
115 | ExprDeepEqual expr_equal; |
116 | for (size_t i = 0; i < alloc->extents.size(); i++) { |
117 | if (!expr_equal(alloc->extents[i], buffer->shape[i])) { |
118 | return false; |
119 | } |
120 | } |
121 | return true; |
122 | }(); |
123 | |
124 | if (matching_buffer) { |
125 | Buffer flattened = GetFlattenedBuffer(buffer); |
126 | |
127 | auto n = alloc.CopyOnWrite(); |
128 | // TODO(rfc-70): Update the DeclBuffer node instead of |
129 | // stripping it out. Stripping it out in the current |
130 | // implementation as not all lowering passes support |
131 | // DeclBuffer. |
132 | // |
133 | // n->body = DeclBuffer(flattened, std::move(decl_buffer->body)); |
134 | n->body = std::move(decl_buffer->body); |
135 | n->extents = flattened->shape; |
136 | return std::move(alloc); |
137 | } else { |
138 | ICHECK(decl_buffer->buffer->axis_separators.empty()) |
139 | << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be " |
140 | "flattened to 1-d physical memory" ; |
141 | } |
142 | } |
143 | |
144 | // Fallback, this is an allocation without a matching DeclBuffer |
145 | PrimExpr flat_extent = 1; |
146 | for (const auto& dim : alloc->extents) { |
147 | flat_extent *= dim; |
148 | } |
149 | |
150 | auto n = alloc.CopyOnWrite(); |
151 | n->extents = {flat_extent}; |
152 | return std::move(alloc); |
153 | } |
154 | |
155 | Buffer GetFlattenedBuffer(Buffer buf) { |
156 | auto it = buffer_remap_.find(buf); |
157 | if (it != buffer_remap_.end()) { |
158 | return it->second; |
159 | } |
160 | auto flattened = buf.GetFlattenedBuffer(); |
161 | |
162 | // TODO(Lunderberg): Move the handling of boolean into a |
163 | // dedicated pass. |
164 | if (flattened->dtype == DataType::Bool()) { |
165 | auto writer = flattened.CopyOnWrite(); |
166 | writer->dtype = DataType::Int(8); |
167 | } |
168 | |
169 | buffer_remap_[buf] = flattened; |
170 | return flattened; |
171 | } |
172 | |
173 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
174 | BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
175 | bool store_returns_bool = (op->value.dtype() == DataType::Bool()); |
176 | store = VisitBufferAccess(store); |
177 | |
178 | // Handle casts from the value's dtype to the dtype of the |
179 | // backing array. |
180 | // TODO(Lunderberg): Move the handling of boolean into a |
181 | // dedicated pass. |
182 | if (store_returns_bool) { |
183 | ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) |
184 | << "Expected int8 backing array for boolean tensor" ; |
185 | auto writer = store.CopyOnWrite(); |
186 | writer->value = tvm::cast(DataType::Int(8), store->value); |
187 | return std::move(store); |
188 | } |
189 | return std::move(store); |
190 | } |
191 | |
192 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
193 | bool load_returns_bool = (op->dtype == DataType::Bool()); |
194 | BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
195 | load = VisitBufferAccess(load); |
196 | // Handle casts from dtype of the backing array to value's dtype. |
197 | // TODO(Lunderberg): Move the handling of boolean into a |
198 | // dedicated pass. |
199 | if (load_returns_bool) { |
200 | ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) |
201 | << "Expected int8 backing array for boolean tensor" ; |
202 | load.CopyOnWrite()->dtype = DataType::Int(8); |
203 | return tvm::cast(DataType::Bool(), load); |
204 | } else { |
205 | return std::move(load); |
206 | } |
207 | } |
208 | |
209 | template <typename Node> |
210 | Node VisitBufferAccess(Node node) { |
211 | ICHECK(node->buffer.defined()); |
212 | auto flattened_indices = node->buffer->ElemOffset(node->indices); |
213 | Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); |
214 | |
215 | auto writer = node.CopyOnWrite(); |
216 | writer->buffer = flattened_buffer; |
217 | writer->indices = flattened_indices; |
218 | return node; |
219 | } |
220 | |
221 | BufferRegion MutateBufferRegion(BufferRegion region) { |
222 | Buffer orig_buf = region->buffer; |
223 | Buffer flattened_buf = GetFlattenedBuffer(orig_buf); |
224 | if (flattened_buf.same_as(orig_buf)) { |
225 | return region; |
226 | } |
227 | |
228 | Array<PrimExpr> min_values; |
229 | Array<PrimExpr> max_values; |
230 | for (const auto& range : region->region) { |
231 | min_values.push_back(range->min); |
232 | max_values.push_back(range->min + range->extent - 1); |
233 | } |
234 | |
235 | Array<PrimExpr> flattened_min = orig_buf->ElemOffset(min_values); |
236 | Array<PrimExpr> flattened_max = orig_buf->ElemOffset(max_values); |
237 | |
238 | Array<Range> flattened_ranges; |
239 | ICHECK_EQ(flattened_min.size(), flattened_max.size()); |
240 | for (size_t i = 0; i < flattened_min.size(); i++) { |
241 | flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); |
242 | } |
243 | |
244 | return BufferRegion(flattened_buf, flattened_ranges); |
245 | } |
246 | |
247 | /*! \brief Map of buffers being remapped. */ |
248 | std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_; |
249 | |
250 | /*! \brief The updated external buffer map. */ |
251 | Map<Var, Buffer> updated_extern_buffer_map_; |
252 | }; |
253 | |
254 | PrimFunc FlattenBuffer(PrimFunc f) { |
255 | // Only apply this pass to TIR that is not from TE schedules |
256 | if (!IsFromLegacyTESchedule(f)) { |
257 | return BufferFlattener::Flatten(f); |
258 | } else { |
259 | return f; |
260 | } |
261 | } |
262 | |
263 | namespace transform { |
264 | |
265 | Pass FlattenBuffer() { |
266 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
267 | return FlattenBuffer(std::move(f)); |
268 | }; |
269 | return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer" , {}); |
270 | } |
271 | |
272 | TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer" ).set_body_typed(FlattenBuffer); |
273 | } // namespace transform |
274 | |
275 | } // namespace tir |
276 | } // namespace tvm |
277 | |