1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ |
18 | // This file is used by cuda code and must remain compilable by nvcc. |
19 | |
20 | #include "tensorflow/core/framework/numeric_types.h" |
21 | #include "tensorflow/core/framework/resource_handle.h" |
22 | #include "tensorflow/core/framework/variant.h" |
23 | #include "tensorflow/core/platform/types.h" |
24 | |
25 | // Two sets of macros: |
26 | // - TF_CALL_float, TF_CALL_double, etc. which call the given macro with |
27 | // the type name as the only parameter - except on platforms for which |
28 | // the type should not be included. |
29 | // - Macros to apply another macro to lists of supported types. These also call |
30 | // into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform |
31 | // as well. |
32 | // If you change the lists of types, please also update the list in types.cc. |
33 | // |
34 | // See example uses of these macros in core/ops. |
35 | // |
36 | // |
37 | // Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple |
38 | // times by passing each invocation a data type supported by TensorFlow. |
39 | // |
40 | // The different variations pass different subsets of the types. |
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow
// except the quantized types. The set of types depends on the compilation
// platform.
//
44 | // This can be used to register a different template instantiation of |
45 | // an OpKernel for different signatures, e.g.: |
46 | /* |
47 | #define REGISTER_PARTITION(type) \ |
48 | REGISTER_KERNEL_BUILDER( \ |
49 | Name("Partition").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
50 | PartitionOp<type>); |
51 | TF_CALL_ALL_TYPES(REGISTER_PARTITION) |
52 | #undef REGISTER_PARTITION |
53 | */ |
54 | |
// Per-type macros: TF_CALL_<type>(m) expands to m(<C++ type>) when <type> is
// enabled for the current build, and to nothing otherwise. All three branches
// below define exactly the same set of TF_CALL_<type> names; only the
// expansions differ.
//
// Branch selection:
//   1. Non-mobile builds, or mobile builds that define
//      SUPPORT_SELECTIVE_REGISTRATION or ANDROID_TEGRA: every type enabled.
//   2. Remaining mobile builds that define __ANDROID_TYPES_FULL__: a reduced
//      set (listed in that branch).
//   3. All other mobile builds: only float, int32 and bool.
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
    defined(ANDROID_TEGRA)

// All types are supported, so all macros are invoked.
//
// Note: macros are defined in the same order as types in types.proto, for
// readability. Non-builtin types are written fully qualified with a leading
// "::" so that name lookup of the expansion starts at the global namespace,
// independent of the namespace the macro is used in.
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m) m(double)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_uint32(m) m(::tensorflow::uint32)
#define TF_CALL_uint8(m) m(::tensorflow::uint8)
#define TF_CALL_int16(m) m(::tensorflow::int16)

#define TF_CALL_int8(m) m(::tensorflow::int8)
// TF_CALL_string expands to the same type as TF_CALL_tstring.
#define TF_CALL_string(m) m(::tensorflow::tstring)
#define TF_CALL_tstring(m) m(::tensorflow::tstring)
#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
#define TF_CALL_variant(m) m(::tensorflow::Variant)
#define TF_CALL_complex64(m) m(::tensorflow::complex64)
#define TF_CALL_int64(m) m(::int64_t)
#define TF_CALL_uint64(m) m(::tensorflow::uint64)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
#define TF_CALL_qint32(m) m(::tensorflow::qint32)
#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16)
#define TF_CALL_qint16(m) m(::tensorflow::qint16)

#define TF_CALL_quint16(m) m(::tensorflow::quint16)
#define TF_CALL_uint16(m) m(::tensorflow::uint16)
#define TF_CALL_complex128(m) m(::tensorflow::complex128)
// Qualified like the other non-builtin types above (was "Eigen::half").
#define TF_CALL_half(m) m(::Eigen::half)

#elif defined(__ANDROID_TYPES_FULL__)

// Only string, half, float, int32, int64, bool, and quantized types
// (qint8, quint8, qint32, qint16, quint16) are supported. The remaining
// macros are defined with an empty expansion.
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_uint32(m)
#define TF_CALL_uint8(m)
#define TF_CALL_int16(m)

#define TF_CALL_int8(m)
#define TF_CALL_string(m) m(::tensorflow::tstring)
#define TF_CALL_tstring(m) m(::tensorflow::tstring)
#define TF_CALL_resource(m)
#define TF_CALL_variant(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m) m(::int64_t)
#define TF_CALL_uint64(m)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
#define TF_CALL_qint32(m) m(::tensorflow::qint32)
#define TF_CALL_bfloat16(m)
#define TF_CALL_qint16(m) m(::tensorflow::qint16)

#define TF_CALL_quint16(m) m(::tensorflow::quint16)
#define TF_CALL_uint16(m)
#define TF_CALL_complex128(m)
#define TF_CALL_half(m) m(::Eigen::half)

#else  // IS_MOBILE_PLATFORM defined; no selective registration, Tegra or
       // __ANDROID_TYPES_FULL__.

// Only float, int32, and bool are supported; every other macro expands to
// nothing.
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_uint32(m)
#define TF_CALL_uint8(m)
#define TF_CALL_int16(m)

#define TF_CALL_int8(m)
#define TF_CALL_string(m)
#define TF_CALL_tstring(m)
#define TF_CALL_resource(m)
#define TF_CALL_variant(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m)
#define TF_CALL_uint64(m)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m)
#define TF_CALL_quint8(m)
#define TF_CALL_qint32(m)
#define TF_CALL_bfloat16(m)
#define TF_CALL_qint16(m)

#define TF_CALL_quint16(m)
#define TF_CALL_uint16(m)
#define TF_CALL_complex128(m)
#define TF_CALL_half(m)

#endif  // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines
154 | |
155 | // Defines for sets of types. |
// Integer types other than int32, in the order
// uint64, int64, uint32, uint16, int16, uint8, int8.
// Each entry goes through its TF_CALL_<type> macro, so types compiled out on
// the current platform expand to nothing.
#define TF_CALL_INTEGRAL_TYPES_NO_INT32(FN) \
  TF_CALL_uint64(FN)                        \
  TF_CALL_int64(FN)                         \
  TF_CALL_uint32(FN)                        \
  TF_CALL_uint16(FN)                        \
  TF_CALL_int16(FN)                         \
  TF_CALL_uint8(FN)                         \
  TF_CALL_int8(FN)

// Every integer type: the int32-free list above, followed by int32.
#define TF_CALL_INTEGRAL_TYPES(FN) \
  TF_CALL_INTEGRAL_TYPES_NO_INT32(FN) TF_CALL_int32(FN)
162 | |
// Floating-point types, in the order half, bfloat16, float, double.
#define TF_CALL_FLOAT_TYPES(FN) \
  TF_CALL_half(FN)              \
  TF_CALL_bfloat16(FN)          \
  TF_CALL_float(FN)             \
  TF_CALL_double(FN)

// Every real (non-complex) number type: integers first, then floating point.
#define TF_CALL_REAL_NUMBER_TYPES(FN) \
  TF_CALL_INTEGRAL_TYPES(FN)          \
  TF_CALL_FLOAT_TYPES(FN)

// Real number types minus bfloat16: integers, then half, float, double.
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(FN) \
  TF_CALL_INTEGRAL_TYPES(FN)                      \
  TF_CALL_half(FN)                                \
  TF_CALL_float(FN)                               \
  TF_CALL_double(FN)

// Real number types minus int32. Unlike TF_CALL_REAL_NUMBER_TYPES, the
// floating-point types come first and the integers follow.
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(FN) \
  TF_CALL_half(FN)                             \
  TF_CALL_bfloat16(FN)                         \
  TF_CALL_float(FN)                            \
  TF_CALL_double(FN)                           \
  TF_CALL_INTEGRAL_TYPES_NO_INT32(FN)
175 | |
// Complex types: complex64, then complex128.
#define TF_CALL_COMPLEX_TYPES(FN) \
  TF_CALL_complex64(FN) TF_CALL_complex128(FN)

// Every number type: the real number types followed by the complex types.
#define TF_CALL_NUMBER_TYPES(FN) \
  TF_CALL_REAL_NUMBER_TYPES(FN)  \
  TF_CALL_COMPLEX_TYPES(FN)

// Same shape as TF_CALL_NUMBER_TYPES, but built on
// TF_CALL_REAL_NUMBER_TYPES_NO_INT32, so int32 is left out.
#define TF_CALL_NUMBER_TYPES_NO_INT32(FN) \
  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(FN)  \
  TF_CALL_COMPLEX_TYPES(FN)

// The POD set: every number type, then bool.
#define TF_CALL_POD_TYPES(FN) \
  TF_CALL_NUMBER_TYPES(FN)    \
  TF_CALL_bool(FN)
186 | |
// The POD types, then tstring, resource handles and variants. Despite the
// name, the quantized types are not part of this set (TF_CALL_POD_TYPES does
// not include them).
#define TF_CALL_ALL_TYPES(FN)  \
  TF_CALL_POD_TYPES(FN)        \
  TF_CALL_tstring(FN)          \
  TF_CALL_resource(FN)         \
  TF_CALL_variant(FN)

// The POD types followed by tstring (no resource or variant).
#define TF_CALL_POD_STRING_TYPES(FN) \
  TF_CALL_POD_TYPES(FN)              \
  TF_CALL_tstring(FN)
193 | |
// Number types used for GPU kernels: half, float, double.
#define TF_CALL_GPU_NUMBER_TYPES(FN) \
  TF_CALL_half(FN)                   \
  TF_CALL_float(FN)                  \
  TF_CALL_double(FN)

// Types used for GPU kernels: the GPU number types, then the complex types,
// then bool.
#define TF_CALL_GPU_ALL_TYPES(FN) \
  TF_CALL_GPU_NUMBER_TYPES(FN)    \
  TF_CALL_COMPLEX_TYPES(FN)       \
  TF_CALL_bool(FN)

// The GPU number types without half: float, then double.
#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(FN) \
  TF_CALL_float(FN)                          \
  TF_CALL_double(FN)
203 | |
// Quantized types: qint8, quint8, qint32. qint16 and quint16 have their own
// TF_CALL_<type> macros but are not part of this set yet.
// TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m)
#define TF_CALL_QUANTIZED_TYPES(FN) \
  TF_CALL_qint8(FN)                 \
  TF_CALL_quint8(FN)                \
  TF_CALL_qint32(FN)

// Types used for save and restore ops, in this order: real number types
// except bfloat16, complex types, quantized types, bool, tstring.
#define TF_CALL_SAVE_RESTORE_TYPES(FN)      \
  TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(FN) \
  TF_CALL_COMPLEX_TYPES(FN)                 \
  TF_CALL_QUANTIZED_TYPES(FN)               \
  TF_CALL_bool(FN)                          \
  TF_CALL_tstring(FN)
214 | |
215 | #endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ |
216 | |