1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/transpose_op.h" |
21 | |
22 | #include "tensorflow/core/framework/bounds_check.h" |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/kernels/transpose_functor.h" |
28 | #include "tensorflow/core/lib/core/status.h" |
29 | #include "tensorflow/core/lib/strings/str_util.h" |
30 | #include "tensorflow/core/platform/logging.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | // inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of |
35 | // integers 0, 1, ..., n - 1 and returns the inverted |
36 | // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). |
37 | // |
38 | // REQUIRES: input is a vector of int32 or int64. |
39 | // REQUIRES: input is a permutation of 0, 1, ..., n-1. |
40 | |
41 | template <typename T> |
42 | class InvertPermutationOp : public OpKernel { |
43 | public: |
44 | explicit InvertPermutationOp(OpKernelConstruction* context) |
45 | : OpKernel(context) {} |
46 | |
47 | void Compute(OpKernelContext* context) override { |
48 | const Tensor& input = context->input(0); |
49 | OP_REQUIRES( |
50 | context, TensorShapeUtils::IsVector(input.shape()), |
51 | errors::InvalidArgument("invert_permutation expects a 1D vector." )); |
52 | auto Tin = input.vec<T>(); |
53 | OP_REQUIRES(context, |
54 | FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()), |
55 | errors::InvalidArgument("permutation of nonnegative int32s " |
56 | "must have <= int32 max elements" )); |
57 | const T N = static_cast<T>(Tin.size()); // Safe: bounds-checked above. |
58 | Tensor* output = nullptr; |
59 | OP_REQUIRES_OK(context, |
60 | context->allocate_output(0, input.shape(), &output)); |
61 | auto Tout = output->vec<T>(); |
62 | std::fill_n(Tout.data(), N, -1); |
63 | for (int i = 0; i < N; ++i) { |
64 | const T d = internal::SubtleMustCopy(Tin(i)); |
65 | OP_REQUIRES(context, FastBoundsCheck(d, N), |
66 | errors::InvalidArgument(d, " is not between 0 and " , N)); |
67 | OP_REQUIRES(context, Tout(d) == -1, |
68 | errors::InvalidArgument(d, " is duplicated in the input." )); |
69 | Tout(d) = i; |
70 | } |
71 | } |
72 | }; |
73 | |
74 | REGISTER_KERNEL_BUILDER( |
75 | Name("InvertPermutation" ).Device(DEVICE_CPU).TypeConstraint<int32>("T" ), |
76 | InvertPermutationOp<int32>); |
77 | REGISTER_KERNEL_BUILDER( |
78 | Name("InvertPermutation" ).Device(DEVICE_CPU).TypeConstraint<int64_t>("T" ), |
79 | InvertPermutationOp<int64_t>); |
80 | |
81 | REGISTER_KERNEL_BUILDER(Name("InvertPermutation" ) |
82 | .Device(DEVICE_DEFAULT) |
83 | .TypeConstraint<int32>("T" ) |
84 | .HostMemory("x" ) |
85 | .HostMemory("y" ), |
86 | InvertPermutationOp<int32>); |
87 | REGISTER_KERNEL_BUILDER(Name("InvertPermutation" ) |
88 | .Device(DEVICE_DEFAULT) |
89 | .TypeConstraint<int64_t>("T" ) |
90 | .HostMemory("x" ) |
91 | .HostMemory("y" ), |
92 | InvertPermutationOp<int64_t>); |
93 | |
94 | namespace { |
95 | template <typename Tperm> |
96 | Status PermutationHelper(const Tensor& perm, const int dims, |
97 | std::vector<int32>* permutation) { |
98 | auto Vperm = perm.vec<Tperm>(); |
99 | if (dims != Vperm.size()) { |
100 | return errors::InvalidArgument("transpose expects a vector of size " , dims, |
101 | ". But input(1) is a vector of size " , |
102 | Vperm.size()); |
103 | } |
104 | // using volatile instead of SubtleMustCopy here so that the |
105 | // asynchrony boundary is permutation. |
106 | const volatile Tperm* perm_begin = |
107 | reinterpret_cast<const volatile Tperm*>(Vperm.data()); |
108 | *permutation = std::vector<int32>(perm_begin, perm_begin + dims); |
109 | |
110 | return OkStatus(); |
111 | } |
112 | } // namespace |
113 | |
114 | // output = TransposeOp(T<any> input, T<int32> perm) takes a tensor |
115 | // of type T and rank N, and a permutation of 0, 1, ..., N-1. It |
116 | // shuffles the dimensions of the input tensor according to permutation. |
117 | // |
118 | // Specifically, the returned tensor output meets the following condition: |
119 | // 1) output.dims() == input.dims(); |
120 | // 2) output.dim_size(i) == input.dim_size(perm[i]); |
121 | // 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) == |
122 | // input.tensor<T, N>(j_0, j_1, ..., j_N-1), |
123 | // where i_s == j_{perm[s]} |
124 | // |
125 | // REQUIRES: perm is a vector of int32. |
126 | // REQUIRES: input.dims() == perm.size(). |
127 | // REQUIRES: perm is a permutation. |
128 | |
129 | void TransposeOp::Compute(OpKernelContext* ctx) { |
130 | const Tensor& input = ctx->input(0); |
131 | const Tensor& perm = ctx->input(1); |
132 | // Preliminary validation of sizes. |
133 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()), |
134 | errors::InvalidArgument("perm must be rank 1, got shape " , |
135 | perm.shape().DebugString())); |
136 | |
137 | // Although Tperm may be an int64 type, an int32 is sufficient to hold |
138 | // dimension range values, so the narrowing here should be safe. |
139 | std::vector<int32> permutation; |
140 | const int dims = input.dims(); |
141 | if (perm.dtype() == DT_INT32) { |
142 | OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation)); |
143 | } else { |
144 | OP_REQUIRES_OK(ctx, PermutationHelper<int64_t>(perm, dims, &permutation)); |
145 | } |
146 | TensorShape shape; |
147 | |
148 | // Check whether permutation is a permutation of integers of [0 .. dims). |
149 | gtl::InlinedVector<bool, 8> bits(dims); |
150 | bool is_identity = true; |
151 | for (int i = 0; i < dims; ++i) { |
152 | const int32_t d = permutation[i]; |
153 | OP_REQUIRES( |
154 | ctx, 0 <= d && d < dims, |
155 | errors::InvalidArgument(d, " is out of range [0 .. " , dims, ")" )); |
156 | bits[d] = true; |
157 | const auto dim_size = input.dim_size(d); |
158 | shape.AddDim(dim_size); |
159 | if (d != i) { |
160 | is_identity = false; |
161 | } |
162 | } |
163 | for (int i = 0; i < dims; ++i) { |
164 | OP_REQUIRES(ctx, bits[i], |
165 | errors::InvalidArgument(i, " is missing from {" , |
166 | absl::StrJoin(permutation, "," ), "}." )); |
167 | } |
168 | |
169 | // 0-D, 1-D, and identity transposes do nothing. |
170 | if (!IsConjugate() && (dims <= 1 || is_identity)) { |
171 | ctx->set_output(0, input); |
172 | return; |
173 | } else if (!IsConjugate() && internal::NonSingletonDimensionsAlign( |
174 | input.shape(), permutation)) { |
175 | Tensor output; |
176 | OP_REQUIRES(ctx, output.CopyFrom(input, shape), |
177 | errors::Unknown("Error reshaping Tensor." )); |
178 | ctx->set_output(0, output); |
179 | return; |
180 | } |
181 | |
182 | Tensor* output = nullptr; |
183 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output)); |
184 | if (shape.num_elements() > 0) { |
185 | OP_REQUIRES_OK(ctx, DoTranspose(ctx, input, permutation, output)); |
186 | } |
187 | } |
188 | |
189 | Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, |
190 | gtl::ArraySlice<int32> perm, Tensor* out) { |
191 | typedef Eigen::ThreadPoolDevice CPUDevice; |
192 | return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm, |
193 | out); |
194 | } |
195 | |
196 | Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, |
197 | const Tensor& in, |
198 | gtl::ArraySlice<int32> perm, |
199 | Tensor* out) { |
200 | typedef Eigen::ThreadPoolDevice CPUDevice; |
201 | return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in, |
202 | perm, out); |
203 | } |
204 | |
205 | #define REGISTER(T) \ |
206 | REGISTER_KERNEL_BUILDER(Name("Transpose") \ |
207 | .Device(DEVICE_CPU) \ |
208 | .TypeConstraint<T>("T") \ |
209 | .HostMemory("perm"), \ |
210 | TransposeCpuOp); \ |
211 | REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \ |
212 | .Device(DEVICE_CPU) \ |
213 | .TypeConstraint<T>("T") \ |
214 | .HostMemory("perm"), \ |
215 | ConjugateTransposeCpuOp); |
216 | |
217 | TF_CALL_ALL_TYPES(REGISTER) |
218 | #undef REGISTER |
219 | |
220 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
221 | Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, |
222 | gtl::ArraySlice<int32> perm, Tensor* out) { |
223 | typedef Eigen::GpuDevice GPUDevice; |
224 | return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm, |
225 | out); |
226 | } |
227 | Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx, |
228 | const Tensor& in, |
229 | gtl::ArraySlice<int32> perm, |
230 | Tensor* out) { |
231 | typedef Eigen::GpuDevice GPUDevice; |
232 | return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<GPUDevice>(), in, |
233 | perm, out); |
234 | } |
235 | |
236 | #define REGISTER(T) \ |
237 | REGISTER_KERNEL_BUILDER(Name("Transpose") \ |
238 | .Device(DEVICE_GPU) \ |
239 | .TypeConstraint<T>("T") \ |
240 | .HostMemory("perm"), \ |
241 | TransposeGpuOp); \ |
242 | REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \ |
243 | .Device(DEVICE_GPU) \ |
244 | .TypeConstraint<T>("T") \ |
245 | .HostMemory("perm"), \ |
246 | ConjugateTransposeGpuOp); |
247 | TF_CALL_POD_TYPES(REGISTER); |
248 | #undef REGISTER |
249 | #endif |
250 | |
251 | } // namespace tensorflow |
252 | |