1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include <algorithm> |
19 | #include <numeric> |
20 | #include <utility> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor.pb.h" |
27 | #include "tensorflow/core/framework/tensor_util.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/framework/variant.h" |
30 | #include "tensorflow/core/framework/variant_encode_decode.h" |
31 | #include "tensorflow/core/kernels/reshape_util.h" |
32 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
33 | #include "tensorflow/core/lib/gtl/optional.h" |
34 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
35 | |
36 | namespace tensorflow { |
37 | |
38 | using CPUDevice = Eigen::ThreadPoolDevice; |
39 | |
40 | namespace { |
41 | |
42 | using sparse::SparseTensor; |
43 | |
44 | class DeserializeSparseOp : public OpKernel { |
45 | public: |
46 | explicit DeserializeSparseOp(OpKernelConstruction* context) |
47 | : OpKernel(context) { |
48 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_)); |
49 | } |
50 | |
51 | void Compute(OpKernelContext* context) override { |
52 | const Tensor& serialized_sparse = context->input(0); |
53 | const int ndims = serialized_sparse.shape().dims(); |
54 | |
55 | OP_REQUIRES( |
56 | context, ndims > 0, |
57 | errors::InvalidArgument("Serialized sparse should have non-zero rank " , |
58 | serialized_sparse.shape().DebugString())); |
59 | |
60 | OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3, |
61 | errors::InvalidArgument( |
62 | "Serialized sparse should have 3 as the last dimension " , |
63 | serialized_sparse.shape().DebugString())); |
64 | |
65 | int num_sparse_tensors = 1; |
66 | for (int i = 0; i < ndims - 1; ++i) { |
67 | num_sparse_tensors *= serialized_sparse.shape().dim_size(i); |
68 | } |
69 | |
70 | OP_REQUIRES( |
71 | context, num_sparse_tensors > 0, |
72 | errors::InvalidArgument( |
73 | "Serialized sparse should have at least 1 serialized tensor, " |
74 | "but has a zero dimension " , |
75 | serialized_sparse.shape().DebugString())); |
76 | |
77 | if (num_sparse_tensors == 1 && ndims == 1) { |
78 | // Special case with a single sparse tensor. We can avoid data |
79 | // motion in the Concat and Reshape. |
80 | const auto& serialized_sparse_t = serialized_sparse.vec<tstring>(); |
81 | |
82 | Tensor output_indices; |
83 | Tensor output_values; |
84 | Tensor output_shape; |
85 | OP_REQUIRES_OK(context, |
86 | this->GetAndValidateSparseTensor( |
87 | serialized_sparse_t(0), serialized_sparse_t(1), |
88 | serialized_sparse_t(2), dtype_, 0 /* index */, |
89 | &output_indices, &output_values, &output_shape)); |
90 | context->set_output(0, output_indices); |
91 | context->set_output(1, output_values); |
92 | context->set_output(2, output_shape); |
93 | return; |
94 | } |
95 | |
96 | std::vector<Tensor> indices; |
97 | std::vector<Tensor> values; |
98 | TensorShape shape; |
99 | indices.reserve(num_sparse_tensors); |
100 | values.reserve(num_sparse_tensors); |
101 | |
102 | const auto& serialized_sparse_t = |
103 | serialized_sparse.flat_inner_dims<tstring, 2>(); |
104 | for (int i = 0; i < num_sparse_tensors; ++i) { |
105 | Tensor output_indices; |
106 | Tensor output_values; |
107 | Tensor output_shape; |
108 | OP_REQUIRES_OK(context, |
109 | this->GetAndValidateSparseTensor( |
110 | serialized_sparse_t(i, 0), serialized_sparse_t(i, 1), |
111 | serialized_sparse_t(i, 2), dtype_, i, &output_indices, |
112 | &output_values, &output_shape)); |
113 | int64_t num_entries = output_indices.dim_size(0); |
114 | int rank = output_indices.dim_size(1); |
115 | |
116 | // Now we expand each SparseTensors' indices and shape by |
117 | // prefixing a dimension |
118 | Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank})); |
119 | const auto& output_indices_t = output_indices.matrix<int64_t>(); |
120 | auto expanded_indices_t = expanded_indices.matrix<int64_t>(); |
121 | expanded_indices_t.chip<1>(0).setZero(); |
122 | if (rank > 0) { |
123 | Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1); |
124 | Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank); |
125 | expanded_indices_t.slice(indices_start, indices_sizes) = |
126 | output_indices_t; |
127 | } |
128 | Tensor expanded_shape(DT_INT64, TensorShape({1 + rank})); |
129 | const auto& output_shape_t = output_shape.vec<int64_t>(); |
130 | auto expanded_shape_t = expanded_shape.vec<int64_t>(); |
131 | expanded_shape_t(0) = 1; |
132 | std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1)); |
133 | |
134 | TensorShape expanded_tensor_shape(expanded_shape.vec<int64_t>()); |
135 | |
136 | indices.push_back(expanded_indices); |
137 | values.push_back(output_values); |
138 | if (i == 0) { |
139 | shape = expanded_tensor_shape; |
140 | } else { |
141 | OP_REQUIRES( |
142 | context, shape.dims() == expanded_tensor_shape.dims(), |
143 | errors::InvalidArgument( |
144 | "Inconsistent shape across SparseTensors: rank prior to " |
145 | "SparseTensor[" , |
146 | i, "] was: " , shape.dims() - 1, " but rank of SparseTensor[" , i, |
147 | "] is: " , expanded_tensor_shape.dims() - 1)); |
148 | for (int j = 1; j < shape.dims(); ++j) { |
149 | // NOTE(mrry): For compatibility with the implementations of |
150 | // DeserializeManySparse, and many ops that generate |
151 | // SparseTensors to batch that do not have a fixed |
152 | // dense_shape (e.g. `tf.parse_single_example()`), we |
153 | // compute the maximum in each dimension to find the |
154 | // smallest dense_shape that bounds all of the input |
155 | // SparseTensors. |
156 | shape.set_dim(j, std::max(shape.dim_size(j), |
157 | expanded_tensor_shape.dim_size(j))); |
158 | } |
159 | } |
160 | } |
161 | |
162 | // Dimension 0 is the primary dimension. |
163 | int rank = shape.dims(); |
164 | gtl::InlinedVector<int64_t, 8> std_order(rank); |
165 | std::iota(std_order.begin(), std_order.end(), 0); |
166 | |
167 | std::vector<SparseTensor> tensors; |
168 | tensors.reserve(num_sparse_tensors); |
169 | for (int i = 0; i < num_sparse_tensors; ++i) { |
170 | SparseTensor tensor; |
171 | OP_REQUIRES_OK(context, SparseTensor::Create(indices[i], values[i], shape, |
172 | std_order, &tensor)); |
173 | tensors.push_back(std::move(tensor)); |
174 | } |
175 | |
176 | gtl::optional<SparseTensor> maybe_output; |
177 | #define HANDLE_TYPE(T) \ |
178 | case DataTypeToEnum<T>::value: { \ |
179 | maybe_output = SparseTensor::Concat<T>(tensors); \ |
180 | break; \ |
181 | } |
182 | |
183 | switch (dtype_) { |
184 | TF_CALL_ALL_TYPES(HANDLE_TYPE); |
185 | TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
186 | #undef HANDLE_TYPE |
187 | default: |
188 | OP_REQUIRES(context, false, |
189 | errors::Unimplemented( |
190 | "DeserializeSparse Unhandled data type: " , dtype_)); |
191 | } |
192 | DCHECK(maybe_output); |
193 | SparseTensor& output = maybe_output.value(); |
194 | |
195 | // Compute the input shape for the reshape operation. |
196 | Tensor input_shape(DT_INT64, TensorShape({output.dims()})); |
197 | std::copy_n(output.shape().data(), output.dims(), |
198 | input_shape.vec<int64_t>().data()); |
199 | |
200 | // Compute the target shape for the reshape operation. |
201 | Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2})); |
202 | for (int i = 0; i < ndims - 1; ++i) { |
203 | target_shape.vec<int64_t>()(i) = serialized_sparse.shape().dim_size(i); |
204 | } |
205 | for (int i = 0; i < output.dims() - 1; ++i) { |
206 | target_shape.vec<int64_t>()(i + ndims - 1) = output.shape().data()[i + 1]; |
207 | } |
208 | |
209 | ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape, |
210 | target_shape, 0 /* output indices index */, |
211 | 2 /* output shape index */); |
212 | context->set_output(1, output.values()); |
213 | } |
214 | |
215 | private: |
216 | Status Deserialize(const tstring& serialized, Tensor* result) { |
217 | TensorProto proto; |
218 | if (!ParseProtoUnlimited(&proto, serialized)) { |
219 | return errors::InvalidArgument("Could not parse serialized proto" ); |
220 | } |
221 | Tensor tensor; |
222 | if (!tensor.FromProto(proto)) { |
223 | return errors::InvalidArgument("Could not construct tensor from proto" ); |
224 | } |
225 | *result = tensor; |
226 | return OkStatus(); |
227 | } |
228 | |
229 | Status GetAndValidateSparseTensor( |
230 | const tstring& serialized_indices, const tstring& serialized_values, |
231 | const tstring& serialized_shape, DataType values_dtype, int index, |
232 | Tensor* output_indices, Tensor* output_values, Tensor* output_shape) { |
233 | // Deserialize and validate the indices. |
234 | TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices)); |
235 | if (!TensorShapeUtils::IsMatrix(output_indices->shape())) { |
236 | return errors::InvalidArgument( |
237 | "Expected serialized_sparse[" , index, |
238 | ", 0] to represent an index matrix but received shape " , |
239 | output_indices->shape().DebugString()); |
240 | } |
241 | int64_t num_entries = output_indices->dim_size(0); |
242 | int rank = output_indices->dim_size(1); |
243 | |
244 | // Deserialize and validate the values. |
245 | TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values)); |
246 | if (!TensorShapeUtils::IsVector(output_values->shape())) { |
247 | return errors::InvalidArgument( |
248 | "Expected serialized_sparse[" , index, |
249 | ", 1] to represent a values vector but received shape " , |
250 | output_values->shape().DebugString()); |
251 | } |
252 | if (values_dtype != output_values->dtype()) { |
253 | return errors::InvalidArgument( |
254 | "Requested SparseTensor of type " , DataTypeString(values_dtype), |
255 | " but SparseTensor[" , index, |
256 | "].values.dtype() == " , DataTypeString(output_values->dtype())); |
257 | } |
258 | if (num_entries != output_values->dim_size(0)) { |
259 | return errors::InvalidArgument( |
260 | "Expected row counts of SparseTensor[" , index, |
261 | "].indices and SparseTensor[" , index, |
262 | "].values to match but they do not: " , num_entries, " vs. " , |
263 | output_values->dim_size(0)); |
264 | } |
265 | |
266 | // Deserialize and validate the shape. |
267 | TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape)); |
268 | if (!TensorShapeUtils::IsVector(output_shape->shape())) { |
269 | return errors::InvalidArgument( |
270 | "Expected serialized_sparse[" , index, |
271 | ", 1] to be a shape vector but its shape is " , |
272 | output_shape->shape().DebugString()); |
273 | } |
274 | if (rank != output_shape->dim_size(0)) { |
275 | return errors::InvalidArgument("Expected column counts of SparseTensor[" , |
276 | index, |
277 | "].indices to match size of SparseTensor[" , |
278 | index, "].shape but they do not: " , rank, |
279 | " vs. " , output_shape->dim_size(0)); |
280 | } |
281 | return OkStatus(); |
282 | } |
283 | |
284 | DataType dtype_; |
285 | }; |
286 | |
// Registers the kernel for `DeserializeSparse` with string-serialized input
// (the "Tserialized" type attribute is constrained to tstring here; a variant
// version, if any, is registered elsewhere).
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("Tserialized"),
                        DeserializeSparseOp)

// The legacy `DeserializeManySparse` op shares the same kernel implementation.
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
                        DeserializeSparseOp)
294 | |
295 | } // namespace |
296 | |
297 | } // namespace tensorflow |
298 | |