1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
17#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
18// This file is used by cuda code and must remain compilable by nvcc.
19
20#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21typedef Eigen::ThreadPoolDevice CPUDevice;
22typedef Eigen::GpuDevice GPUDevice;
23
24
25#include "tensorflow/core/framework/numeric_types.h"
26#include "tensorflow/core/platform/types.h"
27
28namespace tensorflow {
29
// Size-keyed remapping of POD types to equivalent "proxy" types. Swapping a
// type for a same-sized stand-in is acceptable here because the data is only
// ever copied around, never interpreted.
//
// UnusableProxyType is only forward-declared in this header; it is the result
// for every (Device, size) pair that has no explicit specialization.
struct UnusableProxyType;

template <typename Device, int size>
struct proxy_type_pod {
  using type = UnusableProxyType;
};
37template <>
38struct proxy_type_pod<CPUDevice, 16> {
39 typedef ::tensorflow::complex128 type;
40};
41template <>
42struct proxy_type_pod<CPUDevice, 8> {
43 typedef ::int64_t type;
44};
45template <>
46struct proxy_type_pod<CPUDevice, 4> {
47 typedef ::tensorflow::int32 type;
48};
49template <>
50struct proxy_type_pod<CPUDevice, 2> {
51 typedef ::tensorflow::int16 type;
52};
53template <>
54struct proxy_type_pod<CPUDevice, 1> {
55 typedef ::tensorflow::int8 type;
56};
57template <>
58struct proxy_type_pod<GPUDevice, 8> {
59 typedef double type;
60};
61template <>
62struct proxy_type_pod<GPUDevice, 4> {
63 typedef float type;
64};
65template <>
66struct proxy_type_pod<GPUDevice, 2> {
67 typedef Eigen::half type;
68};
69template <>
70struct proxy_type_pod<GPUDevice, 1> {
71 typedef ::tensorflow::int8 type;
72};
73
74
75/// If POD we use proxy_type_pod, otherwise this maps to identity.
76template <typename Device, typename T>
77struct proxy_type {
78 typedef typename std::conditional<
79 std::is_arithmetic<T>::value,
80 typename proxy_type_pod<Device, sizeof(T)>::type, T>::type type;
81 static_assert(sizeof(type) == sizeof(T), "proxy_type_pod is not valid");
82};
83
/// Lists of the active proxy types, one list per device.
///
/// Each list forwards `m` to the per-type TF_CALL_* macros (which are not
/// defined in this header), in the order written below.
#define TF_CALL_CPU_PROXY_TYPES(m) \
  TF_CALL_int64(m)                 \
  TF_CALL_int32(m)                 \
  TF_CALL_uint16(m)                \
  TF_CALL_int16(m)                 \
  TF_CALL_int8(m)                  \
  TF_CALL_complex128(m)

#define TF_CALL_GPU_PROXY_TYPES(m) \
  TF_CALL_double(m)                \
  TF_CALL_float(m)                 \
  TF_CALL_half(m)                  \
  TF_CALL_int32(m)                 \
  TF_CALL_int8(m)
91} // namespace tensorflow
92
93#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
94