1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ |
18 | |
19 | #define EIGEN_USE_THREADS |
20 | |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/kernels/cast_op.h" |
23 | |
24 | namespace tensorflow { |
25 | |
namespace functor {

// Expands CAST_FUNCTORS for the Eigen::ThreadPoolDevice device type.
// NOTE(review): CAST_FUNCTORS comes from one of the included headers
// (presumably cast_op.h), and its expansion is not visible here.
// Presumably it defines the CPU CastFunctor specializations that CAST_CASE
// instantiates for DEVICE = Eigen::ThreadPoolDevice — confirm in cast_op.h.
// EIGEN_USE_THREADS is defined at the top of this header before the includes,
// presumably so that Eigen::ThreadPoolDevice is available here.
CAST_FUNCTORS(Eigen::ThreadPoolDevice);

}  // namespace functor
32 | |
// Applies FN(arg0, arg1, T) to every element type T in the cast type list
// except the 16-bit floating-point types. Eigen::half is appended by
// CURRY_TYPES3_NO_BF16 and bfloat16 by CURRY_TYPES3.
//
// Expansion order: bool, the unsigned integers (8/16/32/64 bit), the signed
// integers (8/16/32/64 bit), float, double, std::complex<float>,
// std::complex<double>.
//
// Every expansion except the final std::complex<double> one is followed by
// ';'. Whoever expands this macro must supply the trailing semicolon.
// (No comments may appear on the continuation lines below: a '//' comment
// would swallow the line-splicing backslash.)
#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1)    \
  FN(arg0, arg1, bool);                         \
  FN(arg0, arg1, uint8);                        \
  FN(arg0, arg1, uint16);                       \
  FN(arg0, arg1, uint32);                       \
  FN(arg0, arg1, uint64);                       \
  FN(arg0, arg1, int8);                         \
  FN(arg0, arg1, int16);                        \
  FN(arg0, arg1, int32);                        \
  FN(arg0, arg1, int64_t);                      \
  FN(arg0, arg1, float);                        \
  FN(arg0, arg1, double);                       \
  FN(arg0, arg1, std::complex<float>);          \
  FN(arg0, arg1, std::complex<double>)
47 | |
// Applies FN(arg0, arg1, T) to every type covered by CURRY_TYPES3_NO_HALF,
// then once more with T = Eigen::half. The result is ';'-terminated.
//
// The ';' after the nested CURRY_TYPES3_NO_HALF(...) is required. That macro
// leaves its last expansion (std::complex<double>) unterminated. Without the
// separator, that expansion and the Eigen::half expansion below would be fused
// into one statement. That only compiles when FN's expansion is itself a
// complete braced statement, as CAST_CASE's `if` is. A doubled ';' is harmless
// both in function bodies and, since C++11, at namespace scope.
#define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  CURRY_TYPES3_NO_HALF(FN, arg0, arg1);      \
  FN(arg0, arg1, Eigen::half);
51 | |
// Applies FN(arg0, arg1, T) to the complete cast type list: everything from
// CURRY_TYPES3_NO_BF16, followed by T = bfloat16. CURRY_TYPES3_NO_BF16's
// expansion already ends in ';', so the bfloat16 expansion can follow it
// directly. The result is ';'-terminated.
#define CURRY_TYPES3(FN, arg0, arg1) \
  CURRY_TYPES3_NO_BF16(FN, arg0, arg1) FN(arg0, arg1, bfloat16);
55 | |
// Conditional early return used when building a cast function for one source
// element type.
//
// Expands to an `if` statement that compares the variable `dst_dtype` with
// the DataType enum value for OUT. `dst_dtype` must be in scope where the
// macro is expanded. On a match, it returns a captureless lambda from the
// enclosing function. That lambda:
//   - constructs functor::CastFunctor<DEVICE, OUT, IN>;
//   - invokes it with the DEVICE obtained from the kernel context, the
//     output tensor viewed as flat OUT elements, and the input tensor viewed
//     as flat IN elements;
//   - forwards `truncate` unchanged.
// On a mismatch nothing happens, and execution continues with the next
// statement.
//
// NOTE(review): the (DEVICE, IN, OUT) parameter order matches the
// FN(arg0, arg1, T) shape of the CURRY_TYPES3* macros. It is therefore
// presumably expanded once per candidate OUT type through them — confirm in
// the cast_op_impl_*.cc files.
#define CAST_CASE(DEVICE, IN, OUT)                                    \
  if (dst_dtype == DataTypeToEnum<OUT>::value) {                      \
    return [](OpKernelContext* context, const Tensor& input,          \
              Tensor* output, bool truncate) {                        \
      functor::CastFunctor<DEVICE, OUT, IN> cast_functor;             \
      cast_functor(context->eigen_device<DEVICE>(),                   \
                   output->flat<OUT>(), input.flat<IN>(), truncate);  \
    };                                                                \
  }
65 | |
66 | // The functions below are implemented in the cast_op_impl_*.cc files. |
67 | CastFunctorType GetCpuCastFromBool(DataType dst_dtype); |
68 | |
69 | CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); |
70 | |
71 | CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); |
72 | |
73 | CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); |
74 | |
75 | CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); |
76 | |
77 | CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); |
78 | |
79 | CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); |
80 | |
81 | CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); |
82 | |
83 | CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); |
84 | |
85 | CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); |
86 | |
87 | CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); |
88 | |
89 | CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); |
90 | |
91 | CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); |
92 | |
93 | CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); |
94 | |
95 | CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); |
96 | |
97 | CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); |
98 | |
#if (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) || \
    (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
// GPU cast lookup functions, declared only in CUDA or ROCm builds. They
// mirror the GetCpuCastFrom* declarations: one function per source element
// type, with `dst_dtype` selecting the destination element type. They are
// listed in the CURRY_TYPES3 type order.
CastFunctorType GetGpuCastFromBool(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);
CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);
CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
133 | |
134 | |
135 | } // namespace tensorflow |
136 | |
137 | #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ |
138 | |