1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/sparse_utils.h"
17
18#include <cstddef>
19#include <cstdint>
20
21#include "tensorflow/core/framework/tensor_shape.h"
22#include "tensorflow/core/platform/errors.h"
23#include "tensorflow/core/platform/macros.h"
24#include "tensorflow/core/platform/status.h"
25
26namespace tensorflow {
27namespace sparse_utils {
28
29template <typename Tindices>
30Tindices FindNextDenseRowStartIndex(
31 const Tindices sparse_index_begin,
32 const typename TTypes<Tindices>::ConstMatrix& indices_mat) {
33 // Search in the index range [begin, end) of indices_mat.
34 Tindices begin = sparse_index_begin;
35 Tindices end = indices_mat.dimension(0);
36 const Tindices orig_sparse_index_end = end;
37
38 // The first dense row we search.
39 const Tindices orig_dense_index_begin = indices_mat(begin, 0);
40 // Early exit if no next dense row index.
41 if (orig_dense_index_begin == static_cast<int64_t>(indices_mat(end - 1, 0))) {
42 return orig_sparse_index_end;
43 }
44
45 Tindices increment = 1;
46 while (begin + increment < end &&
47 indices_mat(begin + increment, 0) == orig_dense_index_begin) {
48 increment *= 2;
49 }
50 // Narrow the search space as an optimization.
51 if (begin + increment < end) {
52 end = begin + increment;
53 }
54 begin += increment / 2;
55
56 // Perform a binary search on the interval [begin, end) for
57 // dense_row_index_to_find.
58 const Tindices dense_row_index_to_find = orig_dense_index_begin;
59 while (begin < end) {
60 const Tindices m = begin + (end - begin) / 2;
61 const Tindices m_dense_row_index = static_cast<Tindices>(indices_mat(m, 0));
62 if (m_dense_row_index == dense_row_index_to_find &&
63 (m + 1 == orig_sparse_index_end ||
64 static_cast<Tindices>(indices_mat(m + 1, 0)) !=
65 dense_row_index_to_find)) {
66 return m + 1;
67 } else if (m_dense_row_index <= dense_row_index_to_find) {
68 begin = m + 1;
69 } else {
70 end = m;
71 }
72 }
73
74 // No next dense row index.
75 return orig_sparse_index_end;
76}
77
78template <typename Tindices>
79std::vector<Tindices> GetStartIndicesOfEachDenseRow(
80 const typename TTypes<Tindices>::ConstMatrix& indices_mat,
81 bool* contains_empty_rows) {
82 int64_t start_sparse_index_of_cur_dense_row = 0;
83 std::vector<Tindices> segment_indices;
84 const Tindices num_entries_in_sparse_tensor = indices_mat.dimension(0);
85 const Tindices num_dense_rows_in_sparse_tensor =
86 1 + indices_mat(num_entries_in_sparse_tensor - 1, 0);
87 // Reserve an extra slot for the 0 we store in the first entry by convention.
88 segment_indices.reserve(1 + num_dense_rows_in_sparse_tensor);
89 segment_indices.push_back(0);
90 for (Tindices i = 0; i < indices_mat(0, 0); ++i) {
91 segment_indices.push_back(0);
92 }
93 *contains_empty_rows = indices_mat(0, 0) > 0;
94 while (true) {
95 const Tindices start_sparse_index_of_next_dense_row =
96 FindNextDenseRowStartIndex<Tindices>(
97 start_sparse_index_of_cur_dense_row, indices_mat);
98 if (start_sparse_index_of_next_dense_row == num_entries_in_sparse_tensor) {
99 segment_indices.push_back(start_sparse_index_of_next_dense_row);
100 break;
101 }
102 // Encode the length of the current dense row as well as the lengths of all
103 // the empty rows until the next dense row,
104 for (Tindices i = 0;
105 i < indices_mat(start_sparse_index_of_next_dense_row, 0) -
106 indices_mat(start_sparse_index_of_cur_dense_row, 0);
107 ++i) {
108 segment_indices.push_back(start_sparse_index_of_next_dense_row);
109 }
110 // If there is more than one row between the current and next non-empty
111 // rows then those rows are empty.
112 *contains_empty_rows |=
113 indices_mat(start_sparse_index_of_next_dense_row, 0) -
114 indices_mat(start_sparse_index_of_cur_dense_row, 0) >
115 1;
116 start_sparse_index_of_cur_dense_row = start_sparse_index_of_next_dense_row;
117 }
118 return segment_indices;
119}
120
121template <typename Tindices>
122std::vector<Tindices> ParseRowStartIndices(
123 const tensorflow::Tensor& tensor,
124 const Tindices num_nonzero_entries_in_sparse_mat) {
125 std::vector<Tindices> out;
126 auto vec = tensor.vec<Tindices>();
127 out.reserve(vec.size() + 1);
128 for (size_t i = 0; i < vec.dimension(0); ++i) {
129 out.push_back(vec(i));
130 }
131 out.push_back(num_nonzero_entries_in_sparse_mat);
132 return out;
133}
134
// Returns true if two consecutive elements of `row_start_indices` are equal,
// i.e. some dense row spans zero entries. The interval ending at the final
// element is not examined: the last dense row is assumed to be non-empty.
//
// Returns false for vectors with fewer than three elements, including an
// empty vector.
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
  // `i + 1 < size()` rather than `i < size() - 1`: the latter underflows when
  // the vector is empty and would index far out of range.
  for (size_t i = 1; i + 1 < row_start_indices.size(); ++i) {
    if (row_start_indices[i] == row_start_indices[i - 1]) {
      return true;
    }
  }
  return false;
}
146
147namespace {
148
149// Ensures indices, values, shape are all of the proper ranks and are
150// compatible.
151Status ValidateSparseTensorShape(const Tensor& indices, const Tensor& values,
152 const Tensor& shape) {
153 // Indices must be a matrix, and values/shape must be a vector.
154 if (!TensorShapeUtils::IsMatrix(indices.shape())) {
155 return errors::InvalidArgument("Sparse indices must be rank 2 but is rank ",
156 indices.shape().dim_sizes().size());
157 }
158 if (!TensorShapeUtils::IsVector(values.shape())) {
159 return errors::InvalidArgument("Sparse values must be rank 1 but is rank ",
160 values.shape().dims());
161 }
162 if (!TensorShapeUtils::IsVector(shape.shape())) {
163 return errors::InvalidArgument("Sparse shape must be rank 1 but is rank ",
164 shape.shape().dims());
165 }
166 // Indices shape must be compatible with the values vector and dense shape.
167 int64_t nnz = indices.dim_size(0);
168 int64_t ndims = indices.dim_size(1);
169 if (values.dim_size(0) != nnz) {
170 return errors::InvalidArgument("Number of elements in indices (", nnz,
171 ") and values (", values.dim_size(0),
172 ") do not match");
173 }
174 if (shape.NumElements() != ndims) {
175 return errors::InvalidArgument("Index rank (", ndims, ") and shape rank (",
176 shape.NumElements(), ") do not match");
177 }
178
179 return OkStatus();
180}
181
182// Creates a debug string for the index tuple in indices(row, :).
183template <typename IndexTensor>
184string CreateIndexString(const IndexTensor& indices, int64_t row) {
185 const int64_t ndims = indices.dimension(1);
186 string index_str = strings::StrCat("indices[", row, ", :] = [");
187 for (int64_t dim = 0; dim < ndims; ++dim) {
188 strings::StrAppend(&index_str, indices(row, dim),
189 dim < ndims - 1 ? ", " : "]");
190 }
191 if (ndims == 0) {
192 strings::StrAppend(&index_str, "]");
193 }
194 return index_str;
195}
196
197// Ensures all sparse indices are within correct bounds.
198template <typename Tindices>
199Status ValidateSparseTensorIndicesUnordered(const Tensor& indices,
200 const Tensor& shape) {
201 // Ensure no index is out-of-bounds.
202 const auto indices_mat = indices.flat_inner_dims<Tindices>();
203 const auto shape_vec = shape.flat<Tindices>();
204 int64_t nnz = indices.dim_size(0);
205 int64_t ndims = indices.dim_size(1);
206
207 for (int64_t i = 0; i < nnz; ++i) {
208 for (int64_t dim = 0; dim < ndims; ++dim) {
209 const Tindices idx = indices_mat(i, dim);
210 if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
211 string index_str = CreateIndexString(indices_mat, i);
212 return errors::InvalidArgument("Sparse index tuple ", index_str,
213 " is out of bounds");
214 }
215 }
216 }
217
218 return OkStatus();
219}
220
221// Ensures all sparse indices are within correct bounds and are
222// lexicographically ordered.
223template <typename Tindices>
224Status ValidateSparseTensorIndicesOrdered(const Tensor& indices,
225 const Tensor& shape) {
226 const auto indices_mat = indices.flat_inner_dims<Tindices>();
227 const auto shape_vec = shape.flat<Tindices>();
228 int64_t nnz = indices.dim_size(0);
229 int64_t ndims = indices.dim_size(1);
230
231 if (nnz == 0) {
232 return OkStatus();
233 }
234
235 // First set of indices must be within range.
236 for (int64_t dim = 0; dim < ndims; ++dim) {
237 const Tindices idx = indices_mat(0, dim);
238 if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
239 string index_str = CreateIndexString(indices_mat, 0);
240 return errors::InvalidArgument("Sparse index tuple ", index_str,
241 " is out of bounds");
242 }
243 }
244
245 // Remaining set of indices must be within range and lexicographically
246 // larger than the previous.
247 for (int64_t i = 1; i < nnz; ++i) {
248 bool different = false;
249 for (int64_t dim = 0; dim < ndims; ++dim) {
250 const Tindices idx = indices_mat(i, dim);
251 const Tindices prev_idx = indices_mat(i - 1, dim);
252 // If indices are already different from previous i, the new index can
253 // be anything within the valid range.
254 if (TF_PREDICT_TRUE(different)) {
255 if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
256 string index_str = CreateIndexString(indices_mat, i);
257 return errors::InvalidArgument("Sparse index tuple ", index_str,
258 " is out of bounds");
259 }
260 } else {
261 // Otherwise, the new index must be >= previous and <= shape(dim).
262 if (TF_PREDICT_FALSE(idx < prev_idx || idx >= shape_vec(dim))) {
263 string index_str = CreateIndexString(indices_mat, i);
264 // Check if index is actually out of bounds.
265 if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) {
266 return errors::InvalidArgument("Sparse index tuple ", index_str,
267 " is out of bounds");
268 } else {
269 return errors::InvalidArgument("Sparse index tuple ", index_str,
270 " is out of order");
271 }
272 } else if (TF_PREDICT_TRUE(idx > prev_idx)) {
273 different = true;
274 }
275 } // if (different)
276 } // for dim in [0, ndims)
277
278 if (TF_PREDICT_FALSE(!different)) {
279 string index_str = CreateIndexString(indices_mat, i);
280 return errors::InvalidArgument("Sparse index tuple ", index_str,
281 " is repeated");
282 }
283 } // for i in [1, nnz)
284
285 return OkStatus();
286}
287
288} // namespace
289
290template <typename Tindices>
291Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
292 const Tensor& shape,
293 IndexValidation index_validation) {
294 TF_RETURN_IF_ERROR(ValidateSparseTensorShape(indices, values, shape));
295 switch (index_validation) {
296 case IndexValidation::kOrdered:
297 return ValidateSparseTensorIndicesOrdered<Tindices>(indices, shape);
298 case IndexValidation::kUnordered:
299 return ValidateSparseTensorIndicesUnordered<Tindices>(indices, shape);
300 case IndexValidation::kNone: {
301 }
302 }
303 return OkStatus();
304}
305
306#define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex) \
307 template TypeIndex FindNextDenseRowStartIndex<TypeIndex>( \
308 const TypeIndex sparse_index_begin, \
309 const TTypes<TypeIndex>::ConstMatrix& indices_mat); \
310 template std::vector<TypeIndex> GetStartIndicesOfEachDenseRow<TypeIndex>( \
311 const TTypes<TypeIndex>::ConstMatrix& indices_mat, \
312 bool* contains_empty_rows); \
313 template bool ContainsEmptyRows<TypeIndex>( \
314 const std::vector<TypeIndex>& row_start_indices); \
315 template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>( \
316 const tensorflow::Tensor& tensor, \
317 const TypeIndex num_nonzero_entries_in_sparse_mat); \
318 template Status ValidateSparseTensor<TypeIndex>( \
319 const Tensor& indices, const Tensor& values, const Tensor& shape, \
320 IndexValidation index_validation)
321
322REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
323REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
324REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
325REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
326REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
327REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
328
329} // namespace sparse_utils
330} // namespace tensorflow
331