1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | #include "tensorflow/core/kernels/nth_element_op.h" |
18 | |
19 | #include <algorithm> |
20 | #include <iostream> |
21 | #include <vector> |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | #include "tensorflow/core/util/work_sharder.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | typedef Eigen::ThreadPoolDevice CPUDevice; |
32 | |
33 | template <typename Device, typename T> |
34 | class NthElementOp : public OpKernel { |
35 | public: |
36 | explicit NthElementOp(OpKernelConstruction* context) : OpKernel(context) { |
    OP_REQUIRES_OK(context, context->GetAttr("reverse", &reverse_));
38 | } |
39 | |
40 | void Compute(OpKernelContext* context) override { |
    // The second input is n, which must be a non-negative scalar.
    const auto& n_in = context->input(1);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(n_in.shape()),
        errors::InvalidArgument("N must be scalar but has rank ", n_in.dims()));
    int n = n_in.scalar<int32>()();
    OP_REQUIRES(context, n >= 0,
                errors::InvalidArgument("n must be non-negative but is ", n));
49 | |
    // The first input is the input tensor, which must have at least 1
    // dimension.
    const Tensor& input_in = context->input(0);
    const int num_dims = input_in.dims();
    OP_REQUIRES(context, num_dims >= 1,
                errors::InvalidArgument(
                    "Input must be at least rank 1 but is rank ", num_dims));
    // The last dimension of the input tensor must be greater than n.
    OP_REQUIRES(
        context, input_in.dim_size(num_dims - 1) > n,
        errors::InvalidArgument("Input must have last dimension > n = ", n));
60 | |
    // std::nth_element only supports nth-smallest selection, so for the
    // reverse (nth-largest) case remap n to its ascending-order equivalent.
62 | if (reverse_) { |
63 | n = input_in.dim_size(num_dims - 1) - n - 1; |
64 | } |
65 | |
    // If input_shape is [d1, d2, ..., dk], then output_shape is
    // [d1, d2, ..., d(k-1)].
67 | TensorShape out_shape; |
68 | for (int i = 0; i < num_dims - 1; ++i) { |
69 | out_shape.AddDim(input_in.dim_size(i)); |
70 | } |
71 | Tensor* output_tensor = nullptr; |
72 | OP_REQUIRES_OK(context, |
73 | context->allocate_output(0, out_shape, &output_tensor)); |
74 | |
75 | functor::NthElementFunctor<Device, T> nthElementFunc; |
76 | nthElementFunc(context, input_in, *output_tensor, n, reverse_); |
77 | } |
78 | |
79 | private: |
80 | bool reverse_; |
81 | }; |
82 | |
83 | namespace functor { |
84 | |
85 | template <typename T> |
86 | struct NthElementFunctor<CPUDevice, T> { |
87 | void operator()(OpKernelContext* context, const Tensor& input_tensor, |
88 | Tensor& output_tensor, int n, bool reverse) { |
89 | const T* input = input_tensor.flat<T>().data(); |
90 | T* output = output_tensor.flat<T>().data(); |
91 | |
    // If input_shape is [d1, d2, ..., dk], then output_shape is
    // [d1, d2, ..., d(k-1)], so num_rows = d1 * d2 * ... * d(k-1) and
    // last_dim = dk.
94 | const int num_rows = output_tensor.NumElements(); |
95 | const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1); |
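    // For example, an input of shape [2, 3, 7] yields an output of shape
    // [2, 3], so num_rows = 6 and last_dim = 7.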
96 | |
    // Shard the work over rows; each shard handles a contiguous range of
    // rows.
98 | auto SubNthElement = [&, input, output, last_dim, n](int64_t start, |
99 | int64_t limit) { |
      // std::nth_element rearranges its input range, so copy each row into a
      // scratch buffer first.
101 | std::vector<T> buf(last_dim); |
102 | |
      for (int64_t b = start; b < limit; ++b) {
        // Copy one row of elements into the buffer.
105 | const T* input_start = input + b * last_dim; |
106 | const T* input_end = input + (b + 1) * last_dim; |
107 | std::copy(input_start, input_end, buf.begin()); |
108 | |
109 | std::nth_element(buf.begin(), buf.begin() + n, buf.end()); |
        // The element placed in the nth position is exactly the element that
        // would occur in that position if the range were fully sorted.
112 | output[b] = buf[n]; |
113 | } |
114 | }; |
115 | |
116 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
    // The average time complexity of partition-based selection (as used by
    // std::nth_element) is O(n), although the worst case can be O(n^2). Here,
    // 20 is an empirical factor for cost_per_unit.
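    // For example, with num_rows = 1024 and last_dim = 256, each row carries
    // an estimated cost of 20 * 256 units, and Shard splits the 1024 rows
    // into ranges that the worker threads process in parallel.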
120 | Shard(worker_threads.num_threads, worker_threads.workers, num_rows, |
121 | 20 * last_dim, SubNthElement); |
122 | } |
123 | }; |
124 | |
125 | } // namespace functor |
126 | |
127 | #define REGISTER_NTHOP(T) \ |
128 | REGISTER_KERNEL_BUILDER( \ |
129 | Name("NthElement").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
130 | NthElementOp<CPUDevice, T>) |
131 | |
132 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_NTHOP); |
133 | #undef REGISTER_NTHOP |
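
// Illustrative usage from Python via the generated raw-op binding (a sketch,
// assuming the standard auto-generated bindings; the public wrapper, if any,
// may differ):
//
//   # values: float32 tensor of shape [2, 4].
//   # result has shape [2]; each entry is the element of the corresponding
//   # row with 0-based rank n in ascending order (descending if reverse).
//   result = tf.raw_ops.NthElement(input=values, n=1, reverse=False)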
134 | |
135 | } // end namespace tensorflow |
136 | |