/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definitions for StridedSliceOp; they must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

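// Copies the strided slice of the op's primary input (input 0) described by
// `begin`, `end`, and `strides` into `result`, which is viewed with shape
// `processing_shape`. When `is_simple_slice` is true every stride is 1 and the
// cheaper Slice functor is used instead of StridedSlice.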
template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64_t>& begin,
                            const gtl::ArraySlice<int64_t>& end,
                            const gtl::ArraySlice<int64_t>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

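// Scatters the incoming gradient (input 4), which is viewed with shape
// `processing_shape`, back into the strided-slice positions of `result`.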
template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64_t>& begin,
                                const gtl::ArraySlice<int64_t>& end,
                                const gtl::ArraySlice<int64_t>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

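// Assigns the right-hand-side value (input 4), broadcast according to `bcast`,
// to the strided-slice positions of `result`.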
template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64_t>& begin,
                  const gtl::ArraySlice<int64_t>& end,
                  const gtl::ArraySlice<int64_t>& strides,
                  const StridedSliceAssignBCast& bcast, Tensor* result);
};
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64_t>& begin,
                            const gtl::ArraySlice<int64_t>& end,
                            const gtl::ArraySlice<int64_t>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
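  // Use a proxy type of the same bit width as T so that element types of
  // equal size can share a single functor instantiation.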
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64_t, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
  } else {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64_t>& begin,
                                const gtl::ArraySlice<int64_t>& end,
                                const gtl::ArraySlice<int64_t>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64_t, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64_t>& begin,
    const gtl::ArraySlice<int64_t>& end,
    const gtl::ArraySlice<int64_t>& strides,
    const StridedSliceAssignBCast& bcast, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  constexpr int kRhsInput = 4;
  const Tensor& input = context->input(kRhsInput);
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      input.bit_casted_shaped<Proxy, NDIM>(bcast.reshape()), begin_di, end_di,
      strides_di, bcast);
}

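// 0-D (scalar) assignment: the slice degenerates to a single element, so both
// tensors are viewed as one-element 1-D tensors and copied with the scalar
// functor.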
template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64_t>& begin,
                  const gtl::ArraySlice<int64_t>& end,
                  const gtl::ArraySlice<int64_t>& strides,
                  const StridedSliceAssignBCast& bcast, Tensor* result) {
    gtl::InlinedVector<int64_t, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles these instantiations externally. It is important that this is done
// before the HandleXXCase functions are instantiated to avoid duplicate
// specialization errors.
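//
// The macros below rely on two standard C++ mechanisms: declaring an explicit
// specialization of operator() promises that its definition exists in some
// other translation unit, and an `extern template struct ...;` explicit
// instantiation declaration suppresses implicit instantiation of the class
// template in this translation unit. A minimal sketch of the same idea, using
// purely illustrative names that are not part of this file:
//
//   template <typename T> struct Functor { void operator()(T) {} };
//   template <> void Functor<float>::operator()(float);  // defined elsewhere
//   extern template struct Functor<float>;  // do not instantiate here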

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM) \
  namespace functor { \
  template <> \
  void StridedSlice<GPUDevice, T, NDIM>::operator()( \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides); \
  extern template struct StridedSlice<GPUDevice, T, NDIM>; \
  template <> \
  void Slice<GPUDevice, T, NDIM>::operator()( \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
  extern template struct Slice<GPUDevice, T, NDIM>; \
  template <> \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()( \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides); \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>; \
  template <> \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()( \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop, \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides, \
      const StridedSliceAssignBCast& bcast); \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>; \
  }  // namespace functor
#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM) \
  namespace functor { \
  template <> \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output, \
      typename TTypes<T, 1>::ConstTensor input); \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>; \
  }  // namespace functor

// Dimension 0 only instantiates some functors, so we only need to prevent
// the ones covered by PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif

#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template void HandleStridedSliceCase<DEVICE, T, DIM>( \
      OpKernelContext * context, const gtl::ArraySlice<int64_t>& begin, \
      const gtl::ArraySlice<int64_t>& end, \
      const gtl::ArraySlice<int64_t>& strides, \
      const TensorShape& processing_shape, bool is_simple_slice, \
      Tensor* result); \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>( \
      OpKernelContext * context, const gtl::ArraySlice<int64_t>& begin, \
      const gtl::ArraySlice<int64_t>& end, \
      const gtl::ArraySlice<int64_t>& strides, \
      const TensorShape& processing_shape, bool is_simple_slice, \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_COMPLEX_TYPES(PREVENT_FOR_N_GPU);

TF_CALL_INTEGRAL_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
#endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
TF_CALL_QUANTIZED_TYPES(DECLARE_FOR_N_CPU);

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_