/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/util/sparse/group_iterator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

namespace {

using sparse::SparseTensor;

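// Serializes a single SparseTensor as a length-3 vector of T: one
// element each for the serialized indices, values, and shape. T is
// either tstring or Variant; the specializations below supply the
// per-type Initialize() and Serialize() implementations.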
template <typename T>
class SerializeSparseOp : public OpKernel {
 public:
  explicit SerializeSparseOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  bool IsExpensive() override;

  Status Initialize(Tensor* result);
  Status Serialize(const Tensor& input, T* result);

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    Tensor serialized_sparse;
    OP_REQUIRES_OK(context, Initialize(&serialized_sparse));

    auto serialized_sparse_t = serialized_sparse.vec<T>();
    OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0)));
    OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1)));
    OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2)));

    context->set_output(0, serialized_sparse);
  }
};

// NOTE(mrry): We specialize the IsExpensive() method differently for
// the string and variant cases, because (i) the string version
// actually performs memory copies as part of its serialization (and
// is hence potentially expensive), and (ii) the variant version
// performs O(1) shallow copies (and hence is much cheaper than
// dispatching to another thread would be).
template <>
bool SerializeSparseOp<tstring>::IsExpensive() {
  return true;
}
template <>
bool SerializeSparseOp<Variant>::IsExpensive() {
  return false;
}

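// String encoding: each component tensor is converted to a TensorProto
// and serialized to a string.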
template <>
Status SerializeSparseOp<tstring>::Initialize(Tensor* result) {
  *result = Tensor(DT_STRING, TensorShape({3}));
  return OkStatus();
}

template <>
Status SerializeSparseOp<tstring>::Serialize(const Tensor& input,
                                             tstring* result) {
  TensorProto proto;
  input.AsProtoTensorContent(&proto);
  *result = proto.SerializeAsString();
  return OkStatus();
}

REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type"),
                        SerializeSparseOp<tstring>);

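// Variant encoding: each component tensor is stored directly in a
// Variant, so serialization is an O(1) shallow copy.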
template <>
Status SerializeSparseOp<Variant>::Initialize(Tensor* result) {
  *result = Tensor(DT_VARIANT, TensorShape({3}));
  return OkStatus();
}

template <>
Status SerializeSparseOp<Variant>::Serialize(const Tensor& input,
                                             Variant* result) {
  *result = input;
  return OkStatus();
}

REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Variant>("out_type"),
                        SerializeSparseOp<Variant>);

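// Functor that serializes the groups of `minibatch` (one group per
// value of the minibatch dimension) into the rows of the [N, 3] output
// `serialized_sparse`. The unspecialized template is empty; only the
// tstring and Variant specializations below are usable.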
template <typename T, typename U>
struct SerializeGroups {};

template <typename T>
struct SerializeGroups<T, tstring> {
  Status operator()(sparse::GroupIterable* minibatch,
                    const Tensor& output_shape, int64_t N, int rank,
                    Tensor* serialized_sparse) {
    auto serialized_sparse_t = serialized_sparse->matrix<tstring>();

    int64_t last_nonempty_group = -1;

    auto serialize = [](const Tensor& input, tstring* result) {
      TensorProto proto;
      input.AsProtoTensorContent(&proto);
      *result = proto.SerializeAsString();
    };

    tstring serialized_shape;
    serialize(output_shape, &serialized_shape);

    auto serialize_empty_element = [&](int64_t b) {
      serialize(Tensor(DT_INT64, {0, rank - 1}), &serialized_sparse_t(b, 0));
      serialize(Tensor(DataTypeToEnum<T>::value, {0}),
                &serialized_sparse_t(b, 1));
      serialized_sparse_t(b, 2) = serialized_shape;
    };

    for (const auto& subset : *minibatch) {
      const int64_t b = subset.group_at(0);
      if (b < 0 || b >= N) {
        return errors::InvalidArgument(
            "Received unexpected column 0 value in input SparseTensor: ", b,
            " < 0 or >= N (= ", N, ")");
      }

      // GroupIterable generates only the non-empty groups of rows, so we must
      // generate empty outputs for any empty rows since the last non-empty
      // group that was generated.
      for (int64_t empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
        serialize_empty_element(empty_b);
      }

      last_nonempty_group = b;

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64_t num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64_t>();
      auto output_values_t = output_values.vec<T>();

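      // Copy the indices and values into the outputs, dropping the
      // minibatch dimension (column 0 of the indices).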
      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      serialize(output_indices, &serialized_sparse_t(b, 0));
      serialize(output_values, &serialized_sparse_t(b, 1));
      serialized_sparse_t(b, 2) = serialized_shape;
    }

    for (int64_t empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
      serialize_empty_element(empty_b);
    }

    return OkStatus();
  }
};

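// Copies `num_values` elements from `src` to `dest`. The generic
// overload uses memcpy and is restricted to simple types by the
// static_assert; the specializations below handle types that require
// element-wise copies.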
template <typename T>
void CopyValues(const T* src, T* dest, int64_t num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
}

template <>
void CopyValues<tstring>(const tstring* src, tstring* dest,
                         int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void CopyValues<Variant>(const Variant* src, Variant* dest,
                         int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void CopyValues<ResourceHandle>(const ResourceHandle* src, ResourceHandle* dest,
                                int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

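// Eigen::half is copied bytewise by delegating to the generic overload
// with T = char.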
template <>
void CopyValues<Eigen::half>(const Eigen::half* src, Eigen::half* dest,
                             int64_t num_values) {
  return CopyValues(reinterpret_cast<const char*>(src),
                    reinterpret_cast<char*>(dest),
                    num_values * sizeof(Eigen::half));
}

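// Variant encoding of each group: rather than proto-serializing, the
// indices, values, and shape tensors are stored directly in the output
// Variants, with values copied via the CopyValues overloads above.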
template <typename T>
struct SerializeGroups<T, Variant> {
  Status operator()(sparse::GroupIterable* minibatch,
                    const Tensor& output_shape, int64_t N, int rank,
                    Tensor* serialized_sparse) {
    auto serialized_sparse_t = serialized_sparse->template matrix<Variant>();

    int64_t last_nonempty_group = -1;

    // The "DataTypeToEnum<T>::value" member is static and declared in-class
    // but not defined out-of-class. This leads to linker errors when a
    // "DataTypeToEnum<T>::value" reference is passed to a routine. Creating
    // a local variable here to work around the linker errors.
    DataType T_type = DataTypeToEnum<T>::value;

    auto serialize_empty_element = [&](int64_t b) {
      serialized_sparse_t(b, 0).emplace<Tensor>(DT_INT64,
                                                TensorShape({0, rank - 1}));
      serialized_sparse_t(b, 1).emplace<Tensor>(T_type, TensorShape({0}));
      serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
    };

    for (const auto& subset : *minibatch) {
      const int64_t b = subset.group_at(0);
      if (b < 0 || b >= N) {
        return errors::InvalidArgument(
            "Received unexpected column 0 value in input SparseTensor: ", b,
            " < 0 or >= N (= ", N, ")");
      }

      // GroupIterable generates only the non-empty groups of rows, so we must
      // generate empty outputs for any empty rows since the last non-empty
      // group that was generated.
      for (int64_t empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
        serialize_empty_element(empty_b);
      }

      last_nonempty_group = b;

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64_t num_entries = values.size();

      Tensor& output_indices = serialized_sparse_t(b, 0).emplace<Tensor>(
          DT_INT64, TensorShape({num_entries, rank - 1}));
      Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
          T_type, TensorShape({num_entries}));

      int64_t* output_indices_ptr =
          static_cast<int64_t*>(DMAHelper::base(&output_indices));
      const int64_t* indices_ptr = indices.data();

      T* output_values_ptr = static_cast<T*>(DMAHelper::base(&output_values));
      const T* values_ptr = values.data();

      // TODO(mrry): Consider adding a template-based specialization for higher
      // ranks.
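      // Copy the indices, dropping the minibatch dimension (column 0 of
      // each row).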
      if (rank == 2) {
        for (int i = 0; i < num_entries; ++i) {
          output_indices_ptr[i] = indices_ptr[(2 * i) + 1];
        }
      } else {
        for (int i = 0; i < num_entries; ++i) {
          // Skip the first index in each row.
          ++indices_ptr;
          for (int d = 1; d < rank; ++d) {
            *output_indices_ptr++ = *indices_ptr++;
          }
        }
      }

      CopyValues(values_ptr, output_values_ptr, num_entries);
      serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
    }

    for (int64_t empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
      serialize_empty_element(empty_b);
    }

    return OkStatus();
  }
};

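// Splits a rank-R SparseTensor along its first (minibatch) dimension
// into N rank-(R-1) SparseTensors, serializing each one into a row of
// the [N, 3] output via SerializeGroups<T, U>.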
template <typename T, typename U>
class SerializeManySparseOp : public OpKernel {
 public:
  explicit SerializeManySparseOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    TensorShape tensor_input_shape;
    OP_REQUIRES_OK(context,
                   TensorShape::BuildTensorShape(input_shape->vec<int64_t>(),
                                                 &tensor_input_shape));
    gtl::InlinedVector<int64_t, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    auto input_shape_t = input_shape->vec<int64_t>();
    const int64_t N = input_shape_t(0);

    Tensor* serialized_sparse;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, {N, 3}, &serialized_sparse));

    OP_REQUIRES_OK(context, input_st.IndicesValid());

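    // Each serialized element's shape is the input shape with the
    // minibatch dimension removed.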
    Tensor output_shape(DT_INT64, {rank - 1});
    auto output_shape_t = output_shape.vec<int64_t>();
    for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);

    // Get groups by minibatch dimension
    sparse::GroupIterable minibatch = input_st.group({0});

    OP_REQUIRES_OK(context, SerializeGroups<T, U>()(&minibatch, output_shape, N,
                                                    rank, serialized_sparse));
  }
};

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("SerializeManySparse")               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<tstring>("out_type"), \
                          SerializeManySparseOp<type, tstring>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("SerializeManySparse")               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<Variant>("out_type"), \
                          SerializeManySparseOp<type, Variant>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace

}  // namespace tensorflow