1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #define EIGEN_USE_THREADS |
16 | |
17 | #include "tensorflow/core/kernels/reshape_util.h" |
18 | |
19 | #include <algorithm> |
20 | #include <numeric> |
21 | #include <unordered_map> |
22 | #include <utility> |
23 | #include <vector> |
24 | |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/op_requires.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/framework/tensor_shape.h" |
30 | #include "tensorflow/core/framework/tensor_util.h" |
31 | #include "tensorflow/core/framework/types.h" |
32 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | using CPUDevice = Eigen::ThreadPoolDevice; |
37 | using GPUDevice = Eigen::GpuDevice; |
38 | |
39 | namespace functor { |
40 | |
41 | template <> |
42 | struct ReshapeSparseTensorFunctor<CPUDevice> { |
43 | Status operator()(OpKernelContext *context, const TensorShape &input_shape, |
44 | const TensorShape &output_shape, |
45 | typename TTypes<int64_t>::ConstMatrix input_indices, |
46 | typename TTypes<int64_t>::Matrix output_indices) const { |
47 | (void)context; // Unused (only used in GPU implementation) |
48 | const int64_t input_rank = input_shape.dims(); |
49 | const int64_t output_rank = output_shape.dims(); |
50 | const int64_t nnz = input_indices.dimension(0); |
51 | gtl::InlinedVector<int64_t, 8> input_strides(input_rank); |
52 | if (input_rank > 0) { |
53 | input_strides[input_rank - 1] = 1; |
54 | for (int d = input_rank - 2; d >= 0; --d) { |
55 | input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1); |
56 | } |
57 | } |
58 | |
59 | gtl::InlinedVector<int64_t, 8> output_strides(output_rank); |
60 | if (output_rank > 0) { |
61 | output_strides[output_rank - 1] = 1; |
62 | for (int d = output_rank - 2; d >= 0; --d) { |
63 | output_strides[d] = |
64 | output_strides[d + 1] * output_shape.dim_size(d + 1); |
65 | } |
66 | } |
67 | |
68 | for (int i = 0; i < nnz; ++i) { |
69 | int64_t id = 0; |
70 | for (int j = 0; j < input_rank; ++j) { |
71 | id += input_indices(i, j) * input_strides[j]; |
72 | } |
73 | for (int j = 0; j < output_rank; ++j) { |
74 | output_indices(i, j) = id / output_strides[j]; |
75 | id %= output_strides[j]; |
76 | } |
77 | } |
78 | return OkStatus(); |
79 | } |
80 | }; |
81 | |
82 | } // namespace functor |
83 | |
// Reshapes a SparseTensor given by (input_indices_in, input_shape_in) to
// `target_shape_in`, writing the new indices tensor to output index
// `output_indices_idx` and the new dense-shape tensor to `output_shape_idx`
// on `context`.
//
// Mirrors dense Reshape semantics: at most one target dimension may be -1,
// in which case its size is inferred so the total element count matches the
// input. All OP_REQUIRES* macros below return early from this function (and
// record the error on `context`) when their condition fails, so statement
// order doubles as validation order.
template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
                         const Tensor &input_indices_in,
                         const Tensor &input_shape_in,
                         const Tensor &target_shape_in, int output_indices_idx,
                         int output_shape_idx) {
  // Validate the ranks of the three inputs: indices must be a [nnz, rank]
  // matrix; both shapes must be rank-1 vectors.
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
              errors::InvalidArgument(
                  "Input indices should be a matrix but received shape " ,
                  input_indices_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
              errors::InvalidArgument(
                  "Input shape should be a vector but received shape " ,
                  input_shape_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
              errors::InvalidArgument(
                  "Target shape should be a vector but received shape " ,
                  target_shape_in.shape().DebugString()));

  const int64_t output_rank = target_shape_in.NumElements();
  // BuildTensorShape also validates the dimension sizes (e.g. rejects
  // negative dims and overflowing element counts).
  TensorShape input_shape;
  OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(
                              input_shape_in.vec<int64_t>(), &input_shape));
  const int64_t dense_size = input_shape.num_elements();
  const int64_t nnz = input_indices_in.shape().dim_size(0);

  // Compute the output shape. Determine product of specified dimensions, and
  // find the index of the unspecified one.
  TensorShape output_shape;
  int64_t product = 1;
  int unknown_index = -1;  // index of the single -1 ("infer me") dim, if any
  auto target_shape = target_shape_in.vec<int64_t>();
  for (int d = 0; d < output_rank; ++d) {
    const int64_t size = target_shape(d);
    if (size == -1) {
      OP_REQUIRES(
          context, unknown_index == -1,
          errors::InvalidArgument("only one output dimension may be -1, "
                                  "not both " ,
                                  unknown_index, " and " , d));
      unknown_index = d;
      // Placeholder size 1; patched below once the missing size is inferred.
      output_shape.AddDim(1);
    } else {
      OP_REQUIRES(context, size >= 0,
                  errors::InvalidArgument("size " , d,
                                          " must be non-negative, not " , size));
      product *= size;
      output_shape.AddDim(size);
    }
  }
  if (unknown_index != -1) {
    // Infer the -1 dimension: dense_size must be an exact multiple of the
    // product of the explicitly specified dimensions.
    OP_REQUIRES(
        context, product > 0,
        errors::InvalidArgument("reshape cannot infer the missing "
                                "input size for an empty tensor unless all "
                                "specified input sizes are non-zero" ));
    const int64_t missing = dense_size / product;
    OP_REQUIRES(
        context, product * missing == dense_size,
        errors::InvalidArgument(
            "Input to reshape is a SparseTensor with " , dense_size,
            " dense values, but the requested shape requires a multiple of " ,
            product, ". input_shape=" , input_shape.DebugString(),
            " output_shape=" , output_shape.DebugString()));
    output_shape.set_dim(unknown_index, missing);
  }

  // The reshape must preserve the total number of dense elements.
  OP_REQUIRES(
      context, output_shape.num_elements() == dense_size,
      errors::InvalidArgument("Input to reshape is a tensor with " , dense_size,
                              " dense values, but the requested shape has " ,
                              output_shape.num_elements(),
                              ". input_shape=" , input_shape.DebugString(),
                              " output_shape=" , output_shape.DebugString()));

  // Optimize for reshaping to the same shape.
  if (input_shape == output_shape) {
    // Forward the inputs unchanged; no index recomputation needed.
    context->set_output(output_indices_idx, input_indices_in);
    context->set_output(output_shape_idx, input_shape_in);
    return;
  }

  // Emit the new dense shape as a rank-1 int64 tensor.
  Tensor *result_shape = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                   TensorShape({output_rank}),
                                                   &result_shape));
  auto output_shape_vec = result_shape->vec<int64_t>();
  for (int j = 0; j < output_shape.dims(); ++j) {
    output_shape_vec(j) = output_shape.dim_size(j);
  }

  // Allocate the [nnz, output_rank] indices output, then delegate the
  // per-element index remapping to the device-specific functor.
  Tensor *result_indices = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(output_indices_idx,
                                          TensorShape({nnz, output_rank}),
                                          &result_indices));
  if (nnz > 0) {
    // Guard against nonzero entries in a zero-element shape, which would
    // make the stride-based remapping in the functor divide by zero.
    OP_REQUIRES(
        context, dense_size > 0 && product > 0,
        errors::InvalidArgument(
            "Input tensor has " , nnz, " non zero elements but input shape (" ,
            input_shape.DebugString(), ") or output shape (" ,
            output_shape.DebugString(), ") is empty" ));
    OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
                                context, input_shape, output_shape,
                                input_indices_in.matrix<int64_t>(),
                                result_indices->matrix<int64_t>()));
  }
}
193 | |
// Explicitly instantiate ReshapeSparseTensor for each supported device so
// the template definition can live in this .cc file while kernels in other
// translation units link against it.
#define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                    \
  template void ReshapeSparseTensor<Device>(                      \
      OpKernelContext * context, const Tensor &input_indices_in,  \
      const Tensor &input_shape_in, const Tensor &target_shape_in, \
      int output_indices_idx, int output_shape_idx)
EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);

// The GPU instantiation is only meaningful when building with CUDA or ROCm;
// its functor specialization is compiled in the corresponding .cu.cc file.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
EXPLICITLY_INSTANTIATE_FUNCTION(GPUDevice);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef EXPLICITLY_INSTANTIATE_FUNCTION
205 | |
206 | } // namespace tensorflow |
207 | |