1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#ifndef ONNX_DATA_TYPE_UTILS_H
6#define ONNX_DATA_TYPE_UTILS_H
7
8#include <mutex>
9#include <string>
10#include <unordered_map>
11#include <unordered_set>
12#include "onnx/common/common.h"
13#include "onnx/onnx_pb.h"
14
15namespace ONNX_NAMESPACE {
// DataType: a pointer to a canonical type string, used as a unique identifier
// for a TypeProto so that types can be compared and passed around cheaply.
// It is a raw, non-owning pointer. NOTE(review): presumably it points at a
// string owned by DataTypeUtils' global type map — confirm before relying on
// its lifetime.
using DataType = std::string const*;
18
19namespace Utils {
20
21// Data type utility, which maintains a global type string to TypeProto map.
22// DataType (string pointer) is used as unique data type identifier for
23// efficiency.
24//
25// Grammar for data type string:
26// <type> ::= <data_type> |
27// tensor(<data_type>) |
28// seq(<type>) |
29// map(<data_type>, <type>)
30// <data_type> :: = float | int32 | string | bool | uint8
31// | int8 | uint16 | int16 | int64 | float16 | double
32//
33// NOTE: <type> ::= <data_type> means the data is scalar (zero dimension).
34//
35// Example: float, tensor(float), etc.
36//
37class DataTypeUtils final {
38 public:
39 // If the DataType input is invalid, this function will throw std::invalid_argument exception.
40 // If ONNX_NO_EXCEPTIONS is set it will abort.
41 static DataType ToType(const std::string& type_str);
42
43 // If the DataType input is invalid, this function will throw std::invalid_argument exception.
44 // If ONNX_NO_EXCEPTIONS is set it will abort.
45 static DataType ToType(const TypeProto& type_proto);
46
47 // If the DataType input is invalid, this function will throw std::invalid_argument exception.
48 // If ONNX_NO_EXCEPTIONS is set it will abort.
49 static const TypeProto& ToTypeProto(const DataType& data_type);
50 static std::string ToDataTypeString(int32_t tensor_data_type);
51
52 private:
53 static void FromString(const std::string& type_str, TypeProto& type_proto);
54
55 static void FromDataTypeString(const std::string& type_str, int32_t& tensor_data_type);
56
57 static std::string ToString(const TypeProto& type_proto, const std::string& left = "", const std::string& right = "");
58
59 // If int32_t input is invalid, this function will throw an exception.
60 // If ONNX_NO_EXCEPTIONS is set it will abort.
61
62 static bool IsValidDataTypeString(const std::string& type_str);
63
64 static std::unordered_map<std::string, TypeProto>& GetTypeStrToProtoMap();
65
66 // Returns lock used for concurrent updates to TypeStrToProtoMap.
67 static std::mutex& GetTypeStrLock();
68};
69} // namespace Utils
70} // namespace ONNX_NAMESPACE
71
72#endif // ! ONNX_DATA_TYPE_UTILS_H
73