1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stdint.h>
17
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20#include "tensorflow/lite/kernels/internal/tensor.h"
21#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22#include "tensorflow/lite/kernels/kernel_util.h"
23
24namespace tflite {
25namespace ops {
26namespace builtin {
27namespace segment_sum {
28
29static const int kInputDataTensor = 0;
30static const int kInputSegmentIdsTensor = 1;
31static const int kOutputTensor = 0;
32
33TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
34 const TfLiteTensor* data,
35 const TfLiteTensor* segment_ids,
36 TfLiteTensor* output) {
37 // Segment ids should be of same cardinality as first input dimension and they
38 // should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid)
39 const int segment_id_size = segment_ids->dims->data[0];
40 TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
41 int previous_segment_id = -1;
42 for (int i = 0; i < segment_id_size; i++) {
43 const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
44 if (i == 0) {
45 TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
46 } else {
47 int delta = current_segment_id - previous_segment_id;
48 TF_LITE_ENSURE(context, delta == 0 || delta == 1);
49 }
50 previous_segment_id = current_segment_id;
51 }
52
53 const int max_index = previous_segment_id;
54
55 const int data_rank = NumDimensions(data);
56 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
57 output_shape->data[0] = max_index + 1;
58 for (int i = 1; i < data_rank; ++i) {
59 output_shape->data[i] = data->dims->data[i];
60 }
61 return context->ResizeTensor(context, output, output_shape);
62}
63
64TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
65 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
66 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
67 const TfLiteTensor* data;
68 TF_LITE_ENSURE_OK(context,
69 GetInputSafe(context, node, kInputDataTensor, &data));
70 const TfLiteTensor* segment_ids;
71 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
72 &segment_ids));
73 TfLiteTensor* output;
74 TF_LITE_ENSURE_OK(context,
75 GetOutputSafe(context, node, kOutputTensor, &output));
76 TF_LITE_ENSURE(context,
77 data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
78 TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
79
80 if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) {
81 SetTensorToDynamic(output);
82 return kTfLiteOk;
83 }
84
85 return ResizeOutputTensor(context, data, segment_ids, output);
86}
87
88TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
89 const TfLiteTensor* data;
90 TF_LITE_ENSURE_OK(context,
91 GetInputSafe(context, node, kInputDataTensor, &data));
92 const TfLiteTensor* segment_ids;
93 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
94 &segment_ids));
95 TfLiteTensor* output;
96 TF_LITE_ENSURE_OK(context,
97 GetOutputSafe(context, node, kOutputTensor, &output));
98
99 if (IsDynamicTensor(output)) {
100 TF_LITE_ENSURE_OK(context,
101 ResizeOutputTensor(context, data, segment_ids, output));
102 }
103
104#define TF_LITE_SEGMENT_SUM(dtype) \
105 reference_ops::SegmentSum<dtype>( \
106 GetTensorShape(data), GetTensorData<dtype>(data), \
107 GetTensorShape(segment_ids), GetTensorData<int32_t>(segment_ids), \
108 GetTensorShape(output), GetTensorData<dtype>(output));
109 switch (data->type) {
110 case kTfLiteInt32:
111 TF_LITE_SEGMENT_SUM(int32_t);
112 break;
113 case kTfLiteFloat32:
114 TF_LITE_SEGMENT_SUM(float);
115 break;
116 default:
117 TF_LITE_KERNEL_LOG(context,
118 "Currently SegmentSum doesn't support type: %s",
119 TfLiteTypeGetName(data->type));
120 return kTfLiteError;
121 }
122#undef TF_LITE_SEGMENT_SUM
123 return kTfLiteOk;
124}
125
126} // namespace segment_sum
127
128TfLiteRegistration* Register_SEGMENT_SUM() {
129 static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare,
130 segment_sum::Eval};
131 return &r;
132}
133
134} // namespace builtin
135} // namespace ops
136} // namespace tflite
137