1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*
20 * \file tvm/runtime/data_type.h
21 * \brief Primitive runtime data type.
22 */
23// Acknowledgement: DataType structure design originates from Halide.
24#ifndef TVM_RUNTIME_DATA_TYPE_H_
25#define TVM_RUNTIME_DATA_TYPE_H_
26
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include <cstdlib>
#include <limits>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
32
33namespace tvm {
34namespace runtime {
35/*!
36 * \brief Runtime primitive data type.
37 *
38 * This class is a thin wrapper of DLDataType.
39 * We also make use of DataType in compiler to store quick hint
40 */
41class DataType {
42 public:
43 /*!
44 * \brief Type code for the DataType.
45 *
46 * DLPack consistency:
47 * 1) kInt is consistent with kDLInt
48 * 2) kUInt is consistent with kDLUInt
49 * 3) kFloat is consistent with kDLFloat
50 */
51 enum TypeCode {
52 kInt = kDLInt,
53 kUInt = kDLUInt,
54 kFloat = kDLFloat,
55 kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
56 kBFloat = kDLBfloat,
57 kCustomBegin = 129
58 };
59 /*! \brief default constructor */
60 DataType() { data_ = DataType::Void(); }
61 /*!
62 * \brief Constructor
63 * \param dtype The DLDataType
64 */
65 explicit DataType(DLDataType dtype) : data_(dtype) {}
66 /*!
67 * \brief Constructor
68 * \param code The type code.
69 * \param bits The number of bits in the type.
70 * \param lanes The number of lanes.
71 */
72 DataType(int code, int bits, int lanes) {
73 data_.code = static_cast<uint8_t>(code);
74 data_.bits = static_cast<uint8_t>(bits);
75 data_.lanes = static_cast<uint16_t>(lanes);
76 if (code == kBFloat) {
77 ICHECK_EQ(bits, 16);
78 }
79 }
80 /*! \return The type code. */
81 int code() const { return static_cast<int>(data_.code); }
82 /*! \return number of bits in the data. */
83 int bits() const { return static_cast<int>(data_.bits); }
84 /*! \return number of bytes to store each scalar. */
85 int bytes() const { return (bits() + 7) / 8; }
86 /*! \return number of lanes in the data. */
87 int lanes() const { return static_cast<int>(data_.lanes); }
88 /*! \return whether type is a scalar type. */
89 bool is_scalar() const { return lanes() == 1; }
90 /*! \return whether type is a scalar type. */
91 bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
92 /*! \return whether type is a float type. */
93 bool is_float() const { return code() == DataType::kFloat; }
94 /*! \return whether type is a float16 type. */
95 bool is_float16() const { return is_float() && bits() == 16; }
96 /*! \return whether type is a bfloat16 type. */
97 bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
98 /*! \return whether type is an int type. */
99 bool is_int() const { return code() == DataType::kInt; }
100 /*! \return whether type is an uint type. */
101 bool is_uint() const { return code() == DataType::kUInt; }
102 /*! \return whether type is a handle type. */
103 bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
104 /*! \return whether type is a vector type. */
105 bool is_vector() const { return lanes() > 1; }
106 /*! \return whether type is a bool vector type. */
107 bool is_vector_bool() const { return is_vector() && bits() == 1; }
108 /*! \return whether type is a Void type. */
109 bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
110 /*!
111 * \brief Create a new data type by change lanes to a specified value.
112 * \param lanes The target number of lanes.
113 * \return the result type.
114 */
115 DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); }
116 /*!
117 * \brief Create a new data type by change bits to a specified value.
118 * \param bits The target number of bits.
119 * \return the result type.
120 */
121 DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); }
122 /*!
123 * \brief Get the scalar version of the type.
124 * \return the result type.
125 */
126 DataType element_of() const { return with_lanes(1); }
127 /*!
128 * \brief Assignment operator.
129 */
130 DataType& operator=(const DataType& rhs) {
131 if (this == &rhs) {
132 return *this;
133 }
134 data_ = rhs.data_;
135 return *this;
136 }
137 /*!
138 * \brief Equal comparator.
139 * \param other The data type to compare against.
140 * \return The comparison result.
141 */
142 bool operator==(const DataType& other) const {
143 return data_.code == other.data_.code && data_.bits == other.data_.bits &&
144 data_.lanes == other.data_.lanes;
145 }
146 /*!
147 * \brief NotEqual comparator.
148 * \param other The data type to compare against.
149 * \return The comparison result.
150 */
151 bool operator!=(const DataType& other) const { return !operator==(other); }
152 /*!
153 * \brief Converter to DLDataType
154 * \return the result.
155 */
156 operator DLDataType() const { return data_; }
157
158 /*!
159 * \brief Construct an int type.
160 * \param bits The number of bits in the type.
161 * \param lanes The number of lanes.
162 * \return The constructed data type.
163 */
164 static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); }
165 /*!
166 * \brief Construct an uint type.
167 * \param bits The number of bits in the type.
168 * \param lanes The number of lanes
169 * \return The constructed data type.
170 */
171 static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
172 /*!
173 * \brief Construct an float type.
174 * \param bits The number of bits in the type.
175 * \param lanes The number of lanes
176 * \return The constructed data type.
177 */
178 static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
179 /*!
180 * \brief Construct an bfloat type.
181 * \param bits The number of bits in the type.
182 * \param lanes The number of lanes
183 * \return The constructed data type.
184 */
185 static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
186 /*!
187 * \brief Construct a bool type.
188 * \param lanes The number of lanes
189 * \return The constructed data type.
190 */
191 static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
192 /*!
193 * \brief Construct a handle type.
194 * \param bits The number of bits in the type.
195 * \param lanes The number of lanes
196 * \return The constructed data type.
197 */
198 static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); }
199 /*!
200 * \brief Construct a Void type.
201 * \return The constructed data type.
202 */
203 static DataType Void() { return DataType(kHandle, 0, 0); }
204 /*!
205 * \brief Get the corresponding type of TVMShapeIndex.
206 * \return The type of TVM shape index.
207 */
208 static DataType ShapeIndex() {
209 if (std::is_signed<tvm_index_t>::value) {
210 return DataType::Int(sizeof(tvm_index_t) * 8);
211 } else {
212 return DataType::UInt(sizeof(tvm_index_t) * 8);
213 }
214 }
215
216 private:
217 DLDataType data_;
218};
219
220/*!
221 * \brief Get the number of bytes needed in a vector.
222 * \param dtype The data type.
223 * \return Number of bytes needed.
224 */
225inline int GetVectorBytes(DataType dtype) {
226 int data_bits = dtype.bits() * dtype.lanes();
227 // allow bool to exist
228 if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
229 dtype == DataType::Int(1)) {
230 return 1;
231 }
232 ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
233 return data_bits / 8;
234}
235
236/*!
237 * \brief Check whether type matches the given spec.
238 * \param t The type
239 * \param code The type code.
240 * \param bits The number of bits to be matched.
241 * \param lanes The number of lanes in the type.
242 */
243inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
244 return t.code == code && t.bits == bits && t.lanes == lanes;
245}
246/*!
247 * \brief Check whether two types are equal .
248 * \param lhs The left operand.
249 * \param rhs The right operand.
250 */
251inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
252 return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
253}
254
255/*!
256 * \brief Runtime utility for getting custom type name from code
257 * \param type_code Custom type code
258 * \return Custom type name
259 */
260TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
261
262/*!
263 * \brief Runtime utility for checking whether custom type is registered
264 * \param type_code Custom type code
265 * \return Bool representing whether type is registered
266 */
267TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
268
269/*!
270 * \brief Runtime utility for parsing string of the form "custom[<typename>]"
271 * \param s String to parse
272 * \param scan pointer to parsing pointer, which is scanning across s
273 * \return type code of custom type parsed
274 */
275TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
276
277/*!
278 * \brief Convert type code to its name
 * \param type_code The type code.
280 * \return The name of type code.
281 */
282inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);
283
284/*!
 * \brief Convert a string to a TVM type.
286 * \param s The string to be converted.
287 * \return The corresponding tvm type.
288 */
289inline DLDataType String2DLDataType(std::string s);
290
291/*!
 * \brief Convert a TVM type to a string.
 * \param t The type to be converted.
 * \return The string representation of the type.
295 */
296inline std::string DLDataType2String(DLDataType t);
297
298// implementation details
299inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
300 switch (static_cast<int>(type_code)) {
301 case kDLInt:
302 return "int";
303 case kDLUInt:
304 return "uint";
305 case kDLFloat:
306 return "float";
307 case DataType::kHandle:
308 return "handle";
309 case kDLBfloat:
310 return "bfloat";
311 default:
312 LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
313 }
314}
315
316inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
317 if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
318 os << "bool";
319 return os;
320 }
321 if (DataType(t).is_void()) {
322 return os << "void";
323 }
324 if (t.code < DataType::kCustomBegin) {
325 os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
326 } else {
327 os << "custom[" << GetCustomTypeName(t.code) << "]";
328 }
329 if (t.code == kTVMOpaqueHandle) return os;
330 os << static_cast<int>(t.bits);
331 if (t.lanes != 1) {
332 os << 'x' << static_cast<int>(t.lanes);
333 }
334 return os;
335}
336
337inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
338 return os << dtype.operator DLDataType();
339}
340
341inline std::string DLDataType2String(DLDataType t) {
342 if (t.bits == 0) return "";
343 std::ostringstream os;
344 os << t;
345 return os.str();
346}
347
348inline DLDataType String2DLDataType(std::string s) {
349 DLDataType t;
350 // handle void type
351 if (s.length() == 0 || s == "void") {
352 t = DataType::Void();
353 return t;
354 }
355 t.bits = 32;
356 t.lanes = 1;
357 const char* scan;
358 if (s.substr(0, 3) == "int") {
359 t.code = kDLInt;
360 scan = s.c_str() + 3;
361 } else if (s.substr(0, 4) == "uint") {
362 t.code = kDLUInt;
363 scan = s.c_str() + 4;
364 } else if (s.substr(0, 5) == "float") {
365 t.code = kDLFloat;
366 scan = s.c_str() + 5;
367 } else if (s.substr(0, 6) == "handle") {
368 t.code = kTVMOpaqueHandle;
369 t.bits = 64; // handle uses 64 bit by default.
370 scan = s.c_str() + 6;
371 } else if (s == "bool") {
372 t.code = kDLUInt;
373 t.bits = 1;
374 t.lanes = 1;
375 return t;
376 } else if (s.substr(0, 6) == "bfloat") {
377 t.code = DataType::kBFloat;
378 scan = s.c_str() + 6;
379 } else if (s.substr(0, 6) == "custom") {
380 t.code = ParseCustomDatatype(s, &scan);
381 } else {
382 scan = s.c_str();
383 LOG(FATAL) << "unknown type " << s;
384 }
385 char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
386 uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
387 if (bits != 0) t.bits = bits;
388 char* endpt = xdelim;
389 if (*xdelim == 'x') {
390 t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
391 }
392 ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
393 return t;
394}
395
396} // namespace runtime
397
398using DataType = runtime::DataType;
399
400} // namespace tvm
401
402namespace std {
403template <>
404struct hash<tvm::DataType> {
405 inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
406 std::size_t operator()(tvm::DataType const& dtype) const {
407 int a = dtype.code();
408 int b = dtype.bits();
409 int c = dtype.lanes();
410 int d = cantor_pairing_function(a, b);
411 return cantor_pairing_function(c, d);
412 }
413};
414} // namespace std
415
416#endif // TVM_RUNTIME_DATA_TYPE_H_
417