1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // Helper Methods for Adapters |
6 | |
7 | #include "onnx/version_converter/helper.h" |
8 | |
9 | namespace ONNX_NAMESPACE { |
10 | namespace version_conversion { |
11 | int check_numpy_unibroadcastable_and_require_broadcast( |
12 | const std::vector<Dimension>& input1_sizes, |
13 | const std::vector<Dimension>& input2_sizes) { |
14 | // Check that input1 is larger |
15 | if (input1_sizes.size() < input2_sizes.size()) |
16 | return -1; |
17 | // Check that axis is input1_sizes.size()-input2_sizes.size() |
18 | bool broadcast = false; |
19 | int axis = (int)(input1_sizes.size() - input2_sizes.size()); |
20 | for (int i = 0; i < (int)input2_sizes.size(); i++) { |
21 | if (input2_sizes[i].dim != input1_sizes[axis + i].dim && input2_sizes[i].dim != 1) |
22 | return -1; |
23 | if (input2_sizes[i].dim != input1_sizes[axis + i].dim) |
24 | broadcast = true; |
25 | } |
26 | // Return true if broadcasting is required |
27 | if (input1_sizes.size() > input2_sizes.size() || broadcast) |
28 | return 1; |
29 | else |
30 | return 0; |
31 | } |
32 | |
33 | void assert_numpy_multibroadcastable( |
34 | const std::vector<Dimension>& input1_sizes, |
35 | const std::vector<Dimension>& input2_sizes) { |
36 | // Generalize above for multibroadcastable case |
37 | const std::vector<Dimension>* A_ptr; |
38 | const std::vector<Dimension>* B_ptr; |
39 | int A; |
40 | int B; |
41 | if (input1_sizes.size() < input2_sizes.size()) { |
42 | A_ptr = &input2_sizes; |
43 | B_ptr = &input1_sizes; |
44 | A = 2; |
45 | B = 1; |
46 | } else { |
47 | A_ptr = &input1_sizes; |
48 | B_ptr = &input2_sizes; |
49 | A = 1; |
50 | B = 2; |
51 | } |
52 | const std::vector<Dimension>& A_sizes = *A_ptr; |
53 | const std::vector<Dimension>& B_sizes = *B_ptr; |
54 | int axis = (int)(A_sizes.size() - B_sizes.size()); |
55 | for (int i = 0; i < (int)B_sizes.size(); i++) { |
56 | ONNX_ASSERTM( |
57 | B_sizes[i].dim == A_sizes[axis + i].dim || B_sizes[i].dim == 1 || A_sizes[axis + i].dim == 1, |
58 | "Dimension %d of input %d does not match " |
59 | "dimension %d of input %d, and neither's value is 1" , |
60 | i, |
61 | B, |
62 | axis + i, |
63 | A); |
64 | } |
65 | } |
66 | |
67 | void assertNotParams(const std::vector<Dimension>& sizes) { |
68 | for (const Dimension& dim : sizes) { |
69 | ONNX_ASSERTM(dim.is_int, "%s Dimension is a param instead of an int." , dim.param.c_str()); |
70 | } |
71 | } |
72 | |
73 | void assertInputsAvailable(const ArrayRef<Value*>& inputs, const char* name, uint64_t num_inputs) { |
74 | ONNX_ASSERTM( |
75 | inputs.size() == num_inputs, |
76 | "%s in opset version 6 can only broadcast" |
77 | " between %d inputs" , |
78 | name, |
79 | num_inputs); |
80 | for (int i = 0; i < (int)num_inputs; i++) { |
81 | ONNX_ASSERTM(inputs[i]->has_sizes(), "Shape of input %d is not available." , num_inputs); |
82 | assertNotParams(inputs[i]->sizes()); |
83 | } |
84 | } |
85 | } // namespace version_conversion |
86 | } // namespace ONNX_NAMESPACE |
87 | |