1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <stdint.h>

#include <algorithm>
#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace unsorted_segment {
30
// Selects which reduction an unsorted-segment kernel applies to the elements
// that share a segment id. The numeric values are spelled out because they
// are what gets printed when an unrecognized type is logged.
enum SegmentType {
  kSegmentMax = 0,   // Keep the largest element of each segment.
  kSegmentMin = 1,   // Keep the smallest element of each segment.
  kSegmentProd = 2,  // Multiply the elements of each segment.
  kSegmentSum = 3,   // Add the elements of each segment.
};
37
// Positions of the op's tensors within the node's input and output lists.
static constexpr int kInputDataTensor = 0;         // Values to be reduced.
static constexpr int kInputSegmentIdsTensor = 1;   // Segment id per element.
static constexpr int kInputNumSegmentsTensor = 2;  // Number of output segments.
static constexpr int kOutputTensor = 0;            // Reduced result.
42
43TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
44 const TfLiteTensor* data,
45 const TfLiteTensor* segment_ids,
46 const TfLiteTensor* num_segments,
47 TfLiteTensor* output) {
48 // The shape of segment_ids is permitted to be any non-empty prefix of
49 // the input data's shape. The shape of output's first dimension is always
50 // equal to num_segments. The remaining dimensions of output's shape are then
51 // taken to be the suffix of input shape after rank(segment_ids)th position.
52 // Public facing tensorflow erroneously describe unsorted_segment ops as only
53 // supporting segment_ids of rank 1, however tensorflow implementation
54 // supports higher dimensional segment_ids as described.
55 const int segment_ids_rank = NumDimensions(segment_ids);
56 const int data_rank = NumDimensions(data);
57 TF_LITE_ENSURE(context, segment_ids_rank <= data_rank);
58 for (int i = 0; i < segment_ids_rank; ++i) {
59 // segment_ids shape must be prefix of data shape.
60 TF_LITE_ENSURE_EQ(context, segment_ids->dims->data[i], data->dims->data[i]);
61 }
62 TF_LITE_ENSURE(context, (num_segments->dims->size == 1 &&
63 num_segments->dims->data[0] == 1) ||
64 num_segments->dims->size == 0);
65 // num_segments can be thought of as number of buckets (segments) in output,
66 // where each segment is the reduction of all elements mapped to that
67 // segment_ids. The shape of said elements is the respective
68 // suffix of the data shape.
69 int32_t num_segments_ = GetTensorData<int32_t>(num_segments)[0];
70 const int num_segment_ids = NumElements(segment_ids);
71 int max_index = -1;
72 for (int i = 0; i < num_segment_ids; i++) {
73 max_index = std::max(GetTensorData<int32_t>(segment_ids)[i], max_index);
74 }
75 // num_segments_ must be at greater than max_index else would map elements
76 // to non existent output segments.
77 TF_LITE_ENSURE(context, max_index < num_segments_);
78 const int output_rank = data_rank - segment_ids_rank + 1;
79 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
80 output_shape->data[0] = num_segments_;
81 // output_shape[1:] should be data_shape[Rank(segment_ids):]
82 for (int i = segment_ids_rank; i < data_rank; ++i) {
83 output_shape->data[i - segment_ids_rank + 1] = data->dims->data[i];
84 }
85 return context->ResizeTensor(context, output, output_shape);
86}
87
88TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
89 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
90 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
91 const TfLiteTensor* data;
92 TF_LITE_ENSURE_OK(context,
93 GetInputSafe(context, node, kInputDataTensor, &data));
94 const TfLiteTensor* segment_ids;
95 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
96 &segment_ids));
97 const TfLiteTensor* num_segments;
98 TF_LITE_ENSURE_OK(
99 context,
100 GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
101 TfLiteTensor* output;
102 TF_LITE_ENSURE_OK(context,
103 GetOutputSafe(context, node, kOutputTensor, &output));
104 TF_LITE_ENSURE(context,
105 data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
106 TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
107 TF_LITE_ENSURE_EQ(context, num_segments->type, kTfLiteInt32);
108
109 if (IsDynamicTensor(data) || !IsConstantTensor(segment_ids) ||
110 !IsConstantTensor(num_segments)) {
111 SetTensorToDynamic(output);
112 return kTfLiteOk;
113 }
114 return ResizeOutputTensor(context, data, segment_ids, num_segments, output);
115}
116
// Binary functor for the "max" segment reduction.
//
// Requires <limits> for std::numeric_limits.
template <typename T>
struct SegmenMax {
  // Neutral starting value: combining it with any finite x yields x.
  static constexpr T kInitialValue = std::numeric_limits<T>::lowest();

  // Returns the larger of `a` and `b`; returns `a` when neither compares
  // less than the other (same tie-breaking as std::max).
  inline T operator()(const T& a, const T& b) const { return (a < b) ? b : a; }
};
122
// Binary functor for the "min" segment reduction.
//
// Requires <limits> for std::numeric_limits.
template <typename T>
struct SegmenMin {
  // Neutral starting value: combining it with any finite x yields x.
  static constexpr T kInitialValue = std::numeric_limits<T>::max();

  // Returns the smaller of `a` and `b`; returns `a` when neither compares
  // less than the other (same tie-breaking as std::min).
  inline T operator()(const T& a, const T& b) const { return (b < a) ? b : a; }
};
128
// Binary functor for the "product" segment reduction.
template <typename T>
struct SegmenProd {
  // Multiplicative identity: 1 * x == x.
  static constexpr T kInitialValue = T(1);

  // Returns the product of the two operands.
  inline T operator()(const T& lhs, const T& rhs) const { return lhs * rhs; }
};
134
// Binary functor for the "sum" segment reduction.
template <typename T>
struct SegmenSum {
  // Additive identity: 0 + x == x.
  static constexpr T kInitialValue = T(0);

  // Returns the sum of the two operands.
  inline T operator()(const T& lhs, const T& rhs) const { return lhs + rhs; }
};
140
141template <typename T>
142TfLiteStatus EvalType(TfLiteContext* context, const RuntimeShape& input_shape,
143 const T* input_data,
144 const RuntimeShape& segment_ids_shape,
145 const int32_t* segment_ids_data,
146 const RuntimeShape& output_shape, T* output_data,
147 SegmentType segment_type) {
148 switch (segment_type) {
149 case kSegmentProd:
150 reference_ops::UnsortedSegmentRef<T, SegmenProd>(
151 input_shape, input_data, segment_ids_shape, segment_ids_data,
152 output_shape, output_data);
153 break;
154 case kSegmentMax:
155 reference_ops::UnsortedSegmentRef<T, SegmenMax>(
156 input_shape, input_data, segment_ids_shape, segment_ids_data,
157 output_shape, output_data);
158 break;
159 case kSegmentSum:
160 reference_ops::UnsortedSegmentRef<T, SegmenSum>(
161 input_shape, input_data, segment_ids_shape, segment_ids_data,
162 output_shape, output_data);
163 break;
164 case kSegmentMin:
165 reference_ops::UnsortedSegmentRef<T, SegmenMin>(
166 input_shape, input_data, segment_ids_shape, segment_ids_data,
167 output_shape, output_data);
168 break;
169 default:
170 TF_LITE_KERNEL_LOG(context, "Not recognized segment type: %d",
171 segment_type);
172 return kTfLiteError;
173 }
174 return kTfLiteOk;
175}
176
177TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node,
178 SegmentType segment_type) {
179 const TfLiteTensor* data;
180 TF_LITE_ENSURE_OK(context,
181 GetInputSafe(context, node, kInputDataTensor, &data));
182 const TfLiteTensor* segment_ids;
183 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
184 &segment_ids));
185 const TfLiteTensor* num_segments;
186 TF_LITE_ENSURE_OK(
187 context,
188 GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
189 TfLiteTensor* output;
190 TF_LITE_ENSURE_OK(context,
191 GetOutputSafe(context, node, kOutputTensor, &output));
192
193 if (IsDynamicTensor(output)) {
194 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids,
195 num_segments, output));
196 }
197 TF_LITE_ENSURE_EQ(context, GetTensorShape(data).Dims(0),
198 GetTensorShape(segment_ids).Dims(0));
199
200#define TF_LITE_UNSORTED_SEGMENT(dtype) \
201 EvalType<dtype>(context, GetTensorShape(data), GetTensorData<dtype>(data), \
202 GetTensorShape(segment_ids), \
203 GetTensorData<int32_t>(segment_ids), GetTensorShape(output), \
204 GetTensorData<dtype>(output), segment_type);
205 switch (data->type) {
206 case kTfLiteInt32:
207 TF_LITE_UNSORTED_SEGMENT(int32_t);
208 break;
209 case kTfLiteFloat32:
210 TF_LITE_UNSORTED_SEGMENT(float);
211 break;
212 default:
213 TF_LITE_KERNEL_LOG(
214 context, "Currently UnsortedSegment doesn't support data type: %s",
215 TfLiteTypeGetName(data->type));
216 return kTfLiteError;
217 }
218#undef TF_LITE_UNSORTED_SEGMENT
219 return kTfLiteOk;
220}
221
222TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
223 return EvalGeneric(context, node, kSegmentProd);
224}
225TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
226 return EvalGeneric(context, node, kSegmentMax);
227}
228TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
229 return EvalGeneric(context, node, kSegmentSum);
230}
231TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
232 return EvalGeneric(context, node, kSegmentMin);
233}
234
235} // namespace unsorted_segment
236
237TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD() {
238 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
239 unsorted_segment::EvalProd};
240 return &r;
241}
242
243TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX() {
244 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
245 unsorted_segment::EvalMax};
246 return &r;
247}
248
249TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM() {
250 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
251 unsorted_segment::EvalSum};
252 return &r;
253}
254
255TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN() {
256 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
257 unsorted_segment::EvalMin};
258 return &r;
259}
260
261} // namespace builtin
262} // namespace ops
263} // namespace tflite
264