1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file arg_binder.cc |
22 | * \brief Helper utility to match and bind arguments. |
23 | */ |
24 | #include "arg_binder.h" |
25 | |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include "ir_utils.h" |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | |
36 | void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg_name, |
37 | std::vector<Stmt>* asserts) { |
38 | PrimExpr scond = ana->Simplify(cond); |
39 | if (is_zero(scond)) { |
40 | LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " |
41 | << " on argument " << arg_name; |
42 | } |
43 | if (!is_one(scond)) { |
44 | std::ostringstream os; |
45 | os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond; |
46 | asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), Evaluate(0))); |
47 | } |
48 | } |
49 | |
50 | bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, |
51 | bool with_lets) { |
52 | ICHECK_EQ(arg.dtype(), value.dtype()); |
53 | if (const VarNode* v = arg.as<VarNode>()) { |
54 | auto it = def_map_->find(v); |
55 | if (it == def_map_->end()) { |
56 | Var v_arg = Downcast<Var>(arg); |
57 | defs_.emplace_back(v_arg); |
58 | if (with_lets) { |
59 | (*def_map_)[v] = arg; |
60 | init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); |
61 | } else { |
62 | (*def_map_)[v] = value; |
63 | } |
64 | return true; |
65 | } else { |
66 | BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_); |
67 | } |
68 | } else { |
69 | BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_); |
70 | } |
71 | return false; |
72 | } |
73 | |
74 | void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, |
75 | bool with_let) { |
76 | Bind_(arg, value, arg_name, with_let); |
77 | } |
78 | |
79 | void ArgBinder::BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& value, |
80 | const std::string& arg_name) { |
81 | ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch" ; |
82 | for (size_t i = 0; i < arg.size(); ++i) { |
83 | std::ostringstream os; |
84 | os << arg_name << "[" << i << "]" ; |
85 | this->Bind(arg[i], value[i], os.str()); |
86 | } |
87 | } |
88 | |
89 | void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, |
90 | bool fuzzy_match) { |
91 | ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch" ; |
92 | ICHECK_EQ(arg->dtype, value->dtype) |
93 | << "Argument " << arg_name << " Buffer bind data type mismatch" ; |
94 | if (value->data_alignment % arg->data_alignment != 0) { |
95 | LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " |
96 | << " required_alignment=" << arg->data_alignment |
97 | << ", provided_alignment=" << value->data_alignment; |
98 | } |
99 | |
100 | if (value->elem_offset.defined()) { |
101 | // bind pointer and offset. |
102 | if (is_zero(arg->elem_offset)) { |
103 | ICHECK(is_zero(value->elem_offset)) |
104 | << "Trying to bind a Buffer with offset into one without offset " |
105 | << " required elem_offset=" << arg->elem_offset |
106 | << ", provided elem_offset=" << value->elem_offset; |
107 | } |
108 | |
109 | this->Bind(arg->data, value->data, arg_name + ".data" ); |
110 | if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset" , false)) { |
111 | if (arg->offset_factor > 1) { |
112 | PrimExpr offset = value->elem_offset; |
113 | PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); |
114 | PrimExpr zero = make_zero(offset.dtype()); |
115 | BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset" , |
116 | &asserts_); |
117 | } |
118 | } |
119 | } |
120 | |
121 | if (arg->shape.size() < value->shape.size()) { |
122 | ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch" ; |
123 | size_t diff = value->shape.size() - arg->shape.size(); |
124 | for (size_t i = 0; i < diff; ++i) { |
125 | ICHECK(is_one(analyzer_.Simplify(value->shape[i]))) |
126 | << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; |
127 | } |
128 | for (size_t i = 0; i < arg->shape.size(); ++i) { |
129 | std::ostringstream os; |
130 | os << arg_name << ".shape[" << i << "]" ; |
131 | this->Bind(arg->shape[i], value->shape[i + diff], os.str()); |
132 | } |
133 | if (value->strides.size() != 0) { |
134 | ICHECK_EQ(arg->strides.size(), arg->shape.size()); |
135 | ICHECK_EQ(value->strides.size(), value->shape.size()); |
136 | for (size_t i = 0; i < arg->strides.size(); ++i) { |
137 | std::ostringstream os; |
138 | os << arg_name << ".strides[" << i << "]" ; |
139 | this->Bind(arg->strides[i], value->strides[i + diff], os.str()); |
140 | } |
141 | } |
142 | } else { |
143 | this->BindArray(arg->shape, value->shape, arg_name + ".shape" ); |
144 | this->BindArray(arg->strides, value->strides, arg_name + ".strides" ); |
145 | } |
146 | } |
147 | |
148 | inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kind) { |
149 | return TVMStructGet(t, arr, 0, kind); |
150 | } |
151 | |
152 | void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, |
153 | const PrimExpr& device_id, const Var& handle, |
154 | const std::string& arg_name) { |
155 | const DataType tvm_shape_type = DataType::ShapeIndex(); |
156 | const DataType tvm_ndim_type = DataType::Int(32); |
157 | const Stmt nop = Evaluate(0); |
158 | // dimension checks |
159 | PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); |
160 | |
161 | // Helper functions for shape/stride name formatting |
162 | auto shape_handle_name = [&]() { return arg_name + ".shape" ; }; |
163 | auto stride_handle_name = [&]() { return arg_name + ".strides" ; }; |
164 | auto array_element_name = [&](const std::string& arr_name, size_t k) { |
165 | std::stringstream ss; |
166 | ss << arr_name << '[' << k << ']'; |
167 | return ss.str(); |
168 | }; |
169 | auto shape_element_name = [&](size_t k) { return array_element_name(shape_handle_name(), k); }; |
170 | auto stride_element_name = [&](size_t k) { return array_element_name(stride_handle_name(), k); }; |
171 | |
172 | PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size())); |
173 | std::ostringstream ndim_err_msg; |
174 | ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); |
175 | auto msg = tvm::tir::StringImm(ndim_err_msg.str()); |
176 | asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); |
177 | // type checks |
178 | std::ostringstream type_err_msg; |
179 | type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype; |
180 | PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == |
181 | IntImm(DataType::UInt(8), buffer->dtype.code()) && |
182 | TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == |
183 | IntImm(DataType::UInt(8), buffer->dtype.bits()) && |
184 | TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == |
185 | IntImm(DataType::UInt(16), buffer->dtype.lanes())); |
186 | if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) || |
187 | buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::UInt(16))) { |
188 | auto type_msg = tvm::tir::StringImm(type_err_msg.str()); |
189 | asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); |
190 | asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); |
191 | } |
192 | // data field |
193 | if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), |
194 | arg_name + ".data" , true)) { |
195 | Var vptr(buffer->data); |
196 | def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); |
197 | // mark alignment of external bufs |
198 | init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, |
199 | IntImm(DataType::Int(32), buffer->data_alignment), nop)); |
200 | } |
201 | |
202 | // shape field |
203 | Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type, |
204 | shape_handle_name()); |
205 | Var v_shape(shape_handle_name(), DataType::Handle()); |
206 | def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); |
207 | init_nest_.emplace_back( |
208 | LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); |
209 | for (size_t k = 0; k < buffer->shape.size(); ++k) { |
210 | if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) || |
211 | buffer->dtype == DataType::Int(1)) { |
212 | break; |
213 | } |
214 | Bind_(buffer->shape[k], |
215 | cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), |
216 | shape_element_name(k), true); |
217 | } |
218 | // strides field |
219 | Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, |
220 | tvm_shape_type, arg_name + ".strides" ); |
221 | def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); |
222 | init_nest_.emplace_back(LetStmt( |
223 | buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); |
224 | PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); |
225 | if (buffer->strides.size() == 0) { |
226 | // Assert the buffer is compact |
227 | DataType stype = buffer->DefaultIndexType(); |
228 | PrimExpr expect_stride = make_const(stype, 1); |
229 | Array<PrimExpr> conds; |
230 | for (size_t i = buffer->shape.size(); i != 0; --i) { |
231 | size_t k = i - 1; |
232 | PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); |
233 | conds.push_back(expect_stride == svalue); |
234 | expect_stride = expect_stride * buffer->shape[k]; |
235 | } |
236 | std::ostringstream stride_err_msg; |
237 | stride_err_msg << stride_handle_name() << ": expected to be compact array" ; |
238 | if (conds.size() != 0) { |
239 | auto stride_msg = tvm::tir::StringImm(stride_err_msg.str()); |
240 | Stmt check = AssertStmt( |
241 | foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); }, |
242 | const_true(1), conds), |
243 | stride_msg, Evaluate(0)); |
244 | check = IfThenElse(Not(v_strides_is_null), check, Stmt()); |
245 | asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); |
246 | } |
247 | } else if (buffer->buffer_type == kAutoBroadcast) { |
248 | DataType stype = buffer->DefaultIndexType(); |
249 | PrimExpr stride = make_const(stype, 1); |
250 | for (size_t i = buffer->shape.size(); i != 0; --i) { |
251 | size_t k = i - 1; |
252 | PrimExpr value = |
253 | cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); |
254 | value = tvm::if_then_else(v_strides_is_null, stride, value); |
255 | value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); |
256 | Bind_(buffer->strides[k], value, stride_element_name(k), true); |
257 | stride = analyzer_.Simplify(stride * buffer->shape[k]); |
258 | } |
259 | } else { |
260 | PrimExpr stride_from_shape = 1; |
261 | |
262 | for (int k = buffer->strides.size() - 1; k >= 0; k--) { |
263 | PrimExpr explicit_stride = |
264 | cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); |
265 | |
266 | Bind_(buffer->strides[k], |
267 | tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), |
268 | stride_element_name(k), true); |
269 | |
270 | stride_from_shape *= |
271 | cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); |
272 | } |
273 | } |
274 | // Byte_offset field. |
275 | int data_bytes = GetVectorBytes(buffer->dtype); |
276 | |
277 | if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) { |
278 | Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), |
279 | TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), |
280 | arg_name + ".byte_offset" , true); |
281 | } else { |
282 | if (Bind_(buffer->elem_offset, |
283 | cast(buffer->elem_offset.dtype(), |
284 | (TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset) / |
285 | make_const(DataType::UInt(64), data_bytes))), |
286 | arg_name + ".elem_offset" , true)) { |
287 | if (buffer->offset_factor > 1) { |
288 | PrimExpr offset = buffer->elem_offset; |
289 | PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); |
290 | PrimExpr zero = make_zero(offset.dtype()); |
291 | BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset" , |
292 | &asserts_); |
293 | } |
294 | } |
295 | } |
296 | // device info. |
297 | Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), |
298 | arg_name + ".device_type" , true); |
299 | Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), |
300 | arg_name + ".device_id" , true); |
301 | } |
302 | |
303 | } // namespace tir |
304 | } // namespace tvm |
305 | |