/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_

#include <numeric>
#include <string>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
// Transpose tensor 'in' into tensor 'out' according to dimension
// permutation 'perm'.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
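//
// Example (a sketch of a call site; `input` and `device` are illustrative
// names, and the output must already be allocated with the permuted shape):
//
//   Tensor output(input.dtype(),
//                 TensorShape({input.dim_size(2), input.dim_size(0),
//                              input.dim_size(1)}));
//   // Moves the last dimension to the front:
//   // output(i, j, k) == input(j, k, i).
//   TF_RETURN_IF_ERROR(DoTranspose(device, input, {2, 0, 1}, &output));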
template <typename Device>
Status DoTranspose(const Device& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out);

// Conjugate and transpose tensor 'in' into tensor 'out' according to
// dimension permutation 'perm'.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
template <typename Device>
Status DoConjugateTranspose(const Device& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out);

// Convenience version of DoTranspose that only swaps the last (inner) two
// dimensions.
template <typename Device>
Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);

// Convenience version of DoConjugateTranspose that only swaps the last
// (inner) two dimensions.
template <typename Device>
Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
                                  Tensor* out);

// Primary device-specific functor, to be specialized for each device and type.
template <typename Device, typename T, bool conjugate = false>
struct Transpose {
  static void run(const Device& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out);
};
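//
// Device/type implementations specialize this functor; the real
// specializations live in files such as transpose_functor_cpu.cc and
// transpose_functor_gpu.cu.cc. A minimal sketch of the shape of a CPU
// specialization (details illustrative):
//
//   template <typename T, bool conjugate>
//   struct Transpose<Eigen::ThreadPoolDevice, T, conjugate> {
//     static void run(const Eigen::ThreadPoolDevice& d, const Tensor& in,
//                     const gtl::ArraySlice<int32> perm, Tensor* out);
//   };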

// Implementation details.
namespace internal {

typedef gtl::InlinedVector<int64_t, 8> TransposeDimsVec;
typedef gtl::InlinedVector<int32, 8> TransposePermsVec;

// Helper function that takes a tensor shape and a permutation, and combines
// neighboring dimensions whose indices in the permutation are consecutive.
// The function outputs the combined shape and the new permutation.
// Example: Tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3} will
// produce new shape {2, 60, 120} and new permutation {0, 2, 1}.
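// Tracing that example: perm[0] == 0 starts a combined dimension of size 2;
// perm[1] == 4 does not extend it (0 + 1 != 4), so the size-120 dimension 4
// starts another; perm[2..4] == {1, 2, 3} are consecutive, so dimensions
// 1, 2, 3 fold into a single combined dimension of size 3 * 4 * 5 = 60.
// Compacting in original dimension order gives the result above.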
inline void ReduceTransposeDimensions(const TensorShape& shape,
                                      gtl::ArraySlice<int32> perm,
                                      TransposePermsVec* new_perm,
                                      TransposeDimsVec* new_dims) {
  CHECK_EQ(shape.dims(), perm.size());
  if (shape.dims() == 1) {
    // A 1-D input has nothing to combine; pass it through (sizing the output
    // vectors before writing to them).
    new_perm->resize(1);
    new_dims->resize(1);
    (*new_perm)[0] = perm[0];
    (*new_dims)[0] = shape.dim_size(0);
    return;
  }
  TransposePermsVec new_dim_position(shape.dims(), -1);
  TransposeDimsVec combined_dims(shape.dims(), 0);
  int cur_head = perm[0];
  new_dim_position[cur_head] = 0;
  combined_dims[0] = shape.dim_size(cur_head);
  int dim_idx = 0;
  for (int perm_idx = 1; perm_idx < shape.dims(); ++perm_idx) {
    // If two consecutive permutation entries are consecutive dimension
    // indices, combine their dimensions.
    if (cur_head + 1 == perm[perm_idx]) {
      cur_head = perm[perm_idx];
      combined_dims[dim_idx] *= shape.dim_size(cur_head);
    } else {
      // Otherwise start a new combined dimension.
      cur_head = perm[perm_idx];
      dim_idx++;
      new_dim_position[cur_head] = dim_idx;
      combined_dims[dim_idx] = shape.dim_size(cur_head);
    }
  }
  // Compact the new permutation and dimension sizes.
  new_perm->resize(dim_idx + 1);
  new_dims->resize(dim_idx + 1);
  dim_idx = 0;
  for (int i = 0; i < new_dim_position.size(); ++i) {
    if (new_dim_position[i] >= 0) {
      int new_perm_idx = new_dim_position[i];
      (*new_perm)[dim_idx] = new_perm_idx;
      (*new_dims)[dim_idx] = combined_dims[new_perm_idx];
      dim_idx++;
    }
  }
}

// If all non-singleton dimensions remain in ascending order, the shuffled
// singletons can be transposed by a reshape, saving a memory allocation and
// copy. |permutation| must be a permutation of {0, ..., input_shape.dims() - 1}.
// That is, for all i, 0 <= permutation[i] < input_shape.dims().
// In practice, this is checked in TransposeOp::Compute prior to calling this
// function, and the function sits here to facilitate unit testing.
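//
// For example (values illustrative): shape {4, 1, 6} with permutation
// {1, 0, 2} only moves the singleton dimension, so the non-singleton
// dimensions 0 and 2 stay in ascending order and this returns true; shape
// {4, 5, 6} with the same permutation swaps real data and returns false.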
inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
                                        const std::vector<int32>& permutation) {
  int last_nonsingleton_perm_dim = -1;
  for (int perm_dim : permutation) {
    if (input_shape.dim_size(perm_dim) == 1) {
      continue;
    }
    if (perm_dim < last_nonsingleton_perm_dim) {
      return false;
    }
    last_nonsingleton_perm_dim = perm_dim;
  }
  return true;
}

// Uses Eigen to transpose.
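// NDIMS must be a compile-time constant, so callers dispatch on the runtime
// rank, e.g. (a sketch; the real dispatch lives in the device-specific
// Transpose specializations):
//
//   switch (in.dims()) {
//     case 2:
//       internal::TransposeUsingEigen<Device, T, 2>(d, in, perm, conjugate,
//                                                   out);
//       break;
//     case 3:
//       internal::TransposeUsingEigen<Device, T, 3>(d, in, perm, conjugate,
//                                                   out);
//       break;
//     // ... up to the maximum rank the kernel supports.
//   }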
template <typename Device, typename T, int NDIMS>
void TransposeUsingEigen(const Device& d, const Tensor& in,
                         const gtl::ArraySlice<int32> perm, bool conjugate,
                         Tensor* out) {
  Eigen::array<int, NDIMS> p;
  for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
  // Reinterpret the flat input/output buffers as typed Eigen tensors of
  // static rank NDIMS.
  auto x = typename TTypes<T, NDIMS>::ConstTensor(
      reinterpret_cast<const T*>(in.tensor_data().data()),
      in.shape().AsEigenDSizes<NDIMS>());
  auto y = typename TTypes<T, NDIMS>::Tensor(
      reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
      out->shape().AsEigenDSizes<NDIMS>());
  if (conjugate) {
    y.device(d) = x.conjugate().shuffle(p);
  } else {
    y.device(d) = x.shuffle(p);
  }
}

// Dispatches to the Transpose functor based on the byte width of the dtype:
// a transpose only moves values around, so every numeric type of a given
// width can share one kernel instantiation.
template <typename Device>
Status DoTransposeImpl(const Device& d, const Tensor& in,
                       const gtl::ArraySlice<int32> perm, bool conjugate,
                       Tensor* out) {
  CHECK_EQ(in.dims(), out->dims());
  CHECK_EQ(in.dims(), perm.size());
  CHECK_EQ(in.dtype(), out->dtype());
  switch (in.dtype()) {
    // 1-byte types.
    case DT_BOOL:
    case DT_INT8:
    case DT_QINT8:
    case DT_QUINT8:
    case DT_UINT8:
      Transpose<Device, uint8>::run(d, in, perm, out);
      break;

    // 2-byte types.
    case DT_BFLOAT16:
    case DT_HALF:
    case DT_INT16:
    case DT_QINT16:
    case DT_QUINT16:
    case DT_UINT16:
      Transpose<Device, uint16>::run(d, in, perm, out);
      break;

    // 4-byte types.
    case DT_FLOAT:
    case DT_INT32:
    case DT_QINT32:
    case DT_UINT32:
      Transpose<Device, uint32>::run(d, in, perm, out);
      break;

    // 8-byte types.
    case DT_DOUBLE:
    case DT_INT64:
    case DT_UINT64:
      Transpose<Device, uint64>::run(d, in, perm, out);
      break;

    case DT_COMPLEX64:
      if (conjugate) {
#if defined(__ANDROID__) and !defined(__clang__)
        // Workaround for GCC compiler bug in Android toolchain.
        return errors::Unimplemented(
            "Conjugate transpose of complex64 not supported for GCC on "
            "Android.");
#else
        Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
#endif
      } else {
        // Without conjugation, complex64 values are just 8-byte payloads.
        Transpose<Device, uint64>::run(d, in, perm, out);
      }
      break;

    case DT_COMPLEX128:
      if (conjugate) {
        Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
                                                               out);
      } else {
        Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
                                                                out);
      }
      break;

    case DT_STRING:
      Transpose<Device, tstring>::run(d, in, perm, out);
      break;

    default:
      return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
  }
  return OkStatus();
}
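
// Each device's translation unit defines DoTranspose/DoConjugateTranspose by
// forwarding here. A sketch of what the CPU definition might look like
// (illustrative; the real one lives in transpose_functor_cpu.cc):
//
//   template <>
//   Status DoTranspose(const Eigen::ThreadPoolDevice& device, const Tensor& in,
//                      const gtl::ArraySlice<int32> perm, Tensor* out) {
//     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false,
//                                      out);
//   }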
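// Transposes only the two innermost (matrix) dimensions, leaving any leading
// batch dimensions in place. For example, a 4-D input uses the permutation
// {0, 1, 3, 2}.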
template <typename Device>
inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
                                    bool conjugate, Tensor* out) {
  const int ndims = in.dims();
  // A tensor with fewer than two dimensions has no matrix to transpose.
  if (ndims < 2) return OkStatus();
  TransposePermsVec perm(ndims);
  std::iota(perm.begin(), perm.end(), 0);
  std::swap(perm[ndims - 2], perm[ndims - 1]);
  return DoTransposeImpl(device, in, perm, conjugate, out);
}

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_