/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/strided_slice_op.h"

#include <algorithm>
#include <array>
#include <iterator>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

/// Constants
constexpr int32_t kShrinkAxis = -1, kNewAxis = -2;

// Sparse slicing specification
// if one does foo[3:5, ..., -3], the begin/end/strides tensors will each
// have length 3
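// e.g. for foo[3:5, ..., -3]: dims == 3, ellipsis_mask == 0b010 (the
// ellipsis is spec entry 1), and shrink_axis_mask == 0b100 (the single
// index -3 shrinks its dimension away)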
struct StridedSliceSparseSpec {
  int64_t dims;
  int32 num_add_axis_after_ellipsis;
  const Tensor* begin_tensor;
  const Tensor* end_tensor;
  const Tensor& strides_tensor;
  const int32 begin_mask, end_mask;
  int32 ellipsis_mask;
  const int32 new_axis_mask, shrink_axis_mask;
};

// Dense slicing specification
// all ellipsis and newaxis entries are expanded out, so if
// foo[3:5, ..., -3] is taken on a 10-dimensional foo,
// each InlinedVector will have 10 entries, whereas the
// sparse spec had length-3 tensors.
struct StridedSliceDenseSpec {
  const int64_t dims;
  int32 begin_mask;
  int32 end_mask;
  bool begin_valid;
  bool end_valid;
  gtl::InlinedVector<int64_t, 4>& begin;
  gtl::InlinedVector<int64_t, 4>& end;
  gtl::InlinedVector<int64_t, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for. The final tensor increases in rank with tf.newaxis
  // entries. If an index in this array is positive, the size of the dimension
  // is obtained from the canonical end minus begin. Otherwise, if it is
  // kNewAxis, the size will be 1. A shrunk dimension is skipped.
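  // e.g. for foo.shape = (10, 10), foo[3, :] yields {kShrinkAxis, 1} here,
  // giving a final shape of (10).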
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // This vector has the same size as final_shape_gather_indices, but it
  // remembers the sparse index that a dimension comes from, instead of the
  // dense index. A -1 in this vector means the index is not from the sparse
  // input.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
  // The dense shrink axis mask indicates which processing dimensions
  // should be shrunk. For example, if foo.shape = (10,10,10,10),
  // foo[3, ..., 5] has a sparse_shrink_axis_mask of 0x5 and a
  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
  int32 shrink_axis_mask;
};

}  // namespace

template <class T>
static Status TF_MUST_USE_RESULT BuildDenseSpec(
    const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
  if (dense->dims < 0) {
    return errors::InvalidArgument("Unexpected negative dense.dims");
  }

  // Build expanded begin, end, strides, begin_mask, end_mask
  // to remove any ellipsis
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->input_shape_gather_indices_sparse.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;
  {
    int full_index = 0;

    const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
    dense->begin_valid = sparse.begin_tensor != nullptr;
    dense->end_valid = sparse.end_tensor != nullptr;

    const T* const begin_flat = sparse.begin_tensor != nullptr
                                    ? sparse.begin_tensor->vec<T>().data()
                                    : nullptr;
    const T* const end_flat = sparse.end_tensor != nullptr
                                  ? sparse.end_tensor->vec<T>().data()
                                  : nullptr;

    for (int i = 0; i < sparse.dims; i++) {
      if ((1 << i) & sparse.ellipsis_mask) {
        // Expand the ellipsis into the appropriate indices
        // NOTE: this only works because we guaranteed one ellipsis
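        // e.g. for a rank-5 input and foo[1, ..., 2]: sparse.dims == 3, the
        // ellipsis sits at i == 1, so next_index == min(5 - (3 - 1) + 1 + 0,
        // 5) == 4 and dense dims 1..3 become full-range slices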
        int32_t next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
                                          sparse.num_add_axis_after_ellipsis,
                                      dense->dims);
        for (; full_index < next_index; full_index++) {
          // new_axis entries aren't real axes, so they have to be skipped
          dense->begin[full_index] = dense->end[full_index] = 0;
          dense->strides[full_index] = 1;
          dense->begin_mask |= (1 << full_index);
          dense->end_mask |= (1 << full_index);
          dense->final_shape_gather_indices.push_back(full_index);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->input_shape_gather_indices_sparse[full_index] = i;
        }
      } else if ((1 << i) & sparse.new_axis_mask) {
        dense->final_shape_gather_indices.push_back(kNewAxis);
        dense->final_shape_gather_indices_sparse.push_back(-1);
      } else {
        if (full_index == dense->begin.size()) {
          return errors::InvalidArgument("Index out of range using input dim ",
                                         full_index, "; input has only ",
                                         dense->dims, " dims");
        }

        // Gather slicing spec into appropriate index
        if (begin_flat != nullptr) {
          dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
        }
        if (end_flat != nullptr) {
          dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
        }
        dense->strides[full_index] =
            internal::SubtleMustCopy<T>(strides_flat[i]);
        if (sparse.begin_mask & (1 << i)) {
          dense->begin_mask |= (1 << full_index);
        }
        if (sparse.end_mask & (1 << i)) {
          dense->end_mask |= (1 << full_index);
        }
        // If shrink, record where to get the dimensionality from (i.e.,
        // new_axis creates a fake 1-size dimension). Also remember the shrink
        // axis (now in dense form) so we can ignore dense->end below.
        if (sparse.shrink_axis_mask & (1 << i)) {
          dense->final_shape_gather_indices.push_back(kShrinkAxis);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->shrink_axis_mask |= (1 << full_index);
        } else {
          dense->final_shape_gather_indices.push_back(full_index);
          // Remember where in the sparse shape the dense dim comes from.
          dense->final_shape_gather_indices_sparse.push_back(i);
        }
        dense->input_shape_gather_indices_sparse[full_index] = i;
        full_index++;
      }
    }
  }
  return OkStatus();
}

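// Illustrative example of the sparse encoding this expects: the Python slice
// foo[2:, :5] on a rank-2 input arrives as begin = [2, 0], end = [0, 5],
// strides = [1, 1], begin_mask = 0b10 (begin omitted for spec entry 1), and
// end_mask = 0b01 (end omitted for spec entry 0); on an input of shape (6, 7)
// the final shape is (4, 5).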
Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
    int32_t new_axis_mask, int32_t shrink_axis_mask,
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
    gtl::InlinedVector<int64_t, 4>* strides,
    StridedSliceShapeSpec* shape_spec) {
  if (input_shape.unknown_rank()) {
    // Note: If the rank is unknown, "input_shape.dims()" is -1.
    return errors::InvalidArgument("Unexpected input_shape with unknown rank");
  }

  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
        begin_tensor->NumElements() == strides_tensor.NumElements() &&
        begin_tensor->NumElements() < 32 /* using 32 bit masks */);
  const bool end_is_wrong =
      end_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
        end_tensor->NumElements() == strides_tensor.NumElements());
  if (begin_is_wrong || end_is_wrong ||
      !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    if (begin_tensor != nullptr && end_tensor != nullptr) {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shapes ", begin_tensor->shape().DebugString(), ", ",
          end_tensor->shape().DebugString(), ", and ",
          strides_tensor.shape().DebugString(), " instead.");
    } else {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shape ", strides_tensor.shape().DebugString(),
          " for strides.");
    }
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e.
  // there is at most one ellipsis ((x & (x - 1)) clears the lowest set bit,
  // so it is zero iff x has at most one bit set)
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipses in slice spec not allowed");
  }

  // Step 1: Account for ellipsis and new axis
  //
  // Check for an ellipsis and count how many new_axis entries appear after it
  // TODO(aselle): Convert this to do a fast log2 followed by counting
  // ones in the remaining bits
  bool ellipsis_seen = false;

  StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};

  for (int32_t i = 0; i < sparse_spec.dims; i++) {
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
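  // e.g. foo[..., tf.newaxis, tf.newaxis] yields num_add_axis_after_ellipsis
  // == 2; new_axis entries before the ellipsis are not counted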
  // If no ellipsis was given, insert one at the end
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this affects the loop iteration below
  }

  // Step 2: Make a sparse spec into a full index spec
  //
  // The sparse spec does not correspond to the number of dimensions.
  // Make a dense spec that corresponds to the number of dimensions.
  //
  // For example, suppose foo[...,3:] on foo.shape=(2,2,3). Then
  // we need to produce the missing begin_mask for the first two
  // dimensions, i.e. from begin_mask_spec=0, end_mask_spec=2
  // we achieve begin_mask=3 (0b011), end_mask=7 (0b111).
  StridedSliceDenseSpec dense_spec = {input_shape.dims(),
                                      0 /* begin_mask */,
                                      0 /* end_mask */,
                                      false /* begin_valid */,
                                      false /* end_valid */,
                                      *begin,
                                      *end,
                                      *strides};

  if (strides_tensor.dtype() == DT_INT32) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT64) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int64_t>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT16) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int16_t>(sparse_spec, &dense_spec));
  } else {
    LOG(FATAL) << "begin must be either int16, int32 or int64";
  }

  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  // and bounds check!
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  processing_shape->Clear();
  for (int i = 0; i < input_shape.dims(); ++i) {
    int64_t& begin_i = (*begin)[i];
    int64_t& end_i = (*end)[i];
    int64_t& stride_i = (*strides)[i];
    int64_t dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }
    bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    if (dim_i == -1) {
      processing_shape->AddDim(shrink_i ? 1 : -1);
      continue;
    }

    const std::array<int64_t, 2> masks = {
        {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    const std::array<int64_t, 2> valid_range = {
        {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};

    auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64_t x_fwd =
            x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
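    // e.g. with dim_i == 10 and stride_i > 0: canonical(-1, 0) forwards -1 to
    // 9, while a masked begin clamps to valid_range[0] == 0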
    if (shrink_i && stride_i <= 0) {
      return errors::InvalidArgument(
          "only stride 1 allowed on non-range indexing.");
    }
    (*is_simple_slice) &= stride_i == 1;

    const bool begin_and_end_masked =
        (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      if (shrink_i) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular, foo[-1] produces sparse_begin = -1, sparse_end = 0,
        // and canonical puts these at n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin + 1.
        int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
        begin_i = x_fwd;
        end_i = begin_i + 1;
        if (x_fwd < 0 || x_fwd >= dim_i) {
          return errors::InvalidArgument(
              "slice index ", begin_i, " of dimension ", i, " out of bounds.");
        }
      } else {
        begin_i = canonical(begin_i, 0);
        end_i = canonical(end_i, 1);
      }
      // Update optimization values
      bool take_all_in_dimension =
          stride_i == 1 && begin_i == 0 && end_i == dim_i;
      (*is_identity) &= take_all_in_dimension;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    } else {
      (*is_identity) &= stride_i == 1 && begin_and_end_masked;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    }
    // Compute the processing shape (the intermediate shape Eigen will produce)
    int64_t interval_length;
    bool known_interval = false;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      interval_length = end_i - begin_i;
      known_interval = true;
    } else if (shrink_i) {
      // The dimension is still known as 1 for the processing_shape, but will
      // be discarded for the final shape.
      interval_length = 1;
      known_interval = true;
    } else if (begin_and_end_masked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dim_i >= 0) {
        if (stride_i < 0) {
          interval_length = -dim_i;
        } else {
          interval_length = dim_i;
        }
        known_interval = true;
      }
    }
    if (known_interval) {
      int64_t size_i;
      // Holds zero if the interval is degenerate, otherwise accounts for the
      // remainder
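      // e.g. interval_length == 5 with stride_i == 2 gives size_i == 3
      // (indices 0, 2, 4)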
      if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
        size_i = 0;
      } else {
        size_i = interval_length / stride_i +
                 (interval_length % stride_i != 0 ? 1 : 0);
      }
      processing_shape->AddDim(size_i);
    } else {
      processing_shape->AddDim(-1);
    }
  }

  // Step 4: Compute the final shape
  //
  // new_axis will increase the dimension by 1 (with a one-size dimension),
  // while slices like foo[3,...] will reduce the dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
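  // e.g. for foo.shape = (10, 10), foo[3, :, tf.newaxis] has processing
  // shape (1, 10) but final shape (10, 1): the shrunk dimension is skipped
  // and the new_axis contributes a size-1 dimension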
  final_shape->Clear();
  if (shape_spec != nullptr) {
    shape_spec->output_to_sparse_mapping.clear();
    shape_spec->output_to_processing_mapping.clear();
    shape_spec->processing_to_sparse_mapping.assign(
        dense_spec.input_shape_gather_indices_sparse.begin(),
        dense_spec.input_shape_gather_indices_sparse.end());

    shape_spec->begin_dense_mask = dense_spec.begin_mask;
    shape_spec->end_dense_mask = dense_spec.end_mask;
    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
  }

  for (int64_t dense_dim = 0;
       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
    int64_t gather_index = dense_spec.final_shape_gather_indices[dense_dim];
    int64_t sparse_index =
        dense_spec.final_shape_gather_indices_sparse[dense_dim];
    if (gather_index >= 0) {
      final_shape->AddDim(processing_shape->dim_size(gather_index));
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
        shape_spec->output_to_processing_mapping.push_back(gather_index);
      }
    } else if (gather_index == kNewAxis) {
      final_shape->AddDim(1);
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(-1);
        shape_spec->output_to_processing_mapping.push_back(-1);
      }
    }
  }

  return OkStatus();
}

Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
    int32_t new_axis_mask, int32_t shrink_axis_mask,
    TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity,
    bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
    gtl::InlinedVector<int64_t, 4>* strides,
    StridedSliceShapeSpec* shape_spec) {
  // Validate with PartialTensorShape output
  PartialTensorShape partial_processing_shape, partial_final_shape;
  TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
      begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
      end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
      &partial_processing_shape, &partial_final_shape, is_identity,
      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));

  // Verify that the output shapes are fully known
  if (!partial_processing_shape.AsTensorShape(processing_shape) ||
      !partial_final_shape.AsTensorShape(final_shape)) {
    return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
                            partial_processing_shape.DebugString(), " and ",
                            partial_final_shape.DebugString());
  }
  return OkStatus();
}

StridedSliceAssignBCast::StridedSliceAssignBCast(
    const StridedSliceAssignBCast::Vec& input_shape,
    const StridedSliceAssignBCast::Vec& output_shape)
    : valid_(true),
      broadcasting_required_(false),
      reshape_(output_shape.size()),
      bcast_(output_shape.size()),
      result_shape_(output_shape) {
  // The input needs to be reshaped to have the same number of dimensions as
  // the output. This is accomplished by either prepending ones or removing
  // leading dimensions, as necessary.
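  // e.g. input shape (3) against output shape (2, 2, 3) is reshaped to
  // (1, 1, 3), while input shape (1, 1, 3) against output shape (3) drops
  // the two leading ones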
  size_t input_start = 0;
  size_t prepend_size = 0;
  if (output_shape.size() < input_shape.size()) {
    // Numpy allows assigning a larger-rank array to a smaller one as long as
    // broadcasting would otherwise work and the prefix dimensions are all 1.
    // Though this behavior is undocumented, we allow it here for consistency.
    // See https://github.com/numpy/numpy/issues/21744 for details.
    input_start = input_shape.size() - output_shape.size();
    for (size_t i = 0; i < input_start; ++i) {
      if (input_shape[i] != 1) {
        valid_ = false;
        return;
      }
    }
  } else {
    prepend_size = output_shape.size() - input_shape.size();
  }
  std::fill_n(reshape_.begin(), prepend_size, 1);
  std::copy(input_shape.begin() + input_start, input_shape.end(),
            reshape_.begin() + prepend_size);

  // In order to broadcast, dimensions must either be equal or one.
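  // e.g. reshape_ == (1, 3) against output (2, 3) gives bcast_ == {2, 1},
  // broadcasting along dimension 0 only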
  for (size_t i = 0; i < output_shape.size(); ++i) {
    if (reshape_[i] == output_shape[i]) {
      bcast_[i] = 1;
    } else if (reshape_[i] == 1) {
      bcast_[i] = output_shape[i];
      broadcasting_required_ = true;
    } else {
      valid_ = false;
      return;
    }
  }
}

bool StridedSliceAssignBCast::RemapDimensions(
    int64_t num_dims, const StridedSliceAssignBCast::Vec& dimension_map) {
  // Each element in the map corresponds to the original result shape, so
  // the sizes must be equal.
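  // e.g. with result_shape_ == (2, 3), num_dims == 3, and dimension_map ==
  // {0, 2}, the remapped result_shape_ is (2, 1, 3); a negative map entry
  // leaves its source dimension unmapped and the target dimension stays 1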
  if (dimension_map.size() != result_shape_.size()) {
    return false;
  }

  // Ensure all indices are within bounds before any modifications are made,
  // otherwise we could be left in a corrupted state.
  for (size_t i = 0; i < dimension_map.size(); ++i) {
    int64_t dim = dimension_map[i];
    if (dim >= num_dims) {
      return false;
    }
  }

  Vec old_reshape = std::move(reshape_);
  Vec old_bcast = std::move(bcast_);
  Vec old_result_shape = std::move(result_shape_);
  reshape_ = Vec(num_dims);
  bcast_ = Vec(num_dims);
  result_shape_ = Vec(num_dims);
  std::fill_n(reshape_.begin(), num_dims, 1);
  std::fill_n(bcast_.begin(), num_dims, 1);
  std::fill_n(result_shape_.begin(), num_dims, 1);
  for (size_t i = 0; i < dimension_map.size(); ++i) {
    int64_t dim = dimension_map[i];
    if (dim >= 0) {
      reshape_[dim] = old_reshape[i];
      bcast_[dim] = old_bcast[i];
      result_shape_[dim] = old_result_shape[i];
    }
  }

  return true;
}

}  // namespace tensorflow