1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/tensor_to_hash_bucket_op.h"
17
18#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19#include "tensorflow/core/framework/register_types.h"
20
21namespace tensorflow {
22
23typedef Eigen::ThreadPoolDevice CPUDevice;
24typedef Eigen::GpuDevice GPUDevice;
25
26template <typename Device, typename T>
27class TensorToHashBucketOp : public OpKernel {
28 public:
29 explicit TensorToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
30 OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
31
32 DataType dtype;
33 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype));
34 OP_REQUIRES(ctx,
35 dtype == DT_INT8 || dtype == DT_UINT8 || dtype == DT_INT16 ||
36 dtype == DT_UINT16 || dtype == DT_INT32 ||
37 dtype == DT_UINT32 || dtype == DT_INT64 ||
38 dtype == DT_UINT64,
39 errors::InvalidArgument("TensorToHashBucketOp doesn't support "
40 "datatype ",
41 DataTypeString(dtype)));
42 }
43
44 void Compute(OpKernelContext* context) override {
45 const Tensor* input_tensor;
46 OP_REQUIRES_OK(context, context->input("input", &input_tensor));
47 const auto& input_flat = input_tensor->flat<T>();
48
49 Tensor* output_tensor = nullptr;
50 OP_REQUIRES_OK(context,
51 context->allocate_output("output", input_tensor->shape(),
52 &output_tensor));
53 auto output_flat = output_tensor->flat<int64_t>();
54
55 functor::LaunchTensorToHashBucket<Device, T>()(
56 context, num_buckets_, input_flat.data(), input_tensor->NumElements(),
57 output_flat.data());
58 }
59
60 private:
61 int64_t num_buckets_;
62
63 TF_DISALLOW_COPY_AND_ASSIGN(TensorToHashBucketOp);
64};
65
66#define REGISTER_CPU_KERNELS(type) \
67 REGISTER_KERNEL_BUILDER(Name("_TensorToHashBucketFast") \
68 .Device(DEVICE_CPU) \
69 .TypeConstraint<type>("T"), \
70 TensorToHashBucketOp<CPUDevice, type>);
71
72TF_CALL_INTEGRAL_TYPES(REGISTER_CPU_KERNELS);
73
74#undef REGISTER_CPU_KERNELS
75
#if GOOGLE_CUDA

// Registers the GPU kernel for "_TensorToHashBucketFast" once for every
// element type that TF_CALL_INTEGRAL_TYPES expands to. Only compiled when
// GOOGLE_CUDA is set.
#define REGISTER_TENSOR_TO_HASH_BUCKET_GPU(elem_type)        \
  REGISTER_KERNEL_BUILDER(Name("_TensorToHashBucketFast")    \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<elem_type>("T"), \
                          TensorToHashBucketOp<GPUDevice, elem_type>);

TF_CALL_INTEGRAL_TYPES(REGISTER_TENSOR_TO_HASH_BUCKET_GPU);

#undef REGISTER_TENSOR_TO_HASH_BUCKET_GPU

#endif  // GOOGLE_CUDA
89
90} // namespace tensorflow
91