1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stdint.h> |
17 | |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
20 | #include "tensorflow/lite/kernels/internal/tensor.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
22 | #include "tensorflow/lite/kernels/kernel_util.h" |
23 | |
24 | namespace tflite { |
25 | namespace ops { |
26 | namespace builtin { |
27 | namespace segment_sum { |
28 | |
// Positions of this op's tensors within the node: inputs are
// (data, segment_ids), and there is a single output.
constexpr int kInputDataTensor = 0;
constexpr int kInputSegmentIdsTensor = 1;
constexpr int kOutputTensor = 0;
32 | |
33 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
34 | const TfLiteTensor* data, |
35 | const TfLiteTensor* segment_ids, |
36 | TfLiteTensor* output) { |
37 | // Segment ids should be of same cardinality as first input dimension and they |
38 | // should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid) |
39 | const int segment_id_size = segment_ids->dims->data[0]; |
40 | TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]); |
41 | int previous_segment_id = -1; |
42 | for (int i = 0; i < segment_id_size; i++) { |
43 | const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i]; |
44 | if (i == 0) { |
45 | TF_LITE_ENSURE_EQ(context, current_segment_id, 0); |
46 | } else { |
47 | int delta = current_segment_id - previous_segment_id; |
48 | TF_LITE_ENSURE(context, delta == 0 || delta == 1); |
49 | } |
50 | previous_segment_id = current_segment_id; |
51 | } |
52 | |
53 | const int max_index = previous_segment_id; |
54 | |
55 | const int data_rank = NumDimensions(data); |
56 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data)); |
57 | output_shape->data[0] = max_index + 1; |
58 | for (int i = 1; i < data_rank; ++i) { |
59 | output_shape->data[i] = data->dims->data[i]; |
60 | } |
61 | return context->ResizeTensor(context, output, output_shape); |
62 | } |
63 | |
64 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
65 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
66 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
67 | const TfLiteTensor* data; |
68 | TF_LITE_ENSURE_OK(context, |
69 | GetInputSafe(context, node, kInputDataTensor, &data)); |
70 | const TfLiteTensor* segment_ids; |
71 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor, |
72 | &segment_ids)); |
73 | TfLiteTensor* output; |
74 | TF_LITE_ENSURE_OK(context, |
75 | GetOutputSafe(context, node, kOutputTensor, &output)); |
76 | TF_LITE_ENSURE(context, |
77 | data->type == kTfLiteInt32 || data->type == kTfLiteFloat32); |
78 | TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32); |
79 | |
80 | if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) { |
81 | SetTensorToDynamic(output); |
82 | return kTfLiteOk; |
83 | } |
84 | |
85 | return ResizeOutputTensor(context, data, segment_ids, output); |
86 | } |
87 | |
88 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
89 | const TfLiteTensor* data; |
90 | TF_LITE_ENSURE_OK(context, |
91 | GetInputSafe(context, node, kInputDataTensor, &data)); |
92 | const TfLiteTensor* segment_ids; |
93 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor, |
94 | &segment_ids)); |
95 | TfLiteTensor* output; |
96 | TF_LITE_ENSURE_OK(context, |
97 | GetOutputSafe(context, node, kOutputTensor, &output)); |
98 | |
99 | if (IsDynamicTensor(output)) { |
100 | TF_LITE_ENSURE_OK(context, |
101 | ResizeOutputTensor(context, data, segment_ids, output)); |
102 | } |
103 | |
104 | #define TF_LITE_SEGMENT_SUM(dtype) \ |
105 | reference_ops::SegmentSum<dtype>( \ |
106 | GetTensorShape(data), GetTensorData<dtype>(data), \ |
107 | GetTensorShape(segment_ids), GetTensorData<int32_t>(segment_ids), \ |
108 | GetTensorShape(output), GetTensorData<dtype>(output)); |
109 | switch (data->type) { |
110 | case kTfLiteInt32: |
111 | TF_LITE_SEGMENT_SUM(int32_t); |
112 | break; |
113 | case kTfLiteFloat32: |
114 | TF_LITE_SEGMENT_SUM(float); |
115 | break; |
116 | default: |
117 | TF_LITE_KERNEL_LOG(context, |
118 | "Currently SegmentSum doesn't support type: %s" , |
119 | TfLiteTypeGetName(data->type)); |
120 | return kTfLiteError; |
121 | } |
122 | #undef TF_LITE_SEGMENT_SUM |
123 | return kTfLiteOk; |
124 | } |
125 | |
126 | } // namespace segment_sum |
127 | |
128 | TfLiteRegistration* Register_SEGMENT_SUM() { |
129 | static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare, |
130 | segment_sum::Eval}; |
131 | return &r; |
132 | } |
133 | |
134 | } // namespace builtin |
135 | } // namespace ops |
136 | } // namespace tflite |
137 | |