/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/strided_slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op_impl.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {
namespace {

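// Fast path for simple (stride-1) 2-D slices: when T can be copied with
// memcpy, each contiguous run of columns in a row is copied from the input
// with a single memcpy call.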
template <typename T>
struct MemCpyFunctor {
  // Returns true if the copy was made with memcpy, false otherwise.
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64_t, 4>& begin,
            const gtl::InlinedVector<int64_t, 4>& end, Tensor* result) {
    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
      auto in = input.tensor<T, 2>();
      auto output = result->tensor<T, 2>();
      // TODO(agarwal): Consider multi-threading if size[0] is large
      for (int row_in = begin[0], row_out = 0; row_in < end[0];
           ++row_in, ++row_out) {
        if (row_in + 1 < end[0]) {
          port::prefetch<port::PREFETCH_HINT_T0>(&output(row_in + 1, 0));
          port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1]));
        }
        memcpy(&output(row_out, 0), &in(row_in, begin[1]),
               (end[1] - begin[1]) * sizeof(T));
      }
      return true;
    }
    return false;
  }
};

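// ResourceHandle is not a plain-old-data type and cannot be copied with
// memcpy, so this specialization always reports that no memcpy copy was made.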
template <>
struct MemCpyFunctor<ResourceHandle> {
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64_t, 4>& begin,
            const gtl::InlinedVector<int64_t, 4>& end, Tensor* result) {
    return false;
  }
};

}  // namespace

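// Forward kernel: produces a strided slice of input 0. Inputs 1, 2 and 3 hold
// the begin, end and strides vectors; the *_mask attributes modify how those
// vectors are interpreted (see docs in ../ops/array_ops.cc).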
template <typename Device, typename T>
class StridedSliceOp : public OpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> end;
    gtl::InlinedVector<int64_t, 4> strides;

    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     context->input(0).shape(), begin_mask, end_mask,
                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
                     &processing_shape, &final_shape, &is_identity,
                     &is_simple_slice, &slice_dim0, &begin, &end, &strides));
    const Tensor& input = context->input(0);

    // Optimization #1, slice is a no-op plus reshape
    if (is_identity) {
      VLOG(1) << "Strided slice identity ";
      Tensor tmp;
      OP_REQUIRES(context, tmp.CopyFrom(input, final_shape),
                  errors::Internal("Copy failed"));
      context->set_output(0, tmp);
      return;
    }

    // Optimization #2, slice is memory contiguous (only occurs in dim 0)
    if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], end[0])) {
      OP_REQUIRES(context, input.dims() >= 1,
                  errors::InvalidArgument(
                      "Input must have rank at least 1, got: ", input.dims()));
      // Otherwise, is_identity should be true.
      VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString();
      // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
      Tensor slice = input.Slice(std::min(begin[0], end[0]), end[0]);
      Tensor tmp;
      OP_REQUIRES(context, tmp.CopyFrom(slice, final_shape),
                  errors::Internal("Copy failed"));
      context->set_output(0, tmp);
      return;
    }

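    // General case: allocate the output and, if the slice is non-empty,
    // dispatch on the rank of the processing shape.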
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, final_shape, &result));
    const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();

    if (processing_shape.num_elements() > 0) {
      // Optimization #3, slice has stride 1 in all dimensions
      // Optimization #3A, slice has only two dimensions
      // TODO(aselle): Here we are restricting processing_shape and
      // final_shape to be 2D. This isn't strictly necessary, but I don't
      // want to blow up code gen size, because shape<> requires a static
      // NDIM and T.
      if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
          input_dims == 2 && processing_shape.dims() == 2 &&
          final_shape.dims() == 2 && new_axis_mask == 0) {
        MemCpyFunctor<T> functor;
        if (functor.Copy(input, begin, end, result)) {
          return;
        }
      }

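// HANDLE_DIM dispatches to a HandleStridedSliceCase instantiation specialized
// for a fixed compile-time rank; ranks 1 through 8 are supported.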
#define HANDLE_DIM(NDIM)                                                       \
  if (processing_dims == NDIM) {                                               \
    HandleStridedSliceCase<Device, T, NDIM>(context, begin, end, strides,      \
                                            processing_shape, is_simple_slice, \
                                            result);                           \
    return;                                                                    \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("Unhandled input dimensions ", input_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

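// Gradient kernel for StridedSlice: input 0 is the shape of the original
// input, inputs 1-3 are the begin/end/strides vectors, and input 4 is the
// incoming gradient dy, which must match the shape produced by the forward
// slice. The output has the original input shape, with dy scattered back into
// the sliced positions.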
template <typename Device, typename T>
class StridedSliceGradOp : public OpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> end;
    gtl::InlinedVector<int64_t, 4> strides;

    TensorShape input_shape;
    const Tensor& input_shape_tensor = context->input(0);
    OP_REQUIRES(
        context, input_shape_tensor.dims() == 1,
        errors::InvalidArgument("shape must be 1-D, got shape.shape = ",
                                input_shape_tensor.shape().DebugString()));
    if (input_shape_tensor.dtype() == DT_INT32) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int32>(),
                                               &input_shape));
    } else if (input_shape_tensor.dtype() == DT_INT64) {
      OP_REQUIRES_OK(context,
                     TensorShapeUtils::MakeShape(
                         input_shape_tensor.vec<int64_t>(), &input_shape));
    } else {
      LOG(FATAL) << "shape must have type int32 or int64.";
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
            &is_simple_slice, &slice_dim0, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice
    TensorShape dy_shape = context->input(4).shape();
    OP_REQUIRES(
        context, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    if (!context->status().ok()) return;

    // const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &result));

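    // Rank-0 (scalar) case: the whole gradient is just dy.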
    if (processing_shape.dims() == 0) {
      auto in = context->input(4);
      OP_REQUIRES(context, result->CopyFrom(in, processing_shape),
                  errors::Internal("Copy failed"));
      return;
    }

#define HANDLE_DIM(NDIM)                                                      \
  if (processing_dims == NDIM) {                                              \
    HandleStridedSliceGradCase<Device, T, NDIM>(context, begin, end, strides, \
                                                processing_shape,             \
                                                is_simple_slice, result);     \
    return;                                                                   \
  }

    HANDLE_DIM(1);
    HANDLE_DIM(2);
    HANDLE_DIM(3);
    HANDLE_DIM(4);
    HANDLE_DIM(5);
    HANDLE_DIM(6);
    HANDLE_DIM(7);
    HANDLE_DIM(8);

#undef HANDLE_DIM
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

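// Assignment kernel shared by StridedSliceAssign, ResourceStridedSliceAssign
// and TensorStridedSliceUpdate. When isTensor is true (TensorStridedSliceUpdate)
// the updated value is written to a new output tensor, forwarding the input
// buffer when possible; when it is false, the slice of a ref or resource
// variable is updated in place.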
template <typename Device, typename T, bool isTensor>
class StridedSliceAssignOp : public OpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape;  // Without reshaping input.
    TensorShape final_shape;       // After reshaping input.
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> end;
    gtl::InlinedVector<int64_t, 4> strides;

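    // Obtain the tensor to update (old_lhs): a forwarded or copied version of
    // input 0 for TensorStridedSliceUpdate, the variable's tensor for a
    // DT_RESOURCE input, or the mutable ref input otherwise.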
    Tensor* old_lhs = nullptr;
    Tensor tmp;
    if (isTensor) {
      const Tensor& input = context->input(0);

      int forwarded_input;
      OP_REQUIRES_OK(context,
                     context->forward_input_or_allocate_output(
                         {0}, 0, input.shape(), &old_lhs, &forwarded_input));
      if (forwarded_input < 0) {
        OP_REQUIRES_OK(context,
                       tensorflow::functor::DoCopy(
                           context->eigen_device<Device>(), input, old_lhs));
      }
    } else {
      if (context->input_dtype(0) == DT_RESOURCE) {
        core::RefCountPtr<Var> v;
        OP_REQUIRES_OK(
            context, LookupResource(context, HandleFromInput(context, 0), &v));
        OP_REQUIRES_OK(context,
                       EnsureSparseVariableAccess<Device, T>(context, v.get()));
        mutex_lock ml(*v->mu());
        old_lhs = v->tensor();
        OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
                    errors::InvalidArgument(
                        "l-value dtype ", DataTypeString(old_lhs->dtype()),
                        " does not match r-value dtype ",
                        DataTypeString(DataTypeToEnum<T>::value)));
      } else {
        context->forward_ref_input_to_ref_output(0, 0);
        tmp = context->mutable_input(0, true);
        old_lhs = &tmp;
      }
    }

    StridedSliceShapeSpec shape_spec;
    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
                     new_axis_mask, shrink_axis_mask, &processing_shape,
                     &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
                     &begin, &end, &strides, &shape_spec));

    if (processing_shape.num_elements() > 0) {
      const Tensor& input = context->input(4);
      TensorShape input_shape = input.shape();
      TensorShape original_shape = old_lhs->shape();
      const int processing_dims = processing_shape.dims();

      StridedSliceAssignBCast bcast(input_shape.dim_sizes(),
                                    final_shape.dim_sizes());
      OP_REQUIRES(context, bcast.IsValid(),
                  errors::InvalidArgument("Cannot broadcast input shape ",
                                          input_shape.DebugString(),
                                          " into final shape ",
                                          final_shape.DebugString()));

      // The assignment RHS and broadcast spec need to be remapped to
      // the same number of dimensions as the unstrided LHS (i.e. processing
      // dimensions). This adds back any shrink axes and removes new axes.
      bool remap_valid = bcast.RemapDimensions(
          processing_dims, shape_spec.output_to_processing_mapping);
      // Sanity check. The following should never fail.
      DCHECK(remap_valid) << "Failed to remap output shape "
                          << final_shape.DebugString()
                          << " to processing shape "
                          << processing_shape.DebugString();

      // The 0-dimensional case implies that the left- and right-hand sides
      // have exactly the same scalar shape.

// Handle general dimensions. The do {} while (false) construct is a common
// way to avoid pedantic extra-semicolon warnings.
#define HANDLE_DIM(NDIM)                                 \
  do {                                                   \
    if (processing_dims == NDIM) {                       \
      HandleStridedSliceAssignCase<Device, T, NDIM>()(   \
          context, begin, end, strides, bcast, old_lhs); \
      return;                                            \
    }                                                    \
  } while (false)
      HANDLE_DIM(0);
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);
#undef HANDLE_DIM

      OP_REQUIRES(context, false,
                  errors::Unimplemented("Unhandled input dimensions ",
                                        processing_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

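// CPU kernel registrations for StridedSlice, StridedSliceGrad,
// StridedSliceAssign, ResourceStridedSliceAssign and TensorStridedSliceUpdate.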
#define REGISTER_STRIDED_SLICE(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                          \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceOp<CPUDevice, type>)              \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                      \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("shape")                      \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceGradOp<CPUDevice, type>)          \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")                    \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<CPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")            \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("ref")                        \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<CPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")              \
                              .Device(DEVICE_CPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<CPUDevice, type, true>)

TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_STRIDED_SLICE);

#undef REGISTER_STRIDED_SLICE

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

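// GPU kernel registrations. The begin, end and strides inputs (and shape for
// the grad op) are kept in host memory so that their values can be read on the
// host when the slice parameters are validated.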
#define REGISTER_GPU(type)                                              \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                          \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceOp<GPUDevice, type>)              \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                      \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("shape")                      \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceGradOp<GPUDevice, type>)          \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")                    \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<GPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")            \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("ref")                        \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<GPUDevice, type, false>) \
  REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")              \
                              .Device(DEVICE_GPU)                       \
                              .TypeConstraint<type>("T")                \
                              .HostMemory("begin")                      \
                              .HostMemory("end")                        \
                              .HostMemory("strides"),                   \
                          StridedSliceAssignOp<GPUDevice, type, true>)

TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("value"),
                        StridedSliceAssignOp<CPUDevice, int32, false>);
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("value"),
                        StridedSliceAssignOp<CPUDevice, int32, false>);
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("value")
                            .HostMemory("output"),
                        StridedSliceAssignOp<CPUDevice, int32, true>);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow