1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | |
18 | #include "tensorflow/core/kernels/bucketize_op.h" |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/framework/register_types.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_shape.h" |
23 | #include "tensorflow/core/platform/logging.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | using CPUDevice = Eigen::ThreadPoolDevice; |
29 | using GPUDevice = Eigen::GpuDevice; |
30 | |
31 | namespace functor { |
32 | |
33 | template <typename T> |
34 | struct BucketizeFunctor<CPUDevice, T> { |
35 | // PRECONDITION: boundaries_vector must be sorted. |
36 | static Status Compute(OpKernelContext* context, |
37 | const typename TTypes<T, 1>::ConstTensor& input, |
38 | const std::vector<float>& boundaries_vector, |
39 | typename TTypes<int32, 1>::Tensor& output) { |
40 | const int N = input.size(); |
41 | for (int i = 0; i < N; i++) { |
42 | auto first_bigger_it = std::upper_bound( |
43 | boundaries_vector.begin(), boundaries_vector.end(), input(i)); |
44 | output(i) = first_bigger_it - boundaries_vector.begin(); |
45 | } |
46 | |
47 | return OkStatus(); |
48 | } |
49 | }; |
50 | |
51 | } // namespace functor |
52 | |
53 | template <typename Device, typename T> |
54 | class BucketizeOp : public OpKernel { |
55 | public: |
56 | explicit BucketizeOp(OpKernelConstruction* context) : OpKernel(context) { |
57 | OP_REQUIRES_OK(context, context->GetAttr("boundaries" , &boundaries_)); |
58 | OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), |
59 | errors::InvalidArgument("Expected sorted boundaries" )); |
60 | } |
61 | |
62 | void Compute(OpKernelContext* context) override { |
63 | const Tensor& input_tensor = context->input(0); |
64 | const auto input = input_tensor.flat<T>(); |
65 | |
66 | Tensor* output_tensor = nullptr; |
67 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), |
68 | &output_tensor)); |
69 | auto output = output_tensor->template flat<int32>(); |
70 | if (input.size() > 0) { |
71 | OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute( |
72 | context, input, boundaries_, output)); |
73 | } |
74 | } |
75 | |
76 | private: |
77 | std::vector<float> boundaries_; |
78 | }; |
79 | |
// Register CPU Bucketize kernels for each supported input dtype. The "T"
// type constraint must match one of the dtypes listed in the op definition.
#define REGISTER_KERNEL(T)                                    \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Bucketize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      BucketizeOp<CPUDevice, T>);

REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64_t);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

// GPU registrations mirror the CPU set; the GPU functor specialization is
// compiled separately (only when building with CUDA or ROCm support).
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_KERNEL(T)                                    \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Bucketize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      BucketizeOp<GPUDevice, T>);

REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64_t);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
103 | |
104 | } // namespace tensorflow |
105 | |