1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
#include <cassert>
#include <cctype>
#include <cstring>
#include <iostream>
#include <iterator>
#include <sstream>
#include <utility>

#include "data_type_utils.h"
11
12namespace ONNX_NAMESPACE {
13namespace Utils {
14
// Singleton wrapper around the tables of allowed data types.
// The single instance is obtained through GetTypesWrapper(). This implements
// construct on first use, which is needed to ensure the tables are
// initialized before they are used; ops registration does not work properly
// without this.
class TypesWrapper final {
 public:
  // Returns the one shared instance.
  static TypesWrapper& GetTypesWrapper();

  // Set of every type name accepted as a data type string.
  std::unordered_set<std::string>& GetAllowedDataTypes();

  // Lookup table: type name -> tensor data type value.
  std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();

  // Lookup table: tensor data type value -> type name.
  std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();

  ~TypesWrapper() = default;
  // Not copyable: there is exactly one instance.
  TypesWrapper(const TypesWrapper&) = delete;
  void operator=(const TypesWrapper&) = delete;

 private:
  // Private so that instances can only be created by GetTypesWrapper().
  TypesWrapper();

  // Declared with int32_t so the members have exactly the types the
  // accessors above return references to.
  std::unordered_map<std::string, int32_t> type_str_to_tensor_data_type_;
  std::unordered_map<int32_t, std::string> tensor_data_type_to_type_str_;
  std::unordered_set<std::string> allowed_data_types_;
};
40
// Non-owning view over a slice of an external character buffer, stored as a
// pointer plus a length. It is used to track the still-"valid" part of a
// string while a parser consumes it from either end.
// The caller must make sure a StringRange is not used after the external
// storage has been freed.
class StringRange final {
 public:
  // Construction. The converting constructors are intentionally implicit so
  // that string literals and std::string can be passed where a StringRange
  // is expected.
  StringRange();
  StringRange(const char* data, size_t size);
  StringRange(const std::string& str);
  StringRange(const char* data);

  // Observers of the current valid range.
  const char* Data() const;
  size_t Size() const;
  bool Empty() const;
  char operator[](size_t idx) const;

  // Re-point the range at an empty string or at a new buffer.
  void Reset();
  void Reset(const char* data, size_t size);
  void Reset(const std::string& str);

  // Prefix / suffix tests against another range.
  bool StartsWith(const StringRange& str) const;
  bool EndsWith(const StringRange& str) const;

  // Strip from the left: whitespace, a fixed number of characters, or a
  // given prefix. Each returns true if something was stripped.
  bool LStrip();
  bool LStrip(size_t size);
  bool LStrip(StringRange str);

  // Strip from the right: whitespace, a fixed number of characters, or a
  // given suffix. Each returns true if something was stripped.
  bool RStrip();
  bool RStrip(size_t size);
  bool RStrip(StringRange str);

  // Strips whitespace on both ends.
  bool LAndRStrip();
  // Strips one enclosing "(" ... ")" pair together with surrounding whitespace.
  void ParensWhitespaceStrip();
  // Index of the first occurrence of `ch`, or std::string::npos.
  size_t Find(const char ch) const;

  // These methods provide a way to return the range of the string
  // which was discarded by LStrip(), i.e. the captured range is what
  // has been stripped from the left since the capture was (re)started.
  StringRange GetCaptured();
  void RestartCapture();

 private:
  // data_ + size_ tracks the "valid" range of the external string buffer.
  const char* data_;
  size_t size_;

  // start_ and end_ track the captured range.
  // end_ advances when LStrip() is called.
  const char* start_;
  const char* end_;
};
86
87std::unordered_map<std::string, TypeProto>& DataTypeUtils::GetTypeStrToProtoMap() {
88 static std::unordered_map<std::string, TypeProto> map;
89 return map;
90}
91
92std::mutex& DataTypeUtils::GetTypeStrLock() {
93 static std::mutex lock;
94 return lock;
95}
96
97DataType DataTypeUtils::ToType(const TypeProto& type_proto) {
98 auto typeStr = ToString(type_proto);
99 std::lock_guard<std::mutex> lock(GetTypeStrLock());
100 if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) {
101 TypeProto type;
102 FromString(typeStr, type);
103 GetTypeStrToProtoMap()[typeStr] = type;
104 }
105 return &(GetTypeStrToProtoMap().find(typeStr)->first);
106}
107
108DataType DataTypeUtils::ToType(const std::string& type_str) {
109 TypeProto type;
110 FromString(type_str, type);
111 return ToType(type);
112}
113
114const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) {
115 std::lock_guard<std::mutex> lock(GetTypeStrLock());
116 auto it = GetTypeStrToProtoMap().find(*data_type);
117 if (GetTypeStrToProtoMap().end() == it) {
118 ONNX_THROW_EX(std::invalid_argument("Invalid data type " + *data_type));
119 }
120 return it->second;
121}
122
123std::string DataTypeUtils::ToString(const TypeProto& type_proto, const std::string& left, const std::string& right) {
124 switch (type_proto.value_case()) {
125 case TypeProto::ValueCase::kTensorType: {
126 // Note: We do not distinguish tensors with zero rank (a shape consisting
127 // of an empty sequence of dimensions) here.
128 return left + "tensor(" + ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right;
129 }
130 case TypeProto::ValueCase::kSequenceType: {
131 return ToString(type_proto.sequence_type().elem_type(), left + "seq(", ")" + right);
132 }
133 case TypeProto::ValueCase::kOptionalType: {
134 return ToString(type_proto.optional_type().elem_type(), left + "optional(", ")" + right);
135 }
136 case TypeProto::ValueCase::kMapType: {
137 std::string map_str = "map(" + ToDataTypeString(type_proto.map_type().key_type()) + ",";
138 return ToString(type_proto.map_type().value_type(), left + map_str, ")" + right);
139 }
140#ifdef ONNX_ML
141 case TypeProto::ValueCase::kOpaqueType: {
142 static const std::string empty;
143 std::string result;
144 const auto& op_type = type_proto.opaque_type();
145 result.append(left).append("opaque(");
146 if (op_type.has_domain() && !op_type.domain().empty()) {
147 result.append(op_type.domain()).append(",");
148 }
149 if (op_type.has_name() && !op_type.name().empty()) {
150 result.append(op_type.name());
151 }
152 result.append(")").append(right);
153 return result;
154 }
155#endif
156 case TypeProto::ValueCase::kSparseTensorType: {
157 // Note: We do not distinguish tensors with zero rank (a shape consisting
158 // of an empty sequence of dimensions) here.
159 return left + "sparse_tensor(" + ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" + right;
160 }
161 default:
162 ONNX_THROW_EX(std::invalid_argument("Unsuported type proto value case."));
163 }
164}
165
166std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) {
167 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
168 auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type);
169 if (t.TensorDataTypeToTypeStr().end() == iter) {
170 ONNX_THROW_EX(std::invalid_argument("Invalid tensor data type " + std::to_string(tensor_data_type) + "."));
171 }
172 return iter->second;
173}
174
175void DataTypeUtils::FromString(const std::string& type_str, TypeProto& type_proto) {
176 StringRange s(type_str);
177 type_proto.Clear();
178 if (s.LStrip("seq")) {
179 s.ParensWhitespaceStrip();
180 return FromString(std::string(s.Data(), s.Size()), *type_proto.mutable_sequence_type()->mutable_elem_type());
181 } else if (s.LStrip("optional")) {
182 s.ParensWhitespaceStrip();
183 return FromString(std::string(s.Data(), s.Size()), *type_proto.mutable_optional_type()->mutable_elem_type());
184 } else if (s.LStrip("map")) {
185 s.ParensWhitespaceStrip();
186 size_t key_size = s.Find(',');
187 StringRange k(s.Data(), key_size);
188 std::string key(k.Data(), k.Size());
189 s.LStrip(key_size);
190 s.LStrip(",");
191 StringRange v(s.Data(), s.Size());
192 int32_t key_type;
193 FromDataTypeString(key, key_type);
194 type_proto.mutable_map_type()->set_key_type(key_type);
195 return FromString(std::string(v.Data(), v.Size()), *type_proto.mutable_map_type()->mutable_value_type());
196 } else
197#ifdef ONNX_ML
198 if (s.LStrip("opaque")) {
199 auto* opaque_type = type_proto.mutable_opaque_type();
200 s.ParensWhitespaceStrip();
201 if (!s.Empty()) {
202 size_t cm = s.Find(',');
203 if (cm != std::string::npos) {
204 if (cm > 0) {
205 opaque_type->mutable_domain()->assign(s.Data(), cm);
206 }
207 s.LStrip(cm + 1); // skip comma
208 }
209 if (!s.Empty()) {
210 opaque_type->mutable_name()->assign(s.Data(), s.Size());
211 }
212 }
213 } else
214#endif
215 if (s.LStrip("sparse_tensor")) {
216 s.ParensWhitespaceStrip();
217 int32_t e;
218 FromDataTypeString(std::string(s.Data(), s.Size()), e);
219 type_proto.mutable_sparse_tensor_type()->set_elem_type(e);
220 } else if (s.LStrip("tensor")) {
221 s.ParensWhitespaceStrip();
222 int32_t e;
223 FromDataTypeString(std::string(s.Data(), s.Size()), e);
224 type_proto.mutable_tensor_type()->set_elem_type(e);
225 } else {
226 // Scalar
227 int32_t e;
228 FromDataTypeString(std::string(s.Data(), s.Size()), e);
229 TypeProto::Tensor* t = type_proto.mutable_tensor_type();
230 t->set_elem_type(e);
231 // Call mutable_shape() to initialize a shape with no dimension.
232 t->mutable_shape();
233 }
234} // namespace Utils
235
236bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) {
237 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
238 const auto& allowedSet = t.GetAllowedDataTypes();
239 return (allowedSet.find(type_str) != allowedSet.end());
240}
241
242void DataTypeUtils::FromDataTypeString(const std::string& type_str, int32_t& tensor_data_type) {
243 if (!IsValidDataTypeString(type_str)) {
244 ONNX_THROW_EX(
245 std::invalid_argument("DataTypeUtils::FromDataTypeString - Received invalid data type string " + type_str));
246 }
247
248 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
249 tensor_data_type = t.TypeStrToTensorDataType()[type_str];
250}
251
252StringRange::StringRange() : data_(""), size_(0), start_(data_), end_(data_) {}
253
254StringRange::StringRange(const char* p_data, size_t p_size) : data_(p_data), size_(p_size), start_(data_), end_(data_) {
255 assert(p_data != nullptr);
256 LAndRStrip();
257}
258
259StringRange::StringRange(const std::string& p_str)
260 : data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) {
261 LAndRStrip();
262}
263
264StringRange::StringRange(const char* p_data) : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) {
265 LAndRStrip();
266}
267
268const char* StringRange::Data() const {
269 return data_;
270}
271
272size_t StringRange::Size() const {
273 return size_;
274}
275
276bool StringRange::Empty() const {
277 return size_ == 0;
278}
279
280char StringRange::operator[](size_t idx) const {
281 return data_[idx];
282}
283
284void StringRange::Reset() {
285 data_ = "";
286 size_ = 0;
287 start_ = end_ = data_;
288}
289
290void StringRange::Reset(const char* data, size_t size) {
291 data_ = data;
292 size_ = size;
293 start_ = end_ = data_;
294}
295
296void StringRange::Reset(const std::string& str) {
297 data_ = str.data();
298 size_ = str.size();
299 start_ = end_ = data_;
300}
301
302bool StringRange::StartsWith(const StringRange& str) const {
303 return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0));
304}
305
306bool StringRange::EndsWith(const StringRange& str) const {
307 return ((size_ >= str.size_) && (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0));
308}
309
310bool StringRange::LStrip() {
311 size_t count = 0;
312 const char* ptr = data_;
313 while (count < size_ && isspace(*ptr)) {
314 count++;
315 ptr++;
316 }
317
318 if (count > 0) {
319 return LStrip(count);
320 }
321 return false;
322}
323
324bool StringRange::LStrip(size_t size) {
325 if (size <= size_) {
326 data_ += size;
327 size_ -= size;
328 end_ += size;
329 return true;
330 }
331 return false;
332}
333
334bool StringRange::LStrip(StringRange str) {
335 if (StartsWith(str)) {
336 return LStrip(str.size_);
337 }
338 return false;
339}
340
341bool StringRange::RStrip() {
342 size_t count = 0;
343 const char* ptr = data_ + size_ - 1;
344 while (count < size_ && isspace(*ptr)) {
345 ++count;
346 --ptr;
347 }
348
349 if (count > 0) {
350 return RStrip(count);
351 }
352 return false;
353}
354
355bool StringRange::RStrip(size_t size) {
356 if (size_ >= size) {
357 size_ -= size;
358 return true;
359 }
360 return false;
361}
362
363bool StringRange::RStrip(StringRange str) {
364 if (EndsWith(str)) {
365 return RStrip(str.size_);
366 }
367 return false;
368}
369
370bool StringRange::LAndRStrip() {
371 bool l = LStrip();
372 bool r = RStrip();
373 return l || r;
374}
375
376void StringRange::ParensWhitespaceStrip() {
377 LStrip();
378 LStrip("(");
379 LAndRStrip();
380 RStrip(")");
381 RStrip();
382}
383
384size_t StringRange::Find(const char ch) const {
385 size_t idx = 0;
386 while (idx < size_) {
387 if (data_[idx] == ch) {
388 return idx;
389 }
390 idx++;
391 }
392 return std::string::npos;
393}
394
395void StringRange::RestartCapture() {
396 start_ = data_;
397 end_ = data_;
398}
399
400StringRange StringRange::GetCaptured() {
401 return StringRange(start_, end_ - start_);
402}
403
404TypesWrapper& TypesWrapper::GetTypesWrapper() {
405 static TypesWrapper types;
406 return types;
407}
408
409std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() {
410 return allowed_data_types_;
411}
412
413std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() {
414 return type_str_to_tensor_data_type_;
415}
416
417std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() {
418 return tensor_data_type_to_type_str_;
419}
420
421TypesWrapper::TypesWrapper() {
422 // DataType strings. These should match the DataTypes defined in onnx.proto
423 type_str_to_tensor_data_type_["float"] = TensorProto_DataType_FLOAT;
424 type_str_to_tensor_data_type_["float16"] = TensorProto_DataType_FLOAT16;
425 type_str_to_tensor_data_type_["bfloat16"] = TensorProto_DataType_BFLOAT16;
426 type_str_to_tensor_data_type_["double"] = TensorProto_DataType_DOUBLE;
427 type_str_to_tensor_data_type_["int8"] = TensorProto_DataType_INT8;
428 type_str_to_tensor_data_type_["int16"] = TensorProto_DataType_INT16;
429 type_str_to_tensor_data_type_["int32"] = TensorProto_DataType_INT32;
430 type_str_to_tensor_data_type_["int64"] = TensorProto_DataType_INT64;
431 type_str_to_tensor_data_type_["uint8"] = TensorProto_DataType_UINT8;
432 type_str_to_tensor_data_type_["uint16"] = TensorProto_DataType_UINT16;
433 type_str_to_tensor_data_type_["uint32"] = TensorProto_DataType_UINT32;
434 type_str_to_tensor_data_type_["uint64"] = TensorProto_DataType_UINT64;
435 type_str_to_tensor_data_type_["complex64"] = TensorProto_DataType_COMPLEX64;
436 type_str_to_tensor_data_type_["complex128"] = TensorProto_DataType_COMPLEX128;
437 type_str_to_tensor_data_type_["string"] = TensorProto_DataType_STRING;
438 type_str_to_tensor_data_type_["bool"] = TensorProto_DataType_BOOL;
439
440 for (auto& str_type_pair : type_str_to_tensor_data_type_) {
441 tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
442 allowed_data_types_.insert(str_type_pair.first);
443 }
444}
445} // namespace Utils
446} // namespace ONNX_NAMESPACE
447