1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17#include "tensorflow/core/kernels/nth_element_op.h"
18
19#include <algorithm>
20#include <iostream>
21#include <vector>
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/register_types.h"
24#include "tensorflow/core/framework/tensor.h"
25#include "tensorflow/core/framework/types.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/util/work_sharder.h"
28
29namespace tensorflow {
30
31typedef Eigen::ThreadPoolDevice CPUDevice;
32
33template <typename Device, typename T>
34class NthElementOp : public OpKernel {
35 public:
36 explicit NthElementOp(OpKernelConstruction* context) : OpKernel(context) {
37 OP_REQUIRES_OK(context, context->GetAttr("reverse", &reverse_));
38 }
39
40 void Compute(OpKernelContext* context) override {
41 // The second args is N, which must be a positive scalar.
42 const auto& n_in = context->input(1);
43 OP_REQUIRES(
44 context, TensorShapeUtils::IsScalar(n_in.shape()),
45 errors::InvalidArgument("N must be scalar but has rank ", n_in.dims()));
46 int n = n_in.scalar<int32>()();
47 OP_REQUIRES(context, n >= 0,
48 errors::InvalidArgument("n must be non-negative but is ", n));
49
50 // The first args is input tensor, which must have 1 dimension at least.
51 const Tensor& input_in = context->input(0);
52 const int num_dims = input_in.dims();
53 OP_REQUIRES(context, num_dims >= 1,
54 errors::InvalidArgument(
55 "Input must be at least rank 1 but is rank ", num_dims));
56 // The last dimension of input tensor must be greater than N.
57 OP_REQUIRES(
58 context, input_in.dim_size(num_dims - 1) > n,
59 errors::InvalidArgument("Input must have last dimension > n = ", n));
60
61 // std::nth_element only support the nth-smallest selection.
62 if (reverse_) {
63 n = input_in.dim_size(num_dims - 1) - n - 1;
64 }
65
66 // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1].
67 TensorShape out_shape;
68 for (int i = 0; i < num_dims - 1; ++i) {
69 out_shape.AddDim(input_in.dim_size(i));
70 }
71 Tensor* output_tensor = nullptr;
72 OP_REQUIRES_OK(context,
73 context->allocate_output(0, out_shape, &output_tensor));
74
75 functor::NthElementFunctor<Device, T> nthElementFunc;
76 nthElementFunc(context, input_in, *output_tensor, n, reverse_);
77 }
78
79 private:
80 bool reverse_;
81};
82
83namespace functor {
84
85template <typename T>
86struct NthElementFunctor<CPUDevice, T> {
87 void operator()(OpKernelContext* context, const Tensor& input_tensor,
88 Tensor& output_tensor, int n, bool reverse) {
89 const T* input = input_tensor.flat<T>().data();
90 T* output = output_tensor.flat<T>().data();
91
92 // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1],
93 // then num_rows = d1*d2...dk-1, last_dim = dk.
94 const int num_rows = output_tensor.NumElements();
95 const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1);
96
97 // Allocate each row to different shard.
98 auto SubNthElement = [&, input, output, last_dim, n](int64_t start,
99 int64_t limit) {
100 // std::nth_element would rearrange the array, so we need a new buffer.
101 std::vector<T> buf(last_dim);
102
103 for (int b = start; b < limit; ++b) {
104 // Copy from one row of elements to buffer
105 const T* input_start = input + b * last_dim;
106 const T* input_end = input + (b + 1) * last_dim;
107 std::copy(input_start, input_end, buf.begin());
108
109 std::nth_element(buf.begin(), buf.begin() + n, buf.end());
110 // The element placed in the nth position is exactly the element that
111 // would occur in this position if the range was fully sorted.
112 output[b] = buf[n];
113 }
114 };
115
116 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
117 // The average time complexity of partition-based nth_element (BFPRT) is
118 // O(n), although the worst time complexity could be O(n^2). Here, 20 is a
119 // empirical factor of cost_per_unit.
120 Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
121 20 * last_dim, SubNthElement);
122 }
123};
124
125} // namespace functor
126
// Registers the CPU NthElement kernel once per element type T; invoked below
// for every real numeric type via TF_CALL_REAL_NUMBER_TYPES.
#define REGISTER_NTHOP(T)                                          \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("NthElement").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      NthElementOp<CPUDevice, T>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_NTHOP);
// The macro is only needed for registration; undefine to avoid leaking it.
#undef REGISTER_NTHOP
134
135} // end namespace tensorflow
136