1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/platform/logging.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | // Helper to define Tensor types given that the scalar is of type T. |
25 | template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex> |
26 | struct TTypes { |
27 | // Rank-<NDIMS> tensor of scalar type T. |
28 | typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>, |
29 | Eigen::Aligned> |
30 | Tensor; |
31 | typedef Eigen::TensorMap< |
32 | Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned> |
33 | ConstTensor; |
34 | |
35 | // Unaligned Rank-<NDIMS> tensor of scalar type T. |
36 | typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> > |
37 | UnalignedTensor; |
38 | typedef Eigen::TensorMap< |
39 | Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType> > |
40 | UnalignedConstTensor; |
41 | |
42 | typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>, |
43 | Eigen::Aligned> |
44 | Tensor32Bit; |
45 | |
46 | // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. |
47 | typedef Eigen::TensorMap< |
48 | Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>, |
49 | Eigen::Aligned> |
50 | Scalar; |
51 | typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>, |
52 | Eigen::RowMajor, IndexType>, |
53 | Eigen::Aligned> |
54 | ConstScalar; |
55 | |
56 | // Unaligned Scalar tensor of scalar type T. |
57 | typedef Eigen::TensorMap< |
58 | Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> > |
59 | UnalignedScalar; |
60 | typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>, |
61 | Eigen::RowMajor, IndexType> > |
62 | UnalignedConstScalar; |
63 | |
64 | // Rank-1 tensor (vector) of scalar type T. |
65 | typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>, |
66 | Eigen::Aligned> |
67 | Flat; |
68 | typedef Eigen::TensorMap< |
69 | Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned> |
70 | ConstFlat; |
71 | typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>, |
72 | Eigen::Aligned> |
73 | Vec; |
74 | typedef Eigen::TensorMap< |
75 | Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned> |
76 | ConstVec; |
77 | |
78 | // Unaligned Rank-1 tensor (vector) of scalar type T. |
79 | typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> > |
80 | UnalignedFlat; |
81 | typedef Eigen::TensorMap< |
82 | Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> > |
83 | UnalignedConstFlat; |
84 | typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> > |
85 | UnalignedVec; |
86 | typedef Eigen::TensorMap< |
87 | Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> > |
88 | UnalignedConstVec; |
89 | |
90 | // Rank-2 tensor (matrix) of scalar type T. |
91 | typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>, |
92 | Eigen::Aligned> |
93 | Matrix; |
94 | typedef Eigen::TensorMap< |
95 | Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned> |
96 | ConstMatrix; |
97 | |
98 | // Unaligned Rank-2 tensor (matrix) of scalar type T. |
99 | typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> > |
100 | UnalignedMatrix; |
101 | typedef Eigen::TensorMap< |
102 | Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType> > |
103 | UnalignedConstMatrix; |
104 | }; |
105 | |
106 | typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32; |
107 | |
108 | template <typename Index, int NumDims> |
109 | bool SafeFor32BitIndexing(const Eigen::DSizes<Index, NumDims>& in) { |
110 | for (int i = 0; i < NumDims; ++i) { |
111 | if (in[i] > std::numeric_limits<Index32>::max()) return false; |
112 | } |
113 | return true; |
114 | } |
115 | |
116 | template <typename Index, size_t NumDims> |
117 | bool SafeFor32BitIndexing(const Eigen::array<Index, NumDims>& in) { |
118 | for (size_t i = 0; i < NumDims; ++i) { |
119 | if (in[i] > std::numeric_limits<Index32>::max()) return false; |
120 | } |
121 | return true; |
122 | } |
123 | |
124 | template <typename TensorType, |
125 | typename Enable = typename TTypes< |
126 | typename TensorType::Scalar, TensorType::NumIndices>::Tensor32Bit> |
127 | bool SafeFor32BitIndexing(TensorType in) { |
128 | return in.size() <= std::numeric_limits<Index32>::max(); |
129 | } |
130 | |
131 | template <typename Index, int NumDims> |
132 | Eigen::DSizes<Index32, NumDims> To32Bit( |
133 | const Eigen::DSizes<Index, NumDims>& in) { |
134 | DCHECK(SafeFor32BitIndexing(in)); |
135 | Eigen::DSizes<Index32, NumDims> out; |
136 | for (int i = 0; i < NumDims; ++i) { |
137 | out[i] = static_cast<Index32>(in[i]); |
138 | } |
139 | return out; |
140 | } |
141 | |
142 | template <typename Index, size_t NumDims> |
143 | Eigen::array<Index32, NumDims> To32Bit(const Eigen::array<Index, NumDims>& in) { |
144 | DCHECK(SafeFor32BitIndexing(in)); |
145 | Eigen::array<Index32, NumDims> out; |
146 | for (size_t i = 0; i < NumDims; ++i) { |
147 | out[i] = static_cast<Index32>(in[i]); |
148 | } |
149 | return out; |
150 | } |
151 | |
152 | template <typename TensorType> |
153 | typename TTypes<typename TensorType::Scalar, |
154 | TensorType::NumIndices>::Tensor32Bit |
155 | To32Bit(TensorType in) { |
156 | typedef typename TTypes<typename TensorType::Scalar, |
157 | TensorType::NumIndices>::Tensor32Bit RetType; |
158 | DCHECK(SafeFor32BitIndexing(in)); |
159 | return RetType(in.data(), To32Bit(in.dimensions())); |
160 | } |
161 | |
162 | namespace internal { |
163 | |
// Generic implementation, used for every Device that has no specialization:
// the arguments are handed to the callable untouched.
template <typename Device>
struct MaybeWith32BitIndexingImpl {
  // Invokes `callable` with `params` perfectly forwarded.
  template <typename Callable, typename... Params>
  void operator()(Callable callable, Params&&... params) const {
    callable(std::forward<Params>(params)...);
  }
};
171 | |
172 | template <> |
173 | struct MaybeWith32BitIndexingImpl<Eigen::GpuDevice> { |
174 | template <typename Func, typename... Args> |
175 | void operator()(Func func, Args&&... args) const { |
176 | auto all = [](const auto&... bool_vals) { |
177 | for (bool b : {bool_vals...}) { |
178 | if (!b) return false; |
179 | } |
180 | return true; |
181 | }; |
182 | if (all(SafeFor32BitIndexing(std::forward<Args>(args))...)) { |
183 | func(To32Bit(std::forward<Args>(args))...); |
184 | } else { |
185 | func(std::forward<Args>(args)...); |
186 | } |
187 | } |
188 | }; |
189 | |
190 | } // namespace internal |
191 | |
192 | template <typename Device, typename Func, typename... Args> |
193 | void MaybeWith32BitIndexing(Func func, Args&&... args) { |
194 | return internal::MaybeWith32BitIndexingImpl<Device>()( |
195 | func, std::forward<Args>(args)...); |
196 | } |
197 | |
198 | } // namespace tensorflow |
199 | #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_ |
200 | |