/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/reshape_util.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

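// CPU specialization of ReshapeSparseTensorFunctor: converts each sparse
// index from the input shape to the output shape by flattening it to a
// linear (row-major) offset and re-expanding that offset with the output
// strides.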
template <>
struct ReshapeSparseTensorFunctor<CPUDevice> {
  Status operator()(OpKernelContext *context, const TensorShape &input_shape,
                    const TensorShape &output_shape,
                    typename TTypes<int64_t>::ConstMatrix input_indices,
                    typename TTypes<int64_t>::Matrix output_indices) const {
    (void)context;  // Unused (only used in GPU implementation)
    const int64_t input_rank = input_shape.dims();
    const int64_t output_rank = output_shape.dims();
    const int64_t nnz = input_indices.dimension(0);
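    // Row-major strides of the input shape; the innermost dimension has
    // stride 1.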
    gtl::InlinedVector<int64_t, 8> input_strides(input_rank);
    if (input_rank > 0) {
      input_strides[input_rank - 1] = 1;
      for (int d = input_rank - 2; d >= 0; --d) {
        input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
      }
    }

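    // Row-major strides of the output shape, computed the same way.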
    gtl::InlinedVector<int64_t, 8> output_strides(output_rank);
    if (output_rank > 0) {
      output_strides[output_rank - 1] = 1;
      for (int d = output_rank - 2; d >= 0; --d) {
        output_strides[d] =
            output_strides[d + 1] * output_shape.dim_size(d + 1);
      }
    }

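    // For each nonzero, flatten the input coordinates to a linear offset and
    // unflatten that offset into output coordinates. For example, reshaping
    // [3, 4] to [6, 2] maps the index (1, 3) to offset 1 * 4 + 3 = 7, which
    // expands to (7 / 2, 7 % 2) = (3, 1).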
    for (int i = 0; i < nnz; ++i) {
      int64_t id = 0;
      for (int j = 0; j < input_rank; ++j) {
        id += input_indices(i, j) * input_strides[j];
      }
      for (int j = 0; j < output_rank; ++j) {
        output_indices(i, j) = id / output_strides[j];
        id %= output_strides[j];
      }
    }
    return OkStatus();
  }
};

}  // namespace functor

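// Shared host-side implementation of sparse reshape: validates the inputs,
// resolves the requested output shape (inferring at most one -1 dimension),
// and converts the indices with the device-specific functor.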
template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
                         const Tensor &input_indices_in,
                         const Tensor &input_shape_in,
                         const Tensor &target_shape_in, int output_indices_idx,
                         int output_shape_idx) {
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
              errors::InvalidArgument(
                  "Input indices should be a matrix but received shape ",
                  input_indices_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
              errors::InvalidArgument(
                  "Input shape should be a vector but received shape ",
                  input_shape_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
              errors::InvalidArgument(
                  "Target shape should be a vector but received shape ",
                  target_shape_in.shape().DebugString()));

  const int64_t output_rank = target_shape_in.NumElements();
  TensorShape input_shape;
  OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(
                              input_shape_in.vec<int64_t>(), &input_shape));
  const int64_t dense_size = input_shape.num_elements();
  const int64_t nnz = input_indices_in.shape().dim_size(0);

  // Compute the output shape: take the product of the specified dimensions
  // and record the index of the unspecified (-1) dimension, if any.
  TensorShape output_shape;
  int64_t product = 1;
  int unknown_index = -1;
  auto target_shape = target_shape_in.vec<int64_t>();
  for (int d = 0; d < output_rank; ++d) {
    const int64_t size = target_shape(d);
    if (size == -1) {
      OP_REQUIRES(
          context, unknown_index == -1,
          errors::InvalidArgument("only one output dimension may be -1, "
                                  "not both ",
                                  unknown_index, " and ", d));
      unknown_index = d;
      output_shape.AddDim(1);
    } else {
      OP_REQUIRES(context, size >= 0,
                  errors::InvalidArgument(
                      "size ", d, " must be non-negative, not ", size));
      product *= size;
      output_shape.AddDim(size);
    }
  }
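  // Infer the size of the -1 dimension, if one was requested, from the
  // remaining dense elements.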
  if (unknown_index != -1) {
    OP_REQUIRES(
        context, product > 0,
        errors::InvalidArgument("reshape cannot infer the missing "
                                "input size for an empty tensor unless all "
                                "specified input sizes are non-zero"));
    const int64_t missing = dense_size / product;
    OP_REQUIRES(
        context, product * missing == dense_size,
        errors::InvalidArgument(
            "Input to reshape is a SparseTensor with ", dense_size,
            " dense values, but the requested shape requires a multiple of ",
            product, ". input_shape=", input_shape.DebugString(),
            " output_shape=", output_shape.DebugString()));
    output_shape.set_dim(unknown_index, missing);
  }

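  // The resolved output shape must describe exactly as many dense elements as
  // the input shape.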
  OP_REQUIRES(
      context, output_shape.num_elements() == dense_size,
      errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
                              " dense values, but the requested shape has ",
                              output_shape.num_elements(),
                              ". input_shape=", input_shape.DebugString(),
                              " output_shape=", output_shape.DebugString()));

  // Optimize for reshaping to the same shape.
  if (input_shape == output_shape) {
    context->set_output(output_indices_idx, input_indices_in);
    context->set_output(output_shape_idx, input_shape_in);
    return;
  }

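  // Emit the resolved output shape as a 1-D int64 tensor.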
  Tensor *result_shape = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                   TensorShape({output_rank}),
                                                   &result_shape));
  auto output_shape_vec = result_shape->vec<int64_t>();
  for (int j = 0; j < output_shape.dims(); ++j) {
    output_shape_vec(j) = output_shape.dim_size(j);
  }

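  // Allocate the output indices and, when there are nonzeros, convert them on
  // the device.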
  Tensor *result_indices = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(output_indices_idx,
                                          TensorShape({nnz, output_rank}),
                                          &result_indices));
  if (nnz > 0) {
    OP_REQUIRES(
        context, dense_size > 0 && product > 0,
        errors::InvalidArgument(
            "Input tensor has ", nnz, " non-zero elements but input shape (",
            input_shape.DebugString(), ") or output shape (",
            output_shape.DebugString(), ") is empty"));
    OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
                                context, input_shape, output_shape,
                                input_indices_in.matrix<int64_t>(),
                                result_indices->matrix<int64_t>()));
  }
}

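// Explicit instantiations for the devices that provide a
// ReshapeSparseTensorFunctor specialization.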
#define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                    \
  template void ReshapeSparseTensor<Device>(                       \
      OpKernelContext * context, const Tensor &input_indices_in,   \
      const Tensor &input_shape_in, const Tensor &target_shape_in, \
      int output_indices_idx, int output_shape_idx)
EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
EXPLICITLY_INSTANTIATE_FUNCTION(GPUDevice);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef EXPLICITLY_INSTANTIATE_FUNCTION

}  // namespace tensorflow