1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /* |
20 | * \file tvm/runtime/data_type.h |
21 | * \brief Primitive runtime data type. |
22 | */ |
23 | // Acknowledgement: DataType structure design originates from Halide. |
24 | #ifndef TVM_RUNTIME_DATA_TYPE_H_ |
25 | #define TVM_RUNTIME_DATA_TYPE_H_ |
26 | |
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include <cstdint>
#include <cstdlib>
#include <limits>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
32 | |
33 | namespace tvm { |
34 | namespace runtime { |
35 | /*! |
36 | * \brief Runtime primitive data type. |
37 | * |
38 | * This class is a thin wrapper of DLDataType. |
39 | * We also make use of DataType in compiler to store quick hint |
40 | */ |
41 | class DataType { |
42 | public: |
43 | /*! |
44 | * \brief Type code for the DataType. |
45 | * |
46 | * DLPack consistency: |
47 | * 1) kInt is consistent with kDLInt |
48 | * 2) kUInt is consistent with kDLUInt |
49 | * 3) kFloat is consistent with kDLFloat |
50 | */ |
51 | enum TypeCode { |
52 | kInt = kDLInt, |
53 | kUInt = kDLUInt, |
54 | kFloat = kDLFloat, |
55 | kHandle = TVMArgTypeCode::kTVMOpaqueHandle, |
56 | kBFloat = kDLBfloat, |
57 | kCustomBegin = 129 |
58 | }; |
59 | /*! \brief default constructor */ |
60 | DataType() { data_ = DataType::Void(); } |
61 | /*! |
62 | * \brief Constructor |
63 | * \param dtype The DLDataType |
64 | */ |
65 | explicit DataType(DLDataType dtype) : data_(dtype) {} |
66 | /*! |
67 | * \brief Constructor |
68 | * \param code The type code. |
69 | * \param bits The number of bits in the type. |
70 | * \param lanes The number of lanes. |
71 | */ |
72 | DataType(int code, int bits, int lanes) { |
73 | data_.code = static_cast<uint8_t>(code); |
74 | data_.bits = static_cast<uint8_t>(bits); |
75 | data_.lanes = static_cast<uint16_t>(lanes); |
76 | if (code == kBFloat) { |
77 | ICHECK_EQ(bits, 16); |
78 | } |
79 | } |
80 | /*! \return The type code. */ |
81 | int code() const { return static_cast<int>(data_.code); } |
82 | /*! \return number of bits in the data. */ |
83 | int bits() const { return static_cast<int>(data_.bits); } |
84 | /*! \return number of bytes to store each scalar. */ |
85 | int bytes() const { return (bits() + 7) / 8; } |
86 | /*! \return number of lanes in the data. */ |
87 | int lanes() const { return static_cast<int>(data_.lanes); } |
88 | /*! \return whether type is a scalar type. */ |
89 | bool is_scalar() const { return lanes() == 1; } |
90 | /*! \return whether type is a scalar type. */ |
91 | bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } |
92 | /*! \return whether type is a float type. */ |
93 | bool is_float() const { return code() == DataType::kFloat; } |
94 | /*! \return whether type is a float16 type. */ |
95 | bool is_float16() const { return is_float() && bits() == 16; } |
96 | /*! \return whether type is a bfloat16 type. */ |
97 | bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } |
98 | /*! \return whether type is an int type. */ |
99 | bool is_int() const { return code() == DataType::kInt; } |
100 | /*! \return whether type is an uint type. */ |
101 | bool is_uint() const { return code() == DataType::kUInt; } |
102 | /*! \return whether type is a handle type. */ |
103 | bool is_handle() const { return code() == DataType::kHandle && !is_void(); } |
104 | /*! \return whether type is a vector type. */ |
105 | bool is_vector() const { return lanes() > 1; } |
106 | /*! \return whether type is a bool vector type. */ |
107 | bool is_vector_bool() const { return is_vector() && bits() == 1; } |
108 | /*! \return whether type is a Void type. */ |
109 | bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } |
110 | /*! |
111 | * \brief Create a new data type by change lanes to a specified value. |
112 | * \param lanes The target number of lanes. |
113 | * \return the result type. |
114 | */ |
115 | DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } |
116 | /*! |
117 | * \brief Create a new data type by change bits to a specified value. |
118 | * \param bits The target number of bits. |
119 | * \return the result type. |
120 | */ |
121 | DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } |
122 | /*! |
123 | * \brief Get the scalar version of the type. |
124 | * \return the result type. |
125 | */ |
126 | DataType element_of() const { return with_lanes(1); } |
127 | /*! |
128 | * \brief Assignment operator. |
129 | */ |
130 | DataType& operator=(const DataType& rhs) { |
131 | if (this == &rhs) { |
132 | return *this; |
133 | } |
134 | data_ = rhs.data_; |
135 | return *this; |
136 | } |
137 | /*! |
138 | * \brief Equal comparator. |
139 | * \param other The data type to compare against. |
140 | * \return The comparison result. |
141 | */ |
142 | bool operator==(const DataType& other) const { |
143 | return data_.code == other.data_.code && data_.bits == other.data_.bits && |
144 | data_.lanes == other.data_.lanes; |
145 | } |
146 | /*! |
147 | * \brief NotEqual comparator. |
148 | * \param other The data type to compare against. |
149 | * \return The comparison result. |
150 | */ |
151 | bool operator!=(const DataType& other) const { return !operator==(other); } |
152 | /*! |
153 | * \brief Converter to DLDataType |
154 | * \return the result. |
155 | */ |
156 | operator DLDataType() const { return data_; } |
157 | |
158 | /*! |
159 | * \brief Construct an int type. |
160 | * \param bits The number of bits in the type. |
161 | * \param lanes The number of lanes. |
162 | * \return The constructed data type. |
163 | */ |
164 | static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } |
165 | /*! |
166 | * \brief Construct an uint type. |
167 | * \param bits The number of bits in the type. |
168 | * \param lanes The number of lanes |
169 | * \return The constructed data type. |
170 | */ |
171 | static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } |
172 | /*! |
173 | * \brief Construct an float type. |
174 | * \param bits The number of bits in the type. |
175 | * \param lanes The number of lanes |
176 | * \return The constructed data type. |
177 | */ |
178 | static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } |
179 | /*! |
180 | * \brief Construct an bfloat type. |
181 | * \param bits The number of bits in the type. |
182 | * \param lanes The number of lanes |
183 | * \return The constructed data type. |
184 | */ |
185 | static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); } |
186 | /*! |
187 | * \brief Construct a bool type. |
188 | * \param lanes The number of lanes |
189 | * \return The constructed data type. |
190 | */ |
191 | static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); } |
192 | /*! |
193 | * \brief Construct a handle type. |
194 | * \param bits The number of bits in the type. |
195 | * \param lanes The number of lanes |
196 | * \return The constructed data type. |
197 | */ |
198 | static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } |
199 | /*! |
200 | * \brief Construct a Void type. |
201 | * \return The constructed data type. |
202 | */ |
203 | static DataType Void() { return DataType(kHandle, 0, 0); } |
204 | /*! |
205 | * \brief Get the corresponding type of TVMShapeIndex. |
206 | * \return The type of TVM shape index. |
207 | */ |
208 | static DataType ShapeIndex() { |
209 | if (std::is_signed<tvm_index_t>::value) { |
210 | return DataType::Int(sizeof(tvm_index_t) * 8); |
211 | } else { |
212 | return DataType::UInt(sizeof(tvm_index_t) * 8); |
213 | } |
214 | } |
215 | |
216 | private: |
217 | DLDataType data_; |
218 | }; |
219 | |
220 | /*! |
221 | * \brief Get the number of bytes needed in a vector. |
222 | * \param dtype The data type. |
223 | * \return Number of bytes needed. |
224 | */ |
225 | inline int GetVectorBytes(DataType dtype) { |
226 | int data_bits = dtype.bits() * dtype.lanes(); |
227 | // allow bool to exist |
228 | if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || |
229 | dtype == DataType::Int(1)) { |
230 | return 1; |
231 | } |
232 | ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes" ; |
233 | return data_bits / 8; |
234 | } |
235 | |
236 | /*! |
237 | * \brief Check whether type matches the given spec. |
238 | * \param t The type |
239 | * \param code The type code. |
240 | * \param bits The number of bits to be matched. |
241 | * \param lanes The number of lanes in the type. |
242 | */ |
243 | inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) { |
244 | return t.code == code && t.bits == bits && t.lanes == lanes; |
245 | } |
246 | /*! |
247 | * \brief Check whether two types are equal . |
248 | * \param lhs The left operand. |
249 | * \param rhs The right operand. |
250 | */ |
251 | inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { |
252 | return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; |
253 | } |
254 | |
/*!
 * \brief Runtime utility for getting the name of a custom data type from its code.
 * \param type_code Custom type code (a code at or above DataType::kCustomBegin).
 * \return Custom type name.
 */
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);

/*!
 * \brief Runtime utility for checking whether a custom data type is registered.
 * \param type_code Custom type code.
 * \return Whether a custom type is registered under this code.
 */
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);

/*!
 * \brief Runtime utility for parsing a string of the form "custom[<typename>]".
 * \param s String to parse.
 * \param scan Output: receives a pointer into s marking where parsing stopped.
 *        String2DLDataType continues from there to read an optional
 *        "<bits>x<lanes>" suffix.
 *        NOTE(review): presumably this points just past the closing ']' — confirm.
 * \return Type code of the parsed custom type.
 */
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
276 | |
/*!
 * \brief Convert a type code to its name.
 * \param type_code The type code.
 * \return The name of the type code.
 */
inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);

/*!
 * \brief Convert a string such as "int32", "float16x4" or "bool" to a TVM type.
 * \param s The string to be converted; an empty string or "void" yields Void.
 * \return The corresponding TVM type.
 */
inline DLDataType String2DLDataType(std::string s);

/*!
 * \brief Convert a TVM type to its string form.
 * \param t The type to be converted.
 * \return The corresponding string; empty when t has zero bits.
 */
inline std::string DLDataType2String(DLDataType t);
297 | |
298 | // implementation details |
299 | inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { |
300 | switch (static_cast<int>(type_code)) { |
301 | case kDLInt: |
302 | return "int" ; |
303 | case kDLUInt: |
304 | return "uint" ; |
305 | case kDLFloat: |
306 | return "float" ; |
307 | case DataType::kHandle: |
308 | return "handle" ; |
309 | case kDLBfloat: |
310 | return "bfloat" ; |
311 | default: |
312 | LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); |
313 | } |
314 | } |
315 | |
316 | inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) |
317 | if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { |
318 | os << "bool" ; |
319 | return os; |
320 | } |
321 | if (DataType(t).is_void()) { |
322 | return os << "void" ; |
323 | } |
324 | if (t.code < DataType::kCustomBegin) { |
325 | os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code)); |
326 | } else { |
327 | os << "custom[" << GetCustomTypeName(t.code) << "]" ; |
328 | } |
329 | if (t.code == kTVMOpaqueHandle) return os; |
330 | os << static_cast<int>(t.bits); |
331 | if (t.lanes != 1) { |
332 | os << 'x' << static_cast<int>(t.lanes); |
333 | } |
334 | return os; |
335 | } |
336 | |
337 | inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) |
338 | return os << dtype.operator DLDataType(); |
339 | } |
340 | |
341 | inline std::string DLDataType2String(DLDataType t) { |
342 | if (t.bits == 0) return "" ; |
343 | std::ostringstream os; |
344 | os << t; |
345 | return os.str(); |
346 | } |
347 | |
348 | inline DLDataType String2DLDataType(std::string s) { |
349 | DLDataType t; |
350 | // handle void type |
351 | if (s.length() == 0 || s == "void" ) { |
352 | t = DataType::Void(); |
353 | return t; |
354 | } |
355 | t.bits = 32; |
356 | t.lanes = 1; |
357 | const char* scan; |
358 | if (s.substr(0, 3) == "int" ) { |
359 | t.code = kDLInt; |
360 | scan = s.c_str() + 3; |
361 | } else if (s.substr(0, 4) == "uint" ) { |
362 | t.code = kDLUInt; |
363 | scan = s.c_str() + 4; |
364 | } else if (s.substr(0, 5) == "float" ) { |
365 | t.code = kDLFloat; |
366 | scan = s.c_str() + 5; |
367 | } else if (s.substr(0, 6) == "handle" ) { |
368 | t.code = kTVMOpaqueHandle; |
369 | t.bits = 64; // handle uses 64 bit by default. |
370 | scan = s.c_str() + 6; |
371 | } else if (s == "bool" ) { |
372 | t.code = kDLUInt; |
373 | t.bits = 1; |
374 | t.lanes = 1; |
375 | return t; |
376 | } else if (s.substr(0, 6) == "bfloat" ) { |
377 | t.code = DataType::kBFloat; |
378 | scan = s.c_str() + 6; |
379 | } else if (s.substr(0, 6) == "custom" ) { |
380 | t.code = ParseCustomDatatype(s, &scan); |
381 | } else { |
382 | scan = s.c_str(); |
383 | LOG(FATAL) << "unknown type " << s; |
384 | } |
385 | char* xdelim; // emulate sscanf("%ux%u", bits, lanes) |
386 | uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10)); |
387 | if (bits != 0) t.bits = bits; |
388 | char* endpt = xdelim; |
389 | if (*xdelim == 'x') { |
390 | t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10)); |
391 | } |
392 | ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; |
393 | return t; |
394 | } |
395 | |
396 | } // namespace runtime |
397 | |
398 | using DataType = runtime::DataType; |
399 | |
400 | } // namespace tvm |
401 | |
402 | namespace std { |
403 | template <> |
404 | struct hash<tvm::DataType> { |
405 | inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; } |
406 | std::size_t operator()(tvm::DataType const& dtype) const { |
407 | int a = dtype.code(); |
408 | int b = dtype.bits(); |
409 | int c = dtype.lanes(); |
410 | int d = cantor_pairing_function(a, b); |
411 | return cantor_pairing_function(c, d); |
412 | } |
413 | }; |
414 | } // namespace std |
415 | |
416 | #endif // TVM_RUNTIME_DATA_TYPE_H_ |
417 | |