/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/transpose_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
// integers 0, 1, ..., n - 1 and returns the inverted
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
//
// REQUIRES: input is a vector of int32 or int64.
// REQUIRES: input is a permutation of 0, 1, ..., n-1.
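//
// For example, p = [2, 0, 1] gives inv = [1, 2, 0]: inv[2] = 0, inv[0] = 1,
// and inv[1] = 2, so inv[p[i]] == i holds for each i.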

template <typename T>
class InvertPermutationOp : public OpKernel {
 public:
  explicit InvertPermutationOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input.shape()),
        errors::InvalidArgument("invert_permutation expects a 1D vector."));
    auto Tin = input.vec<T>();
    OP_REQUIRES(context,
                FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
                errors::InvalidArgument("permutation of nonnegative int32s "
                                        "must have <= int32 max elements"));
    const T N = static_cast<T>(Tin.size());  // Safe: bounds-checked above.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto Tout = output->vec<T>();
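    // -1 marks entries not yet assigned; an index that appears twice in the
    // input finds a non -1 value below and is rejected as a duplicate.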
    std::fill_n(Tout.data(), N, -1);
    for (int i = 0; i < N; ++i) {
      const T d = internal::SubtleMustCopy(Tin(i));
      OP_REQUIRES(context, FastBoundsCheck(d, N),
                  errors::InvalidArgument(d, " is not between 0 and ", N));
      OP_REQUIRES(context, Tout(d) == -1,
                  errors::InvalidArgument(d, " is duplicated in the input."));
      Tout(d) = i;
    }
  }
};

REGISTER_KERNEL_BUILDER(
    Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
    InvertPermutationOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int64_t>("T"),
    InvertPermutationOp<int64_t>);

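// The DEVICE_DEFAULT registrations pin both the input ("x") and the output
// ("y") to host memory, so non-CPU placements still run this host-side kernel.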
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("x")
                            .HostMemory("y"),
                        InvertPermutationOp<int32>);
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int64_t>("T")
                            .HostMemory("x")
                            .HostMemory("y"),
                        InvertPermutationOp<int64_t>);

namespace {
template <typename Tperm>
Status PermutationHelper(const Tensor& perm, const int dims,
                         std::vector<int32>* permutation) {
  auto Vperm = perm.vec<Tperm>();
  if (dims != Vperm.size()) {
    return errors::InvalidArgument("transpose expects a vector of size ", dims,
                                   ". But input(1) is a vector of size ",
                                   Vperm.size());
  }
104 // using volatile instead of SubtleMustCopy here so that the
105 // asynchrony boundary is permutation.
  const volatile Tperm* perm_begin =
      reinterpret_cast<const volatile Tperm*>(Vperm.data());
  *permutation = std::vector<int32>(perm_begin, perm_begin + dims);

  return OkStatus();
}
}  // namespace

// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
// shuffles the dimensions of the input tensor according to permutation.
//
// Specifically, the returned tensor output meets the following condition:
// 1) output.dims() == input.dims();
// 2) output.dim_size(i) == input.dim_size(perm[i]);
// 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) ==
//      input.tensor<T, N>(j_0, j_1, ..., j_N-1),
//    where i_s == j_{perm[s]}
//
// REQUIRES: perm is a vector of int32.
// REQUIRES: input.dims() == perm.size().
// REQUIRES: perm is a permutation.
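//
// For example, an input of shape [2, 3, 4] with perm = [2, 0, 1] produces an
// output of shape [4, 2, 3], with output(c, a, b) == input(a, b, c).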

void TransposeOp::Compute(OpKernelContext* ctx) {
  const Tensor& input = ctx->input(0);
  const Tensor& perm = ctx->input(1);
  // Preliminary validation of sizes.
  OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
              errors::InvalidArgument("perm must be rank 1, got shape ",
                                      perm.shape().DebugString()));

  // Although Tperm may be an int64 type, an int32 is sufficient to hold
  // dimension range values, so the narrowing here should be safe.
  std::vector<int32> permutation;
  const int dims = input.dims();
  if (perm.dtype() == DT_INT32) {
    OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation));
  } else {
    OP_REQUIRES_OK(ctx, PermutationHelper<int64_t>(perm, dims, &permutation));
  }
  TensorShape shape;

  // Check whether permutation is a permutation of integers of [0 .. dims).
  gtl::InlinedVector<bool, 8> bits(dims);
  bool is_identity = true;
  for (int i = 0; i < dims; ++i) {
    const int32_t d = permutation[i];
    OP_REQUIRES(
        ctx, 0 <= d && d < dims,
        errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
    bits[d] = true;
    const auto dim_size = input.dim_size(d);
    shape.AddDim(dim_size);
    if (d != i) {
      is_identity = false;
    }
  }
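  // A duplicated entry in `permutation` leaves some index unset in `bits`,
  // so this pass catches both missing and repeated values.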
  for (int i = 0; i < dims; ++i) {
    OP_REQUIRES(ctx, bits[i],
                errors::InvalidArgument(i, " is missing from {",
                                        absl::StrJoin(permutation, ","), "}."));
  }

  // 0-D, 1-D, and identity transposes do nothing.
  if (!IsConjugate() && (dims <= 1 || is_identity)) {
    ctx->set_output(0, input);
    return;
  } else if (!IsConjugate() && internal::NonSingletonDimensionsAlign(
                                   input.shape(), permutation)) {
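    // The non-singleton dimensions keep their relative order (e.g. shape
    // [1, 5, 1] with perm [1, 0, 2]), so the row-major layout is unchanged
    // and the transpose reduces to the buffer-sharing reshape below.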
    Tensor output;
    OP_REQUIRES(ctx, output.CopyFrom(input, shape),
                errors::Unknown("Error reshaping Tensor."));
    ctx->set_output(0, output);
    return;
  }

  Tensor* output = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
  if (shape.num_elements() > 0) {
    OP_REQUIRES_OK(ctx, DoTranspose(ctx, input, permutation, output));
  }
}

Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                   gtl::ArraySlice<int32> perm, Tensor* out) {
  typedef Eigen::ThreadPoolDevice CPUDevice;
  return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
                                   out);
}

Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
                                            const Tensor& in,
                                            gtl::ArraySlice<int32> perm,
                                            Tensor* out) {
  typedef Eigen::ThreadPoolDevice CPUDevice;
  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
                                            perm, out);
}

#define REGISTER(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          TransposeCpuOp);            \
  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          ConjugateTransposeCpuOp);

TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                   gtl::ArraySlice<int32> perm, Tensor* out) {
  typedef Eigen::GpuDevice GPUDevice;
  return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm,
                                   out);
}
Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
                                            const Tensor& in,
                                            gtl::ArraySlice<int32> perm,
                                            Tensor* out) {
  typedef Eigen::GpuDevice GPUDevice;
  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<GPUDevice>(), in,
                                            perm, out);
}

#define REGISTER(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                              .Device(DEVICE_GPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          TransposeGpuOp);            \
  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
                              .Device(DEVICE_GPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          ConjugateTransposeGpuOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow