1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_ |
18 | |
19 | #include <map> |
20 | #include <set> |
21 | #include <string> |
22 | |
23 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
24 | // Disable clang-format to prevent 'FixedPoint' header from being included |
25 | // before 'Tensor' header on which it depends. |
26 | // clang-format off |
27 | #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" |
28 | // clang-format on |
29 | #include "tensorflow/core/framework/bfloat16.h" |
30 | #include "tensorflow/core/framework/full_type.pb.h" |
31 | #include "tensorflow/core/framework/numeric_types.h" |
32 | #include "tensorflow/core/framework/resource_handle.h" |
33 | #include "tensorflow/core/framework/types.pb.h" |
34 | #include "tensorflow/core/lib/core/stringpiece.h" |
35 | #include "tensorflow/core/lib/gtl/array_slice.h" |
36 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
37 | #include "tensorflow/core/platform/logging.h" |
38 | #include "tensorflow/core/platform/types.h" |
39 | |
40 | namespace tensorflow { |
41 | |
// Forward declaration only: in this header Variant appears just as the C++
// type mapped to DT_VARIANT, so its full definition is not required here.
class Variant;

// MemoryType records where the input or output Tensors of an OpKernel should
// reside: in "Host memory" (e.g., CPU memory) or in "Device" memory (CPU
// memory for CPU devices, GPU memory for GPU devices).
enum MemoryType {
  DEVICE_MEMORY = 0,  // Memory of the device the kernel runs on.
  HOST_MEMORY = 1,    // Host memory.
};
52 | |
53 | // A DeviceType is just a string, but we wrap it up in a class to give |
54 | // some type checking as we're passing these around |
55 | class DeviceType { |
56 | public: |
57 | DeviceType(const char* type) // NOLINT(runtime/explicit) |
58 | : type_(type) {} |
59 | |
60 | explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} |
61 | |
62 | const char* type() const { return type_.c_str(); } |
63 | const std::string& type_string() const { return type_; } |
64 | |
65 | bool operator<(const DeviceType& other) const; |
66 | bool operator==(const DeviceType& other) const; |
67 | bool operator!=(const DeviceType& other) const { return !(*this == other); } |
68 | |
69 | private: |
70 | std::string type_; |
71 | }; |
72 | std::ostream& operator<<(std::ostream& os, const DeviceType& d); |
73 | |
// Convenient constants that can be passed to a DeviceType constructor.
// See comments for CreateOpKernel in op_kernel.h for uses of DEVICE_DEFAULT
// and other device types.
// Only declarations live in this header; the definitions are elsewhere. The
// trailing comment on each line records the string value the constant is
// expected to hold.
TF_EXPORT extern const char* const DEVICE_DEFAULT;     // "DEFAULT"
TF_EXPORT extern const char* const DEVICE_CPU;         // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU;         // "GPU"
TF_EXPORT extern const char* const DEVICE_TPU;         // "TPU"
TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM;  // "TPU_SYSTEM"
82 | |
// DeviceName<Device>::value holds the name associated with an Eigen device
// class. The primary template is empty, so only the specializations below
// provide `value`; naming it for any other Device fails to compile.
template <typename Device>
struct DeviceName {};

// Eigen CPU thread-pool device. `value` is only declared here and is defined
// out of line.
template <>
struct DeviceName<Eigen::ThreadPoolDevice> {
  static const std::string value;
};

// The GPU specialization exists only in CUDA or ROCm builds.
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
template <>
struct DeviceName<Eigen::GpuDevice> {
  static const std::string value;
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
98 | |
99 | |
100 | typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector; |
101 | typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice; |
102 | |
103 | typedef gtl::InlinedVector<DataType, 4> DataTypeVector; |
104 | typedef gtl::ArraySlice<DataType> DataTypeSlice; |
105 | |
106 | typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector; |
107 | typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4> |
108 | PrioritizedDeviceTypeVector; |
109 | |
110 | // Convert the enums to strings for errors: |
111 | std::string DataTypeString(DataType dtype); |
112 | std::string DeviceTypeString(const DeviceType& device_type); |
113 | std::string DataTypeSliceString(const DataTypeSlice dtypes); |
114 | inline std::string DataTypeVectorString(const DataTypeVector& dtypes) { |
115 | return DataTypeSliceString(dtypes); |
116 | } |
117 | |
118 | // DataTypeSet represents a set of DataType values as a simple and efficient |
119 | // bit mask. Note that DataTypeSet cannot represent all DataType values; it |
120 | // cannot represent any of the DT_*_REF values. |
121 | class DataTypeSet { |
122 | private: |
123 | const uint32 mask_; |
124 | |
125 | static constexpr uint32 kNumBits = 32; |
126 | |
127 | public: |
128 | constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {} |
129 | explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {} |
130 | |
131 | constexpr bool Contains(DataType dt) const { |
132 | return (static_cast<uint32>(dt) < kNumBits) && |
133 | ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u; |
134 | } |
135 | |
136 | class Iterator { |
137 | const DataTypeSet& set_; |
138 | uint32 pos_; |
139 | |
140 | public: |
141 | Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) { |
142 | DCHECK_LE(pos, kNumBits); |
143 | } |
144 | DataType operator*() const { return static_cast<DataType>(pos_); } |
145 | Iterator& operator++() { |
146 | ++pos_; |
147 | DCHECK_LE(pos_, kNumBits); |
148 | if (pos_ < kNumBits) { |
149 | uint32 remaining_mask = set_.mask_ >> pos_; |
150 | if (remaining_mask != 0u) { |
151 | pos_ += ctz_uint32(remaining_mask); |
152 | } |
153 | } |
154 | DCHECK_LE(pos_, kNumBits); |
155 | return *this; |
156 | } |
157 | bool operator==(const Iterator& other) const { return pos_ == other.pos_; } |
158 | bool operator!=(const Iterator& other) const { return !(*this == other); } |
159 | size_t operator-(const Iterator& other) const { |
160 | return this->pos_ - other.pos_; |
161 | } |
162 | }; |
163 | |
164 | static uint32 ctz_uint32(uint32 x) { |
165 | DCHECK_NE(x, 0u); |
166 | #ifdef __GNUC__ |
167 | return __builtin_ctz(x); |
168 | #else |
169 | uint32 n = 0u; |
170 | while ((x & 1u) == 0u) { |
171 | x >>= 1; |
172 | ++n; |
173 | } |
174 | return n; |
175 | #endif |
176 | } |
177 | |
178 | static uint32 clz_uint32(uint32 x) { |
179 | DCHECK_NE(x, 0u); |
180 | #ifdef __GNUC__ |
181 | return __builtin_clz(x); |
182 | #else |
183 | uint32 n = 0u; |
184 | while ((x >> (kNumBits - 1u)) == 0u) { |
185 | x <<= 1; |
186 | ++n; |
187 | } |
188 | return n; |
189 | #endif |
190 | } |
191 | |
192 | Iterator begin() const { |
193 | // The begin position is the index of the first bit set to 1 in the entire |
194 | // bit mask. If there are no bits set to 1, then the index is 0. |
195 | if (mask_ != 0) { |
196 | return Iterator(*this, ctz_uint32(mask_)); |
197 | } |
198 | // The set is empty. |
199 | return Iterator(*this, 0); |
200 | } |
201 | |
202 | Iterator end() const { |
203 | // The end position is the index of the highest bit that is set, plus 1. |
204 | // If there are no bits set to 1, then the index is 0. |
205 | if (mask_ != 0) { |
206 | return Iterator(*this, kNumBits - clz_uint32(mask_)); |
207 | } |
208 | // The set is empty. |
209 | return Iterator(*this, 0); |
210 | } |
211 | |
212 | size_t size() const { |
213 | #if defined(__GNUC__) |
214 | return __builtin_popcount(mask_); |
215 | #else |
216 | size_t n = 0; |
217 | uint32 x = mask_; |
218 | while (x > 0) { |
219 | n += x & 1u; |
220 | x >>= 1; |
221 | } |
222 | return n; |
223 | #endif |
224 | } |
225 | |
226 | constexpr DataTypeSet operator|(const DataTypeSet& other) const { |
227 | return DataTypeSet(mask_ | other.mask_); |
228 | } |
229 | }; |
230 | |
231 | // If "sp" names a valid type, store it in "*dt" and return true. Otherwise, |
232 | // return false. |
233 | bool DataTypeFromString(StringPiece sp, DataType* dt); |
234 | |
235 | constexpr inline DataTypeSet ToSet(DataType dt) { |
236 | return DataTypeSet(1u << static_cast<uint32>(dt)); |
237 | } |
238 | |
239 | // DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. |
240 | enum { kDataTypeRefOffset = 100 }; |
241 | inline bool IsRefType(DataType dtype) { |
242 | return dtype > static_cast<DataType>(kDataTypeRefOffset); |
243 | } |
244 | inline DataType MakeRefType(DataType dtype) { |
245 | DCHECK(!IsRefType(dtype)); |
246 | return static_cast<DataType>(dtype + kDataTypeRefOffset); |
247 | } |
248 | inline DataType RemoveRefType(DataType dtype) { |
249 | DCHECK(IsRefType(dtype)); |
250 | return static_cast<DataType>(dtype - kDataTypeRefOffset); |
251 | } |
252 | inline DataType BaseType(DataType dtype) { |
253 | return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; |
254 | } |
255 | |
256 | // Returns true if the actual type is the same as or ref of the expected type. |
257 | inline bool TypesCompatible(DataType expected, DataType actual) { |
258 | return expected == actual || expected == BaseType(actual); |
259 | } |
260 | |
261 | // Does not include _ref types. |
262 | constexpr DataTypeSet kAllTypes = |
263 | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) | |
264 | ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) | |
265 | ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | |
266 | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | |
267 | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) | |
268 | ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | |
269 | ToSet(DT_BFLOAT16); |
270 | inline const DataTypeSet& AllTypes() { return kAllTypes; } |
271 | |
272 | #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) |
273 | |
274 | // Types that support '<' and '>'. |
275 | constexpr DataTypeSet kRealNumberTypes = |
276 | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | |
277 | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) | |
278 | ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); |
279 | inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; } |
280 | |
281 | // Return the list of all numeric types. |
282 | // Includes complex and quantized types. |
283 | // NOTE: On Android, we only include the float and int32 types for now. |
284 | const DataTypeSet kNumberTypes = |
285 | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) | |
286 | ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | |
287 | ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) | |
288 | ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) | |
289 | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); |
290 | inline const DataTypeSet& NumberTypes() { return kNumberTypes; } |
291 | |
292 | constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | |
293 | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | |
294 | ToSet(DT_QINT32); |
295 | inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; } |
296 | |
297 | // Types that support '<' and '>', including quantized types. |
298 | const DataTypeSet kRealAndQuantizedTypes = |
299 | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | |
300 | ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | |
301 | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | |
302 | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16); |
303 | inline const DataTypeSet& RealAndQuantizedTypes() { |
304 | return kRealAndQuantizedTypes; |
305 | } |
306 | |
307 | #elif defined(__ANDROID_TYPES_FULL__) |
308 | |
309 | constexpr DataTypeSet kRealNumberTypes = |
310 | ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF); |
311 | inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } |
312 | |
313 | constexpr DataTypeSet kNumberTypes = |
314 | ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | |
315 | ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF); |
316 | inline DataTypeSet NumberTypes() { return kNumberTypes; } |
317 | |
318 | constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | |
319 | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | |
320 | ToSet(DT_QINT32); |
321 | inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } |
322 | |
323 | constexpr DataTypeSet kRealAndQuantizedTypes = |
324 | ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | |
325 | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | |
326 | ToSet(DT_HALF); |
327 | inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } |
328 | |
329 | #else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) |
330 | |
331 | constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32); |
332 | inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } |
333 | |
334 | constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) | |
335 | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | |
336 | ToSet(DT_QINT32); |
337 | inline DataTypeSet NumberTypes() { return kNumberTypes; } |
338 | |
339 | constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | |
340 | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | |
341 | ToSet(DT_QINT32); |
342 | inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } |
343 | |
344 | constexpr DataTypeSet kRealAndQuantizedTypes = |
345 | ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | |
346 | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32); |
347 | inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } |
348 | |
349 | #endif // defined(IS_MOBILE_PLATFORM) |
350 | |
351 | // Validates type T for whether it is a supported DataType. |
352 | template <class T> |
353 | struct IsValidDataType; |
354 | |
355 | // DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType |
356 | // constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT. |
357 | template <class T> |
358 | struct DataTypeToEnum { |
359 | static_assert(IsValidDataType<T>::value, "Specified Data Type not supported" ); |
360 | }; // Specializations below |
361 | |
362 | // EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g. |
363 | // EnumToDataType<DT_FLOAT>::Type is float. |
364 | template <DataType VALUE> |
365 | struct EnumToDataType {}; // Specializations below |
366 | |
// Template specialization for both DataTypeToEnum and EnumToDataType.
// MATCH_TYPE_AND_ENUM(TYPE, ENUM) expands to three explicit specializations:
//   - DataTypeToEnum<TYPE>: v() and `value` yield ENUM, ref() yields
//     MakeRefType(ENUM);
//   - IsValidDataType<TYPE>: `value` is true;
//   - EnumToDataType<ENUM>: `Type` is TYPE.
// The expansion ends without a ';' so that each use site supplies one.
// The macro is #undef'd after the last specialization that uses it.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                 \
  template <>                                           \
  struct DataTypeToEnum<TYPE> {                         \
    static DataType v() { return ENUM; }                \
    static DataType ref() { return MakeRefType(ENUM); } \
    static constexpr DataType value = ENUM;             \
  };                                                    \
  template <>                                           \
  struct IsValidDataType<TYPE> {                        \
    static constexpr bool value = true;                 \
  };                                                    \
  template <>                                           \
  struct EnumToDataType<ENUM> {                         \
    typedef TYPE Type;                                  \
  }
383 | |
384 | MATCH_TYPE_AND_ENUM(float, DT_FLOAT); |
385 | MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); |
386 | MATCH_TYPE_AND_ENUM(int32, DT_INT32); |
387 | MATCH_TYPE_AND_ENUM(uint32, DT_UINT32); |
388 | MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); |
389 | MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); |
390 | MATCH_TYPE_AND_ENUM(int16, DT_INT16); |
391 | MATCH_TYPE_AND_ENUM(int8, DT_INT8); |
392 | MATCH_TYPE_AND_ENUM(tstring, DT_STRING); |
393 | MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); |
394 | MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128); |
395 | MATCH_TYPE_AND_ENUM(bool, DT_BOOL); |
396 | MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); |
397 | MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); |
398 | MATCH_TYPE_AND_ENUM(qint16, DT_QINT16); |
399 | MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16); |
400 | MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); |
401 | MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); |
402 | MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); |
403 | MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); |
404 | MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); |
405 | |
406 | template <> |
407 | struct DataTypeToEnum<long> { |
408 | static DataType v() { return value; } |
409 | static DataType ref() { return MakeRefType(value); } |
410 | static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64; |
411 | }; |
412 | template <> |
413 | struct IsValidDataType<long> { |
414 | static constexpr bool value = true; |
415 | }; |
416 | template <> |
417 | struct EnumToDataType<DT_INT64> { |
418 | typedef int64_t Type; |
419 | }; |
420 | |
421 | template <> |
422 | struct DataTypeToEnum<unsigned long> { |
423 | static DataType v() { return value; } |
424 | static DataType ref() { return MakeRefType(value); } |
425 | static constexpr DataType value = |
426 | sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64; |
427 | }; |
428 | template <> |
429 | struct IsValidDataType<unsigned long> { |
430 | static constexpr bool value = true; |
431 | }; |
432 | template <> |
433 | struct EnumToDataType<DT_UINT64> { |
434 | typedef tensorflow::uint64 Type; |
435 | }; |
436 | |
437 | template <> |
438 | struct DataTypeToEnum<long long> { |
439 | static DataType v() { return DT_INT64; } |
440 | static DataType ref() { return MakeRefType(DT_INT64); } |
441 | static constexpr DataType value = DT_INT64; |
442 | }; |
443 | template <> |
444 | struct IsValidDataType<long long> { |
445 | static constexpr bool value = true; |
446 | }; |
447 | |
448 | template <> |
449 | struct DataTypeToEnum<unsigned long long> { |
450 | static DataType v() { return DT_UINT64; } |
451 | static DataType ref() { return MakeRefType(DT_UINT64); } |
452 | static constexpr DataType value = DT_UINT64; |
453 | }; |
454 | template <> |
455 | struct IsValidDataType<unsigned long long> { |
456 | static constexpr bool value = true; |
457 | }; |
458 | |
459 | #undef MATCH_TYPE_AND_ENUM |
460 | |
461 | // All types not specialized are marked invalid. |
462 | template <class T> |
463 | struct IsValidDataType { |
464 | static constexpr bool value = false; |
465 | }; |
466 | |
467 | // Extra validity checking; not part of public API. |
468 | static_assert(IsValidDataType<int64_t>::value, "Incorrect impl for int64" ); |
469 | static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32" ); |
470 | |
471 | // TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying |
472 | // is_simple<T> in tensor.cc (and possible choose a more general name?) |
473 | constexpr DataTypeSet kDataTypesCanUseMemcpy = |
474 | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) | |
475 | ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | |
476 | ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | |
477 | ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | |
478 | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | |
479 | ToSet(DT_BFLOAT16) | ToSet(DT_HALF); |
480 | inline bool DataTypeCanUseMemcpy(DataType dt) { |
481 | return kDataTypesCanUseMemcpy.Contains(dt); |
482 | } |
483 | |
484 | // Returns true iff 'dt' is a real, non-quantized floating point type. |
485 | constexpr DataTypeSet kDataTypeIsFloating = |
486 | ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE); |
487 | inline bool DataTypeIsFloating(DataType dt) { |
488 | return kDataTypeIsFloating.Contains(dt); |
489 | } |
490 | |
491 | // Returns true iff 'dt' is a complex type. |
492 | constexpr DataTypeSet kDataTypeIsComplex = |
493 | ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128); |
494 | inline bool DataTypeIsComplex(DataType dt) { |
495 | return kDataTypeIsComplex.Contains(dt); |
496 | } |
497 | |
498 | inline bool DataTypeIsQuantized(DataType dt) { |
499 | return kQuantizedTypes.Contains(dt); |
500 | } |
501 | |
502 | // Is the dtype nonquantized integral? |
503 | constexpr DataTypeSet kDataTypeIsInteger = |
504 | ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) | |
505 | ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64); |
506 | inline bool DataTypeIsInteger(DataType dt) { |
507 | return kDataTypeIsInteger.Contains(dt); |
508 | } |
509 | |
510 | // Is the dtype a signed integral type? |
511 | constexpr DataTypeSet kDataTypeIsSigned = |
512 | ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64); |
513 | inline bool DataTypeIsSigned(DataType dt) { |
514 | return kDataTypeIsSigned.Contains(dt); |
515 | } |
516 | |
517 | // Is the dtype an unsigned integral type? |
518 | constexpr DataTypeSet kDataTypeIsUnsigned = |
519 | ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64); |
520 | inline bool DataTypeIsUnsigned(DataType dt) { |
521 | return kDataTypeIsUnsigned.Contains(dt); |
522 | } |
523 | |
// Returns the size associated with `dt`, or 0 on failure.
// NOTE(review): presumably the size in bytes of one element of `dt`; confirm
// against the out-of-line definition which types yield 0.
int DataTypeSize(DataType dt);

// Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32,
// DEVICE_MEMORY otherwise.
MemoryType MTypeFromDType(const DataType dtype);

// Returns HOST_MEMORY if `dtype` is always on host, DEVICE_MEMORY otherwise.
// The reason we have MTypeFromDType() and MTypeFromDTypeIntsOnDevice(): for
// GPUs, we would like to keep int operations on host for performance concerns.
// But for TPUs (and other devices), int operations are placed on device.
MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype);

// Returns true iff `dt` is a type that always sits on host.
// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
// For DT_RESOURCE, the handle always sits on host (even if the underlying
// object has device-allocated resources).
bool DataTypeAlwaysOnHost(DataType dt);
541 | |
542 | // FullType implementation. |
543 | |
544 | // Reference container for a type definition. These values are usually interned. |
545 | // These containers admit a notion of ordering for efficient access. The |
546 | // ordering has no semantic otherwise. |
547 | struct TypeRef { |
548 | std::shared_ptr<FullTypeDef> full_type; |
549 | |
550 | bool operator==(const TypeRef& other) const { |
551 | // TODO(mdan): This should be more efficient. |
552 | return full_type->SerializeAsString() == |
553 | other.full_type->SerializeAsString(); |
554 | } |
555 | bool operator<(const TypeRef& other) const { |
556 | return full_type->SerializeAsString() < |
557 | other.full_type->SerializeAsString(); |
558 | } |
559 | }; |
560 | |
561 | struct TypeHasher { |
562 | std::size_t operator()(const TypeRef& k) const { |
563 | return std::hash<std::string>()(k.full_type->SerializeAsString()); |
564 | } |
565 | }; |
566 | |
// Maps a legacy DType proto enum to an equivalent FullType ID.
// `dtype` is the legacy DataType to convert. `t` is taken by non-const
// reference, so the definition (not in this header) presumably writes the
// resulting FullType into it; confirm against the implementation.
void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t);
569 | |
570 | } // namespace tensorflow |
571 | |
572 | #endif // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_ |
573 | |