1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include <complex> |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/attr_value.pb.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/kernels/ops_util.h" |
24 | #include "tensorflow/core/kernels/transpose_functor.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/lib/gtl/array_slice.h" |
27 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
28 | |
29 | typedef Eigen::ThreadPoolDevice CPUDevice; |
30 | |
31 | namespace tensorflow { |
32 | namespace { |
33 | |
34 | template <typename T, bool conjugate> |
35 | void TransposeSimple(const CPUDevice& device, const Tensor& in, |
36 | const gtl::ArraySlice<int32> perm, Tensor* out) { |
37 | const int ndims = in.dims(); |
38 | gtl::InlinedVector<int64_t, 8> in_strides = |
39 | ComputeStride<int64_t>(in.shape()); |
40 | gtl::InlinedVector<int64_t, 8> out_strides = |
41 | ComputeStride<int64_t>(out->shape()); |
42 | const T* p = reinterpret_cast<const T*>(in.tensor_data().data()); |
43 | T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data()))); |
44 | auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64_t begin, |
45 | int64_t end) { |
46 | for (int64_t o_idx = begin; o_idx < end; ++o_idx) { |
47 | int64_t i_idx = 0; |
48 | int64_t t = o_idx; |
49 | for (int i = 0; i < ndims; ++i) { |
50 | const int64_t ratio = t / out_strides[i]; |
51 | t -= ratio * out_strides[i]; |
52 | i_idx += ratio * in_strides[perm[i]]; |
53 | } |
54 | if (conjugate) { |
55 | q[o_idx] = Eigen::numext::conj(p[i_idx]); |
56 | } else { |
57 | q[o_idx] = p[i_idx]; |
58 | } |
59 | } |
60 | }; |
61 | double cycles_per_element = |
62 | (conjugate ? 1 : 0) + |
63 | ndims * (Eigen::TensorOpCost::DivCost<int64_t>() + |
64 | 2 * Eigen::TensorOpCost::MulCost<int64_t>() + |
65 | 2 * Eigen::TensorOpCost::AddCost<int64_t>()); |
66 | Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T), |
67 | /*bytes_stored=*/sizeof(T), cycles_per_element); |
68 | device.parallelFor(in.NumElements(), cost, std::move(transpose_fn)); |
69 | } |
70 | |
71 | } // namespace |
72 | |
73 | template <typename T, bool conjugate> |
74 | struct Transpose<CPUDevice, T, conjugate> { |
75 | static void run(const CPUDevice& d, const Tensor& in, |
76 | const gtl::ArraySlice<int32> perm, Tensor* out) { |
77 | switch (in.dims()) { |
78 | case 2: |
79 | internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate, |
80 | out); |
81 | break; |
82 | case 3: |
83 | internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate, |
84 | out); |
85 | break; |
86 | case 4: |
87 | internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate, |
88 | out); |
89 | break; |
90 | case 5: |
91 | internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate, |
92 | out); |
93 | break; |
94 | case 6: |
95 | internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate, |
96 | out); |
97 | break; |
98 | case 7: |
99 | internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate, |
100 | out); |
101 | break; |
102 | case 8: |
103 | internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate, |
104 | out); |
105 | break; |
106 | default: |
107 | TransposeSimple<T, conjugate>(d, in, perm, out); |
108 | break; |
109 | } |
110 | } |
111 | }; |
112 | |
113 | #define INSTANTIATE(DEVICE) \ |
114 | template <> \ |
115 | Status DoTranspose(const DEVICE& device, const Tensor& in, \ |
116 | const gtl::ArraySlice<int32> perm, Tensor* out) { \ |
117 | return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \ |
118 | out); \ |
119 | } \ |
120 | template <> \ |
121 | Status DoConjugateTranspose(const DEVICE& device, const Tensor& in, \ |
122 | const gtl::ArraySlice<int32> perm, \ |
123 | Tensor* out) { \ |
124 | return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, \ |
125 | out); \ |
126 | } \ |
127 | template <> \ |
128 | Status DoMatrixTranspose(const DEVICE& device, const Tensor& in, \ |
129 | Tensor* out) { \ |
130 | return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \ |
131 | out); \ |
132 | } \ |
133 | template <> \ |
134 | Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \ |
135 | Tensor* out) { \ |
136 | return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, \ |
137 | out); \ |
138 | } |
139 | |
140 | INSTANTIATE(CPUDevice) |
141 | |
142 | |
143 | } // namespace tensorflow |
144 | |