1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
#include "tensor_util.h"

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

#include "onnx/common/platform_helpers.h"
8 | |
9 | namespace ONNX_NAMESPACE { |
10 | |
11 | #define DEFINE_PARSE_DATA(type, typed_data_fetch) \ |
12 | template <> \ |
13 | const std::vector<type> ParseData(const Tensor* tensor) { \ |
14 | std::vector<type> res; \ |
15 | if (!tensor->is_raw_data()) { \ |
16 | const auto& data = tensor->typed_data_fetch(); \ |
17 | res.insert(res.end(), data.begin(), data.end()); \ |
18 | return res; \ |
19 | } \ |
20 | /* make copy as we may have to reverse bytes */ \ |
21 | std::string raw_data = tensor->raw(); \ |
22 | /* okay to remove const qualifier as we have already made a copy */ \ |
23 | char* bytes = const_cast<char*>(raw_data.c_str()); \ |
24 | /*onnx is little endian serialized always-tweak byte order if needed*/ \ |
25 | if (!is_processor_little_endian()) { \ |
26 | const size_t element_size = sizeof(type); \ |
27 | const size_t num_elements = raw_data.size() / element_size; \ |
28 | for (size_t i = 0; i < num_elements; ++i) { \ |
29 | char* start_byte = bytes + i * element_size; \ |
30 | char* end_byte = start_byte + element_size - 1; \ |
31 | /* keep swapping */ \ |
32 | for (size_t count = 0; count < element_size / 2; ++count) { \ |
33 | char temp = *start_byte; \ |
34 | *start_byte = *end_byte; \ |
35 | *end_byte = temp; \ |
36 | ++start_byte; \ |
37 | --end_byte; \ |
38 | } \ |
39 | } \ |
40 | } \ |
41 | /* raw_data.c_str()/bytes is a byte array and may not be properly */ \ |
42 | /* aligned for the underlying type */ \ |
43 | /* We need to copy the raw_data.c_str()/bytes as byte instead of */ \ |
44 | /* copying as the underlying type, otherwise we may hit memory */ \ |
45 | /* misalignment issues on certain platforms, such as arm32-v7a */ \ |
46 | const size_t raw_data_size = raw_data.size(); \ |
47 | res.resize(raw_data_size / sizeof(type)); \ |
48 | memcpy(reinterpret_cast<char*>(res.data()), bytes, raw_data_size); \ |
49 | return res; \ |
50 | } |
51 | |
52 | DEFINE_PARSE_DATA(int32_t, int32s) |
53 | DEFINE_PARSE_DATA(int64_t, int64s) |
54 | DEFINE_PARSE_DATA(float, floats) |
55 | DEFINE_PARSE_DATA(double, doubles) |
56 | DEFINE_PARSE_DATA(uint64_t, uint64s) |
57 | |
58 | } // namespace ONNX_NAMESPACE |
59 | |