1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
17#define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
18
19#define EIGEN_USE_THREADS
20
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/kernels/cast_op.h"
23
24namespace tensorflow {
25
namespace functor {

// Expands the CAST_FUNCTORS macro for Eigen::ThreadPoolDevice, the
// thread-pool device that Eigen only exposes when EIGEN_USE_THREADS is
// defined (done at the top of this header). CAST_FUNCTORS is not defined in
// this header; presumably it comes from the included cast_op.h.
// NOTE(review): presumably this provides the
// functor::CastFunctor<Eigen::ThreadPoolDevice, OUT, IN> definitions that
// CAST_CASE instantiates for CPU casts — confirm against cast_op.h.
CAST_FUNCTORS(Eigen::ThreadPoolDevice);


} // namespace functor
32
// Type-list helpers. CURRY_TYPES3*(FN, arg0, arg1) expands to one
// `FN(arg0, arg1, T)` invocation per listed element type T, in list order, so
// a single macro call can emit code for every type in the list.
//
// CURRY_TYPES3_NO_HALF covers bool, uint8, uint16, uint32, uint64, int8,
// int16, int32, int64_t, float, double, std::complex<float> and
// std::complex<double>. Its final FN invocation is intentionally left
// unterminated so that a direct call site supplies the `;` like any other
// function-like macro: `CURRY_TYPES3_NO_HALF(FN, a, b);`.
#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \
  FN(arg0, arg1, bool);                      \
  FN(arg0, arg1, uint8);                     \
  FN(arg0, arg1, uint16);                    \
  FN(arg0, arg1, uint32);                    \
  FN(arg0, arg1, uint64);                    \
  FN(arg0, arg1, int8);                      \
  FN(arg0, arg1, int16);                     \
  FN(arg0, arg1, int32);                     \
  FN(arg0, arg1, int64_t);                   \
  FN(arg0, arg1, float);                     \
  FN(arg0, arg1, double);                    \
  FN(arg0, arg1, std::complex<float>);       \
  FN(arg0, arg1, std::complex<double>)

// Every type in CURRY_TYPES3_NO_HALF plus Eigen::half.
// The `;` after the nested expansion terminates its last (unterminated) FN
// invocation. Without it, the expansion only compiled when FN produced a
// self-terminating statement such as CAST_CASE's `if (...) { ... }` block;
// an expression-like FN yielded two adjacent expressions with no separator.
// The trailing `;` is unchanged, so call sites that add their own `;` just
// get an empty statement, as before.
#define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  CURRY_TYPES3_NO_HALF(FN, arg0, arg1);      \
  FN(arg0, arg1, Eigen::half);

// Every type in CURRY_TYPES3_NO_BF16 plus bfloat16.
#define CURRY_TYPES3(FN, arg0, arg1)   \
  CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  FN(arg0, arg1, bfloat16);
55
// CAST_CASE(DEVICE, IN, OUT) expands to an `if` statement. When the runtime
// `dst_dtype` equals DataTypeToEnum<OUT>::value, it returns a capture-less
// lambda. That lambda runs functor::CastFunctor<DEVICE, OUT, IN> on the
// DEVICE obtained from the kernel context, reading the source tensor as a
// flat IN view and writing the destination as a flat OUT view. `truncate`
// is forwarded to the functor unchanged.
//
// Requirements at the expansion site:
//   * a DataType named `dst_dtype` must be in scope (the macro is not
//     hygienic and reads it by name);
//   * the enclosing function's return type must be constructible from the
//     lambda (e.g. CastFunctorType).
// If OUT does not match, control falls through to the next statement, so
// several CAST_CASEs can be chained.
#define CAST_CASE(DEVICE, IN, OUT)                                       \
  if (dst_dtype == DataTypeToEnum<OUT>::value) {                         \
    return [](OpKernelContext* context, const Tensor& src, Tensor* dst,  \
              bool truncate) {                                           \
      functor::CastFunctor<DEVICE, OUT, IN> cast_fn;                     \
      cast_fn(context->eigen_device<DEVICE>(), dst->flat<OUT>(),         \
              src.flat<IN>(), truncate);                                 \
    };                                                                   \
  }
65
66// The functions below are implemented in the cast_op_impl_*.cc files.
67CastFunctorType GetCpuCastFromBool(DataType dst_dtype);
68
69CastFunctorType GetCpuCastFromUint8(DataType dst_dtype);
70
71CastFunctorType GetCpuCastFromUint16(DataType dst_dtype);
72
73CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
74
75CastFunctorType GetCpuCastFromUint32(DataType dst_dtype);
76
77CastFunctorType GetCpuCastFromUint64(DataType dst_dtype);
78
79CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
80
81CastFunctorType GetCpuCastFromInt16(DataType dst_dtype);
82
83CastFunctorType GetCpuCastFromInt32(DataType dst_dtype);
84
85CastFunctorType GetCpuCastFromInt64(DataType dst_dtype);
86
87CastFunctorType GetCpuCastFromHalf(DataType dst_dtype);
88
89CastFunctorType GetCpuCastFromFloat(DataType dst_dtype);
90
91CastFunctorType GetCpuCastFromDouble(DataType dst_dtype);
92
93CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype);
94
95CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype);
96
97CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype);
98
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// GPU counterparts of the GetCpuCastFrom* factories, declared only in CUDA or
// ROCm builds. Each GetGpuCastFrom<Source>(dst_dtype) returns a
// CastFunctorType for casting tensors of the source element type named in the
// function to `dst_dtype` on the GPU. Like the CPU factories, these are
// implemented in the cast_op_impl_*.cc files.
CastFunctorType GetGpuCastFromBool(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);
CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);
CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
133
134
135} // namespace tensorflow
136
137#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
138