1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_ |
18 | |
19 | #include <numeric> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_types.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | |
27 | namespace tensorflow { |
// Transpose tensor 'in' into tensor 'out' according to dimension
// permutation 'perm': out dimension i is fed from in dimension perm[i].
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
//
// Returns non-OK (e.g. Unimplemented) if the dtype is not supported on the
// device; specialized per device elsewhere — NOTE(review): definition not
// visible in this header, confirm error contract at the definition site.
template <typename Device>
Status DoTranspose(const Device& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out);
38 | |
// Conjugate and transpose tensor 'in' into tensor 'out' according to dimension
// permutation 'perm'. For real dtypes conjugation is the identity, so this
// should match DoTranspose for non-complex inputs.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
template <typename Device>
Status DoConjugateTranspose(const Device& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out);
49 | |
// Convenience version of DoTranspose that only swaps the last (inner) two
// dimensions, i.e. a batched matrix transpose.
template <typename Device>
Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);
54 | |
// Convenience version of DoConjugateTranspose that only swaps the last (inner)
// two dimensions, i.e. a batched conjugate (Hermitian) matrix transpose.
template <typename Device>
Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
                                  Tensor* out);
60 | |
// Primary device specific functor to be specialized for each device and type.
// When 'conjugate' is true, elements are complex-conjugated as they are
// permuted (meaningful for complex T). Specializations live in the per-device
// implementation files.
template <typename Device, typename T, bool conjugate = false>
struct Transpose {
  static void run(const Device& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out);
};
67 | |
68 | // Implementation details. |
69 | namespace internal { |
70 | |
71 | typedef gtl::InlinedVector<int64_t, 8> TransposeDimsVec; |
72 | typedef gtl::InlinedVector<int32, 8> TransposePermsVec; |
73 | |
74 | // Helper function that takes a tensor shape, a permutation, combines the |
75 | // neighboring shapes if their indices in the permutation are consecutive. |
76 | // The function outputs the combined shape and new permutation. |
77 | // Example: Tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3} will |
78 | // produce new shape {2, 60, 120} and new permutation {0, 2, 1}. |
79 | inline void ReduceTransposeDimensions(const TensorShape& shape, |
80 | gtl::ArraySlice<int32> perm, |
81 | TransposePermsVec* new_perm, |
82 | TransposeDimsVec* new_dims) { |
83 | CHECK_EQ(shape.dims(), perm.size()); |
84 | if (shape.dims() == 1) { |
85 | // If input dimension is already 1, no need to reduce dimension. |
86 | new_perm->resize(1); |
87 | (*new_perm)[0] = perm[0]; |
88 | (*new_dims)[0] = shape.dim_size(0); |
89 | return; |
90 | } |
91 | TransposePermsVec new_dim_position(shape.dims(), -1); |
92 | TransposeDimsVec combined_dims(shape.dims(), 0); |
93 | int cur_head = perm[0]; |
94 | new_dim_position[cur_head] = 0; |
95 | combined_dims[0] = shape.dim_size(cur_head); |
96 | int dim_idx = 0; |
97 | for (int perm_idx = 1; perm_idx < shape.dims(); ++perm_idx) { |
98 | // If two indices in permutation are consecutive numbers, combine their |
99 | // dimensions. |
100 | if (cur_head + 1 == perm[perm_idx]) { |
101 | cur_head = perm[perm_idx]; |
102 | combined_dims[dim_idx] *= shape.dim_size(cur_head); |
103 | } else { |
104 | // Else start a new dimension. |
105 | cur_head = perm[perm_idx]; |
106 | dim_idx++; |
107 | new_dim_position[cur_head] = dim_idx; |
108 | combined_dims[dim_idx] = shape.dim_size(cur_head); |
109 | } |
110 | } |
111 | // Compact the new permutations and dimension sizes. |
112 | new_perm->resize(dim_idx + 1); |
113 | new_dims->resize(dim_idx + 1); |
114 | dim_idx = 0; |
115 | for (int i = 0; i < new_dim_position.size(); ++i) { |
116 | if (new_dim_position[i] >= 0) { |
117 | int new_perm_idx = new_dim_position[i]; |
118 | (*new_perm)[dim_idx] = new_perm_idx; |
119 | (*new_dims)[dim_idx] = combined_dims[new_perm_idx]; |
120 | dim_idx++; |
121 | } |
122 | } |
123 | } |
124 | |
125 | // If all non-singleton dimensions remain in ascending order, the shuffled |
126 | // singletons can be transposed by a reshape, saving a memory allocation & copy. |
127 | // |permutation| must be a permutation of {0, .., input_shape.dims() - 1}. |
128 | // That is, for all i, 0 <= perm[i] < input_shape.dims(). |
129 | // In practice, this is checked in TransposeOp::Compute prior to calling this |
130 | // function, and the function sits here to facilitate unit testing. |
131 | inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape, |
132 | const std::vector<int32>& permutation) { |
133 | int last_nonsingleton_perm_dim = -1; |
134 | for (int perm_dim : permutation) { |
135 | if (input_shape.dim_size(perm_dim) == 1) { |
136 | continue; |
137 | } |
138 | if (perm_dim < last_nonsingleton_perm_dim) { |
139 | return false; |
140 | } |
141 | last_nonsingleton_perm_dim = perm_dim; |
142 | } |
143 | return true; |
144 | } |
145 | |
146 | // Uses Eigen to transpose. |
147 | template <typename Device, typename T, int NDIMS> |
148 | void TransposeUsingEigen(const Device& d, const Tensor& in, |
149 | const gtl::ArraySlice<int32> perm, bool conjugate, |
150 | Tensor* out) { |
151 | Eigen::array<int, NDIMS> p; |
152 | for (int i = 0; i < NDIMS; ++i) p[i] = perm[i]; |
153 | auto x = typename TTypes<T, NDIMS>::ConstTensor( |
154 | reinterpret_cast<const T*>(in.tensor_data().data()), |
155 | in.shape().AsEigenDSizes<NDIMS>()); |
156 | auto y = typename TTypes<T, NDIMS>::Tensor( |
157 | reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())), |
158 | out->shape().AsEigenDSizes<NDIMS>()); |
159 | if (conjugate) { |
160 | y.device(d) = x.conjugate().shuffle(p); |
161 | } else { |
162 | y.device(d) = x.shuffle(p); |
163 | } |
164 | } |
165 | |
// Shared implementation behind DoTranspose/DoConjugateTranspose: dispatches on
// the input dtype to a Transpose<Device, T> instantiation. Since a plain
// transpose only moves bytes, dtypes are grouped by element byte size and
// handled by a single unsigned-integer instantiation per size; only complex
// dtypes with conjugate=true need their real element type.
template <typename Device>
Status DoTransposeImpl(const Device& d, const Tensor& in,
                       const gtl::ArraySlice<int32> perm, bool conjugate,
                       Tensor* out) {
  CHECK_EQ(in.dims(), out->dims());
  CHECK_EQ(in.dims(), perm.size());
  CHECK_EQ(in.dtype(), out->dtype());
  switch (in.dtype()) {
    // 1-byte element types: bit-identical under transpose, treat as uint8.
    case DT_BOOL:
    case DT_INT8:
    case DT_QINT8:
    case DT_QUINT8:
    case DT_UINT8:
      Transpose<Device, uint8>::run(d, in, perm, out);
      break;

    // 2-byte element types, treated as uint16.
    case DT_BFLOAT16:
    case DT_HALF:
    case DT_INT16:
    case DT_QINT16:
    case DT_QUINT16:
    case DT_UINT16:
      Transpose<Device, uint16>::run(d, in, perm, out);
      break;

    // 4-byte element types, treated as uint32.
    case DT_FLOAT:
    case DT_INT32:
    case DT_QINT32:
    case DT_UINT32:
      Transpose<Device, uint32>::run(d, in, perm, out);
      break;

    // 8-byte element types, treated as uint64.
    case DT_DOUBLE:
    case DT_INT64:
    case DT_UINT64:
      Transpose<Device, uint64>::run(d, in, perm, out);
      break;

    case DT_COMPLEX64:
      if (conjugate) {
#if defined(__ANDROID__) and !defined(__clang__)
        // Workaround for GCC compiler bug in Android toolchain.
        return errors::Unimplemented(
            "Conjugate transpose of complex64 not supported for GCC on "
            "Android." );
#else
        Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
#endif
      } else {
        // No conjugation needed: complex64 is 8 bytes, so reuse the uint64
        // byte-moving path.
        Transpose<Device, uint64>::run(d, in, perm, out);
      }
      break;

    case DT_COMPLEX128:
      // 16-byte elements have no unsigned-integer stand-in, so both branches
      // use the complex128 instantiation.
      if (conjugate) {
        Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
                                                               out);
      } else {
        Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
                                                                out);
      }
      break;

    case DT_STRING:
      Transpose<Device, tstring>::run(d, in, perm, out);
      break;

    default:
      return errors::Unimplemented("Unsupported dtype on CPU: " , in.dtype());
  }
  return OkStatus();
}
238 | |
239 | template <typename Device> |
240 | inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in, |
241 | bool conjugate, Tensor* out) { |
242 | const int ndims = in.dims(); |
243 | if (ndims == 0) return OkStatus(); |
244 | TransposePermsVec perm(ndims); |
245 | std::iota(perm.begin(), perm.end(), 0); |
246 | std::swap(perm[ndims - 2], perm[ndims - 1]); |
247 | return DoTransposeImpl(device, in, perm, conjugate, out); |
248 | } |
249 | |
250 | } // namespace internal |
251 | } // namespace tensorflow |
252 | |
253 | #endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_ |
254 | |