/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/searchsorted_op.h"

#include <algorithm>
#include <limits>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
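// CPU specialization of UpperBoundFunctor: for each value, finds the index of
// the first element in the corresponding row of sorted_inputs that is
// strictly greater than the value (std::upper_bound semantics).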
template <typename T, typename OutType>
struct UpperBoundFunctor<CPUDevice, T, OutType> {
  static Status Compute(OpKernelContext* context,
                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
                        const typename TTypes<T, 1>::ConstTensor& values,
                        int batch_size, int num_inputs, int num_values,
                        typename TTypes<OutType, 1>::Tensor* output) {
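    // Each shard searches for the values whose per-batch index falls in
    // [first, last), running one binary search per batch row.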
    auto work_fn = [&](int64_t first, int64_t last) {
      for (int b = 0; b < batch_size; ++b) {
        const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
        OutType* output_ptr = output->data() + b * num_values;
        for (int i = first; i < last; ++i) {
          output_ptr[i] = std::upper_bound(sorted_inputs_ptr,
                                           sorted_inputs_ptr + num_inputs,
                                           values(i + b * num_values)) -
                          sorted_inputs_ptr;
        }
      }
    };
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    thread::ThreadPool* thread_pool = worker_threads.workers;
    const float kCostMultiplier = 1.f;  // Can be tuned to minimize overhead
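    // Per-value work estimate: one O(log num_inputs) binary search per batch
    // row; ParallelFor uses this to decide how finely to shard num_values.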
    int64_t cost_per_unit =
        kCostMultiplier * batch_size * Log2Ceiling(num_inputs);
    thread_pool->ParallelFor(num_values, cost_per_unit, work_fn);
    return OkStatus();
  }
};

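// CPU specialization of LowerBoundFunctor: identical to UpperBoundFunctor
// except that it uses std::lower_bound, i.e. it returns the index of the
// first element greater than or equal to the value. For example, with a row
// [1, 3, 3, 5] and value 3, LowerBound yields 1 while UpperBound yields 3.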
template <typename T, typename OutType>
struct LowerBoundFunctor<CPUDevice, T, OutType> {
  static Status Compute(OpKernelContext* context,
                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
                        const typename TTypes<T, 1>::ConstTensor& values,
                        int batch_size, int num_inputs, int num_values,
                        typename TTypes<OutType, 1>::Tensor* output) {
    auto work_fn = [&](int64_t first, int64_t last) {
      for (int b = 0; b < batch_size; ++b) {
        const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
        OutType* output_ptr = output->data() + b * num_values;
        for (int i = first; i < last; ++i) {
          output_ptr[i] = std::lower_bound(sorted_inputs_ptr,
                                           sorted_inputs_ptr + num_inputs,
                                           values(i + b * num_values)) -
                          sorted_inputs_ptr;
        }
      }
    };
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    thread::ThreadPool* thread_pool = worker_threads.workers;
    const float kCostMultiplier = 1.f;  // Can be tuned to minimize overhead
    int64_t cost_per_unit =
        kCostMultiplier * batch_size * Log2Ceiling(num_inputs);
    thread_pool->ParallelFor(num_values, cost_per_unit, work_fn);
    return OkStatus();
  }
};
}  // namespace functor

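// OpKernel for UpperBound: validates that sorted_inputs and values share a
// leading batch dimension, allocates an output with the shape of values, and
// dispatches to the device-specific UpperBoundFunctor.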
template <typename Device, typename T, typename OutType>
class UpperBoundOp : public OpKernel {
 public:
  explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& sorted_inputs_t = ctx->input(0);
    const Tensor& values_t = ctx->input(1);

    // inputs must be at least a matrix
    OP_REQUIRES(
        ctx, sorted_inputs_t.shape().dims() >= 2,
        errors::InvalidArgument("sorted input argument must be a matrix"));
    // must have same batch dim_size for both
    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
                Status(error::INVALID_ARGUMENT,
                       "Leading dim_size of both tensors must match."));

    // this is required because we do indexing in int32 on the GPU
    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
                Status(error::INVALID_ARGUMENT,
                       "values tensor size must be less than INT_MAX"));

    Tensor* output_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));

    if (output_t->dtype() == DT_INT32) {
      OP_REQUIRES(ctx,
                  FastBoundsCheck(sorted_inputs_t.dim_size(1),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("trailing dim_size must be less than "
                                          "INT_MAX for int32 output type, was ",
                                          sorted_inputs_t.dim_size(1)));
    }

    auto output = output_t->template flat<OutType>();
    const auto sorted_inputs = sorted_inputs_t.template flat<T>();
    const auto values = values_t.template flat<T>();

    // For empty inputs, all values will be placed at the zeroth position.
    if (sorted_inputs.size() == 0) {
      functor::SetZeroFunctor<Device, OutType> set_zero;
      set_zero(ctx->eigen_device<Device>(), output);
      return;
    }

    OP_REQUIRES_OK(
        ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
                 ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
                 sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
  }
};

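// OpKernel for LowerBound: same validation and dispatch as UpperBoundOp, but
// delegates to the device-specific LowerBoundFunctor.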
template <typename Device, typename T, typename OutType>
class LowerBoundOp : public OpKernel {
 public:
  explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& sorted_inputs_t = ctx->input(0);
    const Tensor& values_t = ctx->input(1);

    // inputs must be at least a matrix
    OP_REQUIRES(
        ctx, sorted_inputs_t.shape().dims() >= 2,
        errors::InvalidArgument("sorted input argument must be a matrix"));
    // must have same batch dim_size for both
    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
                Status(error::INVALID_ARGUMENT,
                       "Leading dim_size of both tensors must match."));

    // this is required because we do indexing in int32 on the GPU
    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
                Status(error::INVALID_ARGUMENT,
                       "values tensor size must be less than INT_MAX"));

    Tensor* output_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));

    if (output_t->dtype() == DT_INT32) {
      OP_REQUIRES(ctx,
                  FastBoundsCheck(sorted_inputs_t.dim_size(1),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("trailing dim_size must be less than "
                                          "INT_MAX for int32 output type, was ",
                                          sorted_inputs_t.dim_size(1)));
    }

    auto output = output_t->template flat<OutType>();
    const auto sorted_inputs = sorted_inputs_t.template flat<T>();
    const auto values = values_t.template flat<T>();

    // For empty inputs, all values will be placed at the zeroth position.
    if (sorted_inputs.size() == 0) {
      functor::SetZeroFunctor<Device, OutType> set_zero;
      set_zero(ctx->eigen_device<Device>(), output);
      return;
    }

    OP_REQUIRES_OK(
        ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
                 ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
                 sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
  }
};

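// Kernel registrations. Each (device, out_type) combination is registered
// separately and instantiated for all real number input types.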
#define REGISTER_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("UpperBound")                      \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<type>("T")          \
                              .TypeConstraint<int32>("out_type"), \
                          UpperBoundOp<CPUDevice, type, int32>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("UpperBound")                        \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("out_type"), \
                          UpperBoundOp<CPUDevice, type, int64>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("UpperBound")                      \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<type>("T")          \
                              .TypeConstraint<int32>("out_type"), \
                          UpperBoundOp<GPUDevice, type, int32>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("UpperBound")                        \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("out_type"), \
                          UpperBoundOp<GPUDevice, type, int64>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("LowerBound")                      \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<type>("T")          \
                              .TypeConstraint<int32>("out_type"), \
                          LowerBoundOp<CPUDevice, type, int32>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("LowerBound")                        \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("out_type"), \
                          LowerBoundOp<CPUDevice, type, int64>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("LowerBound")                      \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<type>("T")          \
                              .TypeConstraint<int32>("out_type"), \
                          LowerBoundOp<GPUDevice, type, int32>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("LowerBound")                        \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("out_type"), \
                          LowerBoundOp<GPUDevice, type, int64>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow