/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
#define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_

#include <map>
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Disable clang-format to prevent 'FixedPoint' header from being included
// before 'Tensor' header on which it depends.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Variant;

// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "host memory" (e.g., CPU memory) or
// "device memory" (CPU memory for CPU devices, GPU memory for GPU
// devices).
enum MemoryType {
  DEVICE_MEMORY = 0,
  HOST_MEMORY = 1,
};

// A DeviceType is just a string, but we wrap it in a class to give
// some type checking as we pass these around.
class DeviceType {
 public:
  DeviceType(const char* type)  // NOLINT(runtime/explicit)
      : type_(type) {}

  explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}

  const char* type() const { return type_.c_str(); }
  const std::string& type_string() const { return type_; }

  bool operator<(const DeviceType& other) const;
  bool operator==(const DeviceType& other) const;
  bool operator!=(const DeviceType& other) const { return !(*this == other); }

 private:
  std::string type_;
};
std::ostream& operator<<(std::ostream& os, const DeviceType& d);

// Convenient constants that can be passed to a DeviceType constructor.
// See comments for CreateOpKernel in op_kernel.h for uses of DEVICE_DEFAULT
// and other device types.
TF_EXPORT extern const char* const DEVICE_DEFAULT;     // "DEFAULT"
TF_EXPORT extern const char* const DEVICE_CPU;         // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU;         // "GPU"
TF_EXPORT extern const char* const DEVICE_TPU;         // "TPU"
TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM;  // "TPU_SYSTEM"
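
// Illustrative example (not part of the API surface): a DeviceType can be
// constructed from one of the constants above and compared by value, e.g.
//
//   DeviceType d(DEVICE_CPU);
//   DCHECK_EQ(d.type_string(), "CPU");
//   DCHECK(d == DeviceType("CPU"));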

template <typename Device>
struct DeviceName {};

template <>
struct DeviceName<Eigen::ThreadPoolDevice> {
  static const std::string value;
};

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
template <>
struct DeviceName<Eigen::GpuDevice> {
  static const std::string value;
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;

typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
typedef gtl::ArraySlice<DataType> DataTypeSlice;

typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4>
    PrioritizedDeviceTypeVector;

// Convert the enums to strings for errors:
std::string DataTypeString(DataType dtype);
std::string DeviceTypeString(const DeviceType& device_type);
std::string DataTypeSliceString(const DataTypeSlice dtypes);
inline std::string DataTypeVectorString(const DataTypeVector& dtypes) {
  return DataTypeSliceString(dtypes);
}
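
// Illustrative example (assumes the standard registered type names, e.g.
// DT_FLOAT is spelled "float"):
//
//   std::string s = DataTypeString(DT_FLOAT);  // s == "float"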

// DataTypeSet represents a set of DataType values as a simple and efficient
// bit mask. Note that DataTypeSet cannot represent all DataType values; it
// cannot represent any of the DT_*_REF values.
class DataTypeSet {
 private:
  const uint32 mask_;

  static constexpr uint32 kNumBits = 32;

 public:
  constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {}
  explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {}

  constexpr bool Contains(DataType dt) const {
    return (static_cast<uint32>(dt) < kNumBits) &&
           ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u;
  }

  class Iterator {
    const DataTypeSet& set_;
    uint32 pos_;

   public:
    Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) {
      DCHECK_LE(pos, kNumBits);
    }
    DataType operator*() const { return static_cast<DataType>(pos_); }
    Iterator& operator++() {
      ++pos_;
      DCHECK_LE(pos_, kNumBits);
      if (pos_ < kNumBits) {
        uint32 remaining_mask = set_.mask_ >> pos_;
        if (remaining_mask != 0u) {
          pos_ += ctz_uint32(remaining_mask);
        }
      }
      DCHECK_LE(pos_, kNumBits);
      return *this;
    }
    bool operator==(const Iterator& other) const { return pos_ == other.pos_; }
    bool operator!=(const Iterator& other) const { return !(*this == other); }
    size_t operator-(const Iterator& other) const {
      return this->pos_ - other.pos_;
    }
  };

  static uint32 ctz_uint32(uint32 x) {
    DCHECK_NE(x, 0u);
#ifdef __GNUC__
    return __builtin_ctz(x);
#else
    uint32 n = 0u;
    while ((x & 1u) == 0u) {
      x >>= 1;
      ++n;
    }
    return n;
#endif
  }

  static uint32 clz_uint32(uint32 x) {
    DCHECK_NE(x, 0u);
#ifdef __GNUC__
    return __builtin_clz(x);
#else
    uint32 n = 0u;
    while ((x >> (kNumBits - 1u)) == 0u) {
      x <<= 1;
      ++n;
    }
    return n;
#endif
  }

  Iterator begin() const {
    // The begin position is the index of the first bit set to 1 in the entire
    // bit mask. If there are no bits set to 1, then the index is 0.
    if (mask_ != 0) {
      return Iterator(*this, ctz_uint32(mask_));
    }
    // The set is empty.
    return Iterator(*this, 0);
  }

  Iterator end() const {
    // The end position is the index of the highest bit that is set, plus 1.
    // If there are no bits set to 1, then the index is 0.
    if (mask_ != 0) {
      return Iterator(*this, kNumBits - clz_uint32(mask_));
    }
    // The set is empty.
    return Iterator(*this, 0);
  }

  size_t size() const {
#if defined(__GNUC__)
    return __builtin_popcount(mask_);
#else
    size_t n = 0;
    uint32 x = mask_;
    while (x > 0) {
      n += x & 1u;
      x >>= 1;
    }
    return n;
#endif
  }

  constexpr DataTypeSet operator|(const DataTypeSet& other) const {
    return DataTypeSet(mask_ | other.mask_);
  }
};

// If "sp" names a valid type, store it in "*dt" and return true. Otherwise,
// return false.
bool DataTypeFromString(StringPiece sp, DataType* dt);

constexpr inline DataTypeSet ToSet(DataType dt) {
  return DataTypeSet(1u << static_cast<uint32>(dt));
}
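
// Illustrative usage sketch (the set below is hypothetical, not declared by
// this header): DataTypeSet supports membership tests, union via operator|,
// and iteration over the contained DataType values.
//
//   constexpr DataTypeSet kFloatOrDouble = ToSet(DT_FLOAT) | ToSet(DT_DOUBLE);
//   bool has_float = kFloatOrDouble.Contains(DT_FLOAT);  // true
//   for (DataType dt : kFloatOrDouble) {
//     // Visits DT_FLOAT and DT_DOUBLE.
//   }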

// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
enum { kDataTypeRefOffset = 100 };
inline bool IsRefType(DataType dtype) {
  return dtype > static_cast<DataType>(kDataTypeRefOffset);
}
inline DataType MakeRefType(DataType dtype) {
  DCHECK(!IsRefType(dtype));
  return static_cast<DataType>(dtype + kDataTypeRefOffset);
}
inline DataType RemoveRefType(DataType dtype) {
  DCHECK(IsRefType(dtype));
  return static_cast<DataType>(dtype - kDataTypeRefOffset);
}
inline DataType BaseType(DataType dtype) {
  return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
}

// Returns true if the actual type is the same as or ref of the expected type.
inline bool TypesCompatible(DataType expected, DataType actual) {
  return expected == actual || expected == BaseType(actual);
}
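
// Illustrative example of the ref-type helpers above (follows from
// kDataTypeRefOffset; DT_FLOAT_REF is the reference variant of DT_FLOAT):
//
//   MakeRefType(DT_FLOAT) == DT_FLOAT_REF
//   RemoveRefType(DT_FLOAT_REF) == DT_FLOAT
//   BaseType(DT_FLOAT_REF) == DT_FLOAT, BaseType(DT_FLOAT) == DT_FLOAT
//   TypesCompatible(DT_FLOAT, DT_FLOAT_REF) == true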

// Does not include _ref types.
constexpr DataTypeSet kAllTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) |
    ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) |
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
    ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) |
    ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) |
    ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) |
    ToSet(DT_BFLOAT16);
inline const DataTypeSet& AllTypes() { return kAllTypes; }

#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)

// Types that support '<' and '>'.
constexpr DataTypeSet kRealNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
    ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) |
    ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; }

// Return the list of all numeric types.
// Includes complex and quantized types.
// NOTE: On Android, we only include the float and int32 types for now.
const DataTypeSet kNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) |
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) |
    ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) |
    ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
inline const DataTypeSet& NumberTypes() { return kNumberTypes; }

constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
                                        ToSet(DT_QINT32);
inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; }

// Types that support '<' and '>', including quantized types.
const DataTypeSet kRealAndQuantizedTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
    ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
    ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16);
inline const DataTypeSet& RealAndQuantizedTypes() {
  return kRealAndQuantizedTypes;
}

#elif defined(__ANDROID_TYPES_FULL__)

constexpr DataTypeSet kRealNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF);
inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }

constexpr DataTypeSet kNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
    ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF);
inline DataTypeSet NumberTypes() { return kNumberTypes; }

constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
                                        ToSet(DT_QINT32);
inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }

constexpr DataTypeSet kRealAndQuantizedTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
    ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
    ToSet(DT_HALF);
inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }

#else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)

constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32);
inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }

constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) |
                                     ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                     ToSet(DT_QINT32);
inline DataTypeSet NumberTypes() { return kNumberTypes; }

constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
                                        ToSet(DT_QINT32);
inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }

constexpr DataTypeSet kRealAndQuantizedTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
    ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32);
inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }

#endif  // defined(IS_MOBILE_PLATFORM)

// Validates type T for whether it is a supported DataType.
template <class T>
struct IsValidDataType;

// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
template <class T>
struct DataTypeToEnum {
  static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
};  // Specializations below

// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float.
template <DataType VALUE>
struct EnumToDataType {};  // Specializations below

// Template specialization for both DataTypeToEnum and EnumToDataType.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                  \
  template <>                                            \
  struct DataTypeToEnum<TYPE> {                          \
    static DataType v() { return ENUM; }                 \
    static DataType ref() { return MakeRefType(ENUM); }  \
    static constexpr DataType value = ENUM;              \
  };                                                     \
  template <>                                            \
  struct IsValidDataType<TYPE> {                         \
    static constexpr bool value = true;                  \
  };                                                     \
  template <>                                            \
  struct EnumToDataType<ENUM> {                          \
    typedef TYPE Type;                                   \
  }

MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
MATCH_TYPE_AND_ENUM(int32, DT_INT32);
MATCH_TYPE_AND_ENUM(uint32, DT_UINT32);
MATCH_TYPE_AND_ENUM(uint16, DT_UINT16);
MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
MATCH_TYPE_AND_ENUM(int16, DT_INT16);
MATCH_TYPE_AND_ENUM(int8, DT_INT8);
MATCH_TYPE_AND_ENUM(tstring, DT_STRING);
MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);
MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE);
MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT);
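
// Illustrative compile-time round trip implied by the specializations above
// (the int32/int64 variants are checked with static_asserts further below):
//
//   DataTypeToEnum<float>::value == DT_FLOAT
//   DataTypeToEnum<float>::ref() == DT_FLOAT_REF
//   EnumToDataType<DT_FLOAT>::Type is float
//   IsValidDataType<float>::value == true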

template <>
struct DataTypeToEnum<long> {
  static DataType v() { return value; }
  static DataType ref() { return MakeRefType(value); }
  static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64;
};
template <>
struct IsValidDataType<long> {
  static constexpr bool value = true;
};
template <>
struct EnumToDataType<DT_INT64> {
  typedef int64_t Type;
};

template <>
struct DataTypeToEnum<unsigned long> {
  static DataType v() { return value; }
  static DataType ref() { return MakeRefType(value); }
  static constexpr DataType value =
      sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64;
};
template <>
struct IsValidDataType<unsigned long> {
  static constexpr bool value = true;
};
template <>
struct EnumToDataType<DT_UINT64> {
  typedef tensorflow::uint64 Type;
};

template <>
struct DataTypeToEnum<long long> {
  static DataType v() { return DT_INT64; }
  static DataType ref() { return MakeRefType(DT_INT64); }
  static constexpr DataType value = DT_INT64;
};
template <>
struct IsValidDataType<long long> {
  static constexpr bool value = true;
};

template <>
struct DataTypeToEnum<unsigned long long> {
  static DataType v() { return DT_UINT64; }
  static DataType ref() { return MakeRefType(DT_UINT64); }
  static constexpr DataType value = DT_UINT64;
};
template <>
struct IsValidDataType<unsigned long long> {
  static constexpr bool value = true;
};

#undef MATCH_TYPE_AND_ENUM

// All types not specialized are marked invalid.
template <class T>
struct IsValidDataType {
  static constexpr bool value = false;
};

// Extra validity checking; not part of public API.
static_assert(IsValidDataType<int64_t>::value, "Incorrect impl for int64");
static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");

// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
// is_simple<T> in tensor.cc (and possibly choose a more general name?)
constexpr DataTypeSet kDataTypesCanUseMemcpy =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) |
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
    ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
    ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
    ToSet(DT_BFLOAT16) | ToSet(DT_HALF);
inline bool DataTypeCanUseMemcpy(DataType dt) {
  return kDataTypesCanUseMemcpy.Contains(dt);
}
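
// Illustrative example: fixed-width POD types such as DT_FLOAT are in the set
// above, while types like DT_STRING, DT_RESOURCE, and DT_VARIANT are not, so
//
//   DataTypeCanUseMemcpy(DT_FLOAT) == true
//   DataTypeCanUseMemcpy(DT_STRING) == false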

// Returns true iff 'dt' is a real, non-quantized floating point type.
constexpr DataTypeSet kDataTypeIsFloating =
    ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE);
inline bool DataTypeIsFloating(DataType dt) {
  return kDataTypeIsFloating.Contains(dt);
}

// Returns true iff 'dt' is a complex type.
constexpr DataTypeSet kDataTypeIsComplex =
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128);
inline bool DataTypeIsComplex(DataType dt) {
  return kDataTypeIsComplex.Contains(dt);
}

inline bool DataTypeIsQuantized(DataType dt) {
  return kQuantizedTypes.Contains(dt);
}

// Is the dtype a non-quantized integral type?
constexpr DataTypeSet kDataTypeIsInteger =
    ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) |
    ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64);
inline bool DataTypeIsInteger(DataType dt) {
  return kDataTypeIsInteger.Contains(dt);
}

// Is the dtype a signed integral type?
constexpr DataTypeSet kDataTypeIsSigned =
    ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64);
inline bool DataTypeIsSigned(DataType dt) {
  return kDataTypeIsSigned.Contains(dt);
}

// Is the dtype an unsigned integral type?
constexpr DataTypeSet kDataTypeIsUnsigned =
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64);
inline bool DataTypeIsUnsigned(DataType dt) {
  return kDataTypeIsUnsigned.Contains(dt);
}

// Returns the size in bytes of the given data type, or 0 on failure.
int DataTypeSize(DataType dt);
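
// Illustrative example (sizes follow the in-memory representation of each
// type): DataTypeSize(DT_FLOAT) is 4 and DataTypeSize(DT_DOUBLE) is 8, while
// variable-length types such as DT_STRING report 0.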

// Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32,
// DEVICE_MEMORY otherwise.
MemoryType MTypeFromDType(const DataType dtype);

// Returns HOST_MEMORY if `dtype` is always on host, DEVICE_MEMORY otherwise.
// The reason we have MTypeFromDType() and MTypeFromDTypeIntsOnDevice(): for
// GPUs, we would like to keep int operations on host for performance reasons.
// But for TPUs (and other devices), int operations are placed on device.
MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype);

// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
// For DT_RESOURCE, the handle always sits on host (even if the underlying
// object has device-allocated resources).
bool DataTypeAlwaysOnHost(DataType dt);
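
// Illustrative example derived from the comments above:
//
//   MTypeFromDType(DT_INT32) == HOST_MEMORY                // ints kept on host
//   MTypeFromDTypeIntsOnDevice(DT_INT32) == DEVICE_MEMORY  // e.g. for TPUs
//   MTypeFromDType(DT_STRING) == HOST_MEMORY               // always on host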

// FullType implementation.

// Reference container for a type definition. These values are usually
// interned. These containers admit a notion of ordering for efficient access;
// the ordering carries no other semantic meaning.
struct TypeRef {
  std::shared_ptr<FullTypeDef> full_type;

  bool operator==(const TypeRef& other) const {
    // TODO(mdan): This should be more efficient.
    return full_type->SerializeAsString() ==
           other.full_type->SerializeAsString();
  }
  bool operator<(const TypeRef& other) const {
    return full_type->SerializeAsString() <
           other.full_type->SerializeAsString();
  }
};

struct TypeHasher {
  std::size_t operator()(const TypeRef& k) const {
    return std::hash<std::string>()(k.full_type->SerializeAsString());
  }
};
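
// Illustrative usage sketch (hypothetical container choices, not required by
// this header): TypeRef's operator< supports ordered containers, and
// TypeHasher supports unordered ones, e.g.
//
//   std::set<TypeRef> ordered_types;
//   std::unordered_set<TypeRef, TypeHasher> hashed_types;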

// Maps a legacy DType proto enum to an equivalent FullType ID.
void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_