1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
17#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
18
19#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20#define EIGEN_USE_GPU
21#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22
23#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24#include "tensorflow/core/framework/bounds_check.h"
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_types.h"
28
29namespace tensorflow {
30namespace functor {
31
32typedef Eigen::ThreadPoolDevice CPUDevice;
33typedef Eigen::GpuDevice GPUDevice;
34
35// InTopK argument can be passed either via mode attribute (InTopK op), or as an
36// input tensor (InTopKV2 op).
37struct TopKArg {
38 int64_t k_value = -1;
39 const Tensor* k_tensor = nullptr;
40};
41
42template <typename Device, typename T, typename TargetT>
43struct InTopKFunctor {
44 template <int ndims>
45 using Dims = Eigen::DSizes<Eigen::Index, ndims>;
46
47 void operator()(OpKernelContext* context,
48 typename TTypes<T, 2>::ConstTensor predictions,
49 typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
50 typename TTypes<bool>::Vec output) {}
51};
52
53template <typename T, typename TargetT>
54struct InTopKFunctor<CPUDevice, T, TargetT> {
55 void operator()(OpKernelContext* context,
56 typename TTypes<T, 2>::ConstTensor predictions,
57 typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
58 typename TTypes<bool>::Vec output) {
59 const Eigen::Index num_targets = predictions.dimension(0);
60 const Eigen::Index num_classes = predictions.dimension(1);
61
62 int64_t k_val = k.k_value;
63 if (k.k_tensor != nullptr) {
64 if (k.k_tensor->dtype() == DT_INT32) {
65 k_val = k.k_tensor->scalar<int32>()();
66 } else {
67 k_val = k.k_tensor->scalar<int64_t>()();
68 }
69 }
70
71 for (int batch_idx = 0; batch_idx < num_targets; batch_idx++) {
72 auto target = internal::SubtleMustCopy(targets(batch_idx));
73
74 bool cannot_say = !FastBoundsCheck(target, num_classes) ||
75 !std::isfinite(predictions(batch_idx, target));
76
77 int more_probable_classes = 0;
78 if (!cannot_say) {
79 const T target_prediction = predictions(batch_idx, target);
80
81 for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
82 T pred = predictions(batch_idx, class_idx);
83 if (!std::isfinite(pred)) {
84 cannot_say = true;
85 break;
86 } else if (pred > target_prediction) {
87 ++more_probable_classes;
88 if (more_probable_classes > k_val) break;
89 }
90 }
91 }
92 output(batch_idx) = cannot_say ? false : (more_probable_classes < k_val);
93 }
94 }
95};
96
97} // namespace functor
98} // namespace tensorflow
99
100#endif // TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
101