/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64_t);             \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)
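
// Applying CURRY_TYPES2 to a two-argument macro FN instantiates FN once per
// destination type listed above. For example,
// CURRY_TYPES2(REGISTER_CAST_GPU, float) (see REGISTER_CAST_GPU below)
// registers a Cast kernel from float to each of these types.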

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their
  // non-quantized versions, so we use the non-quantized implementation for
  // casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else if (external_src_dtype_ != src_dtype_ ||
             external_dst_dtype_ != dst_dtype_) {
    Tensor in;
    // If the type is a quantized type we need to do a bitcast, since
    // src_dtype_ is different from external_src_dtype_.
    OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out, use_truncation_);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return OkStatus();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not the least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : OkStatus();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return OkStatus();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : OkStatus();
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)
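
// For example, REGISTER_CAST_GPU(float, double) registers GpuCastOp for Cast
// nodes with SrcT=float and DstT=double placed on DEVICE_GPU; the
// CURRY_TYPES2 invocations below apply this macro to every destination type
// for a given source type.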

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
#endif  // !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_DEFAULT).HostMemory("x").HostMemory("y"),
    CpuCastOp);
}  // end namespace tensorflow