/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/reverse_op.h"
#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Reverse rows (middle dimension) of a three dimensional tensor.
// NUM_CHANNELS can be <= 0 to compute it dynamically from <input>.
// Otherwise, it must equal input.dim_size(2) and is used as a compile-time
// constant.
template <typename T, int NUM_CHANNELS>
void ReverseRows(OpKernelContext* context, const Tensor& input,
                 Tensor* result) {
  auto work = [&input, result](int64_t start, int64_t end) {
    const int64_t inner_size =
        NUM_CHANNELS > 0 ? NUM_CHANNELS : input.dim_size(2);
    const int64_t middle_size = input.dim_size(1);
    const int64_t row_size = inner_size * middle_size;
    DCHECK_EQ(input.dim_size(2), inner_size);

    const T* in_ptr = input.bit_casted_tensor<T, 3>().data();
    T* out_ptr = result->bit_casted_tensor<T, 3>().data();

    in_ptr += start * row_size;
    out_ptr += start * row_size;

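    // Copy rows in reverse order within each outer slice: out_ptr starts at
    // the end of the destination row block and steps backward by inner_size
    // while in_ptr walks forward, so input row j lands at output row
    // (middle_size - 1 - j).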
    for (int outer_dim = start; outer_dim < end; ++outer_dim) {
      out_ptr += row_size;
      int remaining = middle_size;
      while (remaining > 0) {
        out_ptr -= inner_size;
        memcpy(out_ptr, in_ptr, inner_size * sizeof(T));
        in_ptr += inner_size;
        --remaining;
      }

      out_ptr += row_size;
    }
  };

  // Shard across outer dimension.
  const int64_t N = input.dim_size(0);
  const int64_t cost_per_unit = input.NumElements() / N;
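  // cost_per_unit is the number of elements in one outer slice; the sharder
  // uses it to decide how many slices to hand to each worker thread.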
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit,
        std::move(work));
}

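// Types whose elements can safely be relocated with a raw memcpy. Only these
// types are eligible for the optimized ReverseRows path above.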
template <typename T>
struct data_type_can_memcpy {
  static constexpr bool value =
      std::is_same<T, uint8>::value || std::is_same<T, int8>::value ||
      std::is_same<T, bool>::value || std::is_same<T, uint16>::value ||
      std::is_same<T, int16>::value || std::is_same<T, Eigen::half>::value ||
      std::is_same<T, int32>::value || std::is_same<T, float>::value ||
      std::is_same<T, int64_t>::value || std::is_same<T, double>::value ||
      std::is_same<T, complex64>::value || std::is_same<T, complex128>::value;
};

template <typename T, int NUM_CHANNELS>
typename std::enable_if<data_type_can_memcpy<T>::value>::type
DoHandleReverseCase(OpKernelContext* context, const Tensor& input,
                    Tensor* result) {
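  // Dispatch on the element size only: since T is memcpy-able, its bytes can
  // be reinterpreted as an unsigned integer (or complex128) of the same
  // width, which keeps the number of ReverseRows instantiations small.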
  if (sizeof(T) == 1) {
    static_assert(sizeof(uint8) == 1, "uint8 must be 1 byte.");
    ReverseRows<uint8, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 2) {
    static_assert(sizeof(uint16) == 2, "uint16 must be 2 bytes");
    ReverseRows<uint16, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 4) {
    static_assert(sizeof(uint32) == 4, "uint32 must be 4 bytes");
    ReverseRows<uint32, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 8) {
    static_assert(sizeof(uint64) == 8, "uint64 must be 8 bytes");
    ReverseRows<uint64, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 16) {
    static_assert(sizeof(complex128) == 16, "complex128 must be 16 bytes");
    ReverseRows<complex128, NUM_CHANNELS>(context, input, result);
  } else {
    context->CtxFailure(errors::InvalidArgument(DataTypeString(input.dtype()),
                                                " has unexpected size of ",
                                                sizeof(T), " bytes"));
  }
}

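// Fallback overload for types that cannot be memcpy'd. It is never reached at
// runtime (call sites are guarded by data_type_can_memcpy), but it must exist
// so those call sites compile for every element type.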
template <typename T, int NUM_CHANNELS>
typename std::enable_if<!data_type_can_memcpy<T>::value>::type
DoHandleReverseCase(OpKernelContext* context, const Tensor& input,
                    Tensor* result) {}

}  // namespace

template <typename Device, typename T, int NDIMS>
void HandleReverseCase(OpKernelContext* context,
                       typename TTypes<bool, 1>::ConstTensor dims,
                       Tensor* result) {
  const Tensor& input = context->input(0);

  // Use optimized reverse if possible.
  if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
      data_type_can_memcpy<T>::value && (!dims(0) && dims(1) && !dims(2))) {
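    // The fast path only handles reversing the middle dimension of a rank-3
    // CPU tensor. Three channels (e.g. RGB images) is common enough to get
    // its own compile-time specialization.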
    if (input.dim_size(2) == 3) {
      DoHandleReverseCase<T, 3>(context, input, result);
    } else {
      DoHandleReverseCase<T, -1>(context, input, result);
    }
    return;
  }
  typename Eigen::array<bool, NDIMS> axes_di;
  for (int i = 0; i < NDIMS; i++) {
    axes_di[i] = dims(i);
  }
  functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
                                       input.tensor<T, NDIMS>(), axes_di,
                                       result->tensor<T, NDIMS>());
}

template <typename Device, typename T>
class ReverseOp : public OpKernel {
 public:
  explicit ReverseOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    // For non-scalar input, reject a zero-sized first dimension.
    if (input.dims() > 0) {
      OP_REQUIRES(
          context, input.dim_size(0) != 0,
          errors::InvalidArgument("Invalid input first dimension. Found 0."));
    }
    const Tensor& dims = context->input(1);

    if (TensorShapeUtils::IsScalar(input.shape())) {
      context->set_output(0, input);
    } else {
      const int input_dims = input.dims();
      OP_REQUIRES(context, TensorShapeUtils::IsVector(dims.shape()),
                  errors::InvalidArgument("'dims' must be 1-dimensional, not ",
                                          dims.dims()));

      OP_REQUIRES(
          context, input_dims == dims.dim_size(0),
          errors::InvalidArgument(
              "'dims' must have the same number of values as 'input' has "
              "dimensions. 'input' has ",
              input_dims, ", 'dims' has ", dims.dim_size(0), " values"));
      OP_REQUIRES(context, input_dims <= 8,
                  errors::Unimplemented(
                      "reverse is not implemented for tensors of rank > 8."));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, input.shape(), &output));

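// Eigen requires the tensor rank as a compile-time constant, so dispatch from
// the runtime rank to a rank-specialized instantiation via this macro.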
#define HANDLE_REVERSE(NDIMS)                                               \
  case NDIMS:                                                               \
    HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
    return;

      switch (input_dims) {
        HANDLE_REVERSE(0);
        HANDLE_REVERSE(1);
        HANDLE_REVERSE(2);
        HANDLE_REVERSE(3);
        HANDLE_REVERSE(4);
        HANDLE_REVERSE(5);
        HANDLE_REVERSE(6);
        HANDLE_REVERSE(7);
        HANDLE_REVERSE(8);
      }
#undef HANDLE_REVERSE
    }
  }
};

template <typename Device, typename T, int NDIMS>
void HandleReverseV2Case(OpKernelContext* context,
                         const gtl::ArraySlice<bool> axes, Tensor* result) {
  const Tensor& input = context->input(0);

  // Use optimized reverse if possible.
  if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
      data_type_can_memcpy<T>::value && (!axes[0] && axes[1] && !axes[2])) {
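    // Same fast path as in HandleReverseCase: memcpy-based row reversal for a
    // rank-3 CPU tensor where only the middle dimension is reversed.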
    if (input.dim_size(2) == 3) {
      DoHandleReverseCase<T, 3>(context, input, result);
    } else {
      DoHandleReverseCase<T, -1>(context, input, result);
    }
    return;
  }

  typename Eigen::array<bool, NDIMS> axes_di;
  for (int i = 0; i < NDIMS; i++) {
    axes_di[i] = axes[i];
  }
  functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
                                       input.tensor<T, NDIMS>(), axes_di,
                                       result->tensor<T, NDIMS>());
}

template <typename Device, typename T, typename Tidx>
class ReverseV2Op : public OpKernel {
 public:
  explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& sparse_dims = context->input(1);

    if (TensorShapeUtils::IsScalar(input.shape()) || input.NumElements() == 0) {
      context->set_output(0, input);
    } else {
      const int input_dims = input.dims();
      const TensorShape& sparse_dims_shape = sparse_dims.shape();
      const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();

      OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
                  errors::InvalidArgument("'dims' must be 1-dimensional, not ",
                                          sparse_dims.dims()));
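      // Expand the sparse list of axis indices into a dense boolean mask,
      // canonicalizing negative axes and rejecting out-of-range or duplicate
      // entries along the way.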
      gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
      for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
        Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
        Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
        OP_REQUIRES(context,
                    canonical_axis >= 0 && canonical_axis < input_dims,
                    errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
                                            " is out of valid range [", 0,
                                            ", ", input_dims - 1, "]"));
        OP_REQUIRES(context, !axes_dense[canonical_axis],
                    errors::InvalidArgument("axis ", canonical_axis,
                                            " specified more than once."));
        axes_dense[canonical_axis] = true;
      }

      OP_REQUIRES(context, input_dims <= 8,
                  errors::Unimplemented(
                      "reverse is not implemented for tensors of rank > 8."));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, input.shape(), &output));

      // TODO(cwhipkey): we can do dimension folding to reduce, e.g., a reverse
      // of a single dimension to the dims=3 or dims=2 case, regardless of the
      // number of dimensions in the tensor. This would let some ops use faster
      // lower-dimension code (and use optimized versions).

#define HANDLE_REVERSE(NDIMS)                                           \
  case NDIMS:                                                           \
    HandleReverseV2Case<Device, T, NDIMS>(context, axes_dense, output); \
    return;

      switch (input_dims) {
        HANDLE_REVERSE(0);
        HANDLE_REVERSE(1);
        HANDLE_REVERSE(2);
        HANDLE_REVERSE(3);
        HANDLE_REVERSE(4);
        HANDLE_REVERSE(5);
        HANDLE_REVERSE(6);
        HANDLE_REVERSE(7);
        HANDLE_REVERSE(8);
      }
#undef HANDLE_REVERSE
    }
  }
};

#define REGISTER_KERNELS(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("Reverse")                      \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T")          \
                              .HostMemory("dims"),             \
                          ReverseOp<CPUDevice, T>)             \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                    \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T")          \
                              .TypeConstraint<int32>("Tidx")   \
                              .HostMemory("axis"),             \
                          ReverseV2Op<CPUDevice, T, int32>)    \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                    \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T")          \
                              .TypeConstraint<int64_t>("Tidx") \
                              .HostMemory("axis"),             \
                          ReverseV2Op<CPUDevice, T, int64>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
namespace functor {
#define DECLARE_GPU_SPEC_DIM(T, DIM)                                  \
  template <>                                                         \
  void Reverse<GPUDevice, T, DIM>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, DIM>::ConstTensor input, \
      const Eigen::array<bool, DIM>& reverse_dims,                    \
      typename TTypes<T, DIM>::Tensor output);                        \
  extern template struct Reverse<GPUDevice, T, DIM>;
#define DECLARE_GPU_SPEC(T)  \
  DECLARE_GPU_SPEC_DIM(T, 0) \
  DECLARE_GPU_SPEC_DIM(T, 1) \
  DECLARE_GPU_SPEC_DIM(T, 2) \
  DECLARE_GPU_SPEC_DIM(T, 3) \
  DECLARE_GPU_SPEC_DIM(T, 4) \
  DECLARE_GPU_SPEC_DIM(T, 5) \
  DECLARE_GPU_SPEC_DIM(T, 6) \
  DECLARE_GPU_SPEC_DIM(T, 7) \
  DECLARE_GPU_SPEC_DIM(T, 8)

TF_CALL_uint8(DECLARE_GPU_SPEC);
TF_CALL_int8(DECLARE_GPU_SPEC);
TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_DIM
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(T)                                \
  REGISTER_KERNEL_BUILDER(Name("Reverse")                      \
                              .Device(DEVICE_GPU)              \
                              .TypeConstraint<T>("T")          \
                              .HostMemory("dims"),             \
                          ReverseOp<GPUDevice, T>)             \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                    \
                              .Device(DEVICE_GPU)              \
                              .TypeConstraint<T>("T")          \
                              .TypeConstraint<int32>("Tidx")   \
                              .HostMemory("axis"),             \
                          ReverseV2Op<GPUDevice, T, int32>)    \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                    \
                              .Device(DEVICE_GPU)              \
                              .TypeConstraint<T>("T")          \
                              .TypeConstraint<int64_t>("Tidx") \
                              .HostMemory("axis"),             \
                          ReverseV2Op<GPUDevice, T, int64>)
TF_CALL_uint8(REGISTER_GPU_KERNELS);
TF_CALL_int8(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Reverse")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("tensor")
                            .HostMemory("dims")
                            .HostMemory("output"),
                        ReverseOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tidx")
                            .HostMemory("tensor")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ReverseV2Op<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tidx")
                            .HostMemory("tensor")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ReverseV2Op<CPUDevice, int32, int64>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow