/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

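// TopK/TopKV2 return, per row, the k largest entries of the last dimension of
// the input together with their column indices. For example, a row [1, 3, 2]
// with k = 2 and sorted = true yields values [3, 2] and indices [1, 2]; as
// implemented below, ties are broken toward the lower index.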
template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64_t num_rows = input.dimension(0);  // generally batch_size
    const int64_t num_cols = input.dimension(1);
    OP_REQUIRES(
        context, num_rows <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "First dimension of flattened input must be <= INT_MAX, got ",
            num_rows));
    OP_REQUIRES(
        context, num_cols <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "Second dimension of flattened input must be <= INT_MAX, got ",
            num_cols));

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status Compute(
      OpKernelContext* context, bool sorted, int k,
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows,
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,
      typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
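    // A single Eigen rowwise maximum finds the values in one vectorized pass,
    // and a linear scan per row recovers the first matching column index, so
    // the general per-row sort below is skipped entirely.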
    if (k == 1) {
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
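        // Re-read the value at the chosen index so values and indices always
        // refer to the same element, even when the equality scan above finds
        // no match (e.g., NaN inputs leave indices(r, 0) at 0).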
        values(r, 0) = input(r, indices(r, 0));
      }

      return OkStatus();
    }

    auto SortIndices = [&](int64_t start_batch, int64_t limit_batch) {
      for (int32_t b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
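        // Both comparators order column indices by decreasing value;
        // stable_comp additionally breaks ties by increasing index so that
        // equal values come out in their original column order.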
        const auto stable_comp = [input_data](const int32_t a,
                                              const int32_t b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32_t a, const int32_t b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
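          // E.g., for row values [5, 3, 3, 1], std::sort may produce indices
          // [0, 2, 1, 3]; re-sorting the tied run of 3s restores [0, 1, 2, 3].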
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
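          // gtl::TopN keeps only the k best columns under stable_comp.
          // Extract() hands back those columns fully sorted (the caller takes
          // ownership of the vector), while unsorted_begin()/unsorted_end()
          // iterate over them in arbitrary order when sorted output is not
          // required.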
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32_t c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32_t i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(
            &indices(b, 0), &indices(b, k), &values(b, 0),
            [b, &input](const int32_t loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
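    // For example, with num_cols = 1024 and k = 8, base_cost is roughly
    // cmp_cost * 1024 * log2(9) ~= cmp_cost * 3246, and sort_cost is four
    // times that since k < num_cols.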
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64_t final_cost = (total_cost >= static_cast<double>(kint64max))
                                   ? kint64max
                                   : static_cast<int64_t>(total_cost);
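    // Shard splits the num_rows rows across the worker pool, invoking
    // SortIndices(start, limit) on each contiguous range; final_cost is the
    // estimated per-row cost used to decide how finely to subdivide the work.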
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return OkStatus();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  Status TopKFunctor<GPUDevice, T>::Compute(                                   \
      OpKernelContext* context, bool sorted, int k,                            \
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows, \
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                                \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

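// k must live in host memory for the GPU kernels: its scalar value is read on
// the CPU (in Compute above) to validate and size the outputs before any
// device work is launched, hence the HostMemory("k") constraint on TopKV2.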
#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow