/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
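//
// A sketch of the op's semantics (illustration only):
//   input = [[1, 9, 3],
//            [4, 5, 6]]
//   ArgMax(input, dimension=1)  =>  [1, 2]     (index of each row's maximum)
//   ArgMin(input, dimension=0)  =>  [0, 1, 0]  (index of each column's minimum)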

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/argmax_op.h"

#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

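// ArgOp computes, along a caller-supplied dimension, the index of the
// extremal element selected by ArgFunctor (functor::ArgMax or
// functor::ArgMin). T is the input element type; Tout is the index type
// produced in the output.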
template <typename Device, typename T, typename Tout, typename ArgFunctor>
class ArgOp : public OpKernel {
 public:
  explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& dimension = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
                errors::InvalidArgument(
                    "dim must be a scalar, but received tensor of shape: ",
                    dimension.shape().DebugString()));

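    // Copy the requested axis out of the tensor before validating it, so the
    // value bounds-checked below cannot change underneath us.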
    const int32_t dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
    const int input_dims = input.dims();

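    // A negative dim counts back from the last dimension, Python-style.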
    int axis = dim < 0 ? dim + input_dims : dim;

    OP_REQUIRES(context, FastBoundsCheck(axis, input_dims),
                errors::InvalidArgument("Expected dimension in the range [",
                                        -input_dims, ", ", input_dims,
                                        "), but got ", dim));
    OP_REQUIRES(
        context, input.dim_size(axis) > 0,
        errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
                                input.shape().DebugString()));

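    // The output shape is the input shape with the reduced axis removed.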
    TensorShape output_shape;
    const TensorShape& input_shape = input.shape();
    for (int d = 0; d < input_dims - 1; ++d) {
      output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));

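    // If some non-reduced dimension is zero, the output is empty and there is
    // nothing to compute.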
    if (output_shape.num_elements() == 0) {
      return;
    }

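// Dispatch to a rank-specialized reduction: Eigen tensor expressions need the
// rank as a compile-time constant, so each supported rank gets its own case.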
#define HANDLE_DIM(NDIM)                                        \
  case NDIM:                                                    \
    ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(),   \
                             input.tensor<T, NDIM>(), axis,     \
                             output->tensor<Tout, NDIM - 1>()); \
    break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Argmax and Argmin only support up "
                                            "to 7 input dimensions, but got ",
                                            input_dims, ". Inputs shape: ",
                                            input.shape().DebugString()));
    }
  }
#undef HANDLE_DIM

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
};

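// Thin wrappers that bind ArgFunctor to the max/min functors, giving the
// registration macros below concrete OpKernel types to name.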
template <typename Device, typename T, typename Tout>
class ArgMaxOp
    : public ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> > {
 public:
  explicit ArgMaxOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> >(context) {}
};

template <typename Device, typename T, typename Tout>
class ArgMinOp
    : public ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> > {
 public:
  explicit ArgMinOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> >(context) {}
};

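// CPU registrations. Each input type T is registered once per supported
// output index type: int64 and int32 for both ops, plus int16 and uint16 for
// ArgMax only.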
#define REGISTER_ARGMAX(type)                                         \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                              \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int64_t>("output_type") \
                              .HostMemory("dimension"),               \
                          ArgMaxOp<CPUDevice, type, int64>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                              \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int64_t>("output_type") \
                              .HostMemory("dimension"),               \
                          ArgMinOp<CPUDevice, type, int64>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                              \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int32>("output_type")   \
                              .HostMemory("dimension"),               \
                          ArgMaxOp<CPUDevice, type, int32>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                              \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int32>("output_type")   \
                              .HostMemory("dimension"),               \
                          ArgMinOp<CPUDevice, type, int32>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                              \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int16>("output_type")   \
                              .HostMemory("dimension"),               \
                          ArgMaxOp<CPUDevice, type, int16>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                              \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<uint16>("output_type")  \
                              .HostMemory("dimension"),               \
                          ArgMaxOp<CPUDevice, type, uint16>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
TF_CALL_bool(REGISTER_ARGMAX);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Forward declarations of the functor specializations for GPU.
namespace functor {

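// The matching definitions are compiled by the device compiler in the
// corresponding .cu.cc file (argmax_op_gpu.cu.cc in the TensorFlow tree);
// declaring them here keeps this translation unit free of device code.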
#define DECLARE_GPU_SPEC(T, Tout, Dims)                                       \
  template <>                                                                 \
  void ArgMax<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output); \
  template <>                                                                 \
  void ArgMin<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output);

#define DECLARE_GPU_SPECS(T)       \
  DECLARE_GPU_SPEC(T, int64_t, 1); \
  DECLARE_GPU_SPEC(T, int64_t, 2); \
  DECLARE_GPU_SPEC(T, int64_t, 3); \
  DECLARE_GPU_SPEC(T, int64_t, 4); \
  DECLARE_GPU_SPEC(T, int64_t, 5); \
  DECLARE_GPU_SPEC(T, int64_t, 6); \
  DECLARE_GPU_SPEC(T, int64_t, 7); \
  DECLARE_GPU_SPEC(T, int32, 1);   \
  DECLARE_GPU_SPEC(T, int32, 2);   \
  DECLARE_GPU_SPEC(T, int32, 3);   \
  DECLARE_GPU_SPEC(T, int32, 4);   \
  DECLARE_GPU_SPEC(T, int32, 5);   \
  DECLARE_GPU_SPEC(T, int32, 6);   \
  DECLARE_GPU_SPEC(T, int32, 7);

#define DECLARE_GPU_CLASS(T)                            \
  extern template struct ArgMax<GPUDevice, T, int64_t>; \
  extern template struct ArgMin<GPUDevice, T, int64_t>; \
  extern template struct ArgMax<GPUDevice, T, int32>;   \
  extern template struct ArgMin<GPUDevice, T, int32>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
TF_CALL_bool(DECLARE_GPU_CLASS);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_CLASS

}  // namespace functor

// Registration of the GPU implementations.
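// Note that "dimension" stays in host memory: Compute() reads the axis scalar
// on the host before launching the device-side reduction.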
#define REGISTER_ARGMAX_GPU(type)                                     \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                              \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int64_t>("output_type") \
                              .TypeConstraint<int32>("Tidx")          \
                              .HostMemory("dimension"),               \
                          ArgMaxOp<GPUDevice, type, int64>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                              \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int64_t>("output_type") \
                              .TypeConstraint<int32>("Tidx")          \
                              .HostMemory("dimension"),               \
                          ArgMinOp<GPUDevice, type, int64>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                              \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int32>("output_type")   \
                              .TypeConstraint<int32>("Tidx")          \
                              .HostMemory("dimension"),               \
                          ArgMaxOp<GPUDevice, type, int32>);          \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                              \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<int32>("output_type")   \
                              .TypeConstraint<int32>("Tidx")          \
                              .HostMemory("dimension"),               \
                          ArgMinOp<GPUDevice, type, int32>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
TF_CALL_bool(REGISTER_ARGMAX_GPU);

#undef REGISTER_ARGMAX_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow