1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/in_topk_op.h"
21
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/framework/tensor_shape.h"
25
26namespace tensorflow {
27
28typedef Eigen::ThreadPoolDevice CPUDevice;
29typedef Eigen::GpuDevice GPUDevice;
30
31template <typename Device, typename T, typename TARGET_T>
32class InTopK : public OpKernel {
33 public:
34 explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
35 if (context->num_inputs() == 2) {
36 OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
37 }
38 }
39
40 void Compute(OpKernelContext* context) override {
41 const auto& predictions_in = context->input(0);
42 const auto& targets_in = context->input(1);
43
44 int64_t k_value = k_;
45 const Tensor* k_tensor = nullptr;
46
47 if (context->num_inputs() == 3) {
48 const auto& k_in = context->input(2);
49
50 OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
51 errors::InvalidArgument("k must be 0-D, got shape ",
52 k_in.shape().DebugString()));
53
54 k_tensor = &k_in;
55 }
56
57 OP_REQUIRES(context, predictions_in.dims() == 2,
58 errors::InvalidArgument("predictions must be 2-dimensional"));
59 OP_REQUIRES(context, targets_in.dims() == 1,
60 errors::InvalidArgument("targets must be 1-dimensional"));
61 OP_REQUIRES(context, predictions_in.dim_size(0) == targets_in.dim_size(0),
62 errors::InvalidArgument("First dimension of predictions ",
63 predictions_in.dim_size(0),
64 " must match length of targets ",
65 targets_in.dim_size(0)));
66
67 const auto predictions = predictions_in.matrix<T>();
68 const auto targets = targets_in.vec<TARGET_T>();
69
70 Tensor* t_out = nullptr;
71 OP_REQUIRES_OK(context,
72 context->allocate_output(
73 0, TensorShape({targets_in.dim_size(0)}), &t_out));
74 auto out = t_out->vec<bool>();
75
76 functor::InTopKFunctor<Device, T, TARGET_T> f;
77 functor::TopKArg arg;
78 arg.k_value = k_value;
79 arg.k_tensor = k_tensor;
80 f(context, predictions, targets, arg, out);
81 }
82
83 private:
84 int k_;
85};
86
87REGISTER_KERNEL_BUILDER(Name("InTopK")
88 .Device(DEVICE_CPU)
89 .HostMemory("predictions")
90 .HostMemory("targets")
91 .HostMemory("precision")
92 .TypeConstraint<int32>("T"),
93 InTopK<CPUDevice, float, int32>);
94REGISTER_KERNEL_BUILDER(Name("InTopK")
95 .Device(DEVICE_CPU)
96 .HostMemory("predictions")
97 .HostMemory("targets")
98 .HostMemory("precision")
99 .TypeConstraint<int64_t>("T"),
100 InTopK<CPUDevice, float, int64>);
101
102REGISTER_KERNEL_BUILDER(Name("InTopKV2")
103 .Device(DEVICE_CPU)
104 .HostMemory("predictions")
105 .HostMemory("targets")
106 .HostMemory("k")
107 .HostMemory("precision")
108 .TypeConstraint<int32>("T"),
109 InTopK<CPUDevice, float, int32>);
110REGISTER_KERNEL_BUILDER(Name("InTopKV2")
111 .Device(DEVICE_CPU)
112 .HostMemory("predictions")
113 .HostMemory("targets")
114 .HostMemory("k")
115 .HostMemory("precision")
116 .TypeConstraint<int64_t>("T"),
117 InTopK<CPUDevice, float, int64>);
118
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the GPU specializations of InTopKFunctor. Their
// definitions are not in this file (presumably a separately compiled GPU
// translation unit). The `extern template` declarations keep this translation
// unit from implicitly instantiating them.
namespace functor {
#define DECLARE_GPU_SPEC(T, TARGET_T)                               \
  template <>                                                       \
  void InTopKFunctor<GPUDevice, T, TARGET_T>::operator()(           \
      OpKernelContext* context,                                     \
      typename TTypes<T, 2>::ConstTensor predictions,               \
      typename TTypes<TARGET_T>::ConstVec targets, const TopKArg k, \
      typename TTypes<bool>::Vec output);                           \
  extern template struct InTopKFunctor<GPUDevice, T, TARGET_T>;

DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64_t);

#undef DECLARE_GPU_SPEC
}  // namespace functor

// In this file, GPU kernels are registered only for "InTopKV2". No HostMemory
// annotations are declared, so `k` is not requested in host memory here.
// The TARGET_T arguments match the DECLARE_GPU_SPEC instantiations above.
REGISTER_KERNEL_BUILDER(
    Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
    InTopK<GPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(
    Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64_t>("T"),
    InTopK<GPUDevice, float, int64_t>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
146
147} // namespace tensorflow
148