/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc
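//
// Informal example of the op's semantics (paraphrased from the tf.one_hot
// documentation, not specific to this kernel): with indices = [0, 2, -1, 1],
// depth = 3, on_value = 5.0, off_value = 0.0 and axis = -1, the output is
//   [[5.0, 0.0, 0.0],   // one_hot(0)
//    [0.0, 0.0, 5.0],   // one_hot(2)
//    [0.0, 0.0, 0.0],   // one_hot(-1): out-of-range indices produce off_value
//    [0.0, 5.0, 0.0]]   // one_hot(1)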

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/one_hot_op.h"

#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename TI>
class OneHotOp : public OpKernel {
 public:
  explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& indices = ctx->input(0);
    const Tensor& depth = ctx->input(1);
    const Tensor& on_value = ctx->input(2);
    const Tensor& off_value = ctx->input(3);
    const TensorShape& indices_shape = indices.shape();

    const int indices_dims = indices_shape.dims();
    const int output_dims = indices_dims + 1;

    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims),
        errors::InvalidArgument("Expected axis to be -1 or in the range [0, ",
                                output_dims, "), but received: ", axis_));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth.shape()),
                errors::InvalidArgument("depth must be a scalar, but got: ",
                                        depth.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value.shape()),
                errors::InvalidArgument("on_value must be a scalar, but got: ",
                                        on_value.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value.shape()),
                errors::InvalidArgument("off_value must be a scalar, but got: ",
                                        off_value.shape().DebugString()));

    const int axis = (axis_ == -1) ? indices_dims : axis_;
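    // From here on, `axis` is the position of the new one-hot dimension in
    // the output; axis_ == -1 maps to the innermost (last) position.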

    // The one-hot dimension.
    const int32_t depth_v = depth.scalar<int32>()();
    OP_REQUIRES(
        ctx, depth_v >= 0,
        errors::InvalidArgument("depth must be non-negative, got: ", depth_v));
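    // MultiplyWithoutOverflow returns a negative value when the int64 product
    // overflows, so a negative result below means the output would be too big.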
    OP_REQUIRES(
        ctx,
        MultiplyWithoutOverflow(indices_shape.num_elements(), depth_v) >= 0,
        errors::InvalidArgument("OneHot result would have shape ",
                                indices_shape.DebugString(), " + [", depth_v,
                                "], which exceeds 2**63 - 1 elements"));

    TensorShape output_shape = indices_shape;
    output_shape.InsertDim(axis, depth_v);
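    // For example, indices of shape [2, 3] with axis == 1 and depth_v == 4
    // produce an output of shape [2, 4, 3].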

    auto on_value_t = on_value.scalar<T>();
    auto off_value_t = off_value.scalar<T>();

    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));

    if (output_shape.num_elements() > 0) {
      // prefix_dim_size == # of elements before the one-hot axis
      // depth_v == size of the new one-hot axis
      // suffix_dim_size == # of elements after the one-hot axis
      int64_t prefix_dim_size = 1;
      for (int i = 0; i < axis; ++i) {
        prefix_dim_size *= indices_shape.dim_size(i);
      }
      int64_t suffix_dim_size = indices_shape.num_elements() / prefix_dim_size;
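      // Continuing the example above ([2, 3] indices, axis == 1):
      // prefix_dim_size == 2 and suffix_dim_size == 6 / 2 == 3, so indices are
      // viewed as a 2 x 3 matrix and the output as a 2 x 4 x 3 tensor.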

      // Split indices into matrix of size prefix_dim_size x suffix_dim_size
      auto indices_t =
          indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
      // Split output into 3-Tensor of size:
      //   prefix_dim_size x depth x suffix_dim_size.
      auto output_t =
          output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});

      functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(),
                                              indices_t, on_value_t,
                                              off_value_t, &output_t);
    }
  }

 private:
  int32 axis_;

  TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
};

#define REGISTER_ONE_HOT_INDEX(type, index_type)                   \
  REGISTER_KERNEL_BUILDER(Name("OneHot")                           \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<index_type>("TI")    \
                              .TypeConstraint<type>("T")           \
                              .HostMemory("depth"),                \
                          OneHotOp<CPUDevice, type, index_type>);

#define REGISTER_ONE_HOT(type)         \
  REGISTER_ONE_HOT_INDEX(type, uint8); \
  REGISTER_ONE_HOT_INDEX(type, int8);  \
  REGISTER_ONE_HOT_INDEX(type, int32); \
  REGISTER_ONE_HOT_INDEX(type, int64_t)

TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
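
// Undefine the CPU registration helpers; this mirrors the cleanup done for
// the GPU registration macros below.
#undef REGISTER_ONE_HOT_INDEX
#undef REGISTER_ONE_HOT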

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC_INDEX(T, TI)                                       \
  template <>                                                               \
  void OneHot<GPUDevice, T, TI>::Compute(                                   \
      const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices,  \
      const typename TTypes<T>::ConstScalar& on_value,                      \
      const typename TTypes<T>::ConstScalar& off_value,                     \
      typename TTypes<T, 3>::Tensor* output);                               \
  extern template struct OneHot<GPUDevice, T, TI>;

#define DECLARE_GPU_SPEC(T)         \
  DECLARE_GPU_SPEC_INDEX(T, uint8); \
  DECLARE_GPU_SPEC_INDEX(T, int8);  \
  DECLARE_GPU_SPEC_INDEX(T, int32); \
  DECLARE_GPU_SPEC_INDEX(T, int64_t);

TF_CALL_int8(DECLARE_GPU_SPEC);
TF_CALL_int32(DECLARE_GPU_SPEC);
TF_CALL_int64(DECLARE_GPU_SPEC);
TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC_INDEX
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type)               \
  REGISTER_KERNEL_BUILDER(Name("OneHot")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<index_type>("TI")    \
                              .TypeConstraint<type>("T")           \
                              .HostMemory("depth"),                \
                          OneHotOp<GPUDevice, type, index_type>);

#define REGISTER_ONE_HOT_GPU(type)         \
  REGISTER_ONE_HOT_GPU_INDEX(type, uint8); \
  REGISTER_ONE_HOT_GPU_INDEX(type, int8);  \
  REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
  REGISTER_ONE_HOT_GPU_INDEX(type, int64_t);

TF_CALL_int8(REGISTER_ONE_HOT_GPU);
TF_CALL_int32(REGISTER_ONE_HOT_GPU);
TF_CALL_int64(REGISTER_ONE_HOT_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_ONE_HOT_GPU);

#undef REGISTER_ONE_HOT_GPU_INDEX
#undef REGISTER_ONE_HOT_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow