1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/types.h"
17#include "tensorflow/core/framework/register_types.h"
18#include "tensorflow/core/lib/strings/str_util.h"
19#include "tensorflow/core/lib/strings/strcat.h"
20#include "tensorflow/core/platform/logging.h"
21
22namespace tensorflow {
23
24struct DataTypeHasher {
25 std::size_t operator()(const DataType& k) const {
26 return std::hash<int>()(static_cast<int>(k));
27 }
28};
29
30// Mapping from some of the DType fields, for backward compatibility. All other
31// dtypes are mapped to TFT_ANY, but can be added here if a counterpart is
32// defined.
33auto* DT_TO_FT = new std::unordered_map<DataType, FullTypeId, DataTypeHasher>({
34 {DT_FLOAT, TFT_FLOAT},
35 {DT_DOUBLE, TFT_DOUBLE},
36 {DT_INT32, TFT_INT32},
37 {DT_UINT8, TFT_UINT8},
38 {DT_INT16, TFT_INT16},
39 {DT_INT8, TFT_INT8},
40 {DT_STRING, TFT_STRING},
41 {DT_COMPLEX64, TFT_COMPLEX64},
42 {DT_INT64, TFT_INT64},
43 {DT_BOOL, TFT_BOOL},
44 {DT_UINT16, TFT_UINT16},
45 {DT_COMPLEX128, TFT_COMPLEX128},
46 {DT_HALF, TFT_HALF},
47 {DT_UINT32, TFT_UINT32},
48 {DT_UINT64, TFT_UINT64},
49 {DT_VARIANT, TFT_LEGACY_VARIANT},
50});
51
52void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t) {
53 t.Clear();
54
55 const auto& mapped = DT_TO_FT->find(dtype);
56 // Only map known types, everything else remains unset. This is so that we
57 // only set the most specific type when it is fully known. For example, if the
58 // dtype is DT_VARIANT, then we don't know much and opt to assume that
59 // the type is unset, rather than TFT_ANY.
60 if (mapped != DT_TO_FT->end()) {
61 t.set_type_id(mapped->second);
62 }
63}
64
65bool DeviceType::operator<(const DeviceType& other) const {
66 return type_ < other.type_;
67}
68
69bool DeviceType::operator==(const DeviceType& other) const {
70 return type_ == other.type_;
71}
72
73std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
74 os << d.type();
75 return os;
76}
77
// Canonical device-type name strings.
// NOTE(review): presumably declared `extern` in types.h so that other
// translation units can use them; confirm there before changing linkage.
const char* const DEVICE_DEFAULT = "DEFAULT";
const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_TPU = "TPU";
const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
83
// Definitions of the DeviceName<Device>::value static members. The CPU entry
// (Eigen::ThreadPoolDevice -> DEVICE_CPU) is always defined. The GPU entry
// (Eigen::GpuDevice -> DEVICE_GPU) is only defined in CUDA or ROCm builds.
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
89
90namespace {
91string DataTypeStringInternal(DataType dtype) {
92 switch (dtype) {
93 case DT_INVALID:
94 return "INVALID";
95 case DT_FLOAT:
96 return "float";
97 case DT_DOUBLE:
98 return "double";
99 case DT_INT32:
100 return "int32";
101 case DT_UINT32:
102 return "uint32";
103 case DT_UINT8:
104 return "uint8";
105 case DT_UINT16:
106 return "uint16";
107 case DT_INT16:
108 return "int16";
109 case DT_INT8:
110 return "int8";
111 case DT_STRING:
112 return "string";
113 case DT_COMPLEX64:
114 return "complex64";
115 case DT_COMPLEX128:
116 return "complex128";
117 case DT_INT64:
118 return "int64";
119 case DT_UINT64:
120 return "uint64";
121 case DT_BOOL:
122 return "bool";
123 case DT_QINT8:
124 return "qint8";
125 case DT_QUINT8:
126 return "quint8";
127 case DT_QUINT16:
128 return "quint16";
129 case DT_QINT16:
130 return "qint16";
131 case DT_QINT32:
132 return "qint32";
133 case DT_BFLOAT16:
134 return "bfloat16";
135 case DT_HALF:
136 return "half";
137 case DT_RESOURCE:
138 return "resource";
139 case DT_VARIANT:
140 return "variant";
141 default:
142 LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
143 return strings::StrCat("unknown dtype enum (", dtype, ")");
144 }
145}
146} // end namespace
147
148string DataTypeString(DataType dtype) {
149 if (IsRefType(dtype)) {
150 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
151 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
152 }
153 return DataTypeStringInternal(dtype);
154}
155
156bool DataTypeFromString(StringPiece sp, DataType* dt) {
157 if (str_util::EndsWith(sp, "_ref")) {
158 sp.remove_suffix(4);
159 DataType non_ref;
160 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
161 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
162 return true;
163 } else {
164 return false;
165 }
166 }
167
168 if (sp == "float" || sp == "float32") {
169 *dt = DT_FLOAT;
170 return true;
171 } else if (sp == "double" || sp == "float64") {
172 *dt = DT_DOUBLE;
173 return true;
174 } else if (sp == "int32") {
175 *dt = DT_INT32;
176 return true;
177 } else if (sp == "uint32") {
178 *dt = DT_UINT32;
179 return true;
180 } else if (sp == "uint8") {
181 *dt = DT_UINT8;
182 return true;
183 } else if (sp == "uint16") {
184 *dt = DT_UINT16;
185 return true;
186 } else if (sp == "int16") {
187 *dt = DT_INT16;
188 return true;
189 } else if (sp == "int8") {
190 *dt = DT_INT8;
191 return true;
192 } else if (sp == "string") {
193 *dt = DT_STRING;
194 return true;
195 } else if (sp == "complex64") {
196 *dt = DT_COMPLEX64;
197 return true;
198 } else if (sp == "complex128") {
199 *dt = DT_COMPLEX128;
200 return true;
201 } else if (sp == "int64") {
202 *dt = DT_INT64;
203 return true;
204 } else if (sp == "uint64") {
205 *dt = DT_UINT64;
206 return true;
207 } else if (sp == "bool") {
208 *dt = DT_BOOL;
209 return true;
210 } else if (sp == "qint8") {
211 *dt = DT_QINT8;
212 return true;
213 } else if (sp == "quint8") {
214 *dt = DT_QUINT8;
215 return true;
216 } else if (sp == "qint16") {
217 *dt = DT_QINT16;
218 return true;
219 } else if (sp == "quint16") {
220 *dt = DT_QUINT16;
221 return true;
222 } else if (sp == "qint32") {
223 *dt = DT_QINT32;
224 return true;
225 } else if (sp == "bfloat16") {
226 *dt = DT_BFLOAT16;
227 return true;
228 } else if (sp == "half" || sp == "float16") {
229 *dt = DT_HALF;
230 return true;
231 } else if (sp == "resource") {
232 *dt = DT_RESOURCE;
233 return true;
234 } else if (sp == "variant") {
235 *dt = DT_VARIANT;
236 return true;
237 }
238 return false;
239}
240
241string DeviceTypeString(const DeviceType& device_type) {
242 return device_type.type();
243}
244
245string DataTypeSliceString(const DataTypeSlice types) {
246 string out;
247 for (auto it = types.begin(); it != types.end(); ++it) {
248 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
249 DataTypeString(*it));
250 }
251 return out;
252}
253
254bool DataTypeAlwaysOnHost(DataType dt) {
255 // Includes DT_STRING and DT_RESOURCE.
256 switch (dt) {
257 case DT_STRING:
258 case DT_STRING_REF:
259 case DT_RESOURCE:
260 return true;
261 default:
262 return false;
263 }
264}
265
// Returns sizeof(T) in bytes, where T is the C++ type whose
// DataTypeToEnum<T>::value equals `dt`, for every type enumerated by the
// TF_CALL_* macro expansions below. Returns 0 for any other dtype.
// NOTE(review): the exact set of covered dtypes depends on the macro
// definitions in register_types.h (not visible here). It presumably excludes
// variable-size types such as string/resource/variant; confirm there.
int DataTypeSize(DataType dt) {
#define CASE(T)                  \
  case DataTypeToEnum<T>::value: \
    return sizeof(T);
  switch (dt) {
    TF_CALL_POD_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    // The TF_CALL_QUANTIZED_TYPES() macro does not cover quint16 and qint16,
    // since they are not widely supported, but they are explicitly listed here
    // for bitcast.
    TF_CALL_qint16(CASE);
    TF_CALL_quint16(CASE);

    default:
      return 0;
  }
#undef CASE
}
284
// Define DataTypeToEnum<T>::value.
//
// These are out-of-line definitions of the constexpr static data members
// declared in the DataTypeToEnum<T> specializations. Before C++17 such a
// definition is required whenever the member is ODR-used (e.g. bound to a
// const reference). Since C++17, constexpr static data members are implicitly
// inline, so these become redundant (deprecated) redeclarations, but they
// remain harmless.
#define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
  constexpr DataType DataTypeToEnum<TYPE>::value;

DEFINE_DATATYPETOENUM_VALUE(float);
DEFINE_DATATYPETOENUM_VALUE(double);
DEFINE_DATATYPETOENUM_VALUE(int32);
DEFINE_DATATYPETOENUM_VALUE(uint32);
DEFINE_DATATYPETOENUM_VALUE(uint16);
DEFINE_DATATYPETOENUM_VALUE(uint8);
DEFINE_DATATYPETOENUM_VALUE(int16);
DEFINE_DATATYPETOENUM_VALUE(int8);
DEFINE_DATATYPETOENUM_VALUE(tstring);
DEFINE_DATATYPETOENUM_VALUE(complex64);
DEFINE_DATATYPETOENUM_VALUE(complex128);
DEFINE_DATATYPETOENUM_VALUE(int64_t);
DEFINE_DATATYPETOENUM_VALUE(uint64);
DEFINE_DATATYPETOENUM_VALUE(bool);
DEFINE_DATATYPETOENUM_VALUE(qint8);
DEFINE_DATATYPETOENUM_VALUE(quint8);
DEFINE_DATATYPETOENUM_VALUE(qint16);
DEFINE_DATATYPETOENUM_VALUE(quint16);
DEFINE_DATATYPETOENUM_VALUE(qint32);
DEFINE_DATATYPETOENUM_VALUE(bfloat16);
DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
DEFINE_DATATYPETOENUM_VALUE(Variant);
#undef DEFINE_DATATYPETOENUM_VALUE
313
314} // namespace tensorflow
315