/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

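// Invokes FN(arg0, T) once for every destination type T that Cast supports,
// keeping the per-type expansions below in one place.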
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64_t);             \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their non-quantized
  // versions, so we use the non-quantized implementation for casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

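// Performs the cast. A null work_ functor means SrcT == DstT, so the input is
// forwarded to the output unchanged; quantized external types are bitcast to
// their plain integer counterparts before the cast functor runs.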
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else if (external_src_dtype_ != src_dtype_ ||
             external_dst_dtype_ != dst_dtype_) {
    Tensor in;
    // If the type is a quantized type we need to do a bitcast, since
    // src_dtype_ is different from external_src_dtype_.
    OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
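    // Run the cast with the output temporarily marked as the plain dtype, then
    // restore the external (quantized) dtype before returning the tensor.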
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out, use_truncation_);
  }
}

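// Builds the error returned when no cast functor exists for the requested
// SrcT -> DstT pair, reported in terms of the externally visible dtypes.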
Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return OkStatus();
  }
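  // Dispatch on the (non-quantized) source type. Each GetCpuCastFrom* helper
  // returns a functor for the requested dst_dtype_, or nullptr if that
  // combination is unsupported.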
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not least based on F16C for Haswell or newer).

  return work_ == nullptr ? Unimplemented() : OkStatus();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
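// GPU specialization of CastOpBase; selects a GetGpuCastFrom* functor in the
// same way the CPU kernel does.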
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return OkStatus();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : OkStatus();
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

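// These explicit registrations are omitted when MLIR-generated GPU kernels
// are enabled.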
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64_t);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
#endif

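// bfloat16 <-> float casts are registered regardless of the MLIR-generated
// kernel setting above.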
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_DEFAULT).HostMemory("x").HostMemory("y"),
    CpuCastOp);
}  // end namespace tensorflow