1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CAST_OP_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/bfloat16.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/tensor_types.h" |
23 | #include "tensorflow/core/framework/types.h" |
24 | #include "tensorflow/core/platform/byte_order.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
// Note that the GPU cast functor templates need to be explicitly instantiated,
// unlike the CPU ones, and hence their specializations differ from the CPU
// ones.
29 | #ifdef SPECIALIZE_FOR_GPUS |
30 | #define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ |
31 | template <typename Device> \ |
32 | struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \ |
33 | void operator()(const Device& d, \ |
34 | typename TTypes<OUT_TYPE>::Flat out_tensor, \ |
35 | typename TTypes<IN_OUT>::ConstFlat in_tensor, \ |
36 | bool truncate = false) { \ |
37 | if (truncate) { \ |
38 | out_tensor.device(d) = \ |
39 | in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ |
40 | .template cast<OUT_TYPE>(); \ |
41 | } else { \ |
42 | out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ |
43 | } \ |
44 | } \ |
45 | }; \ |
46 | template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>; |
47 | #else |
48 | #define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ |
49 | template <> \ |
50 | struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \ |
51 | void operator()(const DEVICE& d, \ |
52 | typename TTypes<OUT_TYPE>::Flat out_tensor, \ |
53 | typename TTypes<IN_OUT>::ConstFlat in_tensor, \ |
54 | bool truncate = false) { \ |
55 | if (truncate) { \ |
56 | out_tensor.device(d) = \ |
57 | in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ |
58 | .template cast<OUT_TYPE>(); \ |
59 | } else { \ |
60 | out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ |
61 | } \ |
62 | } \ |
63 | }; |
64 | #endif |
65 | |
66 | #define CAST_FUNCTORS(devname) \ |
67 | SPECIALIZE_CAST(devname, float, double) \ |
68 | SPECIALIZE_CAST(devname, float, std::complex<double>) \ |
69 | SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \ |
70 | SPECIALIZE_CAST(devname, std::complex<float>, double) \ |
71 | SPECIALIZE_CAST(devname, Eigen::half, double) \ |
72 | SPECIALIZE_CAST(devname, Eigen::half, float) \ |
73 | SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \ |
74 | SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \ |
75 | SPECIALIZE_CAST(devname, bfloat16, float) \ |
76 | template <typename OUT_TYPE, typename IN_OUT> \ |
77 | struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \ |
78 | void operator()(const devname& d, \ |
79 | typename TTypes<OUT_TYPE>::Flat out_tensor, \ |
80 | typename TTypes<IN_OUT>::ConstFlat in_tensor, \ |
81 | bool truncate = false) { \ |
82 | out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ |
83 | } \ |
84 | }; |
85 | |
86 | #if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) |
// If MLIR kernels are enabled, most of the specialized casts above are not
// needed (e.g. Eigen::half to double). We still need the specialized cast from
// Eigen::half to float, because it is used in depthwise_conv_grad_op.cc, and
// the specialized cast from float to double, because it is used in
// resize_bilinear_op.cc.
92 | #define CAST_FUNCTORS_SUBSET(devname) \ |
93 | SPECIALIZE_CAST(devname, float, double) \ |
94 | SPECIALIZE_CAST(devname, Eigen::half, float) \ |
95 | SPECIALIZE_CAST(devname, bfloat16, float) \ |
96 | template <typename OUT_TYPE, typename IN_OUT> \ |
97 | struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \ |
98 | void operator()(const devname& d, \ |
99 | typename TTypes<OUT_TYPE>::Flat out_tensor, \ |
100 | typename TTypes<IN_OUT>::ConstFlat in_tensor, \ |
101 | bool truncate = false) { \ |
102 | out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ |
103 | } \ |
104 | }; |
105 | #endif |
106 | |
107 | namespace tensorflow { |
108 | |
109 | typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*, |
110 | bool trunc)> |
111 | CastFunctorType; |
112 | |
113 | // Common base class of Cast kernels |
114 | class CastOpBase : public OpKernel { |
115 | public: |
116 | explicit CastOpBase(OpKernelConstruction* ctx); |
117 | |
118 | void Compute(OpKernelContext* ctx) override; |
119 | |
120 | protected: |
121 | DataType src_dtype_; |
122 | DataType dst_dtype_; |
123 | DataType external_src_dtype_; |
124 | DataType external_dst_dtype_; |
125 | bool use_truncation_; |
126 | CastFunctorType work_ = nullptr; |
127 | Status Unimplemented(); |
128 | |
129 | TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); |
130 | }; |
131 | |
132 | // CPU implementation of Cast |
133 | class CpuCastOp : public CastOpBase { |
134 | public: |
135 | explicit CpuCastOp(OpKernelConstruction* ctx); |
136 | |
137 | private: |
138 | Status Prepare(); |
139 | }; |
140 | |
141 | namespace functor { |
142 | |
143 | template <typename I> |
144 | constexpr int MantissaWidth() { |
145 | return std::numeric_limits<I>::digits; |
146 | } |
147 | |
148 | template <> |
149 | constexpr int MantissaWidth<Eigen::half>() { |
150 | // Remember, there's 1 hidden bit |
151 | return 10 + 1; |
152 | } |
153 | |
154 | template <> |
155 | constexpr int MantissaWidth<bfloat16>() { |
156 | // Remember, there's 1 hidden bit |
157 | return 7 + 1; |
158 | } |
159 | |
160 | template <typename Device, typename Tout, typename Tin> |
161 | void Cast(const Device& d, typename TTypes<Tout>::Flat o, |
162 | typename TTypes<Tin>::ConstFlat i) { |
163 | o.device(d) = i.template cast<Tout>(); |
164 | } |
165 | |
166 | template <typename Device, typename Tout, typename Tin> |
167 | struct CastFunctor { |
168 | void operator()(const Device& d, typename TTypes<Tout>::Flat o, |
169 | typename TTypes<Tin>::ConstFlat i, bool truncate = false); |
170 | }; |
171 | |
// Only enable LSBZeroSetterHelper for 64- and 32-bit input data types.
// Specialize for others if needed in the future.
174 | template <typename I> |
175 | typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC |
176 | EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { |
177 | // Only zero the bits for non-NaNs. |
178 | // For NaNs, let the non-truncation version handle it. |
179 | if (!std::isnan(t)) { |
180 | uint64_t* p = reinterpret_cast<uint64_t*>(&t); |
181 | *p &= (0xFFFFFFFFFFFFFFFF << n); |
182 | } |
183 | } |
184 | |
185 | template <typename I> |
186 | typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC |
187 | EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { |
188 | // Only zero the bits for non-NaNs. |
189 | // For NaNs, let the non-truncation version handle it. |
190 | if (!std::isnan(t)) { |
191 | uint32_t* p = reinterpret_cast<uint32_t*>(&t); |
192 | *p &= (0xFFFFFFFF << n); |
193 | } |
194 | } |
195 | |
196 | // Set n least significant bits to 0 |
197 | template <typename I, typename O> |
198 | struct LSBZeroSetter { |
199 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const { |
200 | constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); |
201 | static_assert( |
202 | bits > 0, |
203 | "The output type must have fewer mantissa bits than the input type\n" ); |
204 | I t = a; |
205 | LSBZeroSetterHelper(t, bits); |
206 | return t; |
207 | } |
208 | }; |
209 | |
210 | template <typename I, typename O> |
211 | struct LSBZeroSetter<std::complex<I>, std::complex<O>> { |
212 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( |
213 | const std::complex<I>& a) const { |
214 | constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); |
215 | static_assert( |
216 | bits > 0, |
217 | "The output type must have fewer mantissa bits than the input type\n" ); |
218 | I re = std::real(a); |
219 | I img = std::imag(a); |
220 | LSBZeroSetterHelper(re, bits); |
221 | LSBZeroSetterHelper(img, bits); |
222 | std::complex<I> toReturn(re, img); |
223 | return toReturn; |
224 | } |
225 | }; |
226 | |
227 | template <typename I, typename O> |
228 | struct LSBZeroSetter<std::complex<I>, O> { |
  // Zeroes the (MantissaWidth<I>() - MantissaWidth<O>()) least significant
  // mantissa bits of both the real and imaginary parts.
230 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( |
231 | const std::complex<I>& a) const { |
232 | constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); |
233 | static_assert( |
234 | bits > 0, |
235 | "The output type must have fewer mantissa bits than the input type\n" ); |
236 | I re = std::real(a); |
237 | I img = std::imag(a); |
238 | LSBZeroSetterHelper(re, bits); |
239 | LSBZeroSetterHelper(img, bits); |
240 | std::complex<I> toReturn(re, img); |
241 | return toReturn; |
242 | } |
243 | }; |
244 | |
245 | } // end namespace functor |
246 | } // end namespace tensorflow |
247 | |
248 | namespace Eigen { |
249 | namespace internal { |
250 | |
// Eigen's scalar_cast_op can't convert to/from complex numbers, because it is
// limited to casts that can be expressed with static_cast. numpy, however, can
// cast to/from complex, and we want to replicate that behavior, so we add
// specializations for complex here.
254 | template <typename From, typename To> |
255 | struct scalar_cast_op<std::complex<From>, To> { |
256 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To |
257 | operator()(const std::complex<From>& a) const { |
258 | // Replicate numpy behavior of returning just the real part |
259 | return static_cast<To>(a.real()); |
260 | } |
261 | }; |
262 | |
263 | template <typename From, typename To> |
264 | struct scalar_cast_op<From, std::complex<To>> { |
265 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()( |
266 | const From& a) const { |
267 | // Replicate numpy behavior of setting the imaginary part to 0 |
268 | return std::complex<To>(static_cast<To>(a), To(0)); |
269 | } |
270 | }; |
271 | |
272 | template <typename From, typename To> |
273 | struct scalar_cast_op<std::complex<From>, std::complex<To>> { |
274 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()( |
275 | const std::complex<From>& a) const { |
276 | return std::complex<To>(static_cast<To>(a.real()), |
277 | static_cast<To>(a.imag())); |
278 | } |
279 | }; |
280 | |
281 | template <typename From, typename To> |
282 | struct functor_traits_complex_impl { |
283 | enum { Cost = NumTraits<To>::AddCost, PacketAccess = false }; |
284 | }; |
285 | |
286 | template <typename From, typename To> |
287 | struct functor_traits<scalar_cast_op<std::complex<From>, To>> |
288 | : functor_traits_complex_impl<std::complex<From>, To> {}; |
289 | template <typename From, typename To> |
290 | struct functor_traits<scalar_cast_op<From, std::complex<To>>> |
291 | : functor_traits_complex_impl<From, std::complex<To>> {}; |
292 | // Needed to avoid ambiguous partial specialization |
293 | template <typename From, typename To> |
294 | struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>> |
295 | : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {}; |
296 | |
297 | } // namespace internal |
298 | } // namespace Eigen |
299 | |
300 | #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_H_ |
301 | |