/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdint.h>

#include <algorithm>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bucketize {
namespace {

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

struct OpData {
  // boundaries array is owned by the buffer housing TfLiteBucketizeParams.
  const float* boundaries;
  int num_boundaries;
};

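// Caches the boundary array and its length from TfLiteBucketizeParams. The
// buffer that holds the params retains ownership of the boundaries; OpData
// only stores the pointer.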
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  const auto* params = reinterpret_cast<const TfLiteBucketizeParams*>(buffer);

  op_data->boundaries = params->boundaries;
  op_data->num_boundaries = params->num_boundaries;
  return op_data;
}

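// Releases the OpData allocated in Init.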
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

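// Validates that the boundaries are sorted and that the input type is one of
// float32, float64, int32 or int64, then resizes the int32 output to match
// the input shape.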
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  if (!std::is_sorted(opdata->boundaries,
                      opdata->boundaries + opdata->num_boundaries)) {
    TF_LITE_KERNEL_LOG(context, "Expected sorted boundaries");
    return kTfLiteError;
  }

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
      input->type != kTfLiteInt64 && input->type != kTfLiteFloat64) {
    TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by bucketize.",
                       TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = kTfLiteInt32;

  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_shape);
}

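// For each input element, writes the index of the first boundary that is
// strictly greater than the element, i.e. the number of boundaries that are
// less than or equal to it.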
template <typename T>
inline void Bucketize(const RuntimeShape& input_shape, const T* input_data,
                      const float* boundaries, int num_boundaries,
                      const RuntimeShape& output_shape, int32_t* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    auto first_bigger_it = std::upper_bound(
        boundaries, boundaries + num_boundaries, input_data[i]);
    output_data[i] = first_bigger_it - boundaries;
  }
}

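// Fetches the input, output and cached boundaries, then runs Bucketize for
// input element type T.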
template <typename T>
TfLiteStatus BucketizeImpl(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt32);

  Bucketize<T>(GetTensorShape(input), GetTensorData<T>(input),
               opdata->boundaries, opdata->num_boundaries,
               GetTensorShape(output), GetTensorData<int32_t>(output));

  return kTfLiteOk;
}

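// Dispatches to BucketizeImpl based on the input element type.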
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  switch (input->type) {
    case kTfLiteFloat32: {
      return BucketizeImpl<float>(context, node);
    }
    case kTfLiteFloat64: {
      return BucketizeImpl<double>(context, node);
    }
    case kTfLiteInt32: {
      return BucketizeImpl<int32_t>(context, node);
    }
    case kTfLiteInt64: {
      return BucketizeImpl<int64_t>(context, node);
    }
    default: {
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by bucketize.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
    }
  }
}

}  // namespace
}  // namespace bucketize

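// Returns the kernel registration for the builtin BUCKETIZE operator.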
TfLiteRegistration* Register_BUCKETIZE() {
  static TfLiteRegistration r = {bucketize::Init, bucketize::Free,
                                 bucketize::Prepare, bucketize::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite