1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/types.h" |
17 | #include "tensorflow/core/framework/register_types.h" |
18 | #include "tensorflow/core/lib/strings/str_util.h" |
19 | #include "tensorflow/core/lib/strings/strcat.h" |
20 | #include "tensorflow/core/platform/logging.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | struct DataTypeHasher { |
25 | std::size_t operator()(const DataType& k) const { |
26 | return std::hash<int>()(static_cast<int>(k)); |
27 | } |
28 | }; |
29 | |
30 | // Mapping from some of the DType fields, for backward compatibility. All other |
31 | // dtypes are mapped to TFT_ANY, but can be added here if a counterpart is |
32 | // defined. |
33 | auto* DT_TO_FT = new std::unordered_map<DataType, FullTypeId, DataTypeHasher>({ |
34 | {DT_FLOAT, TFT_FLOAT}, |
35 | {DT_DOUBLE, TFT_DOUBLE}, |
36 | {DT_INT32, TFT_INT32}, |
37 | {DT_UINT8, TFT_UINT8}, |
38 | {DT_INT16, TFT_INT16}, |
39 | {DT_INT8, TFT_INT8}, |
40 | {DT_STRING, TFT_STRING}, |
41 | {DT_COMPLEX64, TFT_COMPLEX64}, |
42 | {DT_INT64, TFT_INT64}, |
43 | {DT_BOOL, TFT_BOOL}, |
44 | {DT_UINT16, TFT_UINT16}, |
45 | {DT_COMPLEX128, TFT_COMPLEX128}, |
46 | {DT_HALF, TFT_HALF}, |
47 | {DT_UINT32, TFT_UINT32}, |
48 | {DT_UINT64, TFT_UINT64}, |
49 | {DT_VARIANT, TFT_LEGACY_VARIANT}, |
50 | }); |
51 | |
52 | void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t) { |
53 | t.Clear(); |
54 | |
55 | const auto& mapped = DT_TO_FT->find(dtype); |
56 | // Only map known types, everything else remains unset. This is so that we |
57 | // only set the most specific type when it is fully known. For example, if the |
58 | // dtype is DT_VARIANT, then we don't know much and opt to assume that |
59 | // the type is unset, rather than TFT_ANY. |
60 | if (mapped != DT_TO_FT->end()) { |
61 | t.set_type_id(mapped->second); |
62 | } |
63 | } |
64 | |
65 | bool DeviceType::operator<(const DeviceType& other) const { |
66 | return type_ < other.type_; |
67 | } |
68 | |
69 | bool DeviceType::operator==(const DeviceType& other) const { |
70 | return type_ == other.type_; |
71 | } |
72 | |
73 | std::ostream& operator<<(std::ostream& os, const DeviceType& d) { |
74 | os << d.type(); |
75 | return os; |
76 | } |
77 | |
// Canonical device-type name strings.
const char* const DEVICE_DEFAULT{"DEFAULT"};
const char* const DEVICE_CPU{"CPU"};
const char* const DEVICE_GPU{"GPU"};
const char* const DEVICE_TPU{"TPU"};
const char* const DEVICE_TPU_SYSTEM{"TPU_SYSTEM"};
83 | |
84 | const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU; |
85 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
86 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
87 | const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU; |
88 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
89 | |
90 | namespace { |
91 | string DataTypeStringInternal(DataType dtype) { |
92 | switch (dtype) { |
93 | case DT_INVALID: |
94 | return "INVALID" ; |
95 | case DT_FLOAT: |
96 | return "float" ; |
97 | case DT_DOUBLE: |
98 | return "double" ; |
99 | case DT_INT32: |
100 | return "int32" ; |
101 | case DT_UINT32: |
102 | return "uint32" ; |
103 | case DT_UINT8: |
104 | return "uint8" ; |
105 | case DT_UINT16: |
106 | return "uint16" ; |
107 | case DT_INT16: |
108 | return "int16" ; |
109 | case DT_INT8: |
110 | return "int8" ; |
111 | case DT_STRING: |
112 | return "string" ; |
113 | case DT_COMPLEX64: |
114 | return "complex64" ; |
115 | case DT_COMPLEX128: |
116 | return "complex128" ; |
117 | case DT_INT64: |
118 | return "int64" ; |
119 | case DT_UINT64: |
120 | return "uint64" ; |
121 | case DT_BOOL: |
122 | return "bool" ; |
123 | case DT_QINT8: |
124 | return "qint8" ; |
125 | case DT_QUINT8: |
126 | return "quint8" ; |
127 | case DT_QUINT16: |
128 | return "quint16" ; |
129 | case DT_QINT16: |
130 | return "qint16" ; |
131 | case DT_QINT32: |
132 | return "qint32" ; |
133 | case DT_BFLOAT16: |
134 | return "bfloat16" ; |
135 | case DT_HALF: |
136 | return "half" ; |
137 | case DT_RESOURCE: |
138 | return "resource" ; |
139 | case DT_VARIANT: |
140 | return "variant" ; |
141 | default: |
142 | LOG(ERROR) << "Unrecognized DataType enum value " << dtype; |
143 | return strings::StrCat("unknown dtype enum (" , dtype, ")" ); |
144 | } |
145 | } |
146 | } // end namespace |
147 | |
148 | string DataTypeString(DataType dtype) { |
149 | if (IsRefType(dtype)) { |
150 | DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset); |
151 | return strings::StrCat(DataTypeStringInternal(non_ref), "_ref" ); |
152 | } |
153 | return DataTypeStringInternal(dtype); |
154 | } |
155 | |
156 | bool DataTypeFromString(StringPiece sp, DataType* dt) { |
157 | if (str_util::EndsWith(sp, "_ref" )) { |
158 | sp.remove_suffix(4); |
159 | DataType non_ref; |
160 | if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { |
161 | *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset); |
162 | return true; |
163 | } else { |
164 | return false; |
165 | } |
166 | } |
167 | |
168 | if (sp == "float" || sp == "float32" ) { |
169 | *dt = DT_FLOAT; |
170 | return true; |
171 | } else if (sp == "double" || sp == "float64" ) { |
172 | *dt = DT_DOUBLE; |
173 | return true; |
174 | } else if (sp == "int32" ) { |
175 | *dt = DT_INT32; |
176 | return true; |
177 | } else if (sp == "uint32" ) { |
178 | *dt = DT_UINT32; |
179 | return true; |
180 | } else if (sp == "uint8" ) { |
181 | *dt = DT_UINT8; |
182 | return true; |
183 | } else if (sp == "uint16" ) { |
184 | *dt = DT_UINT16; |
185 | return true; |
186 | } else if (sp == "int16" ) { |
187 | *dt = DT_INT16; |
188 | return true; |
189 | } else if (sp == "int8" ) { |
190 | *dt = DT_INT8; |
191 | return true; |
192 | } else if (sp == "string" ) { |
193 | *dt = DT_STRING; |
194 | return true; |
195 | } else if (sp == "complex64" ) { |
196 | *dt = DT_COMPLEX64; |
197 | return true; |
198 | } else if (sp == "complex128" ) { |
199 | *dt = DT_COMPLEX128; |
200 | return true; |
201 | } else if (sp == "int64" ) { |
202 | *dt = DT_INT64; |
203 | return true; |
204 | } else if (sp == "uint64" ) { |
205 | *dt = DT_UINT64; |
206 | return true; |
207 | } else if (sp == "bool" ) { |
208 | *dt = DT_BOOL; |
209 | return true; |
210 | } else if (sp == "qint8" ) { |
211 | *dt = DT_QINT8; |
212 | return true; |
213 | } else if (sp == "quint8" ) { |
214 | *dt = DT_QUINT8; |
215 | return true; |
216 | } else if (sp == "qint16" ) { |
217 | *dt = DT_QINT16; |
218 | return true; |
219 | } else if (sp == "quint16" ) { |
220 | *dt = DT_QUINT16; |
221 | return true; |
222 | } else if (sp == "qint32" ) { |
223 | *dt = DT_QINT32; |
224 | return true; |
225 | } else if (sp == "bfloat16" ) { |
226 | *dt = DT_BFLOAT16; |
227 | return true; |
228 | } else if (sp == "half" || sp == "float16" ) { |
229 | *dt = DT_HALF; |
230 | return true; |
231 | } else if (sp == "resource" ) { |
232 | *dt = DT_RESOURCE; |
233 | return true; |
234 | } else if (sp == "variant" ) { |
235 | *dt = DT_VARIANT; |
236 | return true; |
237 | } |
238 | return false; |
239 | } |
240 | |
241 | string DeviceTypeString(const DeviceType& device_type) { |
242 | return device_type.type(); |
243 | } |
244 | |
245 | string DataTypeSliceString(const DataTypeSlice types) { |
246 | string out; |
247 | for (auto it = types.begin(); it != types.end(); ++it) { |
248 | strings::StrAppend(&out, ((it == types.begin()) ? "" : ", " ), |
249 | DataTypeString(*it)); |
250 | } |
251 | return out; |
252 | } |
253 | |
254 | bool DataTypeAlwaysOnHost(DataType dt) { |
255 | // Includes DT_STRING and DT_RESOURCE. |
256 | switch (dt) { |
257 | case DT_STRING: |
258 | case DT_STRING_REF: |
259 | case DT_RESOURCE: |
260 | return true; |
261 | default: |
262 | return false; |
263 | } |
264 | } |
265 | |
// Returns the size in bytes of one element of `dt`, i.e. sizeof(T) for the C++
// type T whose DataTypeToEnum<T>::value equals `dt`.
//
// Only the types enumerated by TF_CALL_POD_TYPES and TF_CALL_QUANTIZED_TYPES,
// plus qint16 and quint16, are handled; every other DataType yields 0.
// NOTE(review): presumably that includes DT_STRING, DT_RESOURCE, DT_VARIANT
// and every *_REF value — confirm against the TF_CALL_* definitions in
// register_types.h.
int DataTypeSize(DataType dt) {
// Expands to one `case` label per type T, returning sizeof(T).
#define CASE(T)                  \
  case DataTypeToEnum<T>::value: \
    return sizeof(T);
  switch (dt) {
    TF_CALL_POD_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    // TF_CALL_QUANTIZED_TYPES() macro does not cover quint16 and qint16, since
    // they are not supported widely, but are explicitly listed here for
    // bitcast.
    TF_CALL_qint16(CASE);
    TF_CALL_quint16(CASE);

    default:
      return 0;
  }
#undef CASE
}
284 | |
285 | // Define DataTypeToEnum<T>::value. |
286 | #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \ |
287 | constexpr DataType DataTypeToEnum<TYPE>::value; |
288 | |
289 | DEFINE_DATATYPETOENUM_VALUE(float); |
290 | DEFINE_DATATYPETOENUM_VALUE(double); |
291 | DEFINE_DATATYPETOENUM_VALUE(int32); |
292 | DEFINE_DATATYPETOENUM_VALUE(uint32); |
293 | DEFINE_DATATYPETOENUM_VALUE(uint16); |
294 | DEFINE_DATATYPETOENUM_VALUE(uint8); |
295 | DEFINE_DATATYPETOENUM_VALUE(int16); |
296 | DEFINE_DATATYPETOENUM_VALUE(int8); |
297 | DEFINE_DATATYPETOENUM_VALUE(tstring); |
298 | DEFINE_DATATYPETOENUM_VALUE(complex64); |
299 | DEFINE_DATATYPETOENUM_VALUE(complex128); |
300 | DEFINE_DATATYPETOENUM_VALUE(int64_t); |
301 | DEFINE_DATATYPETOENUM_VALUE(uint64); |
302 | DEFINE_DATATYPETOENUM_VALUE(bool); |
303 | DEFINE_DATATYPETOENUM_VALUE(qint8); |
304 | DEFINE_DATATYPETOENUM_VALUE(quint8); |
305 | DEFINE_DATATYPETOENUM_VALUE(qint16); |
306 | DEFINE_DATATYPETOENUM_VALUE(quint16); |
307 | DEFINE_DATATYPETOENUM_VALUE(qint32); |
308 | DEFINE_DATATYPETOENUM_VALUE(bfloat16); |
309 | DEFINE_DATATYPETOENUM_VALUE(Eigen::half); |
310 | DEFINE_DATATYPETOENUM_VALUE(ResourceHandle); |
311 | DEFINE_DATATYPETOENUM_VALUE(Variant); |
312 | #undef DEFINE_DATATYPETOENUM_VALUE |
313 | |
314 | } // namespace tensorflow |
315 | |