1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" |
17 | |
#include <cmath>
19 | #include <stdint.h> |
20 | |
21 | #include <algorithm> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/lite/c/builtin_op_data.h" |
25 | #include "tensorflow/lite/c/common.h" |
26 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
27 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
28 | #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" |
29 | #include "tensorflow/lite/kernels/internal/tensor.h" |
30 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
31 | #include "tensorflow/lite/kernels/internal/types.h" |
32 | #include "tensorflow/lite/kernels/kernel_util.h" |
33 | |
34 | namespace tflite { |
35 | namespace ops { |
36 | namespace builtin { |
37 | namespace strided_slice { |
38 | |
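// This kernel has two variants: a portable reference implementation and a
// generic optimized one. The variant is selected at registration time via the
// kernel_type template argument of Eval().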
39 | enum KernelType { |
40 | kReference, |
41 | kGenericOptimized, |
42 | }; |
43 | |
44 | constexpr int kInputTensor = 0; |
45 | constexpr int kBeginTensor = 1; |
46 | constexpr int kEndTensor = 2; |
47 | constexpr int kStridesTensor = 3; |
48 | constexpr int kOutputTensor = 0; |
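
// As a rough illustration (values not from this file): a NumPy-style slice
// a[1:3, :, ::2] on a 3-D input lowers to begin = [1, 0, 0], end = [3, 0, 0],
// strides = [1, 1, 2], with begin_mask and end_mask bits set for the
// dimensions whose bounds are omitted (here, bits 1 and 2).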
49 | |
50 | struct StridedSliceContext { |
51 | StridedSliceContext(TfLiteContext* context, TfLiteNode* node) { |
52 | params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data); |
53 | input = GetInput(context, node, kInputTensor); |
54 | begin = GetInput(context, node, kBeginTensor); |
55 | end = GetInput(context, node, kEndTensor); |
56 | strides = GetInput(context, node, kStridesTensor); |
57 | output = GetOutput(context, node, kOutputTensor); |
58 | input_dims = NumDimensions(input); |
59 | } |
60 | const TfLiteStridedSliceParams* params; |
61 | const TfLiteTensor* input; |
62 | const TfLiteTensor* begin; |
63 | const TfLiteTensor* end; |
64 | const TfLiteTensor* strides; |
65 | TfLiteTensor* output; |
66 | |
  // Equivalent input shape after adding axes according to new_axis_mask.
68 | RuntimeShape effective_input_shape; |
69 | int input_dims; |
70 | }; |
71 | |
72 | StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) { |
73 | StridedSliceParams op_params; |
74 | |
  // The ellipsis_mask and new_axis_mask in op_params are not used directly.
  // Those masks are processed here to update begin_mask, end_mask and the
  // index range.
77 | op_params.begin_mask = 0; |
78 | op_params.ellipsis_mask = 0; |
79 | op_params.end_mask = 0; |
80 | op_params.new_axis_mask = 0; |
81 | op_params.shrink_axis_mask = 0; |
82 | |
  // Count the indices where the new_axis_mask is set but the ellipsis_mask is
  // not.
84 | const int begin_count = GetTensorShape(op_context->begin).Dims(0); |
85 | int num_add_axis = 0; |
86 | for (int i = 0; i < begin_count; ++i) { |
87 | if (!((1 << i) & op_context->params->ellipsis_mask) && |
88 | ((1 << i) & op_context->params->new_axis_mask)) { |
89 | num_add_axis++; |
90 | } |
91 | } |
92 | |
  // Calculate the number of input dims after adding the new axes.
94 | const int effective_dims = op_context->input_dims + num_add_axis; |
95 | |
  // If begin, end and strides are not fully provided, the Ellipsis should be
  // expanded to cover multiple dimensions (e.g. for the spec [Ellipsis, 2] on
  // a 3-D input, the Ellipsis applies to the first 2 dimensions). Also, if
  // new_axis_mask and ellipsis_mask are both set at the same index, the
  // new_axis_mask has no effect.
101 | int effective_ellipsis_mask = 0, effective_new_axis_mask = 0; |
102 | int ellipsis_start_idx = effective_dims, expanded_ellipsis = 0; |
103 | for (int i = 0; i < effective_dims;) { |
104 | if ((1 << i) & op_context->params->ellipsis_mask) { |
105 | ellipsis_start_idx = i; |
106 | int ellipsis_end_idx = std::max( |
107 | i + 1, |
108 | std::min(i + 1 + num_add_axis + op_context->input_dims - begin_count, |
109 | effective_dims)); |
110 | expanded_ellipsis = ellipsis_end_idx - ellipsis_start_idx - 1; |
111 | |
      // Set the bits of effective_ellipsis_mask for the expanded range.
113 | for (; i < ellipsis_end_idx; ++i) { |
114 | effective_ellipsis_mask |= (1 << i); |
115 | } |
116 | continue; |
117 | } |
118 | |
119 | if ((1 << (i - expanded_ellipsis)) & op_context->params->new_axis_mask) { |
120 | effective_new_axis_mask |= (1 << i); |
121 | } |
122 | ++i; |
123 | } |
124 | |
125 | // Calculate effective_input_shape and its corresponding begin, end, strides. |
126 | const int32_t* begin_data = GetTensorData<int32_t>(op_context->begin); |
127 | const int32_t* end_data = GetTensorData<int32_t>(op_context->end); |
128 | const int32_t* strides_data = GetTensorData<int32_t>(op_context->strides); |
129 | const RuntimeShape input_shape = GetTensorShape(op_context->input); |
130 | int added_ellipsis = 0, added_axises = 0; |
131 | op_context->effective_input_shape.Resize(effective_dims); |
132 | |
133 | for (int i = 0; i < effective_dims; ++i) { |
134 | if ((1 << i) & effective_ellipsis_mask) { |
      // If the ellipsis_mask bit is set, set the begin_mask and end_mask bits
      // at that index.
136 | added_ellipsis = std::max(0, i - ellipsis_start_idx); |
137 | op_params.begin_mask |= (1 << i); |
138 | op_params.end_mask |= (1 << i); |
139 | op_params.strides[i] = 1; |
140 | op_context->effective_input_shape.SetDim( |
141 | i, input_shape.Dims(i - added_axises)); |
142 | } else if ((1 << i) & effective_new_axis_mask) { |
      // If the new_axis_mask bit is set, it is equivalent to adding a new dim
      // of size 1 to the input tensor. Record the added dim in
      // effective_input_shape.
145 | op_params.start_indices[i] = 0; |
146 | op_params.stop_indices[i] = 1; |
147 | op_params.strides[i] = 1; |
148 | op_context->effective_input_shape.SetDim(i, 1); |
149 | added_axises++; |
150 | } else if (i >= begin_count + expanded_ellipsis) { |
151 | op_params.start_indices[i] = 0; |
152 | op_params.stop_indices[i] = 0; |
153 | op_params.strides[i] = 1; |
154 | op_params.begin_mask |= (1 << i); |
155 | op_params.end_mask |= (1 << i); |
156 | op_context->effective_input_shape.SetDim( |
157 | i, input_shape.Dims(i - added_axises)); |
158 | } else { |
159 | const int orig_idx = i - added_ellipsis; |
160 | op_params.start_indices[i] = begin_data[orig_idx]; |
161 | op_params.stop_indices[i] = end_data[orig_idx]; |
162 | op_params.strides[i] = strides_data[orig_idx]; |
163 | if (op_context->params->begin_mask & (1 << orig_idx)) { |
164 | op_params.begin_mask |= (1 << i); |
165 | } |
166 | if (op_context->params->end_mask & (1 << orig_idx)) { |
167 | op_params.end_mask |= (1 << i); |
168 | } |
169 | if (op_context->params->shrink_axis_mask & (1 << orig_idx)) { |
170 | op_params.shrink_axis_mask |= (1 << i); |
171 | } |
172 | op_context->effective_input_shape.SetDim( |
173 | i, input_shape.Dims(i - added_axises)); |
174 | } |
175 | } |
176 | op_params.start_indices_count = effective_dims; |
177 | op_params.stop_indices_count = effective_dims; |
178 | op_params.strides_count = effective_dims; |
179 | |
180 | return op_params; |
181 | } |
182 | |
183 | // Processes the indexing tensors (begin, end and strides) to resize the |
184 | // output tensor. This function is callable from both Prepare() and Eval() as |
185 | // long as the caller ensures the indexing tensors are present. |
186 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
187 | StridedSliceContext* op_context) { |
188 | std::vector<int> output_shape_vector; |
189 | StridedSliceParams op_params = BuildStridedSliceParams(op_context); |
190 | const RuntimeShape effective_input_shape = op_context->effective_input_shape; |
  TF_LITE_ENSURE_MSG(
      context, effective_input_shape.DimensionsCount() <= 5,
      "StridedSlice op only supports up to 5D output including added axis.");
194 | |
195 | for (int idx = effective_input_shape.DimensionsCount() - 1; idx >= 0; --idx) { |
196 | int32_t stride = op_params.strides[idx]; |
    TF_LITE_ENSURE_MSG(context, stride != 0,
                       "stride value has to be non-zero");
198 | |
199 | int32_t begin = ::tflite::strided_slice::StartForAxis( |
200 | op_params, effective_input_shape, idx); |
201 | int32_t end = ::tflite::strided_slice::StopForAxis( |
202 | op_params, effective_input_shape, idx, begin); |
203 | |
    // When shrinking an axis, the end position does not matter (and can be
    // incorrect when negative indexing is used, see Issue #19260). Always use
    // begin + 1 to generate a length 1 slice, since begin has already been
    // adjusted for negative indices by StartForAxis.
208 | const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx); |
209 | if (shrink_axis) { |
210 | end = begin + 1; |
211 | } |
212 | |
    // This computation is valid for both positive and negative strides.
214 | int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride)); |
215 | dim_shape = dim_shape < 0 ? 0 : dim_shape; |
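    // e.g. begin == 1, end == 7, stride == 2 -> ceil(6 / 2.0) == 3 (elements
    // 1, 3, 5); begin == 5, end == 1, stride == -2 -> ceil(-4 / -2.0) == 2
    // (elements 5, 3).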
216 | if (!shrink_axis) { |
217 | output_shape_vector.push_back(dim_shape); |
218 | } |
219 | } |
220 | |
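  // The dims above were collected from the innermost axis outwards, so copy
  // them in reverse into the output shape.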
221 | TfLiteIntArray* output_shape = |
222 | TfLiteIntArrayCreate(output_shape_vector.size()); |
223 | |
224 | std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(), |
225 | output_shape->data); |
226 | |
227 | TF_LITE_ENSURE_STATUS( |
228 | context->ResizeTensor(context, op_context->output, output_shape)); |
229 | |
230 | return kTfLiteOk; |
231 | } |
232 | |
233 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
234 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); |
235 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
236 | |
237 | StridedSliceContext op_context(context, node); |
238 | |
  // Ensure the validity of the indexing tensors and their dimensions.
240 | TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); |
241 | TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); |
242 | TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); |
243 | TF_LITE_ENSURE_EQ(context, NumElements(op_context.begin), |
244 | NumElements(op_context.end)); |
245 | TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); |
246 | |
  // Only INT32 begin/end/strides are supported.
248 | // TODO(b/175642009): add support for INT64 |
249 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.begin->type, kTfLiteInt32); |
250 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.end->type, kTfLiteInt32); |
251 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.strides->type, kTfLiteInt32); |
  TF_LITE_ENSURE_MSG(context, op_context.input_dims <= 5,
                     "StridedSlice op only supports 1D-5D input arrays.");
254 | |
  // Postpone allocation of the output if any of the indexing tensors is not
  // constant; in that case ResizeOutputTensor() is re-run from Eval() once
  // concrete index values are available.
257 | if (!(IsConstantTensor(op_context.begin) && |
258 | IsConstantTensor(op_context.end) && |
259 | IsConstantTensor(op_context.strides))) { |
260 | SetTensorToDynamic(op_context.output); |
261 | return kTfLiteOk; |
262 | } |
263 | return ResizeOutputTensor(context, &op_context); |
264 | } |
265 | |
266 | template <KernelType kernel_type> |
267 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
268 | StridedSliceContext op_context(context, node); |
269 | |
270 | if (IsDynamicTensor(op_context.output)) { |
271 | TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); |
272 | } |
273 | StridedSliceParams op_params = BuildStridedSliceParams(&op_context); |
274 | |
275 | #define TF_LITE_STRIDED_SLICE(data_type) \ |
276 | { \ |
277 | if (kernel_type == kGenericOptimized) { \ |
278 | optimized_ops::StridedSlice<data_type>( \ |
279 | op_params, op_context.effective_input_shape, op_context.input, \ |
280 | GetTensorShape(op_context.output), op_context.output); \ |
281 | } else { \ |
282 | reference_ops::StridedSlice<data_type>( \ |
283 | op_params, op_context.effective_input_shape, op_context.input, \ |
284 | GetTensorShape(op_context.output), op_context.output); \ |
285 | } \ |
286 | } |
287 | |
288 | switch (op_context.input->type) { |
289 | case kTfLiteFloat32: |
290 | TF_LITE_STRIDED_SLICE(float); |
291 | break; |
292 | case kTfLiteInt32: |
293 | TF_LITE_STRIDED_SLICE(int32_t); |
294 | break; |
295 | case kTfLiteInt64: |
296 | TF_LITE_STRIDED_SLICE(int64_t); |
297 | break; |
298 | case kTfLiteUInt8: |
299 | TF_LITE_STRIDED_SLICE(uint8_t); |
300 | break; |
301 | case kTfLiteInt8: |
302 | TF_LITE_STRIDED_SLICE(int8_t); |
303 | break; |
304 | case kTfLiteInt16: |
305 | TF_LITE_STRIDED_SLICE(int16_t); |
306 | break; |
307 | case kTfLiteBool: |
308 | TF_LITE_STRIDED_SLICE(bool); |
309 | break; |
310 | case kTfLiteString: |
311 | TF_LITE_STRIDED_SLICE(string); |
312 | break; |
313 | default: |
      TF_LITE_KERNEL_LOG(context,
                         "Type %s is currently not supported "
                         "by StridedSlice.",
                         TfLiteTypeGetName(op_context.input->type));
318 | return kTfLiteError; |
319 | } |
320 | #undef TF_LITE_STRIDED_SLICE |
321 | return kTfLiteOk; |
322 | } |
323 | |
324 | } // namespace strided_slice |
325 | |
326 | TfLiteRegistration* Register_STRIDED_SLICE_REF() { |
327 | static TfLiteRegistration r = { |
328 | nullptr, nullptr, strided_slice::Prepare, |
329 | strided_slice::Eval<strided_slice::kReference>}; |
330 | return &r; |
331 | } |
332 | |
333 | TfLiteRegistration* Register_STRIDED_SLICE() { |
334 | static TfLiteRegistration r = { |
335 | nullptr, nullptr, strided_slice::Prepare, |
336 | strided_slice::Eval<strided_slice::kGenericOptimized>}; |
337 | return &r; |
338 | } |
339 | |
340 | } // namespace builtin |
341 | } // namespace ops |
342 | } // namespace tflite |
343 | |