1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stdint.h>
17
18#include <algorithm>
19#include <string>
20#include <vector>
21
22#include "tensorflow/lite/c/common.h"
23#include "tensorflow/lite/context_util.h"
24#include "tensorflow/lite/kernels/internal/compatibility.h"
25#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
26#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
27#include "tensorflow/lite/kernels/internal/tensor.h"
28#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
29#include "tensorflow/lite/kernels/internal/types.h"
30#include "tensorflow/lite/kernels/kernel_util.h"
31#include "tensorflow/lite/string_type.h"
32
33namespace tflite {
34namespace ops {
35namespace builtin {
36namespace slice {
37
// Selects which Slice implementation Eval dispatches to.
enum KernelType {
  kReference = 0,         // reference_ops::Slice
  kGenericOptimized = 1,  // optimized_ops::Slice
};

// Positions of the op's tensors in the node's input / output lists.
constexpr int kInputTensor = 0;
constexpr int kBeginTensor = 1;
constexpr int kSizeTensor = 2;
constexpr int kOutputTensor = 0;

// Slice is evaluated as a 5-D operation: inputs of lower rank are
// front-padded up to 5 dimensions, so 5 is also the largest rank accepted.
constexpr int kMaxDim = 5;
51
52template <typename T>
53TfLiteStatus CalculateOutputShapeVector(TfLiteContext* context,
54 const TfLiteTensor* input,
55 const TfLiteTensor* begin,
56 const TfLiteTensor* size,
57 std::vector<int>* output_shape_vector) {
58 for (int idx = 0; idx < NumDimensions(input); ++idx) {
59 T size_value = GetTensorData<T>(size)[idx];
60 if (size_value < 0) {
61 if (size_value != -1) {
62 TF_LITE_KERNEL_LOG(context, "Invalid size.");
63 return kTfLiteError;
64 }
65 size_value = SizeOfDimension(input, idx) - GetTensorData<T>(begin)[idx];
66 } else {
67 if (SizeOfDimension(input, idx) <
68 GetTensorData<T>(begin)[idx] + size_value) {
69 TF_LITE_KERNEL_LOG(context, "Invalid begin and size.");
70 return kTfLiteError;
71 }
72 }
73 output_shape_vector->push_back(static_cast<int>(size_value));
74 }
75 return kTfLiteOk;
76}
77
78template <typename T>
79void GetBeginAndSizeVectors(int dimensions, const TfLiteTensor* begin,
80 const TfLiteTensor* size, std::vector<int>* begins,
81 std::vector<int>* sizes) {
82 for (int idx = 0; idx < dimensions; ++idx) {
83 begins->push_back(GetTensorData<T>(begin)[idx]);
84 sizes->push_back(GetTensorData<T>(size)[idx]);
85 }
86}
87
88TfLiteStatus ResizeOutputShape(TfLiteContext* context,
89 const TfLiteTensor* input,
90 const TfLiteTensor* begin,
91 const TfLiteTensor* size, TfLiteTensor* output) {
92 std::vector<int> output_shape_vector;
93
94 if (begin->type == kTfLiteInt32) {
95 TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>(
96 context, input, begin, size, &output_shape_vector));
97 } else if (begin->type == kTfLiteInt64) {
98 TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int64_t>(
99 context, input, begin, size, &output_shape_vector));
100 } else {
101 TF_LITE_KERNEL_LOG(context, "Type %d is currently not supported by Slice.",
102 begin->type);
103 return kTfLiteError;
104 }
105
106 TfLiteIntArray* output_shape =
107 TfLiteIntArrayCreate(output_shape_vector.size());
108 std::copy(output_shape_vector.begin(), output_shape_vector.end(),
109 output_shape->data);
110 return context->ResizeTensor(context, output, output_shape);
111}
112
113bool ShapeHasRank(const TfLiteIntArray* shape) {
114 // Note that we consider scalar as false here because there is
115 // no differentiation between scalar and dynamic properly supported.
116 if (shape == nullptr || shape->size == 0) return false;
117 return true;
118}
119
120TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
121 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
122 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
123
124 const TfLiteTensor* input;
125 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
126 const TfLiteTensor* begin;
127 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
128 const TfLiteTensor* size;
129 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
130 TfLiteTensor* output;
131 TF_LITE_ENSURE_OK(context,
132 GetOutputSafe(context, node, kOutputTensor, &output));
133
134 // Ensure validity of input tensor and its dimension.
135 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
136 TF_LITE_ENSURE(context,
137 begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64);
138 TF_LITE_ENSURE(context,
139 size->type == kTfLiteInt32 || size->type == kTfLiteInt64);
140 TF_LITE_ENSURE_EQ(context, NumDimensions(begin), 1);
141 TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
142 TF_LITE_ENSURE_EQ(context, NumElements(begin), NumElements(size));
143 TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim,
144 "Slice op only supports 1D-5D input arrays.");
145
146 // If the shape of output is fully specified then resize even if
147 // the input shape is not staticly defined.
148 if (!HasUnspecifiedDimension(output) && ShapeHasRank(output->dims)) {
149 return kTfLiteOk;
150 }
151 // Postpone allocation of output if any of the indexing tensors is not
152 // constant, or the input tensor has dynamic dimension.
153 if (!(IsConstantTensor(begin) && IsConstantTensor(size)) ||
154 HasUnspecifiedDimension(input)) {
155 SetTensorToDynamic(output);
156 return kTfLiteOk;
157 }
158
159 return ResizeOutputShape(context, input, begin, size, output);
160}
161
162template <KernelType kernel_type>
163TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
164 const TfLiteTensor* input;
165 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
166 const TfLiteTensor* begin;
167 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
168 const TfLiteTensor* size;
169 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
170 TfLiteTensor* output;
171 TF_LITE_ENSURE_OK(context,
172 GetOutputSafe(context, node, kOutputTensor, &output));
173
174 if (IsDynamicTensor(output)) {
175 TF_LITE_ENSURE_OK(context,
176 ResizeOutputShape(context, input, begin, size, output));
177 }
178
179 std::vector<int> begins;
180 begins.reserve(kMaxDim);
181 std::vector<int> sizes;
182 sizes.reserve(kMaxDim);
183
184 for (int i = NumDimensions(input); i < kMaxDim; ++i) {
185 begins.push_back(0);
186 sizes.push_back(1);
187 }
188
189 if (begin->type == kTfLiteInt32) {
190 GetBeginAndSizeVectors<int32_t>(NumDimensions(input), begin, size, &begins,
191 &sizes);
192 } else if (begin->type == kTfLiteInt64) {
193 GetBeginAndSizeVectors<int64_t>(NumDimensions(input), begin, size, &begins,
194 &sizes);
195 } else {
196 TF_LITE_KERNEL_LOG(context, "Type %d is currently not supported by Slice.",
197 begin->type);
198 return kTfLiteError;
199 }
200
201 // The Slice op implementation only accepts 5-D sizes. That constraint is, for
202 // the present, maintained here.
203 //
204 // The dimensions in the kernel used to be in reverse-order, and TFLite
205 // arranged the begins and sizes vectors accordingly. This macro incorporates
206 // the needed reversing.
207#define TF_LITE_SLICE(data_type) \
208 { \
209 TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \
210 TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim); \
211 tflite::SliceParams op_params; \
212 op_params.begin_count = kMaxDim; \
213 op_params.size_count = kMaxDim; \
214 for (int i = 0; i < kMaxDim; ++i) { \
215 op_params.begin[i] = begins[i]; \
216 op_params.size[i] = sizes[i]; \
217 } \
218 \
219 if (kernel_type == kGenericOptimized) { \
220 optimized_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
221 GetTensorShape(output), output); \
222 } else { \
223 reference_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
224 GetTensorShape(output), output); \
225 } \
226 }
227
228 switch (input->type) {
229 case kTfLiteFloat32:
230 TF_LITE_SLICE(float);
231 break;
232 case kTfLiteInt32:
233 TF_LITE_SLICE(int32_t);
234 break;
235 case kTfLiteInt64:
236 TF_LITE_SLICE(int64_t);
237 break;
238 case kTfLiteInt8:
239 TF_LITE_SLICE(int8_t);
240 break;
241 case kTfLiteInt16:
242 TF_LITE_SLICE(int16_t);
243 break;
244 case kTfLiteUInt8:
245 TF_LITE_SLICE(uint8_t);
246 break;
247 case kTfLiteBool:
248 TF_LITE_SLICE(bool);
249 break;
250 case kTfLiteString:
251 TF_LITE_SLICE(string);
252 break;
253 default:
254 TF_LITE_KERNEL_LOG(
255 context, "Type %d is currently not supported by Slice.", input->type);
256 return kTfLiteError;
257 }
258#undef TF_LITE_SLICE
259 return kTfLiteOk;
260}
261
262} // namespace slice
263
264TfLiteRegistration* Register_SLICE_REF() {
265 static TfLiteRegistration r = {nullptr, nullptr, slice::Prepare,
266 slice::Eval<slice::kReference>};
267 return &r;
268}
269
270TfLiteRegistration* Register_SLICE() {
271 static TfLiteRegistration r = {nullptr, nullptr, slice::Prepare,
272 slice::Eval<slice::kGenericOptimized>};
273 return &r;
274}
275
276} // namespace builtin
277} // namespace ops
278} // namespace tflite
279