1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #ifndef ONNX_DATA_TYPE_UTILS_H |
6 | #define ONNX_DATA_TYPE_UTILS_H |
7 | |
8 | #include <mutex> |
9 | #include <string> |
10 | #include <unordered_map> |
11 | #include <unordered_set> |
12 | #include "onnx/common/common.h" |
13 | #include "onnx/onnx_pb.h" |
14 | |
15 | namespace ONNX_NAMESPACE { |
16 | // String pointer as unique TypeProto identifier. |
17 | using DataType = const std::string*; |
18 | |
19 | namespace Utils { |
20 | |
21 | // Data type utility, which maintains a global type string to TypeProto map. |
22 | // DataType (string pointer) is used as unique data type identifier for |
23 | // efficiency. |
24 | // |
25 | // Grammar for data type string: |
26 | // <type> ::= <data_type> | |
27 | // tensor(<data_type>) | |
28 | // seq(<type>) | |
29 | // map(<data_type>, <type>) |
30 | // <data_type> :: = float | int32 | string | bool | uint8 |
31 | // | int8 | uint16 | int16 | int64 | float16 | double |
32 | // |
33 | // NOTE: <type> ::= <data_type> means the data is scalar (zero dimension). |
34 | // |
35 | // Example: float, tensor(float), etc. |
36 | // |
37 | class DataTypeUtils final { |
38 | public: |
39 | // If the DataType input is invalid, this function will throw std::invalid_argument exception. |
40 | // If ONNX_NO_EXCEPTIONS is set it will abort. |
41 | static DataType ToType(const std::string& type_str); |
42 | |
43 | // If the DataType input is invalid, this function will throw std::invalid_argument exception. |
44 | // If ONNX_NO_EXCEPTIONS is set it will abort. |
45 | static DataType ToType(const TypeProto& type_proto); |
46 | |
47 | // If the DataType input is invalid, this function will throw std::invalid_argument exception. |
48 | // If ONNX_NO_EXCEPTIONS is set it will abort. |
49 | static const TypeProto& ToTypeProto(const DataType& data_type); |
50 | static std::string ToDataTypeString(int32_t tensor_data_type); |
51 | |
52 | private: |
53 | static void FromString(const std::string& type_str, TypeProto& type_proto); |
54 | |
55 | static void FromDataTypeString(const std::string& type_str, int32_t& tensor_data_type); |
56 | |
57 | static std::string ToString(const TypeProto& type_proto, const std::string& left = "" , const std::string& right = "" ); |
58 | |
59 | // If int32_t input is invalid, this function will throw an exception. |
60 | // If ONNX_NO_EXCEPTIONS is set it will abort. |
61 | |
62 | static bool IsValidDataTypeString(const std::string& type_str); |
63 | |
64 | static std::unordered_map<std::string, TypeProto>& GetTypeStrToProtoMap(); |
65 | |
66 | // Returns lock used for concurrent updates to TypeStrToProtoMap. |
67 | static std::mutex& GetTypeStrLock(); |
68 | }; |
69 | } // namespace Utils |
70 | } // namespace ONNX_NAMESPACE |
71 | |
72 | #endif // ! ONNX_DATA_TYPE_UTILS_H |
73 | |