1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "tensor_proto_util.h"
6#include <vector>
7#include "onnx/common/platform_helpers.h"
8#include "onnx/defs/data_type_utils.h"
9#include "onnx/defs/shape_inference.h"
10
11namespace ONNX_NAMESPACE {
12
/* Defines a template specialization ToTensor<type>(const type&) that wraps a
 * single scalar in a TensorProto: sets the data type to `enumType` and stores
 * the value in the repeated `field`_data field. No dims are set. */
#define DEFINE_TO_TENSOR_ONE(type, enumType, field) \
  template <>                                       \
  TensorProto ToTensor<type>(const type& value) {   \
    TensorProto tensor;                             \
    tensor.set_data_type(enumType);                 \
    tensor.add_##field##_data(value);               \
    return tensor;                                  \
  }
21
/* Defines a template specialization ToTensor<type>(const std::vector<type>&)
 * that copies every element into the repeated `field`_data field of a
 * TensorProto whose data type is `enumType`. No dims are set.
 * Improvements over the previous version: dropped the clear_##field##_data()
 * call (a no-op on a freshly default-constructed message) and reserve the
 * repeated field up front to avoid repeated reallocation while appending. */
#define DEFINE_TO_TENSOR_LIST(type, enumType, field)                      \
  template <>                                                             \
  TensorProto ToTensor<type>(const std::vector<type>& values) {           \
    TensorProto t;                                                        \
    t.set_data_type(enumType);                                            \
    t.mutable_##field##_data()->Reserve(static_cast<int>(values.size())); \
    for (const type& val : values) {                                      \
      t.add_##field##_data(val);                                          \
    }                                                                     \
    return t;                                                             \
  }
33
/* Defines a template specialization ParseData<type> that returns the elements
 * of `tensor_proto` as a std::vector<type>.
 * Fails shape inference when: the data type is missing/undefined, the stored
 * type does not match `tensorproto_datatype`, the data lives externally, the
 * typed-field element count disagrees with the dims product, raw_data is used
 * for strings, or raw_data is not a whole number of elements.
 * Elements come either from the typed repeated field (`typed_data_fetch`) or
 * from raw_data; raw bytes are byte-swapped on big-endian hosts (ONNX always
 * serializes little-endian) and copied with memcpy to dodge alignment traps.
 * Fixes in this revision:
 *  - expected_size widened to int64_t (dims are int64; the product could
 *    overflow a 32-bit int).
 *  - reject raw_data whose size is not a multiple of sizeof(type); the old
 *    code resized res to the truncated element count but memcpy'd the full
 *    raw size, overflowing the vector's buffer by the trailing bytes.
 *  - mutate the copied buffer via &raw_data[0] instead of writing through a
 *    const_cast of c_str() (mutation through c_str()'s pointer is disallowed).
 *  - added the missing space in the "actual size" error message. */
#define DEFINE_PARSE_DATA(type, typed_data_fetch, tensorproto_datatype)                                            \
  template <>                                                                                                      \
  const std::vector<type> ParseData(const TensorProto* tensor_proto) {                                             \
    if (!tensor_proto->has_data_type() || tensor_proto->data_type() == TensorProto_DataType_UNDEFINED) {           \
      fail_shape_inference("The type of tensor: ", tensor_proto->name(), " is undefined so it cannot be parsed."); \
    } else if (tensor_proto->data_type() != tensorproto_datatype) {                                                \
      fail_shape_inference(                                                                                        \
          "ParseData type mismatch for tensor: ",                                                                  \
          tensor_proto->name(),                                                                                    \
          ". Expected:",                                                                                           \
          Utils::DataTypeUtils::ToDataTypeString(tensorproto_datatype),                                            \
          " Actual:",                                                                                              \
          Utils::DataTypeUtils::ToDataTypeString(tensor_proto->data_type()));                                      \
    }                                                                                                              \
    std::vector<type> res;                                                                                         \
    if (tensor_proto->has_data_location() && tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { \
      fail_shape_inference(                                                                                        \
          "Cannot parse data from external tensors. Please ",                                                      \
          "load external data into raw data for tensor: ",                                                         \
          tensor_proto->name());                                                                                   \
    } else if (!tensor_proto->has_raw_data()) {                                                                    \
      const auto& data = tensor_proto->typed_data_fetch();                                                         \
      /* dims are int64: accumulate in int64_t so the product cannot */                                            \
      /* overflow a 32-bit int for large shapes */                                                                 \
      int64_t expected_size = 1;                                                                                   \
      for (int i = 0; i < tensor_proto->dims_size(); ++i) {                                                        \
        expected_size *= tensor_proto->dims(i);                                                                    \
      }                                                                                                            \
      /* dims_size() == 0 (scalar / unset shape) is deliberately exempt */                                         \
      if (tensor_proto->dims_size() != 0 && data.size() != expected_size) {                                        \
        fail_shape_inference(                                                                                      \
            "Data size mismatch. Tensor: ",                                                                        \
            tensor_proto->name(),                                                                                  \
            " expected size ",                                                                                     \
            expected_size,                                                                                         \
            " does not match the actual size ",                                                                    \
            data.size());                                                                                          \
      }                                                                                                            \
      res.insert(res.end(), data.begin(), data.end());                                                             \
      return res;                                                                                                  \
    }                                                                                                              \
    if (tensor_proto->data_type() == TensorProto_DataType_STRING) {                                                \
      fail_shape_inference(                                                                                        \
          tensor_proto->name(),                                                                                    \
          " data type is string. string",                                                                          \
          " content is required to be stored in repeated bytes string_data field.",                                \
          " raw_data type cannot be string.");                                                                     \
    }                                                                                                              \
    /* The given tensor does have raw_data itself so parse it by given type */                                     \
    /* make copy as we may have to reverse bytes */                                                                \
    std::string raw_data = tensor_proto->raw_data();                                                               \
    /* Reject trailing partial elements: the memcpy below copies the full */                                       \
    /* raw size, so an uneven size would overflow res's buffer. */                                                 \
    if (raw_data.size() % sizeof(type) != 0) {                                                                     \
      fail_shape_inference(                                                                                        \
          "raw_data size of tensor: ",                                                                             \
          tensor_proto->name(),                                                                                    \
          " is not a multiple of the element size, so it cannot be parsed.");                                      \
    }                                                                                                              \
    /* okay to mutate through this pointer: raw_data is our own copy */                                            \
    char* bytes = &raw_data[0];                                                                                    \
    /* onnx is little endian serialized always-tweak byte order if needed */                                       \
    if (!is_processor_little_endian()) {                                                                           \
      const size_t element_size = sizeof(type);                                                                    \
      const size_t num_elements = raw_data.size() / element_size;                                                  \
      for (size_t i = 0; i < num_elements; ++i) {                                                                  \
        char* start_byte = bytes + i * element_size;                                                               \
        char* end_byte = start_byte + element_size - 1;                                                            \
        /* reverse the bytes of this element in place */                                                           \
        for (size_t count = 0; count < element_size / 2; ++count) {                                                \
          char temp = *start_byte;                                                                                 \
          *start_byte = *end_byte;                                                                                 \
          *end_byte = temp;                                                                                        \
          ++start_byte;                                                                                            \
          --end_byte;                                                                                              \
        }                                                                                                          \
      }                                                                                                            \
    }                                                                                                              \
    /* bytes is a byte array and may not be properly aligned for the */                                            \
    /* underlying type; copy bytewise (memcpy) instead of casting to avoid */                                      \
    /* misalignment faults on platforms such as arm32-v7a */                                                       \
    const size_t raw_data_size = raw_data.size();                                                                  \
    res.resize(raw_data_size / sizeof(type));                                                                      \
    memcpy(reinterpret_cast<char*>(res.data()), bytes, raw_data_size);                                             \
    return res;                                                                                                    \
  }
111
// Scalar ToTensor<T> specializations (single value -> TensorProto).
// Note: bool values are routed into the int32_data field, per the `field`
// argument below.
DEFINE_TO_TENSOR_ONE(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_ONE(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_ONE(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_ONE(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_ONE(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_ONE(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_ONE(std::string, TensorProto_DataType_STRING, string)

// Vector ToTensor<T> specializations (std::vector<T> -> TensorProto), using
// the same type-to-field mapping as the scalar forms above.
DEFINE_TO_TENSOR_LIST(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_LIST(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_LIST(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_LIST(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_LIST(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_LIST(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_LIST(std::string, TensorProto_DataType_STRING, string)

// ParseData<T> is only instantiated for these four numeric element types;
// callers needing other types have no specialization in this file.
DEFINE_PARSE_DATA(int32_t, int32_data, TensorProto_DataType_INT32)
DEFINE_PARSE_DATA(int64_t, int64_data, TensorProto_DataType_INT64)
DEFINE_PARSE_DATA(float, float_data, TensorProto_DataType_FLOAT)
DEFINE_PARSE_DATA(double, double_data, TensorProto_DataType_DOUBLE)
132
133} // namespace ONNX_NAMESPACE
134