1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/kernels/searchsorted_op.h" |
19 | |
20 | #include "tensorflow/core/framework/bounds_check.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | #include "tensorflow/core/kernels/fill_functor.h" |
26 | #include "tensorflow/core/lib/core/bits.h" |
27 | #include "tensorflow/core/platform/logging.h" |
28 | #include "tensorflow/core/platform/threadpool.h" |
29 | #include "tensorflow/core/platform/types.h" |
30 | |
31 | namespace tensorflow { |
32 | typedef Eigen::ThreadPoolDevice CPUDevice; |
33 | typedef Eigen::GpuDevice GPUDevice; |
34 | |
35 | namespace functor { |
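// CPU specialization of UpperBoundFunctor. For each value it returns the
// index of the first element of the corresponding row of `sorted_inputs`
// that compares greater than the value (std::upper_bound semantics). The
// `num_values` searches per row are sharded across the CPU thread pool;
// each shard runs its slice of value indices against every batch row.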
36 | template <typename T, typename OutType> |
37 | struct UpperBoundFunctor<CPUDevice, T, OutType> { |
38 | static Status Compute(OpKernelContext* context, |
39 | const typename TTypes<T, 1>::ConstTensor& sorted_inputs, |
40 | const typename TTypes<T, 1>::ConstTensor& values, |
41 | int batch_size, int num_inputs, int num_values, |
42 | typename TTypes<OutType, 1>::Tensor* output) { |
43 | auto work_fn = [&](int64_t first, int64_t last) { |
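      // [first, last) is the range of value indices within a row that this
      // shard handles; the search is repeated for every batch row.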
44 | for (int b = 0; b < batch_size; ++b) { |
45 | const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs; |
46 | OutType* output_ptr = output->data() + b * num_values; |
47 | for (int i = first; i < last; ++i) { |
48 | output_ptr[i] = std::upper_bound(sorted_inputs_ptr, |
49 | sorted_inputs_ptr + num_inputs, |
50 | values(i + b * num_values)) - |
51 | sorted_inputs_ptr; |
52 | } |
53 | } |
54 | }; |
55 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
56 | thread::ThreadPool* thread_pool = worker_threads.workers; |
57 | const float kCostMultiplier = 1.f; // Can be tuned to minimize overhead |
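    // Each ParallelFor work unit is one value index, which is binary searched
    // once per batch row, hence batch_size * log2(num_inputs).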
58 | int64_t cost_per_unit = |
59 | kCostMultiplier * batch_size * Log2Ceiling(num_inputs); |
60 | thread_pool->ParallelFor(num_values, cost_per_unit, work_fn); |
61 | return OkStatus(); |
62 | } |
63 | }; |
64 | |
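// CPU specialization of LowerBoundFunctor. Identical to UpperBoundFunctor
// above, except that it uses std::lower_bound, i.e. it returns the index of
// the first element that does not compare less than the searched value.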
65 | template <typename T, typename OutType> |
66 | struct LowerBoundFunctor<CPUDevice, T, OutType> { |
67 | static Status Compute(OpKernelContext* context, |
68 | const typename TTypes<T, 1>::ConstTensor& sorted_inputs, |
69 | const typename TTypes<T, 1>::ConstTensor& values, |
70 | int batch_size, int num_inputs, int num_values, |
71 | typename TTypes<OutType, 1>::Tensor* output) { |
72 | auto work_fn = [&](int64_t first, int64_t last) { |
73 | for (int b = 0; b < batch_size; ++b) { |
74 | const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs; |
75 | OutType* output_ptr = output->data() + b * num_values; |
76 | for (int i = first; i < last; ++i) { |
77 | output_ptr[i] = std::lower_bound(sorted_inputs_ptr, |
78 | sorted_inputs_ptr + num_inputs, |
79 | values(i + b * num_values)) - |
80 | sorted_inputs_ptr; |
81 | } |
82 | } |
83 | }; |
84 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
85 | thread::ThreadPool* thread_pool = worker_threads.workers; |
86 | const float kCostMultiplier = 1.f; // Can be tuned to minimize overhead |
87 | int64_t cost_per_unit = |
88 | kCostMultiplier * batch_size * Log2Ceiling(num_inputs); |
89 | thread_pool->ParallelFor(num_values, cost_per_unit, work_fn); |
90 | return OkStatus(); |
91 | } |
92 | }; |
93 | } // namespace functor |
94 | |
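// Kernel for the UpperBound op. It validates the input shapes, allocates an
// output with the same shape as `values`, and delegates the per-row binary
// searches to UpperBoundFunctor. For example, with a single batch row:
//   sorted_inputs = [[1, 3, 3, 5]]
//   values        = [[0, 3, 6]]
//   output        = [[0, 3, 4]]   // index of the first element > each value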
95 | template <typename Device, typename T, typename OutType> |
96 | class UpperBoundOp : public OpKernel { |
97 | public: |
98 | explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
99 | |
100 | void Compute(OpKernelContext* ctx) override { |
101 | const Tensor& sorted_inputs_t = ctx->input(0); |
102 | const Tensor& values_t = ctx->input(1); |
103 | |
    // The sorted input must be at least a matrix (a batch of sorted rows).
    OP_REQUIRES(
        ctx, sorted_inputs_t.shape().dims() >= 2,
        errors::InvalidArgument("sorted input argument must be a matrix"));
    // The leading (batch) dimension of both tensors must match.
    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
                Status(error::INVALID_ARGUMENT,
                       "Leading dim_size of both tensors must match."));
112 | |
    // This is required because the GPU kernel indexes values with int32.
    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
                Status(error::INVALID_ARGUMENT,
                       "values tensor size must be less than INT_MAX"));
117 | |
118 | Tensor* output_t; |
119 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t)); |
120 | |
121 | if (output_t->dtype() == DT_INT32) { |
122 | OP_REQUIRES(ctx, |
123 | FastBoundsCheck(sorted_inputs_t.dim_size(1), |
124 | std::numeric_limits<int>::max()), |
125 | errors::InvalidArgument("trailing dim_size must less than " |
126 | "INT_MAX for int32 output type, was " , |
127 | sorted_inputs_t.dim_size(1))); |
128 | } |
129 | |
130 | auto output = output_t->template flat<OutType>(); |
131 | const auto sorted_inputs = sorted_inputs_t.template flat<T>(); |
132 | const auto values = values_t.template flat<T>(); |
133 | |
134 | // For empty inputs, all values will be placed at the zeroth position. |
135 | if (sorted_inputs.size() == 0) { |
136 | functor::SetZeroFunctor<Device, OutType> set_zero; |
137 | set_zero(ctx->eigen_device<Device>(), output); |
138 | return; |
139 | } |
140 | |
141 | OP_REQUIRES_OK( |
142 | ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute( |
143 | ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0), |
144 | sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output)); |
145 | } |
146 | }; |
147 | |
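// Kernel for the LowerBound op. Mirrors UpperBoundOp but delegates to
// LowerBoundFunctor. With the same example inputs as above:
//   sorted_inputs = [[1, 3, 3, 5]]
//   values        = [[0, 3, 6]]
//   output        = [[0, 1, 4]]   // index of the first element >= each value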
148 | template <typename Device, typename T, typename OutType> |
149 | class LowerBoundOp : public OpKernel { |
150 | public: |
151 | explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
152 | |
153 | void Compute(OpKernelContext* ctx) override { |
154 | const Tensor& sorted_inputs_t = ctx->input(0); |
155 | const Tensor& values_t = ctx->input(1); |
156 | |
    // The sorted input must be at least a matrix (a batch of sorted rows).
    OP_REQUIRES(
        ctx, sorted_inputs_t.shape().dims() >= 2,
        errors::InvalidArgument("sorted input argument must be a matrix"));
    // The leading (batch) dimension of both tensors must match.
    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
                Status(error::INVALID_ARGUMENT,
                       "Leading dim_size of both tensors must match."));
165 | |
    // This is required because the GPU kernel indexes values with int32.
    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
                Status(error::INVALID_ARGUMENT,
                       "values tensor size must be less than INT_MAX"));
170 | |
171 | Tensor* output_t; |
172 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t)); |
173 | |
174 | if (output_t->dtype() == DT_INT32) { |
175 | OP_REQUIRES(ctx, |
176 | FastBoundsCheck(sorted_inputs_t.dim_size(1), |
177 | std::numeric_limits<int>::max()), |
178 | errors::InvalidArgument("trailing dim_size must less than " |
179 | "INT_MAX for int32 output type, was " , |
180 | sorted_inputs_t.dim_size(1))); |
181 | } |
182 | |
183 | auto output = output_t->template flat<OutType>(); |
184 | const auto sorted_inputs = sorted_inputs_t.template flat<T>(); |
185 | const auto values = values_t.template flat<T>(); |
186 | |
187 | // For empty inputs, all values will be placed at the zeroth position. |
188 | if (sorted_inputs.size() == 0) { |
189 | functor::SetZeroFunctor<Device, OutType> set_zero; |
190 | set_zero(ctx->eigen_device<Device>(), output); |
191 | return; |
192 | } |
193 | |
194 | OP_REQUIRES_OK( |
195 | ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute( |
196 | ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0), |
197 | sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output)); |
198 | } |
199 | }; |
200 | |
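// Kernel registrations: CPU and, when built with CUDA or ROCm support, GPU
// kernels are registered for every real number type T and for both int32 and
// int64 output types.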
201 | #define REGISTER_KERNELS(type) \ |
202 | REGISTER_KERNEL_BUILDER(Name("UpperBound") \ |
203 | .Device(DEVICE_CPU) \ |
204 | .TypeConstraint<type>("T") \ |
205 | .TypeConstraint<int32>("out_type"), \ |
206 | UpperBoundOp<CPUDevice, type, int32>); |
207 | |
208 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
209 | #undef REGISTER_KERNELS |
210 | |
211 | #define REGISTER_KERNELS(type) \ |
212 | REGISTER_KERNEL_BUILDER(Name("UpperBound") \ |
213 | .Device(DEVICE_CPU) \ |
214 | .TypeConstraint<type>("T") \ |
215 | .TypeConstraint<int64_t>("out_type"), \ |
216 | UpperBoundOp<CPUDevice, type, int64>); |
217 | |
218 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
219 | #undef REGISTER_KERNELS |
220 | |
221 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
222 | |
223 | #define REGISTER_KERNELS(type) \ |
224 | REGISTER_KERNEL_BUILDER(Name("UpperBound") \ |
225 | .Device(DEVICE_GPU) \ |
226 | .TypeConstraint<type>("T") \ |
227 | .TypeConstraint<int32>("out_type"), \ |
228 | UpperBoundOp<GPUDevice, type, int32>); |
229 | |
230 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
231 | #undef REGISTER_KERNELS |
232 | |
233 | #define REGISTER_KERNELS(type) \ |
234 | REGISTER_KERNEL_BUILDER(Name("UpperBound") \ |
235 | .Device(DEVICE_GPU) \ |
236 | .TypeConstraint<type>("T") \ |
237 | .TypeConstraint<int64_t>("out_type"), \ |
238 | UpperBoundOp<GPUDevice, type, int64>); |
239 | |
240 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
241 | #undef REGISTER_KERNELS |
242 | |
243 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
244 | |
245 | #define REGISTER_KERNELS(type) \ |
246 | REGISTER_KERNEL_BUILDER(Name("LowerBound") \ |
247 | .Device(DEVICE_CPU) \ |
248 | .TypeConstraint<type>("T") \ |
249 | .TypeConstraint<int32>("out_type"), \ |
250 | LowerBoundOp<CPUDevice, type, int32>); |
251 | |
252 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
253 | #undef REGISTER_KERNELS |
254 | |
255 | #define REGISTER_KERNELS(type) \ |
256 | REGISTER_KERNEL_BUILDER(Name("LowerBound") \ |
257 | .Device(DEVICE_CPU) \ |
258 | .TypeConstraint<type>("T") \ |
259 | .TypeConstraint<int64_t>("out_type"), \ |
260 | LowerBoundOp<CPUDevice, type, int64>); |
261 | |
262 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
263 | #undef REGISTER_KERNELS |
264 | |
265 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
266 | |
267 | #define REGISTER_KERNELS(type) \ |
268 | REGISTER_KERNEL_BUILDER(Name("LowerBound") \ |
269 | .Device(DEVICE_GPU) \ |
270 | .TypeConstraint<type>("T") \ |
271 | .TypeConstraint<int32>("out_type"), \ |
272 | LowerBoundOp<GPUDevice, type, int32>); |
273 | |
274 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
275 | #undef REGISTER_KERNELS |
276 | |
277 | #define REGISTER_KERNELS(type) \ |
278 | REGISTER_KERNEL_BUILDER(Name("LowerBound") \ |
279 | .Device(DEVICE_GPU) \ |
280 | .TypeConstraint<type>("T") \ |
281 | .TypeConstraint<int64_t>("out_type"), \ |
282 | LowerBoundOp<GPUDevice, type, int64>); |
283 | |
284 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); |
285 | #undef REGISTER_KERNELS |
286 | |
287 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
288 | } // namespace tensorflow |
289 | |