/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definition for StridedSliceOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64_t>& begin,
                            const gtl::ArraySlice<int64_t>& end,
                            const gtl::ArraySlice<int64_t>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64_t>& begin,
                                const gtl::ArraySlice<int64_t>& end,
                                const gtl::ArraySlice<int64_t>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64_t>& begin,
                  const gtl::ArraySlice<int64_t>& end,
                  const gtl::ArraySlice<int64_t>& strides,
                  const StridedSliceAssignBCast& bcast, Tensor* result);
};
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
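// For example (illustrative only; the rank value is arbitrary), a per-rank
// instantiation translation unit could contain just:
//
//   #define STRIDED_SLICE_INSTANTIATE_DIM 3
//   #include "tensorflow/core/kernels/strided_slice_op_impl.h"
//   #undef STRIDED_SLICE_INSTANTIATE_DIM
//
// which emits the explicit instantiations below for that rank only.
//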
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64_t>& begin,
                            const gtl::ArraySlice<int64_t>& end,
                            const gtl::ArraySlice<int64_t>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64_t, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
  } else {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64_t>& begin,
                                const gtl::ArraySlice<int64_t>& end,
                                const gtl::ArraySlice<int64_t>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64_t, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64_t>& begin,
    const gtl::ArraySlice<int64_t>& end,
    const gtl::ArraySlice<int64_t>& strides,
    const StridedSliceAssignBCast& bcast, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  constexpr int kRhsInput = 4;
  const Tensor& input = context->input(kRhsInput);
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      input.bit_casted_shaped<Proxy, NDIM>(bcast.reshape()), begin_di, end_di,
      strides_di, bcast);
}

template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64_t>& begin,
                  const gtl::ArraySlice<int64_t>& end,
                  const gtl::ArraySlice<int64_t>& strides,
                  const StridedSliceAssignBCast& bcast, Tensor* result) {
    gtl::InlinedVector<int64_t, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles the instantiations externally. It is important that this is done
// before the HandleXXCase handlers are instantiated, to avoid duplicate
// specialization errors. (An illustrative expansion follows the macro
// definitions below.)

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)                   \
  namespace functor {                                              \
  template <>                                                      \
  void StridedSlice<GPUDevice, T, NDIM>::operator()(               \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSlice<GPUDevice, T, NDIM>;         \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;                \
  template <>                                                      \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(           \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>;     \
  template <>                                                      \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()(         \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides,       \
      const StridedSliceAssignBCast& bcast);                       \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>;   \
  }  // namespace functor
#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)                   \
  namespace functor {                                            \
  template <>                                                    \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()(       \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output,  \
      typename TTypes<T, 1>::ConstTensor input);                 \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>; \
  }  // namespace functor
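
// For illustration only (float is an arbitrary example type), expanding
// PREVENT_INSTANTIATE_DIM0_ONLY(float, 0) yields roughly:
//
//   namespace functor {
//   template <>
//   void StridedSliceAssignScalar<GPUDevice, float>::operator()(
//       const GPUDevice& d, typename TTypes<float, 1>::Tensor output,
//       typename TTypes<float, 1>::ConstTensor input);
//   extern template struct StridedSliceAssignScalar<GPUDevice, float>;
//   }  // namespace functor
//
// i.e. the specialization and its explicit instantiation are only declared
// here; the definitions come from the GPU translation unit, so they are not
// instantiated again in this file.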

// Dimension 0 only instantiates some of the functors, so we only need to
// prevent the ones covered by PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif

#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)                \
  template void HandleStridedSliceCase<DEVICE, T, DIM>(                 \
      OpKernelContext * context, const gtl::ArraySlice<int64_t>& begin, \
      const gtl::ArraySlice<int64_t>& end,                              \
      const gtl::ArraySlice<int64_t>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,        \
      Tensor* result);                                                  \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>(             \
      OpKernelContext * context, const gtl::ArraySlice<int64_t>& begin, \
      const gtl::ArraySlice<int64_t>& end,                              \
      const gtl::ArraySlice<int64_t>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,        \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM)                \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

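// Each TF_CALL_* invocation below expands the supplied macro once per dtype
// in its set: the DECLARE_FOR_N_* calls emit explicit instantiations of the
// handlers above for the current STRIDED_SLICE_INSTANTIATE_DIM, while the
// PREVENT_FOR_N_GPU calls emit the declarations that keep nvcc from
// re-instantiating the GPU functors here.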
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_COMPLEX_TYPES(PREVENT_FOR_N_GPU);

TF_CALL_INTEGRAL_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
#endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
TF_CALL_QUANTIZED_TYPES(DECLARE_FOR_N_CPU);

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_