/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/reverse_sequence_op.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

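// Validates batch_dim/seq_dim against the input shape and range-checks every
// seq_lengths entry. The generic version first copies seq_lengths out of the
// device (an ordinary memcpy for the CPU device) so the per-element checks
// can run on the host.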
template <typename Device, typename Tlen>
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lengths = context->input(1);

  auto seq_lens_t = seq_lengths.vec<Tlen>();

  std::vector<Tlen> seq_lens_vec(seq_lens_t.size());

  // Copy the seq_lengths values to the host for the validity checks below.
  context->eigen_device<Device>().memcpyDeviceToHost(
      seq_lens_vec.data(), seq_lens_t.data(), sizeof(Tlen) * seq_lens_t.size());

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input rank (",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input rank (",
                                      batch_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(
      context, seq_lengths.NumElements() == input.dim_size(batch_dim),
      errors::InvalidArgument("Length of seq_lengths != input.dims(",
                              batch_dim, ") (", seq_lengths.NumElements(),
                              " vs. ", input.dim_size(batch_dim), ")"));

  for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
    OP_REQUIRES(context, seq_lens_vec[d] >= 0,
                errors::InvalidArgument("seq_lens(", d, ") < 0"));
    OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim),
                errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
                                        seq_dim, ")"));
  }
}

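// Shape-only variant used on GPU, where seq_lengths resides in device memory:
// the per-element bounds checks from the generic CheckErrors are skipped,
// since the values would first have to be copied back to the host.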
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lengths = context->input(1);

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input rank (",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input rank (",
                                      batch_dim, " vs. ", input.dims(), ")"));

  OP_REQUIRES(
      context, seq_lengths.NumElements() == input.dim_size(batch_dim),
      errors::InvalidArgument("Length of seq_lengths != input.dims(",
                              batch_dim, ") (", seq_lengths.NumElements(),
                              " vs. ", input.dim_size(batch_dim), ")"));
}

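// On GPU, both supported Tlen index types use the shape-only checks.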
template <>
void CheckErrors<GPUDevice, int32>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

template <>
void CheckErrors<GPUDevice, int64_t>(OpKernelContext* context, int batch_dim,
                                     int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

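// Kernel that reverses variable-length slices of the input along seq_dim: for
// each index i along batch_dim, the first seq_lengths(i) elements along
// seq_dim are reversed and the rest are left untouched. For example, with
// batch_dim = 0 and seq_dim = 1, the row [1, 2, 3, 4] with seq_lengths(i) = 3
// becomes [3, 2, 1, 4].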
template <typename Device, typename T, typename Tlen>
class ReverseSequenceOp : public OpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
    OP_REQUIRES(context, batch_dim_ >= 0,
                errors::InvalidArgument("Invalid batch_dim ", batch_dim_));
    OP_REQUIRES(context, seq_dim_ >= 0,
                errors::InvalidArgument("Invalid seq_dim ", seq_dim_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& seq_lengths = context->input(1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lengths.shape()),
                errors::InvalidArgument("seq_lengths must be 1-dim, not ",
                                        seq_lengths.dims()));

    auto seq_lens_t = seq_lengths.vec<Tlen>();

    CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
    if (!context->status().ok()) return;

    const int input_dims = input.dims();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

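// HANDLE_DIM expands to one switch case per supported input rank and
// dispatches to the rank-specialized ReverseSequence functor.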
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM:                                                                  \
    functor::ReverseSequence<Device, T, Tlen, NDIM>::Compute(                 \
        context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
        seq_dim_, seq_lens_t, output->tensor<T, NDIM>());                     \
    break;

    switch (input_dims) {
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ReverseSequenceOp: Unhandled input dimensions: ",
                        input_dims));
    }
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;

  TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
};

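// Registration of the CPU implementations.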
#define REGISTER_REVERSE_SEQUENCE(type, len_type)                \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<CPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_LEN(type) \
  REGISTER_REVERSE_SEQUENCE(type, int32);   \
  REGISTER_REVERSE_SEQUENCE(type, int64_t);

TF_CALL_POD_STRING_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tlen, Dims)                                \
  template <>                                                          \
  void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute(             \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
      int32 batch_dim, int32 seq_dim,                                  \
      typename TTypes<Tlen>::ConstVec seq_lengths,                     \
      typename TTypes<T, Dims>::Tensor output);                        \
  extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;

#define DECLARE_GPU_SPEC_LEN(T, Dims) \
  DECLARE_GPU_SPEC(T, int32, Dims);   \
  DECLARE_GPU_SPEC(T, int64_t, Dims);

#define DECLARE_GPU_SPECS(T)  \
  DECLARE_GPU_SPEC_LEN(T, 2); \
  DECLARE_GPU_SPEC_LEN(T, 3); \
  DECLARE_GPU_SPEC_LEN(T, 4); \
  DECLARE_GPU_SPEC_LEN(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type)            \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<GPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int32);   \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int64_t);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN);

#undef REGISTER_REVERSE_SEQUENCE_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow