/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/argmax_op.h"

#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
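
// ArgOp is the shared implementation behind ArgMax and ArgMin: it reduces
// `input` along the axis given by the scalar `dimension` input, writing the
// index of the extremal element (selected by ArgFunctor) into an output whose
// shape is the input shape with that axis removed.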
template <typename Device, typename T, typename Tout, typename ArgFunctor>
class ArgOp : public OpKernel {
 public:
  explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& dimension = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
                errors::InvalidArgument(
                    "dim must be a scalar, but received tensor of shape: ",
                    dimension.shape().DebugString()));

    const int32_t dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
    const int input_dims = input.dims();

    int axis = dim < 0 ? dim + input_dims : dim;

    OP_REQUIRES(context, FastBoundsCheck(axis, input_dims),
                errors::InvalidArgument("Expected dimension in the range [",
                                        -input_dims, ", ", input_dims,
                                        "), but got ", dim));
    OP_REQUIRES(
        context, input.dim_size(axis) > 0,
        errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
                                input.shape().DebugString()));

    TensorShape output_shape;
    const TensorShape& input_shape = input.shape();
    for (int d = 0; d < input_dims - 1; ++d) {
      output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));

    if (output_shape.num_elements() == 0) {
      return;
    }
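
// HANDLE_DIM expands to a switch case that dispatches to the rank-specific
// Reduce##NDIM functor; the switch below covers input ranks 1 through 7.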
#define HANDLE_DIM(NDIM)                                          \
  case NDIM:                                                      \
    ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(),     \
                             input.tensor<T, NDIM>(), axis,       \
                             output->tensor<Tout, NDIM - 1>());   \
    break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Argmax and Argmin only support up "
                                            "to 7 input dimensions, but got ",
                                            input_dims, ". Inputs shape: ",
                                            input.shape().DebugString()));
    }
  }
#undef HANDLE_DIM

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
};
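
// ArgMaxOp and ArgMinOp bind ArgOp to the corresponding reduction functor.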
template <typename Device, typename T, typename Tout>
class ArgMaxOp
    : public ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> > {
 public:
  explicit ArgMaxOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> >(context) {}
};

template <typename Device, typename T, typename Tout>
class ArgMinOp
    : public ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> > {
 public:
  explicit ArgMinOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> >(context) {}
};
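
// Registration of the CPU implementations.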
#define REGISTER_ARGMAX(type)                                          \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int64_t>("output_type")  \
                              .HostMemory("dimension"),                \
                          ArgMaxOp<CPUDevice, type, int64>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int64_t>("output_type")  \
                              .HostMemory("dimension"),                \
                          ArgMinOp<CPUDevice, type, int64>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int32>("output_type")    \
                              .HostMemory("dimension"),                \
                          ArgMaxOp<CPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int32>("output_type")    \
                              .HostMemory("dimension"),                \
                          ArgMinOp<CPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int16>("output_type")    \
                              .HostMemory("dimension"),                \
                          ArgMaxOp<CPUDevice, type, int16>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<uint16>("output_type")   \
                              .HostMemory("dimension"),                \
                          ArgMaxOp<CPUDevice, type, uint16>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
TF_CALL_bool(REGISTER_ARGMAX);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Forward declarations of the functor specializations for GPU.
namespace functor {

#define DECLARE_GPU_SPEC(T, Tout, Dims)                                       \
  template <>                                                                 \
  void ArgMax<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output); \
  template <>                                                                 \
  void ArgMin<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output);

#define DECLARE_GPU_SPECS(T)       \
  DECLARE_GPU_SPEC(T, int64_t, 1); \
  DECLARE_GPU_SPEC(T, int64_t, 2); \
  DECLARE_GPU_SPEC(T, int64_t, 3); \
  DECLARE_GPU_SPEC(T, int64_t, 4); \
  DECLARE_GPU_SPEC(T, int64_t, 5); \
  DECLARE_GPU_SPEC(T, int64_t, 6); \
  DECLARE_GPU_SPEC(T, int64_t, 7); \
  DECLARE_GPU_SPEC(T, int32, 1);   \
  DECLARE_GPU_SPEC(T, int32, 2);   \
  DECLARE_GPU_SPEC(T, int32, 3);   \
  DECLARE_GPU_SPEC(T, int32, 4);   \
  DECLARE_GPU_SPEC(T, int32, 5);   \
  DECLARE_GPU_SPEC(T, int32, 6);   \
  DECLARE_GPU_SPEC(T, int32, 7);
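
// Extern templates: the matching explicit instantiations are compiled in the
// GPU-specific source (argmax_op_gpu.cu.cc).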
#define DECLARE_GPU_CLASS(T)                            \
  extern template struct ArgMax<GPUDevice, T, int64_t>; \
  extern template struct ArgMin<GPUDevice, T, int64_t>; \
  extern template struct ArgMax<GPUDevice, T, int32>;   \
  extern template struct ArgMin<GPUDevice, T, int32>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
TF_CALL_bool(DECLARE_GPU_CLASS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_CLASS

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_ARGMAX_GPU(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                               \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int64_t>("output_type")  \
                              .TypeConstraint<int32>("Tidx")           \
                              .HostMemory("dimension"),                \
                          ArgMaxOp<GPUDevice, type, int64>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                               \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int64_t>("output_type")  \
                              .TypeConstraint<int32>("Tidx")           \
                              .HostMemory("dimension"),                \
                          ArgMinOp<GPUDevice, type, int64>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                               \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int32>("output_type")    \
                              .TypeConstraint<int32>("Tidx")           \
                              .HostMemory("dimension"),                \
                          ArgMaxOp<GPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                               \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<int32>("output_type")    \
                              .TypeConstraint<int32>("Tidx")           \
                              .HostMemory("dimension"),                \
                          ArgMinOp<GPUDevice, type, int32>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
TF_CALL_bool(REGISTER_ARGMAX_GPU);

#undef REGISTER_ARGMAX_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow