1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <stdint.h>

#include <algorithm>
#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace unsorted_segment { |
30 | |
// Selects which reduction a shared Eval path performs over each segment.
// One enumerator per UNSORTED_SEGMENT_* builtin registered below.
enum SegmentType {
  kSegmentMax,
  kSegmentMin,
  kSegmentProd,
  kSegmentSum,
};
37 | |
// Node input/output tensor indices. `constexpr` is the modern idiom for
// namespace-scope integral constants (usable in constant expressions, no
// storage unless odr-used).
constexpr int kInputDataTensor = 0;         // Values to reduce.
constexpr int kInputSegmentIdsTensor = 1;   // int32 segment id per element.
constexpr int kInputNumSegmentsTensor = 2;  // int32 scalar: segment count.
constexpr int kOutputTensor = 0;
42 | |
// Validates the relationship between data, segment_ids and num_segments, then
// resizes `output` to [num_segments, data_shape[rank(segment_ids):]].
// Returns kTfLiteError (via TF_LITE_ENSURE) on any shape violation.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const TfLiteTensor* data,
                                const TfLiteTensor* segment_ids,
                                const TfLiteTensor* num_segments,
                                TfLiteTensor* output) {
  // The shape of segment_ids is permitted to be any non-empty prefix of
  // the input data's shape. The shape of output's first dimension is always
  // equal to num_segments. The remaining dimensions of output's shape are then
  // taken to be the suffix of input shape after rank(segment_ids)th position.
  // Public facing tensorflow erroneously describe unsorted_segment ops as only
  // supporting segment_ids of rank 1, however tensorflow implementation
  // supports higher dimensional segment_ids as described.
  const int segment_ids_rank = NumDimensions(segment_ids);
  const int data_rank = NumDimensions(data);
  TF_LITE_ENSURE(context, segment_ids_rank <= data_rank);
  for (int i = 0; i < segment_ids_rank; ++i) {
    // segment_ids shape must be a prefix of data shape.
    TF_LITE_ENSURE_EQ(context, segment_ids->dims->data[i], data->dims->data[i]);
  }
  // num_segments must be a scalar or a single-element 1-D tensor.
  TF_LITE_ENSURE(context, (num_segments->dims->size == 1 &&
                           num_segments->dims->data[0] == 1) ||
                              num_segments->dims->size == 0);
  // num_segments can be thought of as number of buckets (segments) in output,
  // where each segment is the reduction of all elements mapped to that
  // segment_ids. The shape of said elements is the respective
  // suffix of the data shape.
  int32_t num_segments_ = GetTensorData<int32_t>(num_segments)[0];
  const int num_segment_ids = NumElements(segment_ids);
  int max_index = -1;
  for (int i = 0; i < num_segment_ids; i++) {
    max_index = std::max(GetTensorData<int32_t>(segment_ids)[i], max_index);
  }
  // num_segments_ must be greater than max_index, otherwise elements would be
  // mapped to non-existent output segments.
  TF_LITE_ENSURE(context, max_index < num_segments_);
  const int output_rank = data_rank - segment_ids_rank + 1;
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
  output_shape->data[0] = num_segments_;
  // output_shape[1:] should be data_shape[Rank(segment_ids):]
  for (int i = segment_ids_rank; i < data_rank; ++i) {
    output_shape->data[i - segment_ids_rank + 1] = data->dims->data[i];
  }
  // ResizeTensor takes ownership of output_shape.
  return context->ResizeTensor(context, output, output_shape);
}
87 | |
88 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
89 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); |
90 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
91 | const TfLiteTensor* data; |
92 | TF_LITE_ENSURE_OK(context, |
93 | GetInputSafe(context, node, kInputDataTensor, &data)); |
94 | const TfLiteTensor* segment_ids; |
95 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor, |
96 | &segment_ids)); |
97 | const TfLiteTensor* num_segments; |
98 | TF_LITE_ENSURE_OK( |
99 | context, |
100 | GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments)); |
101 | TfLiteTensor* output; |
102 | TF_LITE_ENSURE_OK(context, |
103 | GetOutputSafe(context, node, kOutputTensor, &output)); |
104 | TF_LITE_ENSURE(context, |
105 | data->type == kTfLiteInt32 || data->type == kTfLiteFloat32); |
106 | TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32); |
107 | TF_LITE_ENSURE_EQ(context, num_segments->type, kTfLiteInt32); |
108 | |
109 | if (IsDynamicTensor(data) || !IsConstantTensor(segment_ids) || |
110 | !IsConstantTensor(num_segments)) { |
111 | SetTensorToDynamic(output); |
112 | return kTfLiteOk; |
113 | } |
114 | return ResizeOutputTensor(context, data, segment_ids, num_segments, output); |
115 | } |
116 | |
// Reduction functor for UNSORTED_SEGMENT_MAX: keeps the larger of two values.
// kInitialValue is the identity for max (the lowest representable value).
template <typename U>
struct SegmenMax {
  inline U operator()(const U& lhs, const U& rhs) const {
    return (lhs < rhs) ? rhs : lhs;
  }
  static constexpr U kInitialValue = std::numeric_limits<U>::lowest();
};
122 | |
// Reduction functor for UNSORTED_SEGMENT_MIN: keeps the smaller of two values.
// kInitialValue is the identity for min (the largest representable value).
template <typename U>
struct SegmenMin {
  inline U operator()(const U& lhs, const U& rhs) const {
    return (rhs < lhs) ? rhs : lhs;
  }
  static constexpr U kInitialValue = std::numeric_limits<U>::max();
};
128 | |
// Reduction functor for UNSORTED_SEGMENT_PROD: multiplies values together.
// kInitialValue is the multiplicative identity.
template <typename U>
struct SegmenProd {
  inline U operator()(const U& lhs, const U& rhs) const { return lhs * rhs; }
  static constexpr U kInitialValue = static_cast<U>(1);
};
134 | |
// Reduction functor for UNSORTED_SEGMENT_SUM: adds values together.
// kInitialValue is the additive identity.
template <typename U>
struct SegmenSum {
  inline U operator()(const U& lhs, const U& rhs) const { return lhs + rhs; }
  static constexpr U kInitialValue = static_cast<U>(0);
};
140 | |
141 | template <typename T> |
142 | TfLiteStatus EvalType(TfLiteContext* context, const RuntimeShape& input_shape, |
143 | const T* input_data, |
144 | const RuntimeShape& segment_ids_shape, |
145 | const int32_t* segment_ids_data, |
146 | const RuntimeShape& output_shape, T* output_data, |
147 | SegmentType segment_type) { |
148 | switch (segment_type) { |
149 | case kSegmentProd: |
150 | reference_ops::UnsortedSegmentRef<T, SegmenProd>( |
151 | input_shape, input_data, segment_ids_shape, segment_ids_data, |
152 | output_shape, output_data); |
153 | break; |
154 | case kSegmentMax: |
155 | reference_ops::UnsortedSegmentRef<T, SegmenMax>( |
156 | input_shape, input_data, segment_ids_shape, segment_ids_data, |
157 | output_shape, output_data); |
158 | break; |
159 | case kSegmentSum: |
160 | reference_ops::UnsortedSegmentRef<T, SegmenSum>( |
161 | input_shape, input_data, segment_ids_shape, segment_ids_data, |
162 | output_shape, output_data); |
163 | break; |
164 | case kSegmentMin: |
165 | reference_ops::UnsortedSegmentRef<T, SegmenMin>( |
166 | input_shape, input_data, segment_ids_shape, segment_ids_data, |
167 | output_shape, output_data); |
168 | break; |
169 | default: |
170 | TF_LITE_KERNEL_LOG(context, "Not recognized segment type: %d" , |
171 | segment_type); |
172 | return kTfLiteError; |
173 | } |
174 | return kTfLiteOk; |
175 | } |
176 | |
// Shared Eval implementation for all four unsorted_segment ops: fetches the
// tensors, performs deferred output resizing when Prepare marked the output
// dynamic, validates shapes, and dispatches on the data tensor's dtype.
TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node,
                         SegmentType segment_type) {
  const TfLiteTensor* data;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputDataTensor, &data));
  const TfLiteTensor* segment_ids;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
                                          &segment_ids));
  const TfLiteTensor* num_segments;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (IsDynamicTensor(output)) {
    // Prepare could not size the output (non-constant segment_ids or
    // num_segments); do it now that actual input values are available.
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids,
                                                  num_segments, output));
  }
  // First dimensions must agree so each leading data slice has a segment id.
  TF_LITE_ENSURE_EQ(context, GetTensorShape(data).Dims(0),
                    GetTensorShape(segment_ids).Dims(0));

#define TF_LITE_UNSORTED_SEGMENT(dtype)                                        \
  EvalType<dtype>(context, GetTensorShape(data), GetTensorData<dtype>(data),   \
                  GetTensorShape(segment_ids),                                 \
                  GetTensorData<int32_t>(segment_ids), GetTensorShape(output), \
                  GetTensorData<dtype>(output), segment_type);
  switch (data->type) {
    case kTfLiteInt32:
      TF_LITE_UNSORTED_SEGMENT(int32_t);
      break;
    case kTfLiteFloat32:
      TF_LITE_UNSORTED_SEGMENT(float);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Currently UnsortedSegment doesn't support data type: %s" ,
          TfLiteTypeGetName(data->type));
      return kTfLiteError;
  }
#undef TF_LITE_UNSORTED_SEGMENT
  return kTfLiteOk;
}
221 | |
// Invoke handler for UNSORTED_SEGMENT_PROD.
TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentProd);
}
// Invoke handler for UNSORTED_SEGMENT_MAX.
TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentMax);
}
// Invoke handler for UNSORTED_SEGMENT_SUM.
TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentSum);
}
// Invoke handler for UNSORTED_SEGMENT_MIN.
TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentMin);
}
234 | |
235 | } // namespace unsorted_segment |
236 | |
// Returns the registration for UNSORTED_SEGMENT_PROD. The op is stateless, so
// only the Prepare and Eval callbacks are supplied (init/free left null).
TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD() {
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalProd};
  return &r;
}
242 | |
// Returns the registration for UNSORTED_SEGMENT_MAX. The op is stateless, so
// only the Prepare and Eval callbacks are supplied (init/free left null).
TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX() {
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalMax};
  return &r;
}
248 | |
// Returns the registration for UNSORTED_SEGMENT_SUM. The op is stateless, so
// only the Prepare and Eval callbacks are supplied (init/free left null).
TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM() {
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalSum};
  return &r;
}
254 | |
// Returns the registration for UNSORTED_SEGMENT_MIN. The op is stateless, so
// only the Prepare and Eval callbacks are supplied (init/free left null).
TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN() {
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalMin};
  return &r;
}
260 | |
261 | } // namespace builtin |
262 | } // namespace ops |
263 | } // namespace tflite |
264 | |