/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

namespace {

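// Copies the elements of an int32 or int64 tensor into an int64_t
// InlinedVector; used below to normalize the begin/size inputs.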
void IntTensorToInt64Vec(const Tensor& tensor,
                         gtl::InlinedVector<int64_t, 4>* out) {
  out->resize(tensor.NumElements());
  int64_t* out_ptr = out->data();
  if (tensor.dtype() == DT_INT32) {
    const int32* tensor_ptr = tensor.flat<int32>().data();
    for (int64_t i = 0; i < tensor.NumElements(); ++i) {
      out_ptr[i] = tensor_ptr[i];
    }
  } else if (tensor.dtype() == DT_INT64) {
    const int64_t* tensor_ptr = tensor.flat<int64_t>().data();
    for (int64_t i = 0; i < tensor.NumElements(); ++i) {
      out_ptr[i] = tensor_ptr[i];
    }
  } else {
    LOG(FATAL) << "begin must be either int32 or int64";
  }
}

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Shared code that is not dependent on the type of T. We do this to reduce
// code size by not duplicating all this for all T (float, double, int32, etc.)
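// Validates the begin and size inputs against the input shape, resolves any
// size[i] == -1, accumulates the output shape, and reports whether the slice
// is the identity (whole input) and whether it only restricts dimension 0.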
void SharedSliceValidation(OpKernelContext* context, const Tensor& input,
                           TensorShape* output_shape, bool* is_identity,
                           bool* slice_dim0,
                           gtl::InlinedVector<int64_t, 4>* begin,
                           gtl::InlinedVector<int64_t, 4>* size) {
  const Tensor& begin_tensor = context->input(1);
  const Tensor& size_tensor = context->input(2);

  OP_REQUIRES(
      context,
      TensorShapeUtils::IsVector(begin_tensor.shape()) &&
          TensorShapeUtils::IsVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input.dims() &&
          size_tensor.NumElements() == input.dims(),
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input.dims(), ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  const int input_dims = input.dims();
  IntTensorToInt64Vec(begin_tensor, begin);
  IntTensorToInt64Vec(size_tensor, size);
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
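      // For example, with input.dim_size(i) == 5 and begin[i] == 1, the
      // resolved size[i] is 4.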
      (*size)[i] = input.dim_size(i) - (*begin)[i];
    }
  }

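  // Check the bounds of each dimension and track whether the slice returns the
  // entire input (is_identity) and whether it only restricts dimension 0
  // (slice_dim0); both properties enable copy-free fast paths in the callers.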
  *is_identity = true;
  *slice_dim0 = true;
  for (int i = 0; i < input_dims; ++i) {
    int64_t b = (*begin)[i];
    int64_t s = (*size)[i];
    if (input.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 (got ", s,
                                  ") when input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input.dim_size(i), "], but got ", b));
      OP_REQUIRES(
          context, 0 <= s && b + s <= input.dim_size(i),
          errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                  input.dim_size(i) - b, "], but got ", s));
    }
    output_shape->AddDim(s);
    const bool take_all = (b == 0) && (s == input.dim_size(i));
    (*is_identity) &= take_all;
    (*slice_dim0) &= (i == 0) || take_all;
  }
}

// Code extracted from SliceOp::Compute so that MklSliceOp can reuse this
// generic code.
template <typename T>
static void SharedSliceCommonCases(OpKernelContext* context,
                                   const Tensor& input,
                                   gtl::InlinedVector<int64_t, 4>* begin,
                                   gtl::InlinedVector<int64_t, 4>* size,
                                   Tensor** result, bool* done) {
  bool is_identity = true;
  bool slice_dim0 = true;
  TensorShape output_shape;
  *done = false;

  SharedSliceValidation(context, input, &output_shape, &is_identity,
                        &slice_dim0, begin, size);
  if (!context->status().ok()) return;
  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    *done = true;
    return;
  }

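  // When only dimension 0 is sliced and the slice start is suitably aligned,
  // Tensor::Slice returns a view that shares the input buffer, so no copy is
  // needed.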
  if (slice_dim0 &&
      IsDim0SliceAligned<T>(input.shape(), (*begin)[0], (*size)[0])) {
    VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
    CHECK_GE(input.dims(), 1);  // Otherwise, is_identity should be true.
    context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0]));
    *done = true;
    return;
  }

  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, result));
}

template <typename Device, typename T>
class SliceOp : public OpKernel {
 public:
  explicit SliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    gtl::InlinedVector<int64_t, 4> begin;
    gtl::InlinedVector<int64_t, 4> size;
    const Tensor& input = context->input(0);
    Tensor* result = nullptr;
    bool done = false;
    SharedSliceCommonCases<T>(context, input, &begin, &size, &result, &done);
    if (!context->status().ok() || done) return;

    const int input_dims = input.dims();

    if (result->NumElements() > 0) {
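      // Fast path for 2-D slices of memcpy-able types on CPU: in row-major
      // layout the selected col_size elements of each input row are
      // contiguous, so each output row can be filled with a single memcpy.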
      if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
          DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        auto input_t = input.tensor<T, 2>();
        auto output_t = result->tensor<T, 2>();

        const int64_t row_begin = begin[0];
        const int64_t col_begin = begin[1];
        const int64_t row_size = size[0];
        const int64_t col_size = size[1];

        // TODO(agarwal): Consider multi-threading this loop for cases where
        // row_size is very large.
        for (int i = 0; i < row_size; ++i) {
          const int64_t row = row_begin + i;
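          // Prefetch the next output row and the next input row while this
          // row is being copied.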
          if (i + 1 < size[0]) {
            port::prefetch<port::PREFETCH_HINT_T0>(&output_t(i + 1, 0));
            port::prefetch<port::PREFETCH_HINT_T0>(
                &input_t(row + 1, col_begin));
          }
          memcpy(&output_t(i, 0), &input_t(row, col_begin),
                 col_size * sizeof(T));
        }
        return;
      }
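// HANDLE_DIM dispatches on the input rank to the rank-specialized HandleCase
// below, which invokes the corresponding Slice functor.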
#define HANDLE_DIM(NDIM)                                   \
  if (input_dims == NDIM) {                                \
    HandleCase<NDIM>(context, begin, size, input, result); \
    return;                                                \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("SliceOp : Unhandled input dimensions"));
    }
  }

 private:
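  // Converts begin/size to Eigen DSizes and runs the rank-NDIM Slice functor
  // on this op's device.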
  template <int NDIM>
  void HandleCase(OpKernelContext* context, gtl::ArraySlice<int64_t> begin,
                  gtl::ArraySlice<int64_t> size, const Tensor& input,
                  Tensor* result) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
    for (int i = 0; i < NDIM; ++i) {
      indices[i] = begin[i];
      sizes[i] = size[i];
    }

    functor::Slice<Device, T, NDIM>()(context->eigen_device<Device>(),
                                      result->tensor<T, NDIM>(),
                                      input.tensor<T, NDIM>(), indices, sizes);
  }
};

}  // namespace

// Forward declarations of the functor specializations defined in the sharded
// source files.
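// The extern template declarations prevent implicit instantiation here;
// sharding the per-type, per-rank definitions across source files keeps each
// translation unit small and lets them compile in parallel.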
namespace functor {
#define DECLARE_CPU_SPEC(T, NDIM)                                  \
  template <>                                                      \
  void Slice<CPUDevice, T, NDIM>::operator()(                      \
      const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<CPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T)  \
  DECLARE_CPU_SPEC(T, 1); \
  DECLARE_CPU_SPEC(T, 2); \
  DECLARE_CPU_SPEC(T, 3); \
  DECLARE_CPU_SPEC(T, 4); \
  DECLARE_CPU_SPEC(T, 5); \
  DECLARE_CPU_SPEC(T, 6); \
  DECLARE_CPU_SPEC(T, 7); \
  DECLARE_CPU_SPEC(T, 8);

TF_CALL_ALL_TYPES(DECLARE_FOR_N);

#undef DECLARE_FOR_N
#undef DECLARE_CPU_SPEC
}  // namespace functor

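// The begin and size inputs are pinned to host memory because the kernel reads
// their values directly to validate the slice and compute the output shape.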
#define REGISTER_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("begin")       \
                              .HostMemory("size"),       \
                          SliceOp<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, NDIM)                                  \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T)  \
  DECLARE_GPU_SPEC(T, 1); \
  DECLARE_GPU_SPEC(T, 2); \
  DECLARE_GPU_SPEC(T, 3); \
  DECLARE_GPU_SPEC(T, 4); \
  DECLARE_GPU_SPEC(T, 5); \
  DECLARE_GPU_SPEC(T, 6); \
  DECLARE_GPU_SPEC(T, 7); \
  DECLARE_GPU_SPEC(T, 8);

TF_CALL_bfloat16(DECLARE_FOR_N);
TF_CALL_int8(DECLARE_FOR_N);
TF_CALL_int32(DECLARE_FOR_N);
TF_CALL_int64(DECLARE_FOR_N);
TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N);

#undef DECLARE_FOR_N
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU(type)                               \
  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("begin")       \
                              .HostMemory("size"),       \
                          SliceOp<GPUDevice, type>)

TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int8(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Slice")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("size")
                            .HostMemory("output"),
                        SliceOp<CPUDevice, int32>);

}  // namespace tensorflow