/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

namespace {

using sparse::SparseTensor;

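// Deserializes one or more SparseTensors, each stored as a triple of
// serialized TensorProto strings (indices, values, shape), concatenates them
// along a new leading dimension, and reshapes the result to the final output.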
class DeserializeSparseOp : public OpKernel {
 public:
  explicit DeserializeSparseOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& serialized_sparse = context->input(0);
    const int ndims = serialized_sparse.shape().dims();

    OP_REQUIRES(
        context, ndims > 0,
        errors::InvalidArgument("Serialized sparse should have non-zero rank ",
                                serialized_sparse.shape().DebugString()));

    OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
                errors::InvalidArgument(
                    "Serialized sparse should have 3 as the last dimension ",
                    serialized_sparse.shape().DebugString()));

    int num_sparse_tensors = 1;
    for (int i = 0; i < ndims - 1; ++i) {
      num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
    }

    OP_REQUIRES(
        context, num_sparse_tensors > 0,
        errors::InvalidArgument(
            "Serialized sparse should have at least 1 serialized tensor, "
            "but has a zero dimension ",
            serialized_sparse.shape().DebugString()));

    if (num_sparse_tensors == 1 && ndims == 1) {
      // Special case with a single sparse tensor. We can avoid data
      // motion in the Concat and Reshape.
      const auto& serialized_sparse_t = serialized_sparse.vec<tstring>();

      Tensor output_indices;
      Tensor output_values;
      Tensor output_shape;
      OP_REQUIRES_OK(context,
                     this->GetAndValidateSparseTensor(
                         serialized_sparse_t(0), serialized_sparse_t(1),
                         serialized_sparse_t(2), dtype_, 0 /* index */,
                         &output_indices, &output_values, &output_shape));
      context->set_output(0, output_indices);
      context->set_output(1, output_values);
      context->set_output(2, output_shape);
      return;
    }

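    // General case: deserialize every SparseTensor, prepend a leading
    // dimension of size 1 to each one's indices and shape, then concatenate
    // along that dimension and reshape to the final output shape.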
    std::vector<Tensor> indices;
    std::vector<Tensor> values;
    TensorShape shape;
    indices.reserve(num_sparse_tensors);
    values.reserve(num_sparse_tensors);

    const auto& serialized_sparse_t =
        serialized_sparse.flat_inner_dims<tstring, 2>();
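    // Each row of serialized_sparse_t holds the (indices, values, shape)
    // triple for one SparseTensor.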
    for (int i = 0; i < num_sparse_tensors; ++i) {
      Tensor output_indices;
      Tensor output_values;
      Tensor output_shape;
      OP_REQUIRES_OK(context,
                     this->GetAndValidateSparseTensor(
                         serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
                         serialized_sparse_t(i, 2), dtype_, i, &output_indices,
                         &output_values, &output_shape));
      int64_t num_entries = output_indices.dim_size(0);
      int rank = output_indices.dim_size(1);

      // Now we expand each SparseTensor's indices and shape by
      // prefixing a dimension of size 1.
      Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
      const auto& output_indices_t = output_indices.matrix<int64_t>();
      auto expanded_indices_t = expanded_indices.matrix<int64_t>();
      expanded_indices_t.chip<1>(0).setZero();
      if (rank > 0) {
        Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
        Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
        expanded_indices_t.slice(indices_start, indices_sizes) =
            output_indices_t;
      }
      Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
      const auto& output_shape_t = output_shape.vec<int64_t>();
      auto expanded_shape_t = expanded_shape.vec<int64_t>();
      expanded_shape_t(0) = 1;
      std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));

      TensorShape expanded_tensor_shape(expanded_shape.vec<int64_t>());

      indices.push_back(expanded_indices);
      values.push_back(output_values);
      if (i == 0) {
        shape = expanded_tensor_shape;
      } else {
        OP_REQUIRES(
            context, shape.dims() == expanded_tensor_shape.dims(),
            errors::InvalidArgument(
                "Inconsistent shape across SparseTensors: rank prior to "
                "SparseTensor[",
                i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
                "] is: ", expanded_tensor_shape.dims() - 1));
        for (int j = 1; j < shape.dims(); ++j) {
          // NOTE(mrry): For compatibility with the implementations of
          // DeserializeManySparse, and many ops that generate
          // SparseTensors to batch that do not have a fixed
          // dense_shape (e.g. `tf.parse_single_example()`), we
          // compute the maximum in each dimension to find the
          // smallest dense_shape that bounds all of the input
          // SparseTensors.
          shape.set_dim(j, std::max(shape.dim_size(j),
                                    expanded_tensor_shape.dim_size(j)));
        }
      }
    }

    // Dimension 0 is the primary dimension.
    int rank = shape.dims();
    gtl::InlinedVector<int64_t, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);

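    // Wrap each expanded (indices, values) pair in a SparseTensor that shares
    // the common bounding shape, so that they can be concatenated below.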
    std::vector<SparseTensor> tensors;
    tensors.reserve(num_sparse_tensors);
    for (int i = 0; i < num_sparse_tensors; ++i) {
      SparseTensor tensor;
      OP_REQUIRES_OK(context, SparseTensor::Create(indices[i], values[i], shape,
                                                   std_order, &tensor));
      tensors.push_back(std::move(tensor));
    }

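    // Concatenate all of the SparseTensors along dimension 0, dispatching on
    // the values dtype via the HANDLE_TYPE macro below.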
    gtl::optional<SparseTensor> maybe_output;
#define HANDLE_TYPE(T)                               \
  case DataTypeToEnum<T>::value: {                   \
    maybe_output = SparseTensor::Concat<T>(tensors); \
    break;                                           \
  }

    switch (dtype_) {
      TF_CALL_ALL_TYPES(HANDLE_TYPE);
      TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
      default:
        OP_REQUIRES(context, false,
                    errors::Unimplemented(
                        "DeserializeSparse Unhandled data type: ", dtype_));
    }
    DCHECK(maybe_output);
    SparseTensor& output = maybe_output.value();

    // Compute the input shape for the reshape operation.
    Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
    std::copy_n(output.shape().data(), output.dims(),
                input_shape.vec<int64_t>().data());

    // Compute the target shape for the reshape operation.
    Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
    for (int i = 0; i < ndims - 1; ++i) {
      target_shape.vec<int64_t>()(i) = serialized_sparse.shape().dim_size(i);
    }
    for (int i = 0; i < output.dims() - 1; ++i) {
      target_shape.vec<int64_t>()(i + ndims - 1) = output.shape().data()[i + 1];
    }

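    // ReshapeSparseTensor writes the reshaped indices to output 0 and the
    // final dense shape to output 2; the concatenated values pass through
    // unchanged as output 1.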
    ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
                                   target_shape, 0 /* output indices index */,
                                   2 /* output shape index */);
    context->set_output(1, output.values());
  }

 private:
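  // Parses a single serialized TensorProto string into `*result`.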
  Status Deserialize(const tstring& serialized, Tensor* result) {
    TensorProto proto;
    if (!ParseProtoUnlimited(&proto, serialized)) {
      return errors::InvalidArgument("Could not parse serialized proto");
    }
    Tensor tensor;
    if (!tensor.FromProto(proto)) {
      return errors::InvalidArgument("Could not construct tensor from proto");
    }
    *result = tensor;
    return OkStatus();
  }

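  // Deserializes the (indices, values, shape) triple for the SparseTensor at
  // position `index` in the input and checks that the three components are
  // mutually consistent and of the requested dtype.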
  Status GetAndValidateSparseTensor(
      const tstring& serialized_indices, const tstring& serialized_values,
      const tstring& serialized_shape, DataType values_dtype, int index,
      Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
    // Deserialize and validate the indices.
    TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
    if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 0] to represent an index matrix but received shape ",
          output_indices->shape().DebugString());
    }
    int64_t num_entries = output_indices->dim_size(0);
    int rank = output_indices->dim_size(1);

    // Deserialize and validate the values.
    TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
    if (!TensorShapeUtils::IsVector(output_values->shape())) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
          ", 1] to represent a values vector but received shape ",
          output_values->shape().DebugString());
    }
    if (values_dtype != output_values->dtype()) {
      return errors::InvalidArgument(
          "Requested SparseTensor of type ", DataTypeString(values_dtype),
          " but SparseTensor[", index,
          "].values.dtype() == ", DataTypeString(output_values->dtype()));
    }
    if (num_entries != output_values->dim_size(0)) {
      return errors::InvalidArgument(
          "Expected row counts of SparseTensor[", index,
          "].indices and SparseTensor[", index,
          "].values to match but they do not: ", num_entries, " vs. ",
          output_values->dim_size(0));
    }

    // Deserialize and validate the shape.
    TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
    if (!TensorShapeUtils::IsVector(output_shape->shape())) {
      return errors::InvalidArgument(
          "Expected serialized_sparse[", index,
271 ", 1] to be a shape vector but its shape is ",
          output_shape->shape().DebugString());
    }
    if (rank != output_shape->dim_size(0)) {
      return errors::InvalidArgument("Expected column counts of SparseTensor[",
                                     index,
                                     "].indices to match size of SparseTensor[",
                                     index, "].shape but they do not: ", rank,
                                     " vs. ", output_shape->dim_size(0));
    }
    return OkStatus();
  }

  DataType dtype_;
};

REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("Tserialized"),
                        DeserializeSparseOp);

REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
                        DeserializeSparseOp);

}  // namespace

}  // namespace tensorflow
298