1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "onnx/defs/shape_inference.h"
6
7namespace ONNX_NAMESPACE {
8
9inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) {
10 if (index >= input_data->dim_size() || index < -input_data->dim_size()) {
11 fail_shape_inference("indices must be in [-rank, rank-1].");
12 } else {
13 *tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index);
14 }
15}
16
17// Returns true if the given axis attribute is 0
18inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
19 auto axisAttr = ctx.getAttribute("axis");
20 // if axis is not defined
21 if (!axisAttr) {
22 if (defaultZero) {
23 return true;
24 } else {
25 fail_shape_inference("Required attribute axis is missing");
26 return false;
27 }
28 }
29 int axis = static_cast<int>(axisAttr->i());
30 auto input_data_0 = ctx.getInputData(0);
31 if (input_data_0 == nullptr) {
32 return false;
33 }
34 int rank = input_data_0->dim_size();
35 if (axis < -rank || axis >= rank) {
36 fail_shape_inference("axis must be in [-rank, rank-1].");
37 return false;
38 }
39 if (axis < 0) {
40 axis += rank;
41 }
42 // Only supports axis = 0 since the data comes from Shape
43 return axis == 0;
44}
45
46inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) {
47 // propogate input data
48 const auto input_data = ctx.getInputData(idx);
49 if (input_data != nullptr) {
50 TensorShapeProto tsp;
51 tsp.CopyFrom(*input_data);
52 ctx.addOutputData(0, std::move(tsp));
53 }
54}
55
56inline void GatherOp13DataPropagator(DataPropagationContext& ctx) {
57 if (!axisIsZero(ctx, true)) {
58 return;
59 }
60 const auto input_data = ctx.getInputData(0);
61 if (input_data == nullptr) {
62 return;
63 }
64 const auto input_indices = ctx.getInputData(1);
65 if (input_data == nullptr || input_indices == nullptr) {
66 return;
67 }
68 TensorShapeProto tsp;
69 for (int i = 0; i < input_indices->dim_size(); ++i) {
70 if (input_indices->dim(i).has_dim_value()) {
71 appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value());
72 } else {
73 return;
74 }
75 }
76 if (tsp.dim_size() > 0) {
77 ctx.addOutputData(0, std::move(tsp));
78 }
79}
80
81} // namespace ONNX_NAMESPACE
82