/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/util/sparse/group_iterator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

namespace {

using sparse::SparseTensor;

template <typename T>
class SerializeSparseOp : public OpKernel {
 public:
  explicit SerializeSparseOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  bool IsExpensive() override;

  Status Initialize(Tensor* result);
  Status Serialize(const Tensor& input, T* result);

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    Tensor serialized_sparse;
    OP_REQUIRES_OK(context, Initialize(&serialized_sparse));

    auto serialized_sparse_t = serialized_sparse.vec<T>();
    OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0)));
    OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1)));
    OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2)));

    context->set_output(0, serialized_sparse);
  }
};
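
// Example (illustrative): the output of SerializeSparseOp is a length-3
// vector of T, where element 0 holds the serialized indices matrix,
// element 1 the serialized values vector, and element 2 the serialized
// dense shape of the input SparseTensor.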

// NOTE(mrry): We specialize the IsExpensive() method differently for
// the string and variant cases, because (i) the string version
// actually performs memory copies as part of its serialization (and
// is hence potentially expensive), and (ii) the variant version
// performs O(1) shallow copies (and hence is much cheaper than
// dispatching to another thread would be).
template <>
bool SerializeSparseOp<tstring>::IsExpensive() {
  return true;
}
template <>
bool SerializeSparseOp<Variant>::IsExpensive() {
  return false;
}

template <>
Status SerializeSparseOp<tstring>::Initialize(Tensor* result) {
  *result = Tensor(DT_STRING, TensorShape({3}));
  return OkStatus();
}

template <>
Status SerializeSparseOp<tstring>::Serialize(const Tensor& input,
                                             tstring* result) {
  TensorProto proto;
  input.AsProtoTensorContent(&proto);
  *result = proto.SerializeAsString();
  return OkStatus();
}

REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type"),
                        SerializeSparseOp<tstring>);

template <>
Status SerializeSparseOp<Variant>::Initialize(Tensor* result) {
  *result = Tensor(DT_VARIANT, TensorShape({3}));
  return OkStatus();
}

template <>
Status SerializeSparseOp<Variant>::Serialize(const Tensor& input,
                                             Variant* result) {
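  // Wrapping the input Tensor in a Variant is a shallow, O(1) copy that
  // shares the underlying buffer (see the IsExpensive() note above).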
  *result = input;
  return OkStatus();
}

REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Variant>("out_type"),
                        SerializeSparseOp<Variant>);

template <typename T, typename U>
struct SerializeGroups {};
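
// SerializeGroups is specialized below for the two supported output types:
// tstring, which proto-serializes each group, and Variant, which stores
// each group as a shallow Tensor copy.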

template <typename T>
struct SerializeGroups<T, tstring> {
  Status operator()(sparse::GroupIterable* minibatch,
                    const Tensor& output_shape, int64_t N, int rank,
                    Tensor* serialized_sparse) {
    auto serialized_sparse_t = serialized_sparse->matrix<tstring>();

    int64_t last_nonempty_group = -1;

    auto serialize = [](const Tensor& input, tstring* result) {
      TensorProto proto;
      input.AsProtoTensorContent(&proto);
      *result = proto.SerializeAsString();
    };

    tstring serialized_shape;
    serialize(output_shape, &serialized_shape);

    auto serialize_empty_element = [&](int64_t b) {
      serialize(Tensor(DT_INT64, {0, rank - 1}), &serialized_sparse_t(b, 0));
      serialize(Tensor(DataTypeToEnum<T>::value, {0}),
                &serialized_sparse_t(b, 1));
      serialized_sparse_t(b, 2) = serialized_shape;
    };

    for (const auto& subset : *minibatch) {
      const int64_t b = subset.group_at(0);
      if (b < 0 || b >= N) {
        return errors::InvalidArgument(
            "Received unexpected column 0 value in input SparseTensor: ", b,
            " < 0 or >= N (= ", N, ")");
      }

      // GroupIterable generates only the non-empty groups of rows, so we must
      // generate empty outputs for any empty rows since the last non-empty
      // group that was generated.
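      // For example (illustrative): if N = 4 and the iterator yields groups
      // only for rows 1 and 3, then rows 0 and 2 are filled in with empty
      // serialized elements.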
      for (int64_t empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
        serialize_empty_element(empty_b);
      }

      last_nonempty_group = b;

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64_t num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64_t>();
      auto output_values_t = output_values.vec<T>();

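      // Copy this group's indices and values, dropping the leading
      // minibatch dimension (column 0) from each index.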
      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      serialize(output_indices, &serialized_sparse_t(b, 0));
      serialize(output_values, &serialized_sparse_t(b, 1));
      serialized_sparse_t(b, 2) = serialized_shape;
    }

    for (int64_t empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
      serialize_empty_element(empty_b);
    }

    return OkStatus();
  }
};

template <typename T>
void CopyValues(const T* src, T* dest, int64_t num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
}

template <>
void CopyValues<tstring>(const tstring* src, tstring* dest,
                         int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void CopyValues<Variant>(const Variant* src, Variant* dest,
                         int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void CopyValues<ResourceHandle>(const ResourceHandle* src, ResourceHandle* dest,
                                int64_t num_values) {
  std::copy_n(src, num_values, dest);
}

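// Eigen::half is trivially copyable, so a bytewise copy is safe; reuse the
// generic memcpy path by reinterpreting the buffers as char.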
template <>
void CopyValues<Eigen::half>(const Eigen::half* src, Eigen::half* dest,
                             int64_t num_values) {
  return CopyValues(reinterpret_cast<const char*>(src),
                    reinterpret_cast<char*>(dest),
                    num_values * sizeof(Eigen::half));
}

template <typename T>
struct SerializeGroups<T, Variant> {
  Status operator()(sparse::GroupIterable* minibatch,
                    const Tensor& output_shape, int64_t N, int rank,
                    Tensor* serialized_sparse) {
    auto serialized_sparse_t = serialized_sparse->template matrix<Variant>();

    int64_t last_nonempty_group = -1;

    // The "DataTypeToEnum<T>::value" member is a static data member that is
    // declared in-class but not defined out-of-line. ODR-using it (e.g.
    // binding it to a reference parameter) can therefore cause linker
    // errors, so we copy it into a local variable as a workaround.
    DataType T_type = DataTypeToEnum<T>::value;

    auto serialize_empty_element = [&](int64_t b) {
      serialized_sparse_t(b, 0).emplace<Tensor>(DT_INT64,
                                                TensorShape({0, rank - 1}));
      serialized_sparse_t(b, 1).emplace<Tensor>(T_type, TensorShape({0}));
      serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
    };

    for (const auto& subset : *minibatch) {
      const int64_t b = subset.group_at(0);
      if (b < 0 || b >= N) {
        return errors::InvalidArgument(
            "Received unexpected column 0 value in input SparseTensor: ", b,
            " < 0 or >= N (= ", N, ")");
      }

      // GroupIterable generates only the non-empty groups of rows, so we must
      // generate empty outputs for any empty rows since the last non-empty
      // group that was generated.
      for (int64_t empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
        serialize_empty_element(empty_b);
      }

      last_nonempty_group = b;

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64_t num_entries = values.size();

      Tensor& output_indices = serialized_sparse_t(b, 0).emplace<Tensor>(
          DT_INT64, TensorShape({num_entries, rank - 1}));
      Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
          T_type, TensorShape({num_entries}));

      int64_t* output_indices_ptr =
          static_cast<int64_t*>(DMAHelper::base(&output_indices));
      const int64_t* indices_ptr = indices.data();

      T* output_values_ptr = static_cast<T*>(DMAHelper::base(&output_values));
      const T* values_ptr = values.data();

      // TODO(mrry): Consider adding a template-based specialization for higher
      // ranks.
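      // Fast path for rank 2: each input index row is (batch, inner), so the
      // surviving inner index is every second element of the flattened
      // indices buffer.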
      if (rank == 2) {
        for (int i = 0; i < num_entries; ++i) {
          output_indices_ptr[i] = indices_ptr[(2 * i) + 1];
        }
      } else {
        for (int i = 0; i < num_entries; ++i) {
          // Skip the first index in each row.
          ++indices_ptr;
          for (int d = 1; d < rank; ++d) {
            *output_indices_ptr++ = *indices_ptr++;
          }
        }
      }

      CopyValues(values_ptr, output_values_ptr, num_entries);
      serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
    }

    for (int64_t empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
      serialize_empty_element(empty_b);
    }

    return OkStatus();
  }
};

template <typename T, typename U>
class SerializeManySparseOp : public OpKernel {
 public:
  explicit SerializeManySparseOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    TensorShape tensor_input_shape;
    OP_REQUIRES_OK(context,
                   TensorShape::BuildTensorShape(input_shape->vec<int64_t>(),
                                                 &tensor_input_shape));
    gtl::InlinedVector<int64_t, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    auto input_shape_t = input_shape->vec<int64_t>();
    const int64_t N = input_shape_t(0);

    Tensor* serialized_sparse;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, {N, 3}, &serialized_sparse));

    OP_REQUIRES_OK(context, input_st.IndicesValid());

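    // The per-element dense shape is the input's dense shape with the
    // leading minibatch dimension dropped.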
    Tensor output_shape(DT_INT64, {rank - 1});
    auto output_shape_t = output_shape.vec<int64_t>();
    for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);

    // Get groups by minibatch dimension.
    sparse::GroupIterable minibatch = input_st.group({0});

    OP_REQUIRES_OK(context, SerializeGroups<T, U>()(&minibatch, output_shape, N,
                                                    rank, serialized_sparse));
  }
};
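
// Example (illustrative): a SparseTensor with dense shape [4, 5] serializes
// to a [4, 3] output matrix, where row b holds the serialized indices,
// values, and dense shape ([5]) of minibatch element b.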

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("SerializeManySparse")               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<tstring>("out_type"), \
                          SerializeManySparseOp<type, tstring>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("SerializeManySparse")               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<Variant>("out_type"), \
                          SerializeManySparseOp<type, Variant>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace

}  // namespace tensorflow