1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ |
18 | // This file is used by cuda code and must remain compilable by nvcc. |
19 | |
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 | typedef Eigen::ThreadPoolDevice CPUDevice; |
22 | typedef Eigen::GpuDevice GPUDevice; |
23 | |
24 | |
25 | #include "tensorflow/core/framework/numeric_types.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | // Remap POD types by size to equivalent proxy types. This works |
31 | // since all we are doing is copying data around. |
32 | struct UnusableProxyType; |
33 | template <typename Device, int size> |
34 | struct proxy_type_pod { |
35 | typedef UnusableProxyType type; |
36 | }; |
37 | template <> |
38 | struct proxy_type_pod<CPUDevice, 16> { |
39 | typedef ::tensorflow::complex128 type; |
40 | }; |
41 | template <> |
42 | struct proxy_type_pod<CPUDevice, 8> { |
43 | typedef ::int64_t type; |
44 | }; |
45 | template <> |
46 | struct proxy_type_pod<CPUDevice, 4> { |
47 | typedef ::tensorflow::int32 type; |
48 | }; |
49 | template <> |
50 | struct proxy_type_pod<CPUDevice, 2> { |
51 | typedef ::tensorflow::int16 type; |
52 | }; |
53 | template <> |
54 | struct proxy_type_pod<CPUDevice, 1> { |
55 | typedef ::tensorflow::int8 type; |
56 | }; |
57 | template <> |
58 | struct proxy_type_pod<GPUDevice, 8> { |
59 | typedef double type; |
60 | }; |
61 | template <> |
62 | struct proxy_type_pod<GPUDevice, 4> { |
63 | typedef float type; |
64 | }; |
65 | template <> |
66 | struct proxy_type_pod<GPUDevice, 2> { |
67 | typedef Eigen::half type; |
68 | }; |
69 | template <> |
70 | struct proxy_type_pod<GPUDevice, 1> { |
71 | typedef ::tensorflow::int8 type; |
72 | }; |
73 | |
74 | |
75 | /// If POD we use proxy_type_pod, otherwise this maps to identity. |
76 | template <typename Device, typename T> |
77 | struct proxy_type { |
78 | typedef typename std::conditional< |
79 | std::is_arithmetic<T>::value, |
80 | typename proxy_type_pod<Device, sizeof(T)>::type, T>::type type; |
81 | static_assert(sizeof(type) == sizeof(T), "proxy_type_pod is not valid" ); |
82 | }; |
83 | |
84 | /// The active proxy types |
85 | #define TF_CALL_CPU_PROXY_TYPES(m) \ |
86 | TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ |
87 | TF_CALL_int8(m) TF_CALL_complex128(m) |
88 | #define TF_CALL_GPU_PROXY_TYPES(m) \ |
89 | TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m) \ |
90 | TF_CALL_int8(m) |
91 | } // namespace tensorflow |
92 | |
93 | #endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ |
94 | |