1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/in_topk_op.h" |
21 | |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | typedef Eigen::ThreadPoolDevice CPUDevice; |
29 | typedef Eigen::GpuDevice GPUDevice; |
30 | |
31 | template <typename Device, typename T, typename TARGET_T> |
32 | class InTopK : public OpKernel { |
33 | public: |
34 | explicit InTopK(OpKernelConstruction* context) : OpKernel(context) { |
35 | if (context->num_inputs() == 2) { |
36 | OP_REQUIRES_OK(context, context->GetAttr("k" , &k_)); |
37 | } |
38 | } |
39 | |
40 | void Compute(OpKernelContext* context) override { |
41 | const auto& predictions_in = context->input(0); |
42 | const auto& targets_in = context->input(1); |
43 | |
44 | int64_t k_value = k_; |
45 | const Tensor* k_tensor = nullptr; |
46 | |
47 | if (context->num_inputs() == 3) { |
48 | const auto& k_in = context->input(2); |
49 | |
50 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()), |
51 | errors::InvalidArgument("k must be 0-D, got shape " , |
52 | k_in.shape().DebugString())); |
53 | |
54 | k_tensor = &k_in; |
55 | } |
56 | |
57 | OP_REQUIRES(context, predictions_in.dims() == 2, |
58 | errors::InvalidArgument("predictions must be 2-dimensional" )); |
59 | OP_REQUIRES(context, targets_in.dims() == 1, |
60 | errors::InvalidArgument("targets must be 1-dimensional" )); |
61 | OP_REQUIRES(context, predictions_in.dim_size(0) == targets_in.dim_size(0), |
62 | errors::InvalidArgument("First dimension of predictions " , |
63 | predictions_in.dim_size(0), |
64 | " must match length of targets " , |
65 | targets_in.dim_size(0))); |
66 | |
67 | const auto predictions = predictions_in.matrix<T>(); |
68 | const auto targets = targets_in.vec<TARGET_T>(); |
69 | |
70 | Tensor* t_out = nullptr; |
71 | OP_REQUIRES_OK(context, |
72 | context->allocate_output( |
73 | 0, TensorShape({targets_in.dim_size(0)}), &t_out)); |
74 | auto out = t_out->vec<bool>(); |
75 | |
76 | functor::InTopKFunctor<Device, T, TARGET_T> f; |
77 | functor::TopKArg arg; |
78 | arg.k_value = k_value; |
79 | arg.k_tensor = k_tensor; |
80 | f(context, predictions, targets, arg, out); |
81 | } |
82 | |
83 | private: |
84 | int k_; |
85 | }; |
86 | |
// CPU registrations. All inputs/outputs are pinned to host memory since the
// computation runs on the CPU regardless of device placement. "InTopK" (V1,
// k as attribute) and "InTopKV2" (k as input) share the same kernel class;
// T constrains the targets dtype (int32 or int64), predictions are float.
REGISTER_KERNEL_BUILDER(Name("InTopK" )
                            .Device(DEVICE_CPU)
                            .HostMemory("predictions" )
                            .HostMemory("targets" )
                            .HostMemory("precision" )
                            .TypeConstraint<int32>("T" ),
                        InTopK<CPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopK" )
                            .Device(DEVICE_CPU)
                            .HostMemory("predictions" )
                            .HostMemory("targets" )
                            .HostMemory("precision" )
                            .TypeConstraint<int64_t>("T" ),
                        InTopK<CPUDevice, float, int64>);

// V2 additionally pins the scalar "k" input to host memory.
REGISTER_KERNEL_BUILDER(Name("InTopKV2" )
                            .Device(DEVICE_CPU)
                            .HostMemory("predictions" )
                            .HostMemory("targets" )
                            .HostMemory("k" )
                            .HostMemory("precision" )
                            .TypeConstraint<int32>("T" ),
                        InTopK<CPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopKV2" )
                            .Device(DEVICE_CPU)
                            .HostMemory("predictions" )
                            .HostMemory("targets" )
                            .HostMemory("k" )
                            .HostMemory("precision" )
                            .TypeConstraint<int64_t>("T" ),
                        InTopK<CPUDevice, float, int64>);
118 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
// The definitions (and explicit template instantiations) live in the
// separately-compiled CUDA/ROCm translation unit; `extern template` prevents
// implicit instantiation here.
namespace functor {
#define DECLARE_GPU_SPEC(T, TARGET_T)                            \
  template <>                                                    \
  void InTopKFunctor<GPUDevice, T, TARGET_T>::operator()(        \
      OpKernelContext* context,                                  \
      typename TTypes<T, 2>::ConstTensor predictions,            \
      typename TTypes<TARGET_T>::ConstVec targets, const TopKArg k, \
      typename TTypes<bool>::Vec output);                        \
  extern template struct InTopKFunctor<GPUDevice, T, TARGET_T>;

DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64_t);

#undef DECLARE_GPU_SPEC
}  // namespace functor

// Only the V2 op (k as an input tensor) has a GPU kernel; inputs/outputs stay
// in device memory (no HostMemory pinning, unlike the CPU registrations).
REGISTER_KERNEL_BUILDER(
    Name("InTopKV2" ).Device(DEVICE_GPU).TypeConstraint<int32>("T" ),
    InTopK<GPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(
    Name("InTopKV2" ).Device(DEVICE_GPU).TypeConstraint<int64_t>("T" ),
    InTopK<GPUDevice, float, int64>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
146 | |
147 | } // namespace tensorflow |
148 | |