1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
#include <cstdint>
#include <utility>

#include "onnx/defs/shape_inference.h"
6 | |
7 | namespace ONNX_NAMESPACE { |
8 | |
9 | inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) { |
10 | if (index >= input_data->dim_size() || index < -input_data->dim_size()) { |
11 | fail_shape_inference("indices must be in [-rank, rank-1]." ); |
12 | } else { |
13 | *tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index); |
14 | } |
15 | } |
16 | |
17 | // Returns true if the given axis attribute is 0 |
18 | inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) { |
19 | auto axisAttr = ctx.getAttribute("axis" ); |
20 | // if axis is not defined |
21 | if (!axisAttr) { |
22 | if (defaultZero) { |
23 | return true; |
24 | } else { |
25 | fail_shape_inference("Required attribute axis is missing" ); |
26 | return false; |
27 | } |
28 | } |
29 | int axis = static_cast<int>(axisAttr->i()); |
30 | auto input_data_0 = ctx.getInputData(0); |
31 | if (input_data_0 == nullptr) { |
32 | return false; |
33 | } |
34 | int rank = input_data_0->dim_size(); |
35 | if (axis < -rank || axis >= rank) { |
36 | fail_shape_inference("axis must be in [-rank, rank-1]." ); |
37 | return false; |
38 | } |
39 | if (axis < 0) { |
40 | axis += rank; |
41 | } |
42 | // Only supports axis = 0 since the data comes from Shape |
43 | return axis == 0; |
44 | } |
45 | |
46 | inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) { |
47 | // propogate input data |
48 | const auto input_data = ctx.getInputData(idx); |
49 | if (input_data != nullptr) { |
50 | TensorShapeProto tsp; |
51 | tsp.CopyFrom(*input_data); |
52 | ctx.addOutputData(0, std::move(tsp)); |
53 | } |
54 | } |
55 | |
56 | inline void GatherOp13DataPropagator(DataPropagationContext& ctx) { |
57 | if (!axisIsZero(ctx, true)) { |
58 | return; |
59 | } |
60 | const auto input_data = ctx.getInputData(0); |
61 | if (input_data == nullptr) { |
62 | return; |
63 | } |
64 | const auto input_indices = ctx.getInputData(1); |
65 | if (input_data == nullptr || input_indices == nullptr) { |
66 | return; |
67 | } |
68 | TensorShapeProto tsp; |
69 | for (int i = 0; i < input_indices->dim_size(); ++i) { |
70 | if (input_indices->dim(i).has_dim_value()) { |
71 | appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value()); |
72 | } else { |
73 | return; |
74 | } |
75 | } |
76 | if (tsp.dim_size() > 0) { |
77 | ctx.addOutputData(0, std::move(tsp)); |
78 | } |
79 | } |
80 | |
81 | } // namespace ONNX_NAMESPACE |
82 | |