1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file lower_match_buffer.cc
22 * \brief The pass for lowering match_buffer.
23 */
24
25#include <tvm/arith/analyzer.h>
26#include <tvm/tir/function.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31#include "../ir/functor_common.h"
32#include "ir_utils.h"
33
34namespace tvm {
35namespace tir {
36class MatchBufferLower : public StmtExprMutator {
37 public:
38 explicit MatchBufferLower(const PrimFunc& func) {
39 for (const Var& param : func->params) {
40 // Mark input var as const variable.
41 if (!param.dtype().is_handle()) var_map_.Set(param, param);
42 }
43 }
44
45 private:
46 Stmt VisitStmt_(const BlockNode* op) final {
47 for (const MatchBufferRegion& match_buffer : op->match_buffers) {
48 CheckAndUpdateVarMap(match_buffer);
49 }
50
51 Stmt stmt = StmtExprMutator ::VisitStmt_(op);
52 op = stmt.as<BlockNode>();
53 ICHECK(op != nullptr);
54 Array<BufferRegion> reads =
55 op->reads.Map(std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1));
56 Array<BufferRegion> writes = op->writes.Map(
57 std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1));
58
59 if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) {
60 return stmt;
61 } else {
62 auto n = CopyOnWrite(op);
63 n->match_buffers = {};
64 n->reads = std::move(reads);
65 n->writes = std::move(writes);
66 return Stmt(n);
67 }
68 }
69
70 Stmt VisitStmt_(const ForNode* op) final {
71 analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
72 return StmtExprMutator::VisitStmt_(op);
73 }
74
75 PrimExpr VisitExpr_(const VarNode* op) final {
76 Var v = GetRef<Var>(op);
77 auto it = var_map_.find(v);
78 if (it != var_map_.end()) {
79 return (*it).second;
80 } else {
81 return std::move(v);
82 }
83 }
84
85 Stmt VisitStmt_(const BufferStoreNode* op) final {
86 Stmt stmt = StmtExprMutator::VisitStmt_(op);
87 op = stmt.as<BufferStoreNode>();
88 ICHECK(op != nullptr);
89
90 auto it = match_buffers_.find(op->buffer);
91 if (it == match_buffers_.end()) {
92 return stmt;
93 } else {
94 const Buffer& buffer = (*it).first;
95 const BufferRegion& source = (*it).second;
96
97 auto n = CopyOnWrite(op);
98 n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices);
99 n->buffer = source->buffer;
100 return Stmt(n);
101 }
102 }
103
104 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
105 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
106 op = expr.as<BufferLoadNode>();
107 ICHECK(op != nullptr);
108
109 auto it = match_buffers_.find(op->buffer);
110 if (it == match_buffers_.end()) {
111 return expr;
112 } else {
113 const Buffer& buffer = (*it).first;
114 const BufferRegion& source = (*it).second;
115 Array<PrimExpr> indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices);
116 return BufferLoad(source->buffer, indices);
117 }
118 }
119
120 PrimExpr VisitExpr_(const LoadNode* op) final {
121 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
122 CHECK(var_map_.find(op->buffer_var) == var_map_.end())
123 << "Load from buffer created by match_buffer is not allowed, but got: " << expr;
124 return expr;
125 }
126
127 Stmt VisitStmt_(const StoreNode* op) final {
128 Stmt stmt = StmtExprMutator::VisitStmt_(op);
129 CHECK(var_map_.find(op->buffer_var) == var_map_.end())
130 << "Store from buffer created by match_buffer is not allowed, but got: " << stmt;
131 return stmt;
132 }
133
134 BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) {
135 const Buffer& buffer = buffer_region->buffer;
136 auto it = match_buffers_.find(buffer);
137 if (it == match_buffers_.end()) {
138 return buffer_region;
139 } else {
140 const BufferRegion& source = (*it).second;
141 Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region);
142 return BufferRegion(source->buffer, std::move(region));
143 }
144 }
145
146 private:
147 void CheckAndUpdateVarMap(const MatchBufferRegion& match_buffer) {
148 // Step.1. Check
149 const Buffer& buffer = match_buffer->buffer;
150 const BufferRegion& source = VisitBufferRegion(match_buffer->source);
151 const Buffer& source_buffer = source->buffer;
152
153 // Step.1.1. Check scope & dtype
154 ICHECK_EQ(buffer.scope(), source_buffer.scope())
155 << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs."
156 << source_buffer.scope();
157 ICHECK_EQ(buffer->dtype, source_buffer->dtype)
158 << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs."
159 << source_buffer->dtype;
160
161 // Step.1.2. Check data alignment
162 if (source_buffer->data_alignment % buffer->data_alignment != 0) {
163 LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
164 << " required_alignment=" << buffer->data_alignment
165 << ", provided_alignment=" << source_buffer->data_alignment;
166 }
167 if (is_zero(buffer->elem_offset)) {
168 ICHECK(is_zero(source_buffer->elem_offset))
169 << "Trying to bind a Buffer with offset into one without offset "
170 << " required elem_offset=" << buffer->elem_offset
171 << ", provided elem_offset=" << source_buffer->elem_offset;
172 }
173
174 // Step.2. Update
175 match_buffers_.Set(buffer, source);
176 // Step.2.1. Update buffer data
177 Bind(buffer->data, source_buffer->data, buffer->name + ".data");
178
179 // Step.2.2. Update element offset
180 // We use the ElemOffset method to avoid duplicating the index calculation.
181 {
182 Array<PrimExpr> indices;
183 indices.reserve(source->region.size());
184 for (const Range& range : source->region) {
185 indices.push_back(range->min);
186 }
187
188 Array<PrimExpr> buffer_start_indices = source_buffer->ElemOffset(indices);
189 if (buffer_start_indices.size() == 1) {
190 Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset");
191 CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
192 << "The source elem_offset " << buffer_start_indices[0]
193 << " does not satisfy the offset_factor " << buffer->offset_factor << ".";
194 } else {
195 // Non-zero elem_offset is ill-defined for non-flat memory.
196 // If needed in the future, will require `Array<PrimExpr>
197 // elem_offsets`, with one offset for each flattened index.
198 Bind(buffer->elem_offset, make_const(buffer->elem_offset.dtype(), 0));
199 }
200 }
201
202 // Step 2.3. Check and update strides
203 // Check if target buffer strides are defined
204 ICHECK(source->region.size() >= buffer->shape.size());
205 int offset = source->region.size() - buffer->shape.size();
206 if (!buffer->strides.empty()) {
207 ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
208 if (source_buffer->strides.empty()) {
209 PrimExpr stride = make_const(buffer->strides.back().dtype(), 1);
210 for (size_t i = buffer->shape.size(); i > 0; --i) {
211 const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
212 Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
213 stride *= shape;
214 }
215 } else {
216 ICHECK_EQ(buffer->shape.size() + offset, source_buffer->strides.size());
217 for (size_t i = buffer->shape.size(); i > 0; --i) {
218 const PrimExpr& stride = source_buffer->strides[i - 1 + offset];
219 Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
220 }
221 }
222 }
223
224 // Step 2.4. Check and update shape
225 for (size_t i = 0; i < buffer->shape.size(); ++i) {
226 const Range& range = source->region[i + offset];
227 Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i));
228 }
229 }
230
231 void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") {
232 CHECK_EQ(arg.dtype(), value.dtype())
233 << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype;
234 // Handle recursive case
235 value = Substitute(std::move(value), var_map_);
236 if (arg->IsInstance<VarNode>()) {
237 Var v = Downcast<Var>(arg);
238 auto it = var_map_.find(v);
239 if (it == var_map_.end()) {
240 var_map_.Set(v, value);
241 analyzer_.Bind(v, value);
242 } else {
243 AssertBinding((*it).second, value, arg_name);
244 }
245 } else {
246 AssertBinding(arg, value, arg_name);
247 }
248 }
249
250 void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs,
251 const std::string& arg_name = "argument") {
252 CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name
253 << " unmet: " << lhs << "==" << rhs << ".";
254 }
255
256 private:
257 /*! \brief Buffer region mapping. */
258 Map<Buffer, BufferRegion> match_buffers_;
259 /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */
260 Map<Var, PrimExpr> var_map_;
261 /*! \brief The analyzer */
262 arith::Analyzer analyzer_;
263};
264
265PrimFunc LowerMatchBuffer(PrimFunc func) {
266 auto fptr = func.CopyOnWrite();
267 fptr->body = MatchBufferLower(func)(std::move(fptr->body));
268 return func;
269}
270
271namespace transform {
272
273Pass LowerMatchBuffer() {
274 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
275 return LowerMatchBuffer(std::move(f));
276 };
277 return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {});
278}
279
280TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer);
281
282} // namespace transform
283
284} // namespace tir
285} // namespace tvm
286