1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <cmath>
#include <functional>
#include <limits>
#include <type_traits>
22
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
25#include "tensorflow/lite/kernels/internal/tensor.h"
26#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27#include "tensorflow/lite/kernels/kernel_util.h"
28
29namespace tflite {
30namespace ops {
31namespace builtin {
32namespace range {
33namespace {
34
// Positions of the three scalar operands in the node's input list.
constexpr int kStartTensor{0};
constexpr int kLimitTensor{1};
constexpr int kDeltaTensor{2};
// Position of the single result in the node's output list.
constexpr int kOutputTensor{0};
39
40template <typename T>
41TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta,
42 int* size) {
43 TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0));
44 TF_LITE_ENSURE(
45 context, (start >= limit && delta < 0) || (start <= limit && delta > 0));
46 *size =
47 (std::is_integral<T>::value
48 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
49 : std::ceil(std::abs((limit - start) / delta)));
50 return kTfLiteOk;
51}
52
53TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* start,
54 const TfLiteTensor* limit, const TfLiteTensor* delta,
55 TfLiteTensor* output) {
56 // The output will always be a 1-d array.
57 int size = 0;
58 switch (start->type) {
59 case kTfLiteInt32: {
60 TF_LITE_ENSURE_OK(context,
61 GetSize(context, *GetTensorData<int32_t>(start),
62 *GetTensorData<int32_t>(limit),
63 *GetTensorData<int32_t>(delta), &size));
64 break;
65 }
66 case kTfLiteFloat32: {
67 TF_LITE_ENSURE_OK(context, GetSize(context, *GetTensorData<float>(start),
68 *GetTensorData<float>(limit),
69 *GetTensorData<float>(delta), &size));
70 break;
71 }
72 default: {
73 TF_LITE_KERNEL_LOG(context, "Unknown data type: %d", start->type);
74 return kTfLiteError;
75 }
76 }
77 TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(1);
78 output_shape_array->data[0] = size;
79 return context->ResizeTensor(context, output, output_shape_array);
80}
81
82TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
83 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
84 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
85
86 const TfLiteTensor* start;
87 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
88 const TfLiteTensor* limit;
89 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
90 const TfLiteTensor* delta;
91 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
92 // Make sure all the inputs are scalars.
93 TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0);
94 TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0);
95 TF_LITE_ENSURE_EQ(context, NumDimensions(delta), 0);
96
97 // Currently only supports int32 and float.
98 // TODO(b/117912892): Support quantization as well.
99 const auto dtype = start->type;
100 if (dtype != kTfLiteFloat32 && dtype != kTfLiteInt32) {
101 TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %s",
102 TfLiteTypeGetName(dtype));
103 return kTfLiteError;
104 }
105
106 TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
107 TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);
108
109 TfLiteTensor* output;
110 TF_LITE_ENSURE_OK(context,
111 GetOutputSafe(context, node, kOutputTensor, &output));
112 output->type = dtype;
113
114 if (IsConstantTensor(start) && IsConstantTensor(limit) &&
115 IsConstantTensor(delta)) {
116 return ResizeOutput(context, start, limit, delta, output);
117 }
118
119 SetTensorToDynamic(output);
120 return kTfLiteOk;
121}
122
123template <typename T>
124void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta,
125 TfLiteTensor* output) {
126 const T start_value = *GetTensorData<T>(start);
127 const T delta_value = *GetTensorData<T>(delta);
128 T* output_data = GetTensorData<T>(output);
129 const int num_elements = NumElements(output);
130 T value = start_value;
131 for (int i = 0; i < num_elements; ++i) {
132 output_data[i] = value;
133 value += delta_value;
134 }
135}
136
137TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
138 const TfLiteTensor* start;
139 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
140 const TfLiteTensor* limit;
141 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
142 const TfLiteTensor* delta;
143 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
144
145 TfLiteTensor* output;
146 TF_LITE_ENSURE_OK(context,
147 GetOutputSafe(context, node, kOutputTensor, &output));
148
149 if (IsDynamicTensor(output)) {
150 TF_LITE_ENSURE_OK(context,
151 ResizeOutput(context, start, limit, delta, output));
152 }
153
154 switch (output->type) {
155 case kTfLiteInt32: {
156 EvalImpl<int32_t>(start, delta, output);
157 break;
158 }
159 case kTfLiteFloat32: {
160 EvalImpl<float>(start, delta, output);
161 break;
162 }
163 default: {
164 TF_LITE_KERNEL_LOG(context, "Unsupported data type: %d", output->type);
165 return kTfLiteError;
166 }
167 }
168 return kTfLiteOk;
169}
170
171} // namespace
172} // namespace range
173
174TfLiteRegistration* Register_RANGE() {
175 static TfLiteRegistration r = {nullptr, nullptr, range::Prepare, range::Eval};
176 return &r;
177}
178
179} // namespace builtin
180} // namespace ops
181} // namespace tflite
182