/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include <cmath>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "onnx/common/assertions.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {

struct Tensor final {
 private:
  bool is_segment_;
  int64_t segment_begin_;
  int64_t segment_end_;
  bool has_name_;
  std::string name_;
  int32_t elem_type_;
  std::vector<int64_t> sizes_;

  std::vector<float> float_data_;
  std::vector<double> double_data_;
  std::vector<int32_t> int32_data_;
  std::vector<int64_t> int64_data_;
  std::vector<uint64_t> uint64_data_;
  std::vector<std::string> string_data_;

  bool is_raw_data_;
  std::string raw_data_;

  template <typename F, typename T>
  void bin_func(const F& f, T* ptr, const T* a_ptr);

  template <typename F, typename T>
  void un_func(const F& f, T* ptr);

  template <typename T>
  void scale_dim(T* ptr, const T* s_ptr);

 public:
  Tensor()
      : is_segment_(false),
        segment_begin_(0),
        segment_end_(0),
        has_name_(false),
        elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
        is_raw_data_(false) {}

  Tensor(const Tensor& other)
      : is_segment_(other.is_segment_),
        segment_begin_(other.segment_begin_),
        segment_end_(other.segment_end_),
        has_name_(other.has_name_),
        elem_type_(other.elem_type_),
        sizes_(other.sizes_),
        float_data_(other.float_data_),
        double_data_(other.double_data_),
        int32_data_(other.int32_data_),
        int64_data_(other.int64_data_),
        uint64_data_(other.uint64_data_),
        is_raw_data_(other.is_raw_data_) {
    // Deep copy. Avoid copy on write when using gcc<5.0
    string_data_.resize(other.string_data_.size());
    for (unsigned int i = 0; i < other.string_data_.size(); ++i) {
      string_data_[i] = std::string(other.string_data_[i].data(), other.string_data_[i].size());
    }
    name_ = std::string(other.name_.data(), other.name_.size());
    raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
  }
  Tensor(Tensor&&) = default;
  ~Tensor() = default;

  friend void swap(Tensor& first, Tensor& second) {
    using std::swap;
    swap(first.is_segment_, second.is_segment_);
    swap(first.segment_begin_, second.segment_begin_);
    swap(first.segment_end_, second.segment_end_);
    swap(first.has_name_, second.has_name_);
    swap(first.name_, second.name_);
    swap(first.elem_type_, second.elem_type_);
    swap(first.sizes_, second.sizes_);
    swap(first.float_data_, second.float_data_);
    swap(first.double_data_, second.double_data_);
    swap(first.int32_data_, second.int32_data_);
    swap(first.int64_data_, second.int64_data_);
    swap(first.uint64_data_, second.uint64_data_);
    swap(first.is_raw_data_, second.is_raw_data_);
    swap(first.string_data_, second.string_data_);
    swap(first.raw_data_, second.raw_data_);
  }

  Tensor& operator=(Tensor other) noexcept {
    swap(*this, other);
    return *this;
  }

  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }
  std::vector<int64_t>& sizes() {
    return sizes_;
  }

  int64_t size_from_dim(int dim) const {
    if (dim < 0) {
      dim += (int)sizes_.size();
    }
    ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
    return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
  }
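
  // Example (illustrative, assuming a hypothetical tensor `t` with
  // t.sizes() == {2, 3, 4}):
  //   t.size_from_dim(0);   // 2 * 3 * 4 == 24 (total element count)
  //   t.size_from_dim(1);   // 3 * 4 == 12
  //   t.size_from_dim(-1);  // 4; negative dims count from the end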

  int32_t elem_type() const {
    return elem_type_;
  }

  int32_t& elem_type() {
    return elem_type_;
  }

  std::vector<std::string>& strings() {
    return string_data_;
  }

  const std::vector<std::string>& strings() const {
    return string_data_;
  }

  std::vector<float>& floats() {
    return float_data_;
  }

  const std::vector<float>& floats() const {
    return float_data_;
  }

  std::vector<double>& doubles() {
    return double_data_;
  }

  const std::vector<double>& doubles() const {
    return double_data_;
  }

  std::vector<int32_t>& int32s() {
    return int32_data_;
  }

  const std::vector<int32_t>& int32s() const {
    return int32_data_;
  }

  std::vector<int64_t>& int64s() {
    return int64_data_;
  }

  const std::vector<int64_t>& int64s() const {
    return int64_data_;
  }

  std::vector<uint64_t>& uint64s() {
    return uint64_data_;
  }

  const std::vector<uint64_t>& uint64s() const {
    return uint64_data_;
  }

  const std::string& raw() const {
    return raw_data_;
  }

  void set_raw_data(std::string raw_data) {
    is_raw_data_ = true;
    raw_data_ = std::move(raw_data);
  }

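  // Illustrative sketch, not a usage contract: values live either in one of
  // the typed vectors above or, once set_raw_data() has been called, in the
  // raw byte string; data<T>() returns a pointer into whichever store is
  // active. Assuming a hypothetical FLOAT tensor `t` and byte string `bytes`:
  //   t.floats() = {1.0f, 2.0f};   // typed storage, data<float>() -> float_data_
  //   t.set_raw_data(bytes);       // raw storage, data<float>() -> raw_data_ bytes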
  template <typename T>
  T* data();

  template <typename T>
  const T* data() const;

  bool is_segment() const {
    return is_segment_;
  }

  int64_t segment_begin() const {
    return segment_begin_;
  }

  int64_t segment_end() const {
    return segment_end_;
  }

  void set_segment_begin_and_end(int64_t begin, int64_t end) {
    is_segment_ = true;
    segment_begin_ = begin;
    segment_end_ = end;
  }

  bool hasName() const {
    return has_name_;
  }

  const std::string& name() const {
    return name_;
  }

  void setName(std::string name) {
    has_name_ = true;
    name_ = std::move(name);
  }

  bool is_raw_data() const {
    return is_raw_data_;
  }

  // this += a
  // Supported for
  // FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void add(const Tensor& a);

  // this -= a
  // Supported for
  // FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void subtract(const Tensor& a);

  // this *= a
  // Supported for
  // FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void multiply(const Tensor& a);

  // this /= a
  // Supported for
  // FLOAT, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void divide(const Tensor& a);

  // Element-wise square root of this tensor
  // Supported for
  // FLOAT, DOUBLE
  // TODO: Support for FLOAT16
  void sqrt();

  // Element-wise scaling of this tensor by s
  // s is one-dimensional with size M, where M is the size of the first dimension of this tensor
  // s must have the same data type as this tensor
  // Supported for
  // FLOAT16, FLOAT, DOUBLE
  void scale_by_first_dim(const Tensor& s);
};
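
// Illustrative usage sketch (hypothetical FLOAT tensors `a` and `b` of the
// same shape), not part of the API surface: the element-wise helpers mutate
// the tensor in place and require matching element types and sizes, checked
// via TENSOR_ASSERTM in the definitions below.
//   a.add(b);       // a[i] += b[i]
//   a.multiply(b);  // a[i] *= b[i]
//   a.sqrt();       // a[i] = std::sqrt(a[i])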

#define define_data(type, field) \
  template <> \
  inline type* Tensor::data<type>() { \
    if (is_raw_data_) { \
      return (type*)const_cast<char*>(&raw_data_.data()[0]); \
    } else { \
      return field.data(); \
    } \
  } \
\
  template <> \
  inline const type* Tensor::data<type>() const { \
    if (is_raw_data_) { \
      return (const type*)(raw_data_.data()); \
    } else { \
      return field.data(); \
    } \
  }

define_data(float, float_data_);
define_data(double, double_data_);
define_data(int32_t, int32_data_);
define_data(int64_t, int64_data_);
define_data(uint64_t, uint64_data_);
define_data(std::string, string_data_);
#undef define_data
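
// Sketch of what the macro above generates, shown for
// define_data(float, float_data_): Tensor::data<float>() specializations that
// return float_data_.data() for typed storage, or reinterpret the bytes of
// raw_data_ when is_raw_data_ is set.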

template <typename F, typename T>
inline void Tensor::bin_func(const F& f, T* ptr, const T* a_ptr) {
  const int64_t num_elements = size_from_dim(0);
  for (int64_t i = 0; i < num_elements; ++i) {
    ptr[i] = f(ptr[i], a_ptr[i]);
  }
}

template <typename F, typename T>
inline void Tensor::un_func(const F& f, T* ptr) {
  const int64_t num_elements = size_from_dim(0);
  for (int64_t i = 0; i < num_elements; ++i) {
    ptr[i] = f(ptr[i]);
  }
}

template <typename T>
inline void Tensor::scale_dim(T* ptr, const T* s_ptr) {
  int64_t elems_per_first_dim = size_from_dim(1);
  int64_t first_dim_size = sizes_[0];
  int64_t counter = 0;
  for (int64_t i = 0; i < first_dim_size; ++i) {
    for (int64_t j = 0; j < elems_per_first_dim; ++j) {
      ptr[counter++] *= s_ptr[i];
    }
  }
}
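
// Illustrative walk-through of scale_dim (hypothetical values): for a tensor
// of shape {2, 3} holding {1, 2, 3, 4, 5, 6} and s_ptr pointing at {10, 100},
// every element of row i is multiplied by s_ptr[i], yielding
// {10, 20, 30, 400, 500, 600}.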

// Defines the in-place, element-wise add/subtract/multiply/divide members
// declared in Tensor; the functor template f is instantiated for the concrete
// type selected by elem_type_.
#define APPLY_BINARY_FUNCTION(op_name, f) \
  inline void Tensor::op_name(const Tensor& other) { \
    TENSOR_ASSERTM( \
        other.elem_type() == elem_type_, \
        "Tensor types do not match: %s != %s", \
        to_string(elem_type_).c_str(), \
        to_string(other.elem_type()).c_str()); \
    TENSOR_ASSERTM(other.sizes() == sizes_, "Tensor sizes do not match."); \
    switch (elem_type_) { \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { \
        bin_func(f<float>(), data<float>(), other.data<float>()); \
        break; \
      } \
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
      case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
      case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
      case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { \
        bin_func(f<int32_t>(), data<int32_t>(), other.data<int32_t>()); \
        break; \
      } \
      case ONNX_NAMESPACE::TensorProto_DataType_INT64: { \
        bin_func(f<int64_t>(), data<int64_t>(), other.data<int64_t>()); \
        break; \
      } \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { \
        bin_func(f<uint64_t>(), data<uint64_t>(), other.data<uint64_t>()); \
        break; \
      } \
      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { \
        bin_func(f<double>(), data<double>(), other.data<double>()); \
        break; \
      } \
      default: \
        TENSOR_ASSERTM( \
            false, \
            "Operation %s not supported for data type %s", \
            #op_name, \
            to_string(elem_type_).c_str()); \
    } \
  }

APPLY_BINARY_FUNCTION(add, std::plus)
APPLY_BINARY_FUNCTION(subtract, std::minus)
APPLY_BINARY_FUNCTION(multiply, std::multiplies)
APPLY_BINARY_FUNCTION(divide, std::divides)

#undef APPLY_BINARY_FUNCTION

inline void Tensor::sqrt() {
  switch (elem_type_) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      un_func<float (*)(float), float>(std::sqrt, data<float>());
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
      un_func<double (*)(double), double>(std::sqrt, data<double>());
      break;
    }
    default:
      TENSOR_ASSERTM(false, "Operation sqrt not supported for data type %s", to_string(elem_type_).c_str());
  }
}

inline void Tensor::scale_by_first_dim(const Tensor& other) {
  ONNX_ASSERT(sizes_.size() > 1 && other.sizes().size() == 1 && other.sizes()[0] == sizes_[0]);
  ONNX_ASSERT(other.elem_type() == elem_type_);

  switch (elem_type_) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      scale_dim(data<float>(), other.data<float>());
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
      scale_dim(data<int32_t>(), other.data<int32_t>());
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
      scale_dim(data<double>(), other.data<double>());
      break;
    }
    default:
      TENSOR_ASSERTM(
          false, "Operation scale_by_first_dim not supported for data type %s", to_string(elem_type_).c_str());
  }
}
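
// Illustrative usage sketch (hypothetical tensors): for a FLOAT tensor `t` of
// shape {M, N} and a one-dimensional FLOAT tensor `s` of shape {M},
// t.scale_by_first_dim(s) multiplies every element in row i of `t` by the
// i-th element of `s`, as implemented by scale_dim above.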

} // namespace ONNX_NAMESPACE