1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
21 | #define EIGEN_USE_GPU |
22 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
23 | |
24 | #include "tensorflow/core/kernels/reverse_sequence_op.h" |
25 | |
26 | #include <memory> |
27 | #include <vector> |
28 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
29 | #include "tensorflow/core/framework/op_kernel.h" |
30 | #include "tensorflow/core/framework/register_types.h" |
31 | #include "tensorflow/core/framework/tensor.h" |
32 | #include "tensorflow/core/framework/tensor_shape.h" |
33 | #include "tensorflow/core/framework/tensor_types.h" |
34 | #include "tensorflow/core/framework/types.h" |
35 | #include "tensorflow/core/platform/logging.h" |
36 | #include "tensorflow/core/platform/macros.h" |
37 | |
38 | namespace tensorflow { |
39 | |
40 | typedef Eigen::ThreadPoolDevice CPUDevice; |
41 | typedef Eigen::GpuDevice GPUDevice; |
42 | |
43 | template <typename Device, typename Tlen> |
44 | void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { |
45 | const Tensor& input = context->input(0); |
46 | const Tensor& seq_lengths = context->input(1); |
47 | |
48 | auto seq_lens_t = seq_lengths.vec<Tlen>(); |
49 | |
50 | std::vector<Tlen> seq_lens_vec(seq_lens_t.size()); |
51 | |
52 | // Copy seq_len info down for validity checks |
53 | context->eigen_device<Device>().memcpyDeviceToHost( |
54 | seq_lens_vec.data(), seq_lens_t.data(), sizeof(Tlen) * seq_lens_t.size()); |
55 | |
56 | OP_REQUIRES(context, batch_dim != seq_dim, |
57 | errors::InvalidArgument("batch_dim == seq_dim == " , seq_dim)); |
58 | OP_REQUIRES(context, seq_dim < input.dims(), |
59 | errors::InvalidArgument("seq_dim must be < input rank" , " ( " , |
60 | seq_dim, " vs. " , input.dims(), ")" )); |
61 | OP_REQUIRES(context, batch_dim < input.dims(), |
62 | errors::InvalidArgument("batch_dim must be < input rank" , " ( " , |
63 | batch_dim, " vs. " , input.dims(), ")" )); |
64 | OP_REQUIRES( |
65 | context, seq_lengths.NumElements() == input.dim_size(batch_dim), |
66 | errors::InvalidArgument("Length of seq_lengths != input.dims(" , batch_dim, |
67 | "), " , "(" , seq_lengths.NumElements(), " vs. " , |
68 | input.dim_size(batch_dim), ")" )); |
69 | |
70 | for (size_t d = 0; d < seq_lens_vec.size(); ++d) { |
71 | OP_REQUIRES(context, seq_lens_vec[d] >= 0, |
72 | errors::InvalidArgument("seq_lens(" , d, ") < 0" )); |
73 | OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim), |
74 | errors::InvalidArgument("seq_lens(" , d, ") > input.dims(" , |
75 | seq_dim, ")" )); |
76 | } |
77 | } |
78 | |
79 | void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { |
80 | const Tensor& input = context->input(0); |
81 | const Tensor& seq_lengths = context->input(1); |
82 | |
83 | OP_REQUIRES(context, batch_dim != seq_dim, |
84 | errors::InvalidArgument("batch_dim == seq_dim == " , seq_dim)); |
85 | OP_REQUIRES(context, seq_dim < input.dims(), |
86 | errors::InvalidArgument("seq_dim must be < input rank" , " ( " , |
87 | seq_dim, " vs. " , input.dims(), ")" )); |
88 | OP_REQUIRES(context, batch_dim < input.dims(), |
89 | errors::InvalidArgument("batch_dim must be < input rank" , " ( " , |
90 | batch_dim, " vs. " , input.dims(), ")" )); |
91 | |
92 | OP_REQUIRES( |
93 | context, seq_lengths.NumElements() == input.dim_size(batch_dim), |
94 | errors::InvalidArgument("Length of seq_lengths != input.dims(" , batch_dim, |
95 | "), " , "(" , seq_lengths.NumElements(), " vs. " , |
96 | input.dim_size(batch_dim), ")" )); |
97 | } |
98 | |
99 | template <> |
100 | void CheckErrors<GPUDevice, int32>(OpKernelContext* context, int batch_dim, |
101 | int seq_dim) { |
102 | CheckErrorsGPU(context, batch_dim, seq_dim); |
103 | } |
104 | |
105 | template <> |
106 | void CheckErrors<GPUDevice, int64_t>(OpKernelContext* context, int batch_dim, |
107 | int seq_dim) { |
108 | CheckErrorsGPU(context, batch_dim, seq_dim); |
109 | } |
110 | |
111 | template <typename Device, typename T, typename Tlen> |
112 | class ReverseSequenceOp : public OpKernel { |
113 | public: |
114 | explicit ReverseSequenceOp(OpKernelConstruction* context) |
115 | : OpKernel(context) { |
116 | OP_REQUIRES_OK(context, context->GetAttr("batch_dim" , &batch_dim_)); |
117 | OP_REQUIRES_OK(context, context->GetAttr("seq_dim" , &seq_dim_)); |
118 | OP_REQUIRES(context, batch_dim_ >= 0, |
119 | errors::InvalidArgument("Invalid batch_dim " , batch_dim_)); |
120 | OP_REQUIRES(context, seq_dim_ >= 0, |
121 | errors::InvalidArgument("Invalid seq_dim " , seq_dim_)); |
122 | } |
123 | |
124 | void Compute(OpKernelContext* context) override { |
125 | const Tensor& input = context->input(0); |
126 | const Tensor& seq_lengths = context->input(1); |
127 | |
128 | // Preliminary validation of sizes. |
129 | OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lengths.shape()), |
130 | errors::InvalidArgument("seq_lengths must be 1-dim, not " , |
131 | seq_lengths.dims())); |
132 | |
133 | auto seq_lens_t = seq_lengths.vec<Tlen>(); |
134 | |
135 | CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_); |
136 | if (!context->status().ok()) return; |
137 | |
138 | const int input_dims = input.dims(); |
139 | |
140 | Tensor* output = nullptr; |
141 | OP_REQUIRES_OK(context, |
142 | context->allocate_output(0, input.shape(), &output)); |
143 | |
144 | #define HANDLE_DIM(NDIM) \ |
145 | case NDIM: \ |
146 | functor::ReverseSequence<Device, T, Tlen, NDIM>::Compute( \ |
147 | context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \ |
148 | seq_dim_, seq_lens_t, output->tensor<T, NDIM>()); \ |
149 | break; |
150 | |
151 | switch (input_dims) { |
152 | HANDLE_DIM(2); |
153 | HANDLE_DIM(3); |
154 | HANDLE_DIM(4); |
155 | HANDLE_DIM(5); |
156 | |
157 | default: |
158 | OP_REQUIRES(context, false, |
159 | errors::InvalidArgument( |
160 | "ReverseSequenceOp : Unhandled input dimensions: " , |
161 | input_dims)); |
162 | } |
163 | } |
164 | |
165 | private: |
166 | int32 batch_dim_; |
167 | int32 seq_dim_; |
168 | |
169 | TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp); |
170 | }; |
171 | |
172 | #define REGISTER_REVERSE_SEQUENCE(type, len_type) \ |
173 | REGISTER_KERNEL_BUILDER(Name("ReverseSequence") \ |
174 | .Device(DEVICE_CPU) \ |
175 | .TypeConstraint<type>("T") \ |
176 | .TypeConstraint<len_type>("Tlen"), \ |
177 | ReverseSequenceOp<CPUDevice, type, len_type>); |
178 | |
179 | #define REGISTER_REVERSE_SEQUENCE_LEN(type) \ |
180 | REGISTER_REVERSE_SEQUENCE(type, int32); \ |
181 | REGISTER_REVERSE_SEQUENCE(type, int64_t); |
182 | |
183 | TF_CALL_POD_STRING_TYPES(REGISTER_REVERSE_SEQUENCE_LEN); |
184 | |
185 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
186 | |
187 | // Forward declarations of the functor specializations for GPU. |
188 | namespace functor { |
189 | #define DECLARE_GPU_SPEC(T, Tlen, Dims) \ |
190 | template <> \ |
191 | void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \ |
192 | const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ |
193 | int32 batch_dim, int32 seq_dim, \ |
194 | typename TTypes<Tlen>::ConstVec seq_lengths, \ |
195 | typename TTypes<T, Dims>::Tensor output); \ |
196 | extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>; |
197 | |
198 | #define DECLARE_GPU_SPEC_LEN(T, Dims) \ |
199 | DECLARE_GPU_SPEC(T, int32, Dims); \ |
200 | DECLARE_GPU_SPEC(T, int64_t, Dims); |
201 | |
202 | #define DECLARE_GPU_SPECS(T) \ |
203 | DECLARE_GPU_SPEC_LEN(T, 2); \ |
204 | DECLARE_GPU_SPEC_LEN(T, 3); \ |
205 | DECLARE_GPU_SPEC_LEN(T, 4); \ |
206 | DECLARE_GPU_SPEC_LEN(T, 5); |
207 | |
208 | TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); |
209 | TF_CALL_bool(DECLARE_GPU_SPECS); |
210 | |
211 | } // namespace functor |
212 | |
213 | // Registration of the GPU implementations. |
214 | #define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type) \ |
215 | REGISTER_KERNEL_BUILDER(Name("ReverseSequence") \ |
216 | .Device(DEVICE_GPU) \ |
217 | .TypeConstraint<type>("T") \ |
218 | .TypeConstraint<len_type>("Tlen"), \ |
219 | ReverseSequenceOp<GPUDevice, type, len_type>); |
220 | |
221 | #define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \ |
222 | REGISTER_REVERSE_SEQUENCE_GPU(type, int32); \ |
223 | REGISTER_REVERSE_SEQUENCE_GPU(type, int64_t); |
224 | |
225 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN); |
226 | TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN); |
227 | |
228 | #undef REGISTER_REVERSE_SEQUENCE_GPU |
229 | |
230 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
231 | |
232 | } // namespace tensorflow |
233 | |