1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
17#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
18
19// Functor definition for StridedSliceOp, must be compilable by nvcc.
20
21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22#include "tensorflow/core/framework/resource_handle.h"
23#include "tensorflow/core/framework/tensor_types.h"
24#include "tensorflow/core/framework/variant_encode_decode.h"
25#include "tensorflow/core/platform/types.h"
26#include "tensorflow/core/util/strided_slice_op.h"
27
28namespace tensorflow {
29namespace functor {
30
31template <typename Device, typename T, int NDIMS>
32struct StridedSlice {
33 void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
34 typename TTypes<T, NDIMS>::ConstTensor input,
35 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& start_indices,
36 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& stop_indices,
37 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& strides) {
38 MaybeWith32BitIndexing<Device>(
39 [&](auto output32, auto input32, const auto& start_indices32,
40 const auto& stop_indices32, const auto& strides32) {
41 output32.device(d) =
42 input32.stridedSlice(start_indices32, stop_indices32, strides32);
43 },
44 output, input, start_indices, stop_indices, strides);
45 }
46};
47
48template <typename T, int NDIMS, typename Device>
49struct InitOutput {
50 static void run(const Device& d, typename TTypes<T, NDIMS>::Tensor output) {
51 output.device(d) = output.constant(T(0));
52 }
53};
54
55template <int NDIMS, typename Device>
56struct InitOutput<ResourceHandle, NDIMS, Device> {
57 static void run(const Device& d,
58 typename TTypes<ResourceHandle, NDIMS>::Tensor output) {
59 output.device(d) = output.constant(ResourceHandle());
60 }
61};
62
63template <int NDIMS, typename Device>
64struct InitOutput<tstring, NDIMS, Device> {
65 static void run(const Device& d,
66 typename TTypes<tstring, NDIMS>::Tensor output) {
67 output.device(d) = output.constant(tstring());
68 }
69};
70
71template <typename Device, typename T, int NDIMS>
72struct StridedSliceGrad {
73 void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
74 typename TTypes<T, NDIMS>::ConstTensor input,
75 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& start_indices,
76 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& stop_indices,
77 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& strides) {
78 InitOutput<T, NDIMS, Device>::run(d, output);
79 MaybeWith32BitIndexing<Device>(
80 [&](auto output32, const auto& start_indices32,
81 const auto& stop_indices32, const auto& strides32) {
82 output32.stridedSlice(start_indices32, stop_indices32, strides32)
83 .device(d) = input;
84 },
85 output, start_indices, stop_indices, strides);
86 }
87};
88
89template <typename Device, typename T, int NDIMS>
90struct StridedSliceAssign {
91 void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
92 typename TTypes<T, NDIMS>::ConstTensor input,
93 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& start_indices,
94 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& stop_indices,
95 const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& strides,
96 const StridedSliceAssignBCast& bcast) {
97 MaybeWith32BitIndexing<Device>(
98 [&](auto output32, auto input32, const auto& start_indices32,
99 const auto& stop_indices32, const auto& strides32) {
100 if (bcast.IsBroadcastingRequired()) {
101 output32.stridedSlice(start_indices32, stop_indices32, strides32)
102 .device(d) = input32.broadcast(bcast.bcast());
103 } else {
104 output32.stridedSlice(start_indices32, stop_indices32, strides32)
105 .device(d) = input32;
106 }
107 },
108 output, input, start_indices, stop_indices, strides);
109 }
110};
111
112template <typename Device, typename T>
113struct StridedSliceAssignScalar {
114 void operator()(const Device& d, typename TTypes<T, 1>::Tensor output,
115 typename TTypes<T, 1>::ConstTensor input) {
116 output.device(d) = input;
117 }
118};
119
120} // namespace functor
121} // namespace tensorflow
122
123#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
124