/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

// Note that the GPU cast functor templates need to be explicitly instantiated,
// unlike the CPU ones, so their specializations differ from the CPU
// specializations below.
#ifdef SPECIALIZE_FOR_GPUS
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
  template <typename Device> \
  struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \
    void operator()(const Device& d, \
                    typename TTypes<OUT_TYPE>::Flat out_tensor, \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor, \
                    bool truncate = false) { \
      if (truncate) { \
        out_tensor.device(d) = \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
                .template cast<OUT_TYPE>(); \
      } else { \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      } \
    } \
  }; \
  template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
#else
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
  template <> \
  struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \
    void operator()(const DEVICE& d, \
                    typename TTypes<OUT_TYPE>::Flat out_tensor, \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor, \
                    bool truncate = false) { \
      if (truncate) { \
        out_tensor.device(d) = \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
                .template cast<OUT_TYPE>(); \
      } else { \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      } \
    } \
  };
#endif

#define CAST_FUNCTORS(devname) \
  SPECIALIZE_CAST(devname, float, double) \
  SPECIALIZE_CAST(devname, float, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double) \
  SPECIALIZE_CAST(devname, Eigen::half, double) \
  SPECIALIZE_CAST(devname, Eigen::half, float) \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \
  SPECIALIZE_CAST(devname, bfloat16, float) \
  template <typename OUT_TYPE, typename IN_OUT> \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \
    void operator()(const devname& d, \
                    typename TTypes<OUT_TYPE>::Flat out_tensor, \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor, \
                    bool truncate = false) { \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
    } \
  };

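// Usage note (the call sites live outside this header, so this is an
// assumption): each device's cast implementation is expected to expand
// CAST_FUNCTORS(DeviceType) once, emitting the specialized truncating casts
// above plus the generic fallback CastFunctor for all remaining type pairs.
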
#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
// If MLIR-generated kernels are enabled, most of the specialized casts above
// are not needed. We still need the specialized cast from Eigen::half to
// float, because it is used in depthwise_conv_grad_op.cc, the specialized cast
// from float to double, because it is used in resize_bilinear_op.cc, and the
// specialized cast from bfloat16 to float.
#define CAST_FUNCTORS_SUBSET(devname) \
  SPECIALIZE_CAST(devname, float, double) \
  SPECIALIZE_CAST(devname, Eigen::half, float) \
  SPECIALIZE_CAST(devname, bfloat16, float) \
  template <typename OUT_TYPE, typename IN_OUT> \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \
    void operator()(const devname& d, \
                    typename TTypes<OUT_TYPE>::Flat out_tensor, \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor, \
                    bool truncate = false) { \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
    } \
  };
#endif

namespace tensorflow {

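// Signature of the functor that performs the actual cast; `trunc` requests
// that the input's low-order mantissa bits be zeroed (see LSBZeroSetter below)
// before the conversion.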
typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
                           bool trunc)>
    CastFunctorType;

// Common base class of Cast kernels
class CastOpBase : public OpKernel {
 public:
  explicit CastOpBase(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 protected:
  DataType src_dtype_;
  DataType dst_dtype_;
  DataType external_src_dtype_;
  DataType external_dst_dtype_;
  bool use_truncation_;
  CastFunctorType work_ = nullptr;
  Status Unimplemented();

  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
};

// CPU implementation of Cast
class CpuCastOp : public CastOpBase {
 public:
  explicit CpuCastOp(OpKernelConstruction* ctx);

 private:
  Status Prepare();
};

namespace functor {

template <typename I>
constexpr int MantissaWidth() {
  return std::numeric_limits<I>::digits;
}

template <>
constexpr int MantissaWidth<Eigen::half>() {
  // Remember, there's 1 hidden bit
  return 10 + 1;
}

template <>
constexpr int MantissaWidth<bfloat16>() {
  // Remember, there's 1 hidden bit
  return 7 + 1;
}
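
// For reference: MantissaWidth<float>() == 24, MantissaWidth<double>() == 53,
// MantissaWidth<Eigen::half>() == 11, and MantissaWidth<bfloat16>() == 8
// (all counts include the hidden bit).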

template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
          typename TTypes<Tin>::ConstFlat i) {
  o.device(d) = i.template cast<Tout>();
}

template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
  void operator()(const Device& d, typename TTypes<Tout>::Flat o,
                  typename TTypes<Tin>::ConstFlat i, bool truncate = false);
};

// Only enable LSBZeroSetterHelper for 64- and 32-bit input data types.
// Specialize for other sizes if needed in the future.
template <typename I>
typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint64_t* p = reinterpret_cast<uint64_t*>(&t);
    *p &= (0xFFFFFFFFFFFFFFFF << n);
  }
}

template <typename I>
typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint32_t* p = reinterpret_cast<uint32_t*>(&t);
    *p &= (0xFFFFFFFF << n);
  }
}
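
// Both helpers clear the lowest n bits of t's raw representation in place;
// LSBZeroSetter below calls them with
// n = MantissaWidth<I>() - MantissaWidth<O>().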

// Set n least significant bits to 0
template <typename I, typename O>
struct LSBZeroSetter {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I t = a;
    LSBZeroSetterHelper(t, bits);
    return t;
  }
};
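
// Worked example (illustrative): LSBZeroSetter<float, bfloat16>()(x) clears
// the low 24 - 8 = 16 bits of x's bit pattern, so the subsequent cast to
// bfloat16 truncates the mantissa instead of rounding to nearest.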

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, O> {
  // Zeroes the least significant mantissa bits of both the real and imaginary
  // parts.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

}  // end namespace functor
}  // end namespace tensorflow

namespace Eigen {
namespace internal {

// Eigen can't convert to/from complex numbers, because it is limited to cases
// that can be static_casted. But numpy is able to cast to/from complex, which
// we want to replicate. So we add specializations for complex here.
template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, To> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
  operator()(const std::complex<From>& a) const {
    // Replicate numpy behavior of returning just the real part
    return static_cast<To>(a.real());
  }
};
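
// For example (numpy-style behavior provided by the specialization above):
//   Eigen::internal::scalar_cast_op<std::complex<double>, float> cast;
//   float r = cast(std::complex<double>(2.0, 3.0));  // r == 2.0f; the
//                                                    // imaginary part is
//                                                    // discarded.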

template <typename From, typename To>
struct scalar_cast_op<From, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const From& a) const {
    // Replicate numpy behavior of setting the imaginary part to 0
    return std::complex<To>(static_cast<To>(a), To(0));
  }
};

template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const std::complex<From>& a) const {
    return std::complex<To>(static_cast<To>(a.real()),
                            static_cast<To>(a.imag()));
  }
};

template <typename From, typename To>
struct functor_traits_complex_impl {
  enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
};

template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, To>>
    : functor_traits_complex_impl<std::complex<From>, To> {};
template <typename From, typename To>
struct functor_traits<scalar_cast_op<From, std::complex<To>>>
    : functor_traits_complex_impl<From, std::complex<To>> {};
// Needed to avoid ambiguous partial specialization
template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
    : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_CAST_OP_H_