1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <cmath>
#include <cstdlib>
#include <functional>
#include <limits>
#include <type_traits>
22 | |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
25 | #include "tensorflow/lite/kernels/internal/tensor.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
27 | #include "tensorflow/lite/kernels/kernel_util.h" |
28 | |
29 | namespace tflite { |
30 | namespace ops { |
31 | namespace builtin { |
32 | namespace range { |
33 | namespace { |
34 | |
// Input tensor indices, in the order range(start, limit, delta) takes them.
constexpr int kStartTensor = 0;
constexpr int kLimitTensor = kStartTensor + 1;
constexpr int kDeltaTensor = kLimitTensor + 1;

// Index of the op's single output tensor.
constexpr int kOutputTensor = 0;
39 | |
40 | template <typename T> |
41 | TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta, |
42 | int* size) { |
43 | TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0)); |
44 | TF_LITE_ENSURE( |
45 | context, (start >= limit && delta < 0) || (start <= limit && delta > 0)); |
46 | *size = |
47 | (std::is_integral<T>::value |
48 | ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) |
49 | : std::ceil(std::abs((limit - start) / delta))); |
50 | return kTfLiteOk; |
51 | } |
52 | |
53 | TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* start, |
54 | const TfLiteTensor* limit, const TfLiteTensor* delta, |
55 | TfLiteTensor* output) { |
56 | // The output will always be a 1-d array. |
57 | int size = 0; |
58 | switch (start->type) { |
59 | case kTfLiteInt32: { |
60 | TF_LITE_ENSURE_OK(context, |
61 | GetSize(context, *GetTensorData<int32_t>(start), |
62 | *GetTensorData<int32_t>(limit), |
63 | *GetTensorData<int32_t>(delta), &size)); |
64 | break; |
65 | } |
66 | case kTfLiteFloat32: { |
67 | TF_LITE_ENSURE_OK(context, GetSize(context, *GetTensorData<float>(start), |
68 | *GetTensorData<float>(limit), |
69 | *GetTensorData<float>(delta), &size)); |
70 | break; |
71 | } |
72 | default: { |
73 | TF_LITE_KERNEL_LOG(context, "Unknown data type: %d" , start->type); |
74 | return kTfLiteError; |
75 | } |
76 | } |
77 | TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(1); |
78 | output_shape_array->data[0] = size; |
79 | return context->ResizeTensor(context, output, output_shape_array); |
80 | } |
81 | |
82 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
83 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); |
84 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
85 | |
86 | const TfLiteTensor* start; |
87 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start)); |
88 | const TfLiteTensor* limit; |
89 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit)); |
90 | const TfLiteTensor* delta; |
91 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta)); |
92 | // Make sure all the inputs are scalars. |
93 | TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0); |
94 | TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0); |
95 | TF_LITE_ENSURE_EQ(context, NumDimensions(delta), 0); |
96 | |
97 | // Currently only supports int32 and float. |
98 | // TODO(b/117912892): Support quantization as well. |
99 | const auto dtype = start->type; |
100 | if (dtype != kTfLiteFloat32 && dtype != kTfLiteInt32) { |
101 | TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %s" , |
102 | TfLiteTypeGetName(dtype)); |
103 | return kTfLiteError; |
104 | } |
105 | |
106 | TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype); |
107 | TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype); |
108 | |
109 | TfLiteTensor* output; |
110 | TF_LITE_ENSURE_OK(context, |
111 | GetOutputSafe(context, node, kOutputTensor, &output)); |
112 | output->type = dtype; |
113 | |
114 | if (IsConstantTensor(start) && IsConstantTensor(limit) && |
115 | IsConstantTensor(delta)) { |
116 | return ResizeOutput(context, start, limit, delta, output); |
117 | } |
118 | |
119 | SetTensorToDynamic(output); |
120 | return kTfLiteOk; |
121 | } |
122 | |
123 | template <typename T> |
124 | void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta, |
125 | TfLiteTensor* output) { |
126 | const T start_value = *GetTensorData<T>(start); |
127 | const T delta_value = *GetTensorData<T>(delta); |
128 | T* output_data = GetTensorData<T>(output); |
129 | const int num_elements = NumElements(output); |
130 | T value = start_value; |
131 | for (int i = 0; i < num_elements; ++i) { |
132 | output_data[i] = value; |
133 | value += delta_value; |
134 | } |
135 | } |
136 | |
137 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
138 | const TfLiteTensor* start; |
139 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start)); |
140 | const TfLiteTensor* limit; |
141 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit)); |
142 | const TfLiteTensor* delta; |
143 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta)); |
144 | |
145 | TfLiteTensor* output; |
146 | TF_LITE_ENSURE_OK(context, |
147 | GetOutputSafe(context, node, kOutputTensor, &output)); |
148 | |
149 | if (IsDynamicTensor(output)) { |
150 | TF_LITE_ENSURE_OK(context, |
151 | ResizeOutput(context, start, limit, delta, output)); |
152 | } |
153 | |
154 | switch (output->type) { |
155 | case kTfLiteInt32: { |
156 | EvalImpl<int32_t>(start, delta, output); |
157 | break; |
158 | } |
159 | case kTfLiteFloat32: { |
160 | EvalImpl<float>(start, delta, output); |
161 | break; |
162 | } |
163 | default: { |
164 | TF_LITE_KERNEL_LOG(context, "Unsupported data type: %d" , output->type); |
165 | return kTfLiteError; |
166 | } |
167 | } |
168 | return kTfLiteOk; |
169 | } |
170 | |
171 | } // namespace |
172 | } // namespace range |
173 | |
174 | TfLiteRegistration* Register_RANGE() { |
175 | static TfLiteRegistration r = {nullptr, nullptr, range::Prepare, range::Eval}; |
176 | return &r; |
177 | } |
178 | |
179 | } // namespace builtin |
180 | } // namespace ops |
181 | } // namespace tflite |
182 | |