1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/topk_op.h" |
21 | |
22 | #include <algorithm> |
23 | #include <numeric> |
24 | #include <vector> |
25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/framework/tensor_shape.h" |
30 | #include "tensorflow/core/framework/types.h" |
31 | #include "tensorflow/core/lib/gtl/top_n.h" |
32 | #include "tensorflow/core/util/work_sharder.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | typedef Eigen::ThreadPoolDevice CPUDevice; |
37 | typedef Eigen::GpuDevice GPUDevice; |
38 | |
39 | template <typename Device, typename T> |
40 | class TopK : public OpKernel { |
41 | public: |
42 | explicit TopK(OpKernelConstruction* context) : OpKernel(context) { |
43 | OP_REQUIRES_OK(context, context->GetAttr("sorted" , &sorted_)); |
44 | if (num_inputs() < 2) { // k is an attr (TopK). |
45 | OP_REQUIRES_OK(context, context->GetAttr("k" , &k_)); |
46 | } else { // k is an input (TopKV2), so we won't know it until Compute. |
47 | k_ = -1; |
48 | } |
49 | } |
50 | |
51 | void Compute(OpKernelContext* context) override { |
52 | int k = k_; |
53 | if (num_inputs() >= 2) { |
54 | const auto& k_in = context->input(1); |
55 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()), |
56 | errors::InvalidArgument("k must be scalar, got shape " , |
57 | k_in.shape().DebugString())); |
58 | k = k_in.scalar<int32>()(); |
59 | } |
60 | OP_REQUIRES(context, k >= 0, |
61 | errors::InvalidArgument("Need k >= 0, got " , k)); |
62 | const auto& input_in = context->input(0); |
63 | OP_REQUIRES(context, input_in.dims() >= 1, |
64 | errors::InvalidArgument("input must be >= 1-D, got shape " , |
65 | input_in.shape().DebugString())); |
66 | OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k, |
67 | errors::InvalidArgument( |
68 | "input must have at least k columns. Had " , |
69 | input_in.dim_size(input_in.dims() - 1), ", needed " , k)); |
70 | |
71 | const auto& input = input_in.flat_inner_dims<T>(); |
72 | |
73 | const int64_t num_rows = input.dimension(0); // generally batch_size |
74 | const int64_t num_cols = input.dimension(1); |
75 | OP_REQUIRES( |
76 | context, num_rows <= std::numeric_limits<int32>::max(), |
77 | errors::InvalidArgument( |
78 | "First dimension of flattened input must be <= INT_MAX, got " , |
79 | num_rows)); |
80 | OP_REQUIRES( |
81 | context, num_cols <= std::numeric_limits<int32>::max(), |
82 | errors::InvalidArgument( |
83 | "Second dimension of flattened input must be <= INT_MAX, got " , |
84 | num_cols)); |
85 | |
86 | TensorShape output_shape = input_in.shape(); |
87 | output_shape.set_dim(input_in.dims() - 1, k); |
88 | Tensor* values_out = nullptr; |
89 | OP_REQUIRES_OK(context, |
90 | context->allocate_output(0, output_shape, &values_out)); |
91 | Tensor* indices_out = nullptr; |
92 | OP_REQUIRES_OK(context, |
93 | context->allocate_output(1, output_shape, &indices_out)); |
94 | |
95 | // Nothing to do for top-nothing or over nothing. |
96 | if (k == 0 || num_rows == 0) return; |
97 | |
98 | auto values = values_out->flat_inner_dims<T>(); |
99 | auto indices = indices_out->flat_inner_dims<int32>(); |
100 | Status s = functor::TopKFunctor<Device, T>::Compute( |
101 | context, sorted_, k, input, num_rows, num_cols, values, indices); |
102 | OP_REQUIRES_OK(context, s); |
103 | } |
104 | |
105 | private: |
106 | int k_; |
107 | bool sorted_; |
108 | }; |
109 | |
110 | namespace functor { |
111 | |
112 | template <typename T> |
113 | struct TopKFunctor<CPUDevice, T> { |
114 | static EIGEN_ALWAYS_INLINE Status Compute( |
115 | OpKernelContext* context, bool sorted, int k, |
116 | const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows, |
117 | const int64_t num_cols, typename TTypes<T, 2>::Tensor values, |
118 | typename TTypes<int, 2>::Tensor indices) { |
119 | const CPUDevice& d = context->eigen_device<CPUDevice>(); |
120 | |
121 | // Special case for k == 1. |
122 | if (k == 1) { |
123 | typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols; |
124 | typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one; |
125 | rows_by_one.set(0, num_rows); |
126 | |
127 | values.device(d) = |
128 | input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one); |
129 | // Get the indices of the maximum values. |
130 | for (int r = 0; r < num_rows; ++r) { |
131 | indices(r, 0) = 0; |
132 | for (int c = 0; c < num_cols; ++c) { |
133 | if (values(r, 0) == input(r, c)) { |
134 | indices(r, 0) = c; |
135 | break; |
136 | } |
137 | } |
138 | values(r, 0) = input(r, indices(r, 0)); |
139 | } |
140 | |
141 | return OkStatus(); |
142 | } |
143 | |
144 | auto SortIndices = [&](int64_t start_batch, int64_t limit_batch) { |
145 | for (int32_t b = start_batch; b < limit_batch; ++b) { |
146 | const T* input_data = &input(b, 0); |
147 | const auto stable_comp = [input_data](const int32_t a, |
148 | const int32_t b) { |
149 | if (input_data[b] < input_data[a]) { |
150 | return true; |
151 | } else if (input_data[b] > input_data[a]) { |
152 | return false; |
153 | } else { |
154 | return a < b; |
155 | } |
156 | }; |
157 | const auto comp = [input_data](const int32_t a, const int32_t b) { |
158 | return input_data[b] < input_data[a]; |
159 | }; |
160 | // TODO(ebrevdo): For large k < num_cols, instead of using |
161 | // TopN, it may be faster to create a temporary vector of |
162 | // values 0..num_cols - 1 and then use std::partial_sort_copy |
163 | // of this into indices. Choosing the appropriate minimum k or |
164 | // ratio of k/num_cols will require some experimentation. |
165 | if (k == num_cols) { |
166 | auto* begin = &indices(b, 0); |
167 | auto* end = &indices(b, k); |
168 | // Set the initial array of indices 0 ... k - 1. |
169 | std::iota(begin, end, 0); |
170 | // We want an in-place sort, but we can cheat because we're sorting |
171 | // indices that started out sorted. First, do a std::sort, which |
172 | // is notably faster than std::stable_sort. |
173 | std::sort(begin, end, comp); |
174 | // Then, for runs of adjacent elements that were equal, sort the |
175 | // indices in those runs in increasing order. |
176 | for (auto* run_begin = begin; run_begin != end;) { |
177 | auto* run_end = run_begin + 1; |
178 | if (run_end == end) break; |
179 | if (input_data[*run_begin] == input_data[*run_end]) { |
180 | while (++run_end != end) { |
181 | if (input_data[*run_begin] != input_data[*run_end]) break; |
182 | } |
183 | std::sort(run_begin, run_end); |
184 | } |
185 | run_begin = run_end; |
186 | } |
187 | } else { |
188 | // Use the TopN heap object to sort. |
189 | gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp); |
190 | filter.reserve(num_cols); |
191 | for (int32_t c = 0; c < num_cols; ++c) { |
192 | filter.push(c); |
193 | } |
194 | |
195 | int32_t i = 0; |
196 | if (sorted) { |
197 | std::unique_ptr<std::vector<int32>> top_k(filter.Extract()); |
198 | for (auto top_k_it = top_k->begin(); top_k_it != top_k->end(); |
199 | ++top_k_it, ++i) { |
200 | indices(b, i) = *top_k_it; |
201 | } |
202 | } else { |
203 | for (auto top_k_it = filter.unsorted_begin(); |
204 | top_k_it != filter.unsorted_end(); ++top_k_it, ++i) { |
205 | indices(b, i) = *top_k_it; |
206 | } |
207 | } |
208 | } |
209 | // Now that the indices are sorted, copy the values over in |
210 | // sorted order. |
211 | std::transform( |
212 | &indices(b, 0), &indices(b, k), &values(b, 0), |
213 | [b, &input](const int32_t loc) { return input(b, loc); }); |
214 | } // for (int32 b = ... |
215 | }; |
216 | |
217 | // Guesstimate of cost; 4*N*log(K) where N == num_cols. |
218 | // If K == N, assume the cost is N*log(K + 1). |
219 | const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() + |
220 | Eigen::TensorOpCost::AddCost<T>(); |
221 | const double base_cost = |
222 | cmp_cost * |
223 | static_cast<double>(num_cols * |
224 | Eigen::numext::log2(static_cast<float>(k + 1))); |
225 | const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost; |
226 | const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>(); |
227 | const double total_cost = sort_cost + copy_cost; |
228 | const int64_t final_cost = (total_cost >= static_cast<double>(kint64max)) |
229 | ? kint64max |
230 | : static_cast<int64_t>(total_cost); |
231 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
232 | Shard(worker_threads.num_threads, worker_threads.workers, num_rows, |
233 | final_cost, SortIndices); |
234 | |
235 | return OkStatus(); |
236 | } |
237 | }; |
238 | |
239 | } // namespace functor |
240 | |
241 | #define REGISTER_KERNELS_NAME(name, type) \ |
242 | REGISTER_KERNEL_BUILDER( \ |
243 | Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
244 | TopK<CPUDevice, type>) |
245 | |
246 | #define REGISTER_KERNELS(type) \ |
247 | REGISTER_KERNELS_NAME(TopK, type); \ |
248 | REGISTER_KERNELS_NAME(TopKV2, type) |
249 | |
250 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
251 | #undef REGISTER_KERNELS_NAME |
252 | #undef REGISTER_KERNELS |
253 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

// Declares (without defining) the GPU specialization of
// TopKFunctor<GPUDevice, T>::Compute and marks the class instantiation
// extern; presumably both are defined in a separately compiled GPU source.
#define DECLARE_TOPK_GPU_SPEC(T)                                              \
  template <>                                                                 \
  Status TopKFunctor<GPUDevice, T>::Compute(                                  \
      OpKernelContext* context, bool sorted, int k,                           \
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows, \
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,           \
      typename TTypes<int, 2>::Tensor indices);                               \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_TOPK_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_TOPK_GPU_SPEC);

#undef DECLARE_TOPK_GPU_SPEC

}  // namespace functor

// GPU kernels for TopK and TopKV2. TopKV2's scalar "k" input is declared
// HostMemory, so the kernel receives it in host memory.
#define REGISTER_TOPK_GPU_KERNELS(type)                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TOPK_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_TOPK_GPU_KERNELS);
#undef REGISTER_TOPK_GPU_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM
288 | |
289 | } // end namespace tensorflow |
290 | |