1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
17#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20#include "tensorflow/core/platform/logging.h"
21
22namespace tensorflow {
23
24// Helper to define Tensor types given that the scalar is of type T.
25template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
26struct TTypes {
27 // Rank-<NDIMS> tensor of scalar type T.
28 typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
29 Eigen::Aligned>
30 Tensor;
31 typedef Eigen::TensorMap<
32 Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
33 ConstTensor;
34
35 // Unaligned Rank-<NDIMS> tensor of scalar type T.
36 typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> >
37 UnalignedTensor;
38 typedef Eigen::TensorMap<
39 Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType> >
40 UnalignedConstTensor;
41
42 typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
43 Eigen::Aligned>
44 Tensor32Bit;
45
46 // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
47 typedef Eigen::TensorMap<
48 Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
49 Eigen::Aligned>
50 Scalar;
51 typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
52 Eigen::RowMajor, IndexType>,
53 Eigen::Aligned>
54 ConstScalar;
55
56 // Unaligned Scalar tensor of scalar type T.
57 typedef Eigen::TensorMap<
58 Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> >
59 UnalignedScalar;
60 typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
61 Eigen::RowMajor, IndexType> >
62 UnalignedConstScalar;
63
64 // Rank-1 tensor (vector) of scalar type T.
65 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
66 Eigen::Aligned>
67 Flat;
68 typedef Eigen::TensorMap<
69 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
70 ConstFlat;
71 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
72 Eigen::Aligned>
73 Vec;
74 typedef Eigen::TensorMap<
75 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
76 ConstVec;
77
78 // Unaligned Rank-1 tensor (vector) of scalar type T.
79 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
80 UnalignedFlat;
81 typedef Eigen::TensorMap<
82 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
83 UnalignedConstFlat;
84 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
85 UnalignedVec;
86 typedef Eigen::TensorMap<
87 Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
88 UnalignedConstVec;
89
90 // Rank-2 tensor (matrix) of scalar type T.
91 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
92 Eigen::Aligned>
93 Matrix;
94 typedef Eigen::TensorMap<
95 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
96 ConstMatrix;
97
98 // Unaligned Rank-2 tensor (matrix) of scalar type T.
99 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> >
100 UnalignedMatrix;
101 typedef Eigen::TensorMap<
102 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType> >
103 UnalignedConstMatrix;
104};
105
106typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
107
108template <typename Index, int NumDims>
109bool SafeFor32BitIndexing(const Eigen::DSizes<Index, NumDims>& in) {
110 for (int i = 0; i < NumDims; ++i) {
111 if (in[i] > std::numeric_limits<Index32>::max()) return false;
112 }
113 return true;
114}
115
116template <typename Index, size_t NumDims>
117bool SafeFor32BitIndexing(const Eigen::array<Index, NumDims>& in) {
118 for (size_t i = 0; i < NumDims; ++i) {
119 if (in[i] > std::numeric_limits<Index32>::max()) return false;
120 }
121 return true;
122}
123
124template <typename TensorType,
125 typename Enable = typename TTypes<
126 typename TensorType::Scalar, TensorType::NumIndices>::Tensor32Bit>
127bool SafeFor32BitIndexing(TensorType in) {
128 return in.size() <= std::numeric_limits<Index32>::max();
129}
130
131template <typename Index, int NumDims>
132Eigen::DSizes<Index32, NumDims> To32Bit(
133 const Eigen::DSizes<Index, NumDims>& in) {
134 DCHECK(SafeFor32BitIndexing(in));
135 Eigen::DSizes<Index32, NumDims> out;
136 for (int i = 0; i < NumDims; ++i) {
137 out[i] = static_cast<Index32>(in[i]);
138 }
139 return out;
140}
141
142template <typename Index, size_t NumDims>
143Eigen::array<Index32, NumDims> To32Bit(const Eigen::array<Index, NumDims>& in) {
144 DCHECK(SafeFor32BitIndexing(in));
145 Eigen::array<Index32, NumDims> out;
146 for (size_t i = 0; i < NumDims; ++i) {
147 out[i] = static_cast<Index32>(in[i]);
148 }
149 return out;
150}
151
152template <typename TensorType>
153typename TTypes<typename TensorType::Scalar,
154 TensorType::NumIndices>::Tensor32Bit
155To32Bit(TensorType in) {
156 typedef typename TTypes<typename TensorType::Scalar,
157 TensorType::NumIndices>::Tensor32Bit RetType;
158 DCHECK(SafeFor32BitIndexing(in));
159 return RetType(in.data(), To32Bit(in.dimensions()));
160}
161
162namespace internal {
163
// Primary template, used for every Device without a specialization: calls
// `func` on the arguments exactly as supplied, performing no index
// conversion.
template <typename Device>
struct MaybeWith32BitIndexingImpl {
  template <typename Callable, typename... Params>
  void operator()(Callable func, Params&&... args) const {
    // Perfect-forward so lvalues stay lvalues and rvalues stay rvalues.
    func(std::forward<Params>(args)...);
  }
};
171
172template <>
173struct MaybeWith32BitIndexingImpl<Eigen::GpuDevice> {
174 template <typename Func, typename... Args>
175 void operator()(Func func, Args&&... args) const {
176 auto all = [](const auto&... bool_vals) {
177 for (bool b : {bool_vals...}) {
178 if (!b) return false;
179 }
180 return true;
181 };
182 if (all(SafeFor32BitIndexing(std::forward<Args>(args))...)) {
183 func(To32Bit(std::forward<Args>(args))...);
184 } else {
185 func(std::forward<Args>(args)...);
186 }
187 }
188};
189
190} // namespace internal
191
192template <typename Device, typename Func, typename... Args>
193void MaybeWith32BitIndexing(Func func, Args&&... args) {
194 return internal::MaybeWith32BitIndexingImpl<Device>()(
195 func, std::forward<Args>(args)...);
196}
197
198} // namespace tensorflow
199#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
200