1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "tensor_proto_util.h" |
6 | #include <vector> |
7 | #include "onnx/common/platform_helpers.h" |
8 | #include "onnx/defs/data_type_utils.h" |
9 | #include "onnx/defs/shape_inference.h" |
10 | |
11 | namespace ONNX_NAMESPACE { |
12 | |
/* Declares a ToTensor<type> specialization that wraps one scalar `value`  */
/* in a TensorProto: data_type is set to `enumType` and the payload is     */
/* appended to the proto's repeated `field`_data field.                    */
/* (Comments inside the macro must be C-style, placed before the line-     */
/* continuation backslash.)                                                */
#define DEFINE_TO_TENSOR_ONE(type, enumType, field) \
  template <>                                       \
  TensorProto ToTensor<type>(const type& value) {   \
    TensorProto tensor;                             \
    tensor.set_data_type(enumType);                 \
    tensor.add_##field##_data(value);               \
    return tensor;                                  \
  }
21 | |
/* Declares a ToTensor<type> specialization converting a std::vector<type> */
/* into a TensorProto: data_type is set to `enumType` and every element is */
/* appended to the repeated `field`_data field, preserving order.          */
/* Note: a previous version called t.clear_##field##_data() first, but `t` */
/* is freshly default-constructed here so the repeated field is already    */
/* empty; the redundant clear call has been removed (no behavior change).  */
#define DEFINE_TO_TENSOR_LIST(type, enumType, field)            \
  template <>                                                   \
  TensorProto ToTensor<type>(const std::vector<type>& values) { \
    TensorProto t;                                              \
    t.set_data_type(enumType);                                  \
    for (const type& val : values) {                            \
      t.add_##field##_data(val);                                \
    }                                                           \
    return t;                                                   \
  }
33 | |
/* Declares a ParseData<type> specialization that extracts the payload of  */
/* `tensor_proto` as a std::vector<type>, from either the typed repeated   */
/* field (`typed_data_fetch`) or the raw_data bytes.                       */
/* Fails shape inference when: the data type is undefined or differs from  */
/* `tensorproto_datatype`; the data is stored externally; the typed        */
/* element count disagrees with the declared dims; or raw_data is not a    */
/* whole number of `type`-sized elements.                                  */
/* Fixes vs. previous version:                                             */
/*  - expected_size is int64_t: the product of int64 dims overflowed int.  */
/*  - raw_data.size() must be a multiple of sizeof(type): resize() floors  */
/*    the element count, so the unchecked memcpy of the full byte length   */
/*    could write past the vector's storage on truncated/corrupt input.    */
#define DEFINE_PARSE_DATA(type, typed_data_fetch, tensorproto_datatype) \
  template <> \
  const std::vector<type> ParseData(const TensorProto* tensor_proto) { \
    if (!tensor_proto->has_data_type() || tensor_proto->data_type() == TensorProto_DataType_UNDEFINED) { \
      fail_shape_inference("The type of tensor: ", tensor_proto->name(), " is undefined so it cannot be parsed."); \
    } else if (tensor_proto->data_type() != tensorproto_datatype) { \
      fail_shape_inference( \
          "ParseData type mismatch for tensor: ", \
          tensor_proto->name(), \
          ". Expected:", \
          Utils::DataTypeUtils::ToDataTypeString(tensorproto_datatype), \
          " Actual:", \
          Utils::DataTypeUtils::ToDataTypeString(tensor_proto->data_type())); \
    } \
    std::vector<type> res; \
    if (tensor_proto->has_data_location() && tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { \
      fail_shape_inference( \
          "Cannot parse data from external tensors. Please ", \
          "load external data into raw data for tensor: ", \
          tensor_proto->name()); \
    } else if (!tensor_proto->has_raw_data()) { \
      const auto& data = tensor_proto->typed_data_fetch(); \
      /* 64-bit accumulation: the product of int64 dims can overflow int */ \
      int64_t expected_size = 1; \
      for (int i = 0; i < tensor_proto->dims_size(); ++i) { \
        expected_size *= tensor_proto->dims(i); \
      } \
      if (tensor_proto->dims_size() != 0 && static_cast<int64_t>(data.size()) != expected_size) { \
        fail_shape_inference( \
            "Data size mismatch. Tensor: ", \
            tensor_proto->name(), \
            " expected size ", \
            expected_size, \
            " does not match the actual size", \
            data.size()); \
      } \
      res.insert(res.end(), data.begin(), data.end()); \
      return res; \
    } \
    if (tensor_proto->data_type() == TensorProto_DataType_STRING) { \
      fail_shape_inference( \
          tensor_proto->name(), \
          " data type is string. string", \
          " content is required to be stored in repeated bytes string_data field.", \
          " raw_data type cannot be string."); \
    } \
    /* The given tensor does have raw_data itself so parse it by given type */ \
    /* make copy as we may have to reverse bytes */ \
    std::string raw_data = tensor_proto->raw_data(); \
    /* Reject truncated/corrupt payloads: a partial trailing element would */ \
    /* otherwise overflow the destination buffer in the memcpy below, */ \
    /* because resize() floors the element count. */ \
    if (raw_data.size() % sizeof(type) != 0) { \
      fail_shape_inference( \
          "raw_data size of tensor: ", \
          tensor_proto->name(), \
          " is not a multiple of the element size ", \
          sizeof(type)); \
    } \
    /* okay to remove const qualifier as we have already made a copy */ \
    char* bytes = const_cast<char*>(raw_data.c_str()); \
    /* onnx is little endian serialized always-tweak byte order if needed */ \
    if (!is_processor_little_endian()) { \
      const size_t element_size = sizeof(type); \
      const size_t num_elements = raw_data.size() / element_size; \
      for (size_t i = 0; i < num_elements; ++i) { \
        char* start_byte = bytes + i * element_size; \
        char* end_byte = start_byte + element_size - 1; \
        /* keep swapping */ \
        for (size_t count = 0; count < element_size / 2; ++count) { \
          char temp = *start_byte; \
          *start_byte = *end_byte; \
          *end_byte = temp; \
          ++start_byte; \
          --end_byte; \
        } \
      } \
    } \
    /* raw_data.c_str()/bytes is a byte array and may not be properly */ \
    /* aligned for the underlying type */ \
    /* We need to copy the raw_data.c_str()/bytes as byte instead of */ \
    /* copying as the underlying type, otherwise we may hit memory */ \
    /* misalignment issues on certain platforms, such as arm32-v7a */ \
    const size_t raw_data_size = raw_data.size(); \
    res.resize(raw_data_size / sizeof(type)); \
    memcpy(reinterpret_cast<char*>(res.data()), bytes, raw_data_size); \
    return res; \
  }
111 | |
/* Scalar ToTensor<T> specializations. Note that bool payloads are stored */
/* in the int32_data field, per the TensorProto storage convention.       */
DEFINE_TO_TENSOR_ONE(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_ONE(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_ONE(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_ONE(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_ONE(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_ONE(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_ONE(std::string, TensorProto_DataType_STRING, string)

/* Vector ToTensor<T> specializations, mirroring the scalar set above. */
DEFINE_TO_TENSOR_LIST(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_LIST(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_LIST(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_LIST(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_LIST(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_LIST(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_LIST(std::string, TensorProto_DataType_STRING, string)

/* ParseData<T> is specialized only for these four numeric element types; */
/* string tensors must be read from string_data and are rejected by the   */
/* raw_data path above.                                                   */
DEFINE_PARSE_DATA(int32_t, int32_data, TensorProto_DataType_INT32)
DEFINE_PARSE_DATA(int64_t, int64_data, TensorProto_DataType_INT64)
DEFINE_PARSE_DATA(float, float_data, TensorProto_DataType_FLOAT)
DEFINE_PARSE_DATA(double, double_data, TensorProto_DataType_DOUBLE)
132 | |
133 | } // namespace ONNX_NAMESPACE |
134 | |