1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
17
18#include <math.h>
19#include <stdint.h>
20
21#include <algorithm>
22#include <vector>
23
24#include "tensorflow/lite/c/builtin_op_data.h"
25#include "tensorflow/lite/c/common.h"
26#include "tensorflow/lite/kernels/internal/compatibility.h"
27#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
28#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
29#include "tensorflow/lite/kernels/internal/tensor.h"
30#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
31#include "tensorflow/lite/kernels/internal/types.h"
32#include "tensorflow/lite/kernels/kernel_util.h"
33
34namespace tflite {
35namespace ops {
36namespace builtin {
37namespace strided_slice {
38
// Selects which StridedSlice implementation Eval() dispatches to.
enum KernelType {
  kReference = 0,         // reference_ops::StridedSlice
  kGenericOptimized = 1,  // optimized_ops::StridedSlice
};

// Positions of the op's input tensors within node->inputs.
constexpr int kInputTensor = 0;    // tensor being sliced
constexpr int kBeginTensor = 1;    // 1-D begin indices
constexpr int kEndTensor = 2;      // 1-D end indices
constexpr int kStridesTensor = 3;  // 1-D strides

// Position of the single output tensor within node->outputs.
constexpr int kOutputTensor = 0;
49
50struct StridedSliceContext {
51 StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
52 params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
53 input = GetInput(context, node, kInputTensor);
54 begin = GetInput(context, node, kBeginTensor);
55 end = GetInput(context, node, kEndTensor);
56 strides = GetInput(context, node, kStridesTensor);
57 output = GetOutput(context, node, kOutputTensor);
58 input_dims = NumDimensions(input);
59 }
60 const TfLiteStridedSliceParams* params;
61 const TfLiteTensor* input;
62 const TfLiteTensor* begin;
63 const TfLiteTensor* end;
64 const TfLiteTensor* strides;
65 TfLiteTensor* output;
66
67 // Equivalent input shape after adding axis according to new_axis_mask.
68 RuntimeShape effective_input_shape;
69 int input_dims;
70};
71
72StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) {
73 StridedSliceParams op_params;
74
75 // The ellipsis_mask and new_axis_mask in op_params are not used. Those masks
76 // are processed here to update begin_mask, end_mask and the index range.
77 op_params.begin_mask = 0;
78 op_params.ellipsis_mask = 0;
79 op_params.end_mask = 0;
80 op_params.new_axis_mask = 0;
81 op_params.shrink_axis_mask = 0;
82
83 // Count indexes where the new_axis_mask is set but the ellipsis_mask is not.
84 const int begin_count = GetTensorShape(op_context->begin).Dims(0);
85 int num_add_axis = 0;
86 for (int i = 0; i < begin_count; ++i) {
87 if (!((1 << i) & op_context->params->ellipsis_mask) &&
88 ((1 << i) & op_context->params->new_axis_mask)) {
89 num_add_axis++;
90 }
91 }
92
93 // Calculate the dims of input after adding new axises.
94 const int effective_dims = op_context->input_dims + num_add_axis;
95
96 // If begin, end and strides are not fully provided, it means Ellipsis should
97 // be expanded to multiple dimensions (Ex: for spec [Ellipsis, 2] on a 3D
98 // input, the Ellipsis should be applied for the first 2 dimensions). Besides,
99 // If the new_axis_mask and the ellipsis_mask are set at the same index, the
100 // new_axis_mask will have no effect.
101 int effective_ellipsis_mask = 0, effective_new_axis_mask = 0;
102 int ellipsis_start_idx = effective_dims, expanded_ellipsis = 0;
103 for (int i = 0; i < effective_dims;) {
104 if ((1 << i) & op_context->params->ellipsis_mask) {
105 ellipsis_start_idx = i;
106 int ellipsis_end_idx = std::max(
107 i + 1,
108 std::min(i + 1 + num_add_axis + op_context->input_dims - begin_count,
109 effective_dims));
110 expanded_ellipsis = ellipsis_end_idx - ellipsis_start_idx - 1;
111
112 // Set bit for effective_ellipsis_mask.
113 for (; i < ellipsis_end_idx; ++i) {
114 effective_ellipsis_mask |= (1 << i);
115 }
116 continue;
117 }
118
119 if ((1 << (i - expanded_ellipsis)) & op_context->params->new_axis_mask) {
120 effective_new_axis_mask |= (1 << i);
121 }
122 ++i;
123 }
124
125 // Calculate effective_input_shape and its corresponding begin, end, strides.
126 const int32_t* begin_data = GetTensorData<int32_t>(op_context->begin);
127 const int32_t* end_data = GetTensorData<int32_t>(op_context->end);
128 const int32_t* strides_data = GetTensorData<int32_t>(op_context->strides);
129 const RuntimeShape input_shape = GetTensorShape(op_context->input);
130 int added_ellipsis = 0, added_axises = 0;
131 op_context->effective_input_shape.Resize(effective_dims);
132
133 for (int i = 0; i < effective_dims; ++i) {
134 if ((1 << i) & effective_ellipsis_mask) {
135 // If ellipsis_mask, set the begin_mask and end_mask at that index.
136 added_ellipsis = std::max(0, i - ellipsis_start_idx);
137 op_params.begin_mask |= (1 << i);
138 op_params.end_mask |= (1 << i);
139 op_params.strides[i] = 1;
140 op_context->effective_input_shape.SetDim(
141 i, input_shape.Dims(i - added_axises));
142 } else if ((1 << i) & effective_new_axis_mask) {
143 // If new_axis_mask is set, it is equivalent to adding a new dim of 1 to
144 // input tensor. Store added shape to effective_input_shape.
145 op_params.start_indices[i] = 0;
146 op_params.stop_indices[i] = 1;
147 op_params.strides[i] = 1;
148 op_context->effective_input_shape.SetDim(i, 1);
149 added_axises++;
150 } else if (i >= begin_count + expanded_ellipsis) {
151 op_params.start_indices[i] = 0;
152 op_params.stop_indices[i] = 0;
153 op_params.strides[i] = 1;
154 op_params.begin_mask |= (1 << i);
155 op_params.end_mask |= (1 << i);
156 op_context->effective_input_shape.SetDim(
157 i, input_shape.Dims(i - added_axises));
158 } else {
159 const int orig_idx = i - added_ellipsis;
160 op_params.start_indices[i] = begin_data[orig_idx];
161 op_params.stop_indices[i] = end_data[orig_idx];
162 op_params.strides[i] = strides_data[orig_idx];
163 if (op_context->params->begin_mask & (1 << orig_idx)) {
164 op_params.begin_mask |= (1 << i);
165 }
166 if (op_context->params->end_mask & (1 << orig_idx)) {
167 op_params.end_mask |= (1 << i);
168 }
169 if (op_context->params->shrink_axis_mask & (1 << orig_idx)) {
170 op_params.shrink_axis_mask |= (1 << i);
171 }
172 op_context->effective_input_shape.SetDim(
173 i, input_shape.Dims(i - added_axises));
174 }
175 }
176 op_params.start_indices_count = effective_dims;
177 op_params.stop_indices_count = effective_dims;
178 op_params.strides_count = effective_dims;
179
180 return op_params;
181}
182
183// Processes the indexing tensors (begin, end and strides) to resize the
184// output tensor. This function is callable from both Prepare() and Eval() as
185// long as the caller ensures the indexing tensors are present.
186TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
187 StridedSliceContext* op_context) {
188 std::vector<int> output_shape_vector;
189 StridedSliceParams op_params = BuildStridedSliceParams(op_context);
190 const RuntimeShape effective_input_shape = op_context->effective_input_shape;
191 TF_LITE_ENSURE_MSG(
192 context, effective_input_shape.DimensionsCount() <= 5,
193 "StridedSlice op only supports up to 5D output including added axis.");
194
195 for (int idx = effective_input_shape.DimensionsCount() - 1; idx >= 0; --idx) {
196 int32_t stride = op_params.strides[idx];
197 TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
198
199 int32_t begin = ::tflite::strided_slice::StartForAxis(
200 op_params, effective_input_shape, idx);
201 int32_t end = ::tflite::strided_slice::StopForAxis(
202 op_params, effective_input_shape, idx, begin);
203
204 // When shrinking an axis, the end position does not matter (and can be
205 // incorrect when negative indexing is used, see Issue #19260). Always use
206 // begin + 1 to generate a length 1 slice, since begin has
207 // already been adjusted for negative indices by GetBeginValueAtIndex.
208 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
209 if (shrink_axis) {
210 end = begin + 1;
211 }
212
213 // This is valid for both positive and negative strides
214 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
215 dim_shape = dim_shape < 0 ? 0 : dim_shape;
216 if (!shrink_axis) {
217 output_shape_vector.push_back(dim_shape);
218 }
219 }
220
221 TfLiteIntArray* output_shape =
222 TfLiteIntArrayCreate(output_shape_vector.size());
223
224 std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
225 output_shape->data);
226
227 TF_LITE_ENSURE_STATUS(
228 context->ResizeTensor(context, op_context->output, output_shape));
229
230 return kTfLiteOk;
231}
232
233TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
234 TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
235 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
236
237 StridedSliceContext op_context(context, node);
238
239 // Ensure validity of input tensor and its dimension
240 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
241 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
242 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
243 TF_LITE_ENSURE_EQ(context, NumElements(op_context.begin),
244 NumElements(op_context.end));
245 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
246
247 // Only INT32 begin/end/strides are supported
248 // TODO(b/175642009): add support for INT64
249 TF_LITE_ENSURE_TYPES_EQ(context, op_context.begin->type, kTfLiteInt32);
250 TF_LITE_ENSURE_TYPES_EQ(context, op_context.end->type, kTfLiteInt32);
251 TF_LITE_ENSURE_TYPES_EQ(context, op_context.strides->type, kTfLiteInt32);
252 TF_LITE_ENSURE_MSG(context, op_context.input_dims <= 5,
253 "StridedSlice op only supports 1D-5D input arrays.");
254
255 // Postpone allocation of output if any of the indexing tensors is not
256 // constant
257 if (!(IsConstantTensor(op_context.begin) &&
258 IsConstantTensor(op_context.end) &&
259 IsConstantTensor(op_context.strides))) {
260 SetTensorToDynamic(op_context.output);
261 return kTfLiteOk;
262 }
263 return ResizeOutputTensor(context, &op_context);
264}
265
266template <KernelType kernel_type>
267TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
268 StridedSliceContext op_context(context, node);
269
270 if (IsDynamicTensor(op_context.output)) {
271 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
272 }
273 StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
274
275#define TF_LITE_STRIDED_SLICE(data_type) \
276 { \
277 if (kernel_type == kGenericOptimized) { \
278 optimized_ops::StridedSlice<data_type>( \
279 op_params, op_context.effective_input_shape, op_context.input, \
280 GetTensorShape(op_context.output), op_context.output); \
281 } else { \
282 reference_ops::StridedSlice<data_type>( \
283 op_params, op_context.effective_input_shape, op_context.input, \
284 GetTensorShape(op_context.output), op_context.output); \
285 } \
286 }
287
288 switch (op_context.input->type) {
289 case kTfLiteFloat32:
290 TF_LITE_STRIDED_SLICE(float);
291 break;
292 case kTfLiteInt32:
293 TF_LITE_STRIDED_SLICE(int32_t);
294 break;
295 case kTfLiteInt64:
296 TF_LITE_STRIDED_SLICE(int64_t);
297 break;
298 case kTfLiteUInt8:
299 TF_LITE_STRIDED_SLICE(uint8_t);
300 break;
301 case kTfLiteInt8:
302 TF_LITE_STRIDED_SLICE(int8_t);
303 break;
304 case kTfLiteInt16:
305 TF_LITE_STRIDED_SLICE(int16_t);
306 break;
307 case kTfLiteBool:
308 TF_LITE_STRIDED_SLICE(bool);
309 break;
310 case kTfLiteString:
311 TF_LITE_STRIDED_SLICE(string);
312 break;
313 default:
314 TF_LITE_KERNEL_LOG(context,
315 "Type %s is currently not supported "
316 "by StridedSlice.",
317 TfLiteTypeGetName(op_context.input->type));
318 return kTfLiteError;
319 }
320#undef TF_LITE_STRIDED_SLICE
321 return kTfLiteOk;
322}
323
324} // namespace strided_slice
325
326TfLiteRegistration* Register_STRIDED_SLICE_REF() {
327 static TfLiteRegistration r = {
328 nullptr, nullptr, strided_slice::Prepare,
329 strided_slice::Eval<strided_slice::kReference>};
330 return &r;
331}
332
333TfLiteRegistration* Register_STRIDED_SLICE() {
334 static TfLiteRegistration r = {
335 nullptr, nullptr, strided_slice::Prepare,
336 strided_slice::Eval<strided_slice::kGenericOptimized>};
337 return &r;
338}
339
340} // namespace builtin
341} // namespace ops
342} // namespace tflite
343