1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include <cctype> |
6 | #include <iostream> |
7 | #include <iterator> |
8 | #include <sstream> |
9 | |
10 | #include "data_type_utils.h" |
11 | |
12 | namespace ONNX_NAMESPACE { |
13 | namespace Utils { |
14 | |
// Singleton holding the tables of primitive tensor element types.
//
// The single instance is created on first use (see GetTypesWrapper()), which
// is needed to ensure the tables are initialized before other static objects
// use them. Ops registration does not work properly without this.
class TypesWrapper final {
 public:
  // Returns the one shared instance.
  static TypesWrapper& GetTypesWrapper();

  // Names of all primitive data types that are accepted, e.g. "float".
  std::unordered_set<std::string>& GetAllowedDataTypes();

  // Lookup table: type name -> tensor data type id.
  std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();

  // Lookup table: tensor data type id -> type name.
  std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();

  ~TypesWrapper() = default;
  // The wrapper is a singleton and therefore not copyable.
  TypesWrapper(const TypesWrapper&) = delete;
  void operator=(const TypesWrapper&) = delete;

 private:
  // Only GetTypesWrapper() may construct the instance.
  TypesWrapper();

  // Stored with int32_t so the member types are exactly the types the
  // accessors above are declared to return.
  std::unordered_map<std::string, int32_t> type_str_to_tensor_data_type_;
  std::unordered_map<int32_t, std::string> tensor_data_type_to_type_str_;
  std::unordered_set<std::string> allowed_data_types_;
};
40 | |
// Non-owning view over a character buffer: a pointer plus a length.
// It tracks the currently "valid" slice of an external string and can shrink
// that slice from either end. The view never copies or frees the buffer, so
// the caller must keep the external storage alive for as long as the
// StringRange is in use.
class StringRange final {
 public:
  // --- Construction -------------------------------------------------------
  // The default constructor yields an empty range. The other constructors
  // wrap the given buffer and then trim whitespace from both ends.
  StringRange();
  StringRange(const char* data, size_t size);
  StringRange(const std::string& str);
  StringRange(const char* data);

  // --- Re-targeting -------------------------------------------------------
  // Point the range at nothing / at a new buffer and restart the capture.
  void Reset();
  void Reset(const char* data, size_t size);
  void Reset(const std::string& str);

  // --- Observers ----------------------------------------------------------
  const char* Data() const;
  size_t Size() const;
  bool Empty() const;
  char operator[](size_t idx) const;
  bool StartsWith(const StringRange& str) const;
  bool EndsWith(const StringRange& str) const;
  // Index of the first occurrence of ch, or std::string::npos if absent.
  size_t Find(const char ch) const;

  // --- Trimming from the left ---------------------------------------------
  // Each overload returns true when something was removed: leading
  // whitespace, exactly `size` characters, or the prefix `str`.
  bool LStrip();
  bool LStrip(size_t size);
  bool LStrip(StringRange str);

  // --- Trimming from the right --------------------------------------------
  // Mirror images of the LStrip overloads.
  bool RStrip();
  bool RStrip(size_t size);
  bool RStrip(StringRange str);

  // Trims whitespace on both ends; true if either end changed.
  bool LAndRStrip();
  // Removes surrounding whitespace together with one leading '(' and one
  // trailing ')' when present.
  void ParensWhitespaceStrip();

  // --- Capture ------------------------------------------------------------
  // These methods provide a way to get back the part of the string that
  // LStrip() discarded since the last RestartCapture().
  StringRange GetCaptured();
  void RestartCapture();

 private:
  // data_ and size_ describe the "valid" range of the external buffer.
  const char* data_;
  size_t size_;

  // [start_, end_) is the captured range; end_ moves forward whenever
  // LStrip() removes characters.
  const char* start_;
  const char* end_;
};
86 | |
87 | std::unordered_map<std::string, TypeProto>& DataTypeUtils::GetTypeStrToProtoMap() { |
88 | static std::unordered_map<std::string, TypeProto> map; |
89 | return map; |
90 | } |
91 | |
92 | std::mutex& DataTypeUtils::GetTypeStrLock() { |
93 | static std::mutex lock; |
94 | return lock; |
95 | } |
96 | |
97 | DataType DataTypeUtils::ToType(const TypeProto& type_proto) { |
98 | auto typeStr = ToString(type_proto); |
99 | std::lock_guard<std::mutex> lock(GetTypeStrLock()); |
100 | if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) { |
101 | TypeProto type; |
102 | FromString(typeStr, type); |
103 | GetTypeStrToProtoMap()[typeStr] = type; |
104 | } |
105 | return &(GetTypeStrToProtoMap().find(typeStr)->first); |
106 | } |
107 | |
108 | DataType DataTypeUtils::ToType(const std::string& type_str) { |
109 | TypeProto type; |
110 | FromString(type_str, type); |
111 | return ToType(type); |
112 | } |
113 | |
114 | const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) { |
115 | std::lock_guard<std::mutex> lock(GetTypeStrLock()); |
116 | auto it = GetTypeStrToProtoMap().find(*data_type); |
117 | if (GetTypeStrToProtoMap().end() == it) { |
118 | ONNX_THROW_EX(std::invalid_argument("Invalid data type " + *data_type)); |
119 | } |
120 | return it->second; |
121 | } |
122 | |
123 | std::string DataTypeUtils::ToString(const TypeProto& type_proto, const std::string& left, const std::string& right) { |
124 | switch (type_proto.value_case()) { |
125 | case TypeProto::ValueCase::kTensorType: { |
126 | // Note: We do not distinguish tensors with zero rank (a shape consisting |
127 | // of an empty sequence of dimensions) here. |
128 | return left + "tensor(" + ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right; |
129 | } |
130 | case TypeProto::ValueCase::kSequenceType: { |
131 | return ToString(type_proto.sequence_type().elem_type(), left + "seq(" , ")" + right); |
132 | } |
133 | case TypeProto::ValueCase::kOptionalType: { |
134 | return ToString(type_proto.optional_type().elem_type(), left + "optional(" , ")" + right); |
135 | } |
136 | case TypeProto::ValueCase::kMapType: { |
137 | std::string map_str = "map(" + ToDataTypeString(type_proto.map_type().key_type()) + "," ; |
138 | return ToString(type_proto.map_type().value_type(), left + map_str, ")" + right); |
139 | } |
140 | #ifdef ONNX_ML |
141 | case TypeProto::ValueCase::kOpaqueType: { |
142 | static const std::string empty; |
143 | std::string result; |
144 | const auto& op_type = type_proto.opaque_type(); |
145 | result.append(left).append("opaque(" ); |
146 | if (op_type.has_domain() && !op_type.domain().empty()) { |
147 | result.append(op_type.domain()).append("," ); |
148 | } |
149 | if (op_type.has_name() && !op_type.name().empty()) { |
150 | result.append(op_type.name()); |
151 | } |
152 | result.append(")" ).append(right); |
153 | return result; |
154 | } |
155 | #endif |
156 | case TypeProto::ValueCase::kSparseTensorType: { |
157 | // Note: We do not distinguish tensors with zero rank (a shape consisting |
158 | // of an empty sequence of dimensions) here. |
159 | return left + "sparse_tensor(" + ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" + right; |
160 | } |
161 | default: |
162 | ONNX_THROW_EX(std::invalid_argument("Unsuported type proto value case." )); |
163 | } |
164 | } |
165 | |
166 | std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) { |
167 | TypesWrapper& t = TypesWrapper::GetTypesWrapper(); |
168 | auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type); |
169 | if (t.TensorDataTypeToTypeStr().end() == iter) { |
170 | ONNX_THROW_EX(std::invalid_argument("Invalid tensor data type " + std::to_string(tensor_data_type) + "." )); |
171 | } |
172 | return iter->second; |
173 | } |
174 | |
175 | void DataTypeUtils::FromString(const std::string& type_str, TypeProto& type_proto) { |
176 | StringRange s(type_str); |
177 | type_proto.Clear(); |
178 | if (s.LStrip("seq" )) { |
179 | s.ParensWhitespaceStrip(); |
180 | return FromString(std::string(s.Data(), s.Size()), *type_proto.mutable_sequence_type()->mutable_elem_type()); |
181 | } else if (s.LStrip("optional" )) { |
182 | s.ParensWhitespaceStrip(); |
183 | return FromString(std::string(s.Data(), s.Size()), *type_proto.mutable_optional_type()->mutable_elem_type()); |
184 | } else if (s.LStrip("map" )) { |
185 | s.ParensWhitespaceStrip(); |
186 | size_t key_size = s.Find(','); |
187 | StringRange k(s.Data(), key_size); |
188 | std::string key(k.Data(), k.Size()); |
189 | s.LStrip(key_size); |
190 | s.LStrip("," ); |
191 | StringRange v(s.Data(), s.Size()); |
192 | int32_t key_type; |
193 | FromDataTypeString(key, key_type); |
194 | type_proto.mutable_map_type()->set_key_type(key_type); |
195 | return FromString(std::string(v.Data(), v.Size()), *type_proto.mutable_map_type()->mutable_value_type()); |
196 | } else |
197 | #ifdef ONNX_ML |
198 | if (s.LStrip("opaque" )) { |
199 | auto* opaque_type = type_proto.mutable_opaque_type(); |
200 | s.ParensWhitespaceStrip(); |
201 | if (!s.Empty()) { |
202 | size_t cm = s.Find(','); |
203 | if (cm != std::string::npos) { |
204 | if (cm > 0) { |
205 | opaque_type->mutable_domain()->assign(s.Data(), cm); |
206 | } |
207 | s.LStrip(cm + 1); // skip comma |
208 | } |
209 | if (!s.Empty()) { |
210 | opaque_type->mutable_name()->assign(s.Data(), s.Size()); |
211 | } |
212 | } |
213 | } else |
214 | #endif |
215 | if (s.LStrip("sparse_tensor" )) { |
216 | s.ParensWhitespaceStrip(); |
217 | int32_t e; |
218 | FromDataTypeString(std::string(s.Data(), s.Size()), e); |
219 | type_proto.mutable_sparse_tensor_type()->set_elem_type(e); |
220 | } else if (s.LStrip("tensor" )) { |
221 | s.ParensWhitespaceStrip(); |
222 | int32_t e; |
223 | FromDataTypeString(std::string(s.Data(), s.Size()), e); |
224 | type_proto.mutable_tensor_type()->set_elem_type(e); |
225 | } else { |
226 | // Scalar |
227 | int32_t e; |
228 | FromDataTypeString(std::string(s.Data(), s.Size()), e); |
229 | TypeProto::Tensor* t = type_proto.mutable_tensor_type(); |
230 | t->set_elem_type(e); |
231 | // Call mutable_shape() to initialize a shape with no dimension. |
232 | t->mutable_shape(); |
233 | } |
234 | } // namespace Utils |
235 | |
236 | bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) { |
237 | TypesWrapper& t = TypesWrapper::GetTypesWrapper(); |
238 | const auto& allowedSet = t.GetAllowedDataTypes(); |
239 | return (allowedSet.find(type_str) != allowedSet.end()); |
240 | } |
241 | |
242 | void DataTypeUtils::FromDataTypeString(const std::string& type_str, int32_t& tensor_data_type) { |
243 | if (!IsValidDataTypeString(type_str)) { |
244 | ONNX_THROW_EX( |
245 | std::invalid_argument("DataTypeUtils::FromDataTypeString - Received invalid data type string " + type_str)); |
246 | } |
247 | |
248 | TypesWrapper& t = TypesWrapper::GetTypesWrapper(); |
249 | tensor_data_type = t.TypeStrToTensorDataType()[type_str]; |
250 | } |
251 | |
252 | StringRange::StringRange() : data_("" ), size_(0), start_(data_), end_(data_) {} |
253 | |
254 | StringRange::StringRange(const char* p_data, size_t p_size) : data_(p_data), size_(p_size), start_(data_), end_(data_) { |
255 | assert(p_data != nullptr); |
256 | LAndRStrip(); |
257 | } |
258 | |
259 | StringRange::StringRange(const std::string& p_str) |
260 | : data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) { |
261 | LAndRStrip(); |
262 | } |
263 | |
264 | StringRange::StringRange(const char* p_data) : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) { |
265 | LAndRStrip(); |
266 | } |
267 | |
268 | const char* StringRange::Data() const { |
269 | return data_; |
270 | } |
271 | |
272 | size_t StringRange::Size() const { |
273 | return size_; |
274 | } |
275 | |
276 | bool StringRange::Empty() const { |
277 | return size_ == 0; |
278 | } |
279 | |
280 | char StringRange::operator[](size_t idx) const { |
281 | return data_[idx]; |
282 | } |
283 | |
284 | void StringRange::Reset() { |
285 | data_ = "" ; |
286 | size_ = 0; |
287 | start_ = end_ = data_; |
288 | } |
289 | |
290 | void StringRange::Reset(const char* data, size_t size) { |
291 | data_ = data; |
292 | size_ = size; |
293 | start_ = end_ = data_; |
294 | } |
295 | |
296 | void StringRange::Reset(const std::string& str) { |
297 | data_ = str.data(); |
298 | size_ = str.size(); |
299 | start_ = end_ = data_; |
300 | } |
301 | |
302 | bool StringRange::StartsWith(const StringRange& str) const { |
303 | return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0)); |
304 | } |
305 | |
306 | bool StringRange::EndsWith(const StringRange& str) const { |
307 | return ((size_ >= str.size_) && (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0)); |
308 | } |
309 | |
310 | bool StringRange::LStrip() { |
311 | size_t count = 0; |
312 | const char* ptr = data_; |
313 | while (count < size_ && isspace(*ptr)) { |
314 | count++; |
315 | ptr++; |
316 | } |
317 | |
318 | if (count > 0) { |
319 | return LStrip(count); |
320 | } |
321 | return false; |
322 | } |
323 | |
324 | bool StringRange::LStrip(size_t size) { |
325 | if (size <= size_) { |
326 | data_ += size; |
327 | size_ -= size; |
328 | end_ += size; |
329 | return true; |
330 | } |
331 | return false; |
332 | } |
333 | |
334 | bool StringRange::LStrip(StringRange str) { |
335 | if (StartsWith(str)) { |
336 | return LStrip(str.size_); |
337 | } |
338 | return false; |
339 | } |
340 | |
341 | bool StringRange::RStrip() { |
342 | size_t count = 0; |
343 | const char* ptr = data_ + size_ - 1; |
344 | while (count < size_ && isspace(*ptr)) { |
345 | ++count; |
346 | --ptr; |
347 | } |
348 | |
349 | if (count > 0) { |
350 | return RStrip(count); |
351 | } |
352 | return false; |
353 | } |
354 | |
355 | bool StringRange::RStrip(size_t size) { |
356 | if (size_ >= size) { |
357 | size_ -= size; |
358 | return true; |
359 | } |
360 | return false; |
361 | } |
362 | |
363 | bool StringRange::RStrip(StringRange str) { |
364 | if (EndsWith(str)) { |
365 | return RStrip(str.size_); |
366 | } |
367 | return false; |
368 | } |
369 | |
370 | bool StringRange::LAndRStrip() { |
371 | bool l = LStrip(); |
372 | bool r = RStrip(); |
373 | return l || r; |
374 | } |
375 | |
376 | void StringRange::ParensWhitespaceStrip() { |
377 | LStrip(); |
378 | LStrip("(" ); |
379 | LAndRStrip(); |
380 | RStrip(")" ); |
381 | RStrip(); |
382 | } |
383 | |
384 | size_t StringRange::Find(const char ch) const { |
385 | size_t idx = 0; |
386 | while (idx < size_) { |
387 | if (data_[idx] == ch) { |
388 | return idx; |
389 | } |
390 | idx++; |
391 | } |
392 | return std::string::npos; |
393 | } |
394 | |
395 | void StringRange::RestartCapture() { |
396 | start_ = data_; |
397 | end_ = data_; |
398 | } |
399 | |
400 | StringRange StringRange::GetCaptured() { |
401 | return StringRange(start_, end_ - start_); |
402 | } |
403 | |
404 | TypesWrapper& TypesWrapper::GetTypesWrapper() { |
405 | static TypesWrapper types; |
406 | return types; |
407 | } |
408 | |
409 | std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() { |
410 | return allowed_data_types_; |
411 | } |
412 | |
413 | std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() { |
414 | return type_str_to_tensor_data_type_; |
415 | } |
416 | |
417 | std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() { |
418 | return tensor_data_type_to_type_str_; |
419 | } |
420 | |
421 | TypesWrapper::TypesWrapper() { |
422 | // DataType strings. These should match the DataTypes defined in onnx.proto |
423 | type_str_to_tensor_data_type_["float" ] = TensorProto_DataType_FLOAT; |
424 | type_str_to_tensor_data_type_["float16" ] = TensorProto_DataType_FLOAT16; |
425 | type_str_to_tensor_data_type_["bfloat16" ] = TensorProto_DataType_BFLOAT16; |
426 | type_str_to_tensor_data_type_["double" ] = TensorProto_DataType_DOUBLE; |
427 | type_str_to_tensor_data_type_["int8" ] = TensorProto_DataType_INT8; |
428 | type_str_to_tensor_data_type_["int16" ] = TensorProto_DataType_INT16; |
429 | type_str_to_tensor_data_type_["int32" ] = TensorProto_DataType_INT32; |
430 | type_str_to_tensor_data_type_["int64" ] = TensorProto_DataType_INT64; |
431 | type_str_to_tensor_data_type_["uint8" ] = TensorProto_DataType_UINT8; |
432 | type_str_to_tensor_data_type_["uint16" ] = TensorProto_DataType_UINT16; |
433 | type_str_to_tensor_data_type_["uint32" ] = TensorProto_DataType_UINT32; |
434 | type_str_to_tensor_data_type_["uint64" ] = TensorProto_DataType_UINT64; |
435 | type_str_to_tensor_data_type_["complex64" ] = TensorProto_DataType_COMPLEX64; |
436 | type_str_to_tensor_data_type_["complex128" ] = TensorProto_DataType_COMPLEX128; |
437 | type_str_to_tensor_data_type_["string" ] = TensorProto_DataType_STRING; |
438 | type_str_to_tensor_data_type_["bool" ] = TensorProto_DataType_BOOL; |
439 | |
440 | for (auto& str_type_pair : type_str_to_tensor_data_type_) { |
441 | tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first; |
442 | allowed_data_types_.insert(str_type_pair.first); |
443 | } |
444 | } |
445 | } // namespace Utils |
446 | } // namespace ONNX_NAMESPACE |
447 | |