1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file arg_binder.cc
22 * \brief Helper utility to match and bind arguments.
23 */
24#include "arg_binder.h"
25
26#include <tvm/runtime/device_api.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30
31#include "ir_utils.h"
32
33namespace tvm {
34namespace tir {
35
36void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg_name,
37 std::vector<Stmt>* asserts) {
38 PrimExpr scond = ana->Simplify(cond);
39 if (is_zero(scond)) {
40 LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", "
41 << " on argument " << arg_name;
42 }
43 if (!is_one(scond)) {
44 std::ostringstream os;
45 os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond;
46 asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), Evaluate(0)));
47 }
48}
49
50bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
51 bool with_lets) {
52 ICHECK_EQ(arg.dtype(), value.dtype());
53 if (const VarNode* v = arg.as<VarNode>()) {
54 auto it = def_map_->find(v);
55 if (it == def_map_->end()) {
56 Var v_arg = Downcast<Var>(arg);
57 defs_.emplace_back(v_arg);
58 if (with_lets) {
59 (*def_map_)[v] = arg;
60 init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
61 } else {
62 (*def_map_)[v] = value;
63 }
64 return true;
65 } else {
66 BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_);
67 }
68 } else {
69 BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_);
70 }
71 return false;
72}
73
74void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
75 bool with_let) {
76 Bind_(arg, value, arg_name, with_let);
77}
78
79void ArgBinder::BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& value,
80 const std::string& arg_name) {
81 ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch";
82 for (size_t i = 0; i < arg.size(); ++i) {
83 std::ostringstream os;
84 os << arg_name << "[" << i << "]";
85 this->Bind(arg[i], value[i], os.str());
86 }
87}
88
89void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name,
90 bool fuzzy_match) {
91 ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch";
92 ICHECK_EQ(arg->dtype, value->dtype)
93 << "Argument " << arg_name << " Buffer bind data type mismatch";
94 if (value->data_alignment % arg->data_alignment != 0) {
95 LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
96 << " required_alignment=" << arg->data_alignment
97 << ", provided_alignment=" << value->data_alignment;
98 }
99
100 if (value->elem_offset.defined()) {
101 // bind pointer and offset.
102 if (is_zero(arg->elem_offset)) {
103 ICHECK(is_zero(value->elem_offset))
104 << "Trying to bind a Buffer with offset into one without offset "
105 << " required elem_offset=" << arg->elem_offset
106 << ", provided elem_offset=" << value->elem_offset;
107 }
108
109 this->Bind(arg->data, value->data, arg_name + ".data");
110 if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
111 if (arg->offset_factor > 1) {
112 PrimExpr offset = value->elem_offset;
113 PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
114 PrimExpr zero = make_zero(offset.dtype());
115 BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
116 &asserts_);
117 }
118 }
119 }
120
121 if (arg->shape.size() < value->shape.size()) {
122 ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
123 size_t diff = value->shape.size() - arg->shape.size();
124 for (size_t i = 0; i < diff; ++i) {
125 ICHECK(is_one(analyzer_.Simplify(value->shape[i])))
126 << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape;
127 }
128 for (size_t i = 0; i < arg->shape.size(); ++i) {
129 std::ostringstream os;
130 os << arg_name << ".shape[" << i << "]";
131 this->Bind(arg->shape[i], value->shape[i + diff], os.str());
132 }
133 if (value->strides.size() != 0) {
134 ICHECK_EQ(arg->strides.size(), arg->shape.size());
135 ICHECK_EQ(value->strides.size(), value->shape.size());
136 for (size_t i = 0; i < arg->strides.size(); ++i) {
137 std::ostringstream os;
138 os << arg_name << ".strides[" << i << "]";
139 this->Bind(arg->strides[i], value->strides[i + diff], os.str());
140 }
141 }
142 } else {
143 this->BindArray(arg->shape, value->shape, arg_name + ".shape");
144 this->BindArray(arg->strides, value->strides, arg_name + ".strides");
145 }
146}
147
148inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kind) {
149 return TVMStructGet(t, arr, 0, kind);
150}
151
152void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
153 const PrimExpr& device_id, const Var& handle,
154 const std::string& arg_name) {
155 const DataType tvm_shape_type = DataType::ShapeIndex();
156 const DataType tvm_ndim_type = DataType::Int(32);
157 const Stmt nop = Evaluate(0);
158 // dimension checks
159 PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
160
161 // Helper functions for shape/stride name formatting
162 auto shape_handle_name = [&]() { return arg_name + ".shape"; };
163 auto stride_handle_name = [&]() { return arg_name + ".strides"; };
164 auto array_element_name = [&](const std::string& arr_name, size_t k) {
165 std::stringstream ss;
166 ss << arr_name << '[' << k << ']';
167 return ss.str();
168 };
169 auto shape_element_name = [&](size_t k) { return array_element_name(shape_handle_name(), k); };
170 auto stride_element_name = [&](size_t k) { return array_element_name(stride_handle_name(), k); };
171
172 PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
173 std::ostringstream ndim_err_msg;
174 ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
175 auto msg = tvm::tir::StringImm(ndim_err_msg.str());
176 asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
177 // type checks
178 std::ostringstream type_err_msg;
179 type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
180 PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) ==
181 IntImm(DataType::UInt(8), buffer->dtype.code()) &&
182 TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) ==
183 IntImm(DataType::UInt(8), buffer->dtype.bits()) &&
184 TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
185 IntImm(DataType::UInt(16), buffer->dtype.lanes()));
186 if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) ||
187 buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::UInt(16))) {
188 auto type_msg = tvm::tir::StringImm(type_err_msg.str());
189 asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
190 asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
191 }
192 // data field
193 if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
194 arg_name + ".data", true)) {
195 Var vptr(buffer->data);
196 def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
197 // mark alignment of external bufs
198 init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
199 IntImm(DataType::Int(32), buffer->data_alignment), nop));
200 }
201
202 // shape field
203 Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type,
204 shape_handle_name());
205 Var v_shape(shape_handle_name(), DataType::Handle());
206 def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
207 init_nest_.emplace_back(
208 LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop));
209 for (size_t k = 0; k < buffer->shape.size(); ++k) {
210 if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) ||
211 buffer->dtype == DataType::Int(1)) {
212 break;
213 }
214 Bind_(buffer->shape[k],
215 cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})),
216 shape_element_name(k), true);
217 }
218 // strides field
219 Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
220 tvm_shape_type, arg_name + ".strides");
221 def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
222 init_nest_.emplace_back(LetStmt(
223 buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
224 PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
225 if (buffer->strides.size() == 0) {
226 // Assert the buffer is compact
227 DataType stype = buffer->DefaultIndexType();
228 PrimExpr expect_stride = make_const(stype, 1);
229 Array<PrimExpr> conds;
230 for (size_t i = buffer->shape.size(); i != 0; --i) {
231 size_t k = i - 1;
232 PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
233 conds.push_back(expect_stride == svalue);
234 expect_stride = expect_stride * buffer->shape[k];
235 }
236 std::ostringstream stride_err_msg;
237 stride_err_msg << stride_handle_name() << ": expected to be compact array";
238 if (conds.size() != 0) {
239 auto stride_msg = tvm::tir::StringImm(stride_err_msg.str());
240 Stmt check = AssertStmt(
241 foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
242 const_true(1), conds),
243 stride_msg, Evaluate(0));
244 check = IfThenElse(Not(v_strides_is_null), check, Stmt());
245 asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
246 }
247 } else if (buffer->buffer_type == kAutoBroadcast) {
248 DataType stype = buffer->DefaultIndexType();
249 PrimExpr stride = make_const(stype, 1);
250 for (size_t i = buffer->shape.size(); i != 0; --i) {
251 size_t k = i - 1;
252 PrimExpr value =
253 cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
254 value = tvm::if_then_else(v_strides_is_null, stride, value);
255 value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
256 Bind_(buffer->strides[k], value, stride_element_name(k), true);
257 stride = analyzer_.Simplify(stride * buffer->shape[k]);
258 }
259 } else {
260 PrimExpr stride_from_shape = 1;
261
262 for (int k = buffer->strides.size() - 1; k >= 0; k--) {
263 PrimExpr explicit_stride =
264 cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
265
266 Bind_(buffer->strides[k],
267 tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride),
268 stride_element_name(k), true);
269
270 stride_from_shape *=
271 cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
272 }
273 }
274 // Byte_offset field.
275 int data_bytes = GetVectorBytes(buffer->dtype);
276
277 if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
278 Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
279 TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
280 arg_name + ".byte_offset", true);
281 } else {
282 if (Bind_(buffer->elem_offset,
283 cast(buffer->elem_offset.dtype(),
284 (TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset) /
285 make_const(DataType::UInt(64), data_bytes))),
286 arg_name + ".elem_offset", true)) {
287 if (buffer->offset_factor > 1) {
288 PrimExpr offset = buffer->elem_offset;
289 PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
290 PrimExpr zero = make_zero(offset.dtype());
291 BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
292 &asserts_);
293 }
294 }
295 }
296 // device info.
297 Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
298 arg_name + ".device_type", true);
299 Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
300 arg_name + ".device_id", true);
301}
302
303} // namespace tir
304} // namespace tvm
305