1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Helpers for writing OpKernels for sparse tensors. |
17 | #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_ |
18 | #define TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_ |
19 | |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/tensor_types.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace sparse_utils { |
28 | |
29 | // Finds the index i of the first entry (row of indices_mat) for which |
30 | // indices_mat(sparse_index_begin, 0) < indices_mat(i, 0). |
31 | // The search is conducted over the half-open interval |
32 | // [sparse_index_begin, indices_mat.dimension(0)); when no such i is found, |
33 | // indices_mat.dimension(0) is returned. |
34 | // Precondition: indices_mat(k, 0) is non-decreasing for k in the interval |
35 | // [sparse_index_begin, indices_mat.dimension(0)). |
36 | // Requires 0 <= sparse_index_begin < indices_mat.dimension(0). |
37 | template <typename Tindices> |
38 | Tindices FindNextDenseRowStartIndex( |
39 | const Tindices sparse_index_begin, |
40 | const typename TTypes<Tindices>::ConstMatrix& indices_mat); |
41 | |
42 | // Returns the vector v of indices in indices_mat at which new dense matrix |
43 | // rows begin. |
44 | // v.front() = 0, v.back() = indices_mat.dimension(0), and for i > 0, |
45 | // v[i] - v[i-1] is the length of dense row i-1 (counting rows from 0). |
46 | // Sets *contains_empty_rows to true if and only if indices_mat contains |
47 | // empty rows (rows without values) between row 0 and the last row. |
48 | template <typename Tindices> |
49 | std::vector<Tindices> GetStartIndicesOfEachDenseRow( |
50 | const typename TTypes<Tindices>::ConstMatrix& indices_mat, |
51 | bool* contains_empty_rows); |
52 | |
53 | // Converts tensor.vec<Tindices>() to a std::vector<Tindices>, appends |
54 | // the value num_nonzero_entries_in_sparse_mat, and returns the result. |
55 | template <typename Tindices> |
56 | std::vector<Tindices> ParseRowStartIndices( |
57 | const tensorflow::Tensor& tensor, |
58 | const Tindices num_nonzero_entries_in_sparse_mat); |
59 | |
60 | // Returns true if and only if the sparse matrix whose row start indices are |
61 | // given by row_start_indices has empty dense rows (between its first and |
62 | // last dense rows). Consistency guarantee: if row_start_indices was produced |
63 | // by GetStartIndicesOfEachDenseRow(indices_mat, &b), this function returns |
64 | // the same value that was stored in b. |
65 | template <typename Tindices> |
66 | bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices); |
67 | |
68 | // Methods for validating sparse indices. |
69 | enum class IndexValidation { |
70 | kNone, // Indices are not used by the op, or are not directly accessible |
71 | // (e.g. on GPU). |
72 | kOrdered, // Indices must be unique, in lexicographical order, and within |
73 | // safe bounds. |
74 | kUnordered // Indices must be within safe bounds, but may repeat or appear |
75 | // out-of-order. |
76 | }; |
77 | |
78 | // Validates that the three component tensors of a sparse tensor have the |
79 | // proper shapes. Also validates index values as specified by index_validation. |
80 | template <typename Tindices> |
81 | Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, |
82 | const Tensor& shape, |
83 | IndexValidation index_validation); |
84 | |
85 | } // namespace sparse_utils |
86 | } // namespace tensorflow |
87 | |
88 | #endif // TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_ |
89 | |