1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. |
6 | // Adventurous users should note that the APIs will probably change. |
7 | |
8 | #pragma once |
9 | |
10 | #include <cmath> |
11 | #include <functional> |
12 | #include <numeric> |
13 | #include "onnx/common/assertions.h" |
14 | #include "onnx/onnx_pb.h" |
15 | #include "onnx/string_utils.h" |
16 | |
17 | namespace ONNX_NAMESPACE { |
18 | |
// In-memory, mutable representation of an ONNX TensorProto: element type,
// shape (sizes), optional name and segment info, and the payload stored
// either in one of the typed vectors below or as an opaque raw-data byte
// string (mirroring TensorProto.raw_data). Element-wise arithmetic helpers
// are declared at the bottom of the class and defined later in this header.
struct Tensor final {
 private:
  // Mirrors TensorProto.segment: set when this tensor is one slice of a
  // larger logical tensor.
  bool is_segment_;
  int64_t segment_begin_;
  int64_t segment_end_;
  // The name is optional; has_name_ distinguishes "unset" from empty string.
  bool has_name_;
  std::string name_;
  // One of the TensorProto_DataType_* enum values; UNDEFINED by default.
  int32_t elem_type_;
  std::vector<int64_t> sizes_;

  // Typed storage; which vector is active depends on elem_type_. Per the
  // TensorProto encoding, several narrow types share a wider vector (the
  // switch in APPLY_BINARY_FUNCTION below routes BOOL/INT8/INT16/UINT8/
  // UINT16 through int32_data_ and UINT32 through uint64_data_;
  // scale_by_first_dim routes FLOAT16 through int32_data_).
  std::vector<float> float_data_;
  std::vector<double> double_data_;
  std::vector<int32_t> int32_data_;
  std::vector<int64_t> int64_data_;
  std::vector<uint64_t> uint64_data_;
  std::vector<std::string> string_data_;

  // When is_raw_data_ is true the payload is the packed bytes in raw_data_
  // and the typed vectors above are not used (see data<T>() below).
  bool is_raw_data_;
  std::string raw_data_;

  // ptr[i] = f(ptr[i], a_ptr[i]) for every element of this tensor.
  template <typename F, typename T>
  void bin_func(const F& f, T* ptr, const T* a_ptr);

  // ptr[i] = f(ptr[i]) for every element of this tensor.
  template <typename F, typename T>
  void un_func(const F& f, T* ptr);

  // Multiplies every element of row i (first dimension) by s_ptr[i].
  template <typename T>
  void scale_dim(T* ptr, const T* s_ptr);

 public:
  // Default: undefined element type, no name, no segment, no raw data.
  Tensor()
      : is_segment_(false),
        segment_begin_(0),
        segment_end_(0),
        has_name_(false),
        elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
        is_raw_data_(false) {}

  Tensor(const Tensor& other)
      : is_segment_(other.is_segment_),
        segment_begin_(other.segment_begin_),
        segment_end_(other.segment_end_),
        has_name_(other.has_name_),
        elem_type_(other.elem_type_),
        sizes_(other.sizes_),
        float_data_(other.float_data_),
        double_data_(other.double_data_),
        int32_data_(other.int32_data_),
        int64_data_(other.int64_data_),
        uint64_data_(other.uint64_data_),
        is_raw_data_(other.is_raw_data_) {
    // Deep copy. Avoid copy on write when using gcc<5.0
    // (rebuilding each string from its data/size forces fresh allocations).
    string_data_.resize(other.string_data_.size());
    for (unsigned int i = 0; i < other.string_data_.size(); ++i) {
      string_data_[i] = std::string(other.string_data_[i].data(), other.string_data_[i].size());
    }
    name_ = std::string(other.name_.data(), other.name_.size());
    raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
  }
  Tensor(Tensor&&) = default;
  ~Tensor() = default;

  // Member-wise swap; enables the copy-and-swap assignment operator below.
  friend void swap(Tensor& first, Tensor& second) {
    using std::swap;
    swap(first.is_segment_, second.is_segment_);
    swap(first.segment_begin_, second.segment_begin_);
    swap(first.segment_end_, second.segment_end_);
    swap(first.has_name_, second.has_name_);
    swap(first.name_, second.name_);
    swap(first.elem_type_, second.elem_type_);
    swap(first.sizes_, second.sizes_);
    swap(first.float_data_, second.float_data_);
    swap(first.double_data_, second.double_data_);
    swap(first.int32_data_, second.int32_data_);
    swap(first.int64_data_, second.int64_data_);
    swap(first.uint64_data_, second.uint64_data_);
    swap(first.is_raw_data_, second.is_raw_data_);
    swap(first.string_data_, second.string_data_);
    swap(first.raw_data_, second.raw_data_);
  }

  // Unified copy/move assignment: `other` is constructed by copy or move at
  // the call site, then swapped in.
  Tensor& operator=(Tensor other) noexcept {
    swap(*this, other);
    return *this;
  }

  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }
  std::vector<int64_t>& sizes() {
    return sizes_;
  }

  // Product of dimensions from `dim` (inclusive) to the last dimension; a
  // negative `dim` counts from the back. size_from_dim(0) is the element
  // count of the whole tensor. Asserts that `dim` is in range.
  int64_t size_from_dim(int dim) const {
    if (dim < 0) {
      dim += (int)sizes_.size();
    }
    ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
    return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
  }

  int32_t elem_type() const {
    return elem_type_;
  }

  int32_t& elem_type() {
    return elem_type_;
  }

  // Direct access to the typed storage vectors. Callers are responsible for
  // using the vector that matches elem_type_ (and for checking
  // is_raw_data()).
  std::vector<std::string>& strings() {
    return string_data_;
  }

  const std::vector<std::string>& strings() const {
    return string_data_;
  }

  std::vector<float>& floats() {
    return float_data_;
  }

  const std::vector<float>& floats() const {
    return float_data_;
  }

  std::vector<double>& doubles() {
    return double_data_;
  }

  const std::vector<double>& doubles() const {
    return double_data_;
  }

  std::vector<int32_t>& int32s() {
    return int32_data_;
  }

  const std::vector<int32_t>& int32s() const {
    return int32_data_;
  }

  std::vector<int64_t>& int64s() {
    return int64_data_;
  }

  const std::vector<int64_t>& int64s() const {
    return int64_data_;
  }

  std::vector<uint64_t>& uint64s() {
    return uint64_data_;
  }

  const std::vector<uint64_t>& uint64s() const {
    return uint64_data_;
  }

  // Packed-byte payload; only meaningful when is_raw_data() is true.
  const std::string& raw() const {
    return raw_data_;
  }

  // Installs a packed-byte payload and switches the tensor to raw mode.
  // Does not clear the typed vectors.
  void set_raw_data(std::string raw_data) {
    is_raw_data_ = true;
    raw_data_ = std::move(raw_data);
  }

  // Pointer to the first element, viewed as T. Specializations are
  // generated by the define_data macro below; in raw mode the raw_data_
  // bytes are reinterpreted as T*.
  template <typename T>
  T* data();

  template <typename T>
  const T* data() const;

  bool is_segment() const {
    return is_segment_;
  }

  int64_t segment_begin() const {
    return segment_begin_;
  }

  int64_t segment_end() const {
    return segment_end_;
  }

  void set_segment_begin_and_end(int64_t begin, int64_t end) {
    is_segment_ = true;
    segment_begin_ = begin;
    segment_end_ = end;
  }

  bool hasName() const {
    return has_name_;
  }

  const std::string& name() const {
    return name_;
  }

  void setName(std::string name) {
    has_name_ = true;
    name_ = std::move(name);
  }

  bool is_raw_data() const {
    return is_raw_data_;
  }

  // this += a
  // Supported for
  // FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE,
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void add(const Tensor& a);

  // this -= a
  // Supported for
  // FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void subtract(const Tensor& a);

  // this *= a
  // Supported for
  // FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void multiply(const Tensor& a);

  // this /= a
  // Supported for
  // FLOAT, INT8, INT16, INT32, UINT8, UINT16, INT64,
  // UINT32, UINT64, DOUBLE
  // TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
  void divide(const Tensor& a);

  // Element-wise square root of This
  // Supported for
  // FLOAT, DOUBLE,
  // TODO: Support for FLOAT16
  void sqrt();

  // Element wise scaling of tensor s
  // s is one dimensional, has size M, where M is size of first dimension of tensor
  // s must have the data type corresponding to this
  // Supported for
  // FLOAT16, FLOAT, DOUBLE
  void scale_by_first_dim(const Tensor& s);
};
267 | |
268 | #define define_data(type, field) \ |
269 | template <> \ |
270 | inline type* Tensor::data<type>() { \ |
271 | if (is_raw_data_) { \ |
272 | return (type*)const_cast<char*>(&raw_data_.data()[0]); \ |
273 | } else { \ |
274 | return field.data(); \ |
275 | } \ |
276 | } \ |
277 | \ |
278 | template <> \ |
279 | inline const type* Tensor::data<type>() const { \ |
280 | if (is_raw_data_) { \ |
281 | return (const type*)(raw_data_.data()); \ |
282 | } else { \ |
283 | return field.data(); \ |
284 | } \ |
285 | } |
286 | |
287 | define_data(float, float_data_); |
288 | define_data(double, double_data_); |
289 | define_data(int32_t, int32_data_); |
290 | define_data(int64_t, int64_data_); |
291 | define_data(uint64_t, uint64_data_); |
292 | define_data(std::string, string_data_); |
293 | #undef define_data |
294 | |
295 | template <typename F, typename T> |
296 | inline void Tensor::bin_func(const F& f, T* ptr, const T* a_ptr) { |
297 | const int64_t num_elements = size_from_dim(0); |
298 | for (int64_t i = 0; i < num_elements; ++i) { |
299 | ptr[i] = f(ptr[i], a_ptr[i]); |
300 | } |
301 | } |
302 | |
303 | template <typename F, typename T> |
304 | inline void Tensor::un_func(const F& f, T* ptr) { |
305 | const int64_t num_elements = size_from_dim(0); |
306 | for (int64_t i = 0; i < num_elements; ++i) { |
307 | ptr[i] = f(ptr[i]); |
308 | } |
309 | } |
310 | |
311 | template <typename T> |
312 | inline void Tensor::scale_dim(T* ptr, const T* s_ptr) { |
313 | int64_t elems_per_first_dim = size_from_dim(1); |
314 | int64_t first_dim_size = sizes_[0]; |
315 | int64_t counter = 0; |
316 | for (int64_t i = 0; i < first_dim_size; ++i) { |
317 | for (int64_t j = 0; j < elems_per_first_dim; ++j) { |
318 | ptr[counter++] *= s_ptr[i]; |
319 | } |
320 | } |
321 | } |
322 | |
323 | #define APPLY_BINARY_FUNCTION(op_name, f) \ |
324 | inline void Tensor::op_name(const Tensor& other) { \ |
325 | TENSOR_ASSERTM( \ |
326 | other.elem_type() == elem_type_, \ |
327 | "Tensor types do not match: %s != %s", \ |
328 | to_string(elem_type_).c_str(), \ |
329 | " vs. ", \ |
330 | to_string(other.elem_type()).c_str()); \ |
331 | TENSOR_ASSERTM(other.sizes() == sizes_, "Tensor sizes do not match."); \ |
332 | switch (elem_type_) { \ |
333 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { \ |
334 | bin_func(f<float>(), data<float>(), other.data<float>()); \ |
335 | break; \ |
336 | } \ |
337 | case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ |
338 | case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ |
339 | case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ |
340 | case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ |
341 | case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ |
342 | case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { \ |
343 | bin_func(f<int32_t>(), data<int32_t>(), other.data<int32_t>()); \ |
344 | break; \ |
345 | } \ |
346 | case ONNX_NAMESPACE::TensorProto_DataType_INT64: { \ |
347 | bin_func(f<int64_t>(), data<int64_t>(), other.data<int64_t>()); \ |
348 | break; \ |
349 | } \ |
350 | case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ |
351 | case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { \ |
352 | bin_func(f<uint64_t>(), data<uint64_t>(), other.data<uint64_t>()); \ |
353 | break; \ |
354 | } \ |
355 | case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { \ |
356 | bin_func(f<double>(), data<double>(), other.data<double>()); \ |
357 | break; \ |
358 | } \ |
359 | default: \ |
360 | TENSOR_ASSERTM( \ |
361 | false, \ |
362 | "Operation %s not supported for data type %s", \ |
363 | #op_name, \ |
364 | " not supported for data type ", \ |
365 | to_string(elem_type_).c_str()); \ |
366 | } \ |
367 | } |
368 | |
369 | APPLY_BINARY_FUNCTION(add, std::plus) |
370 | APPLY_BINARY_FUNCTION(subtract, std::minus) |
371 | APPLY_BINARY_FUNCTION(multiply, std::multiplies) |
372 | APPLY_BINARY_FUNCTION(divide, std::divides) |
373 | |
374 | #undef APPLY_BINARY_FUNCTION |
375 | |
376 | inline void Tensor::sqrt() { |
377 | switch (elem_type_) { |
378 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { |
379 | un_func<float (*)(float), float>(std::sqrt, data<float>()); |
380 | break; |
381 | } |
382 | case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { |
383 | un_func<double (*)(double), double>(std::sqrt, data<double>()); |
384 | break; |
385 | } |
386 | default: |
387 | TENSOR_ASSERTM(false, "Operation sqrt not supported for data type %s" , to_string(elem_type_).c_str()); |
388 | } |
389 | } |
390 | |
// In-place row scaling: every element of slice i along the first dimension
// is multiplied by other[i]. `other` must be 1-D with length equal to
// sizes_[0], this tensor must have rank > 1, and element types must match.
inline void Tensor::scale_by_first_dim(const Tensor& other) {
  ONNX_ASSERT(sizes_.size() > 1 && other.sizes().size() == 1 && other.sizes()[0] == sizes_[0]);
  ONNX_ASSERT(other.elem_type() == elem_type_);

  switch (elem_type_) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      scale_dim(data<float>(), other.data<float>());
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
      // FLOAT16 payloads live in int32 storage. NOTE(review): scale_dim
      // multiplies these as plain int32 values, i.e. it operates on the raw
      // fp16 bit patterns rather than on half-precision floats — confirm
      // this is the intended semantics for the FLOAT16 path.
      scale_dim(data<int32_t>(), other.data<int32_t>());
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
      scale_dim(data<double>(), other.data<double>());
      break;
    }
    default:
      TENSOR_ASSERTM(
          false, "Operation scale_by_first_dim not supported for data type %s" , to_string(elem_type_).c_str());
  }
}
413 | |
414 | } // namespace ONNX_NAMESPACE |
415 | |