1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Helpers for writing OpKernels for sparse tensors.
17#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_
18#define TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_
19
20#include <vector>
21
22#include "tensorflow/core/framework/tensor.h"
23#include "tensorflow/core/framework/tensor_types.h"
24#include "tensorflow/core/platform/types.h"
25
26namespace tensorflow {
27namespace sparse_utils {
28
29// Find the index i of the first element for which
30// indices_mat(sparse_index_begin, 0) < indices_mat(i, 0).
31// The search is conducted in the open interval
32// [sparse_index_begin, indices_mat.dimension(0)) and when no such i is found,
33// indices_mat.dimension(0) is returned.
34// indices_mat(k, 0) should be non-decreasing over the interval
35// [begin, indices_mat.dimension(0)).
36// Requires 0 <= sparse_index_begin < indices_mat.dimension(0).
37template <typename Tindices>
38Tindices FindNextDenseRowStartIndex(
39 const Tindices sparse_index_begin,
40 const typename TTypes<Tindices>::ConstMatrix& indices_mat);
41
42// Returns the vector v of indices in indices_mat at which new dense matrix
43// rows begin.
44// v.front() = 0, v.back() = indices_mat.dimension(0), and for i > 0,
45// v[i] - v[i-1] is the length of the ith dense row in indices_mat.
46// *contains_empty_rows = true if and only if indices_mat contains empty rows
47// (rows without values) between row 0 and the last row.
48template <typename Tindices>
49std::vector<Tindices> GetStartIndicesOfEachDenseRow(
50 const typename TTypes<Tindices>::ConstMatrix& indices_mat,
51 bool* contains_empty_rows);
52
53// Converts tensor.vec<Tindices> to an std::vector<Tindices> object, appends
54// the value num_nonzero_entries_in_sparse_mat, and returns the result.
55template <typename Tindices>
56std::vector<Tindices> ParseRowStartIndices(
57 const tensorflow::Tensor& tensor,
58 const Tindices num_nonzero_entries_in_sparse_mat);
59
// Returns true if and only if the sparse matrix whose row start indices are
// given by row_start_indices has empty dense rows (rows without values)
// between dense row 0 and its last dense row.
// row_start_indices uses the layout returned by GetStartIndicesOfEachDenseRow,
// and the two functions agree: if
//   row_start_indices == GetStartIndicesOfEachDenseRow(indices_mat, &flag)
// then ContainsEmptyRows(row_start_indices) == flag.
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices);
67
// Strategies available for validating the indices of a sparse tensor.
enum class IndexValidation {
  // No validation of index values: the op does not use the indices, or they
  // are not directly accessible (e.g. on GPU).
  kNone,
  // Indices must be unique, appear in lexicographical order, and lie within
  // safe bounds.
  kOrdered,
  // Indices must lie within safe bounds, but may repeat or appear
  // out-of-order.
  kUnordered,
};
77
78// Validates the three component tensors of a sparse tensor have the proper
79// shapes. Also validates index values according to the method supplied.
80template <typename Tindices>
81Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
82 const Tensor& shape,
83 IndexValidation index_validation);
84
85} // namespace sparse_utils
86} // namespace tensorflow
87
88#endif // TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_
89