1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/sparse_utils.h" |
17 | |
18 | #include <cstddef> |
19 | #include <cstdint> |
20 | |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/platform/errors.h" |
23 | #include "tensorflow/core/platform/macros.h" |
24 | #include "tensorflow/core/platform/status.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace sparse_utils { |
28 | |
29 | template <typename Tindices> |
30 | Tindices FindNextDenseRowStartIndex( |
31 | const Tindices sparse_index_begin, |
32 | const typename TTypes<Tindices>::ConstMatrix& indices_mat) { |
33 | // Search in the index range [begin, end) of indices_mat. |
34 | Tindices begin = sparse_index_begin; |
35 | Tindices end = indices_mat.dimension(0); |
36 | const Tindices orig_sparse_index_end = end; |
37 | |
38 | // The first dense row we search. |
39 | const Tindices orig_dense_index_begin = indices_mat(begin, 0); |
40 | // Early exit if no next dense row index. |
41 | if (orig_dense_index_begin == static_cast<int64_t>(indices_mat(end - 1, 0))) { |
42 | return orig_sparse_index_end; |
43 | } |
44 | |
45 | Tindices increment = 1; |
46 | while (begin + increment < end && |
47 | indices_mat(begin + increment, 0) == orig_dense_index_begin) { |
48 | increment *= 2; |
49 | } |
50 | // Narrow the search space as an optimization. |
51 | if (begin + increment < end) { |
52 | end = begin + increment; |
53 | } |
54 | begin += increment / 2; |
55 | |
56 | // Perform a binary search on the interval [begin, end) for |
57 | // dense_row_index_to_find. |
58 | const Tindices dense_row_index_to_find = orig_dense_index_begin; |
59 | while (begin < end) { |
60 | const Tindices m = begin + (end - begin) / 2; |
61 | const Tindices m_dense_row_index = static_cast<Tindices>(indices_mat(m, 0)); |
62 | if (m_dense_row_index == dense_row_index_to_find && |
63 | (m + 1 == orig_sparse_index_end || |
64 | static_cast<Tindices>(indices_mat(m + 1, 0)) != |
65 | dense_row_index_to_find)) { |
66 | return m + 1; |
67 | } else if (m_dense_row_index <= dense_row_index_to_find) { |
68 | begin = m + 1; |
69 | } else { |
70 | end = m; |
71 | } |
72 | } |
73 | |
74 | // No next dense row index. |
75 | return orig_sparse_index_end; |
76 | } |
77 | |
78 | template <typename Tindices> |
79 | std::vector<Tindices> GetStartIndicesOfEachDenseRow( |
80 | const typename TTypes<Tindices>::ConstMatrix& indices_mat, |
81 | bool* contains_empty_rows) { |
82 | int64_t start_sparse_index_of_cur_dense_row = 0; |
83 | std::vector<Tindices> segment_indices; |
84 | const Tindices num_entries_in_sparse_tensor = indices_mat.dimension(0); |
85 | const Tindices num_dense_rows_in_sparse_tensor = |
86 | 1 + indices_mat(num_entries_in_sparse_tensor - 1, 0); |
87 | // Reserve an extra slot for the 0 we store in the first entry by convention. |
88 | segment_indices.reserve(1 + num_dense_rows_in_sparse_tensor); |
89 | segment_indices.push_back(0); |
90 | for (Tindices i = 0; i < indices_mat(0, 0); ++i) { |
91 | segment_indices.push_back(0); |
92 | } |
93 | *contains_empty_rows = indices_mat(0, 0) > 0; |
94 | while (true) { |
95 | const Tindices start_sparse_index_of_next_dense_row = |
96 | FindNextDenseRowStartIndex<Tindices>( |
97 | start_sparse_index_of_cur_dense_row, indices_mat); |
98 | if (start_sparse_index_of_next_dense_row == num_entries_in_sparse_tensor) { |
99 | segment_indices.push_back(start_sparse_index_of_next_dense_row); |
100 | break; |
101 | } |
102 | // Encode the length of the current dense row as well as the lengths of all |
103 | // the empty rows until the next dense row, |
104 | for (Tindices i = 0; |
105 | i < indices_mat(start_sparse_index_of_next_dense_row, 0) - |
106 | indices_mat(start_sparse_index_of_cur_dense_row, 0); |
107 | ++i) { |
108 | segment_indices.push_back(start_sparse_index_of_next_dense_row); |
109 | } |
110 | // If there is more than one row between the current and next non-empty |
111 | // rows then those rows are empty. |
112 | *contains_empty_rows |= |
113 | indices_mat(start_sparse_index_of_next_dense_row, 0) - |
114 | indices_mat(start_sparse_index_of_cur_dense_row, 0) > |
115 | 1; |
116 | start_sparse_index_of_cur_dense_row = start_sparse_index_of_next_dense_row; |
117 | } |
118 | return segment_indices; |
119 | } |
120 | |
121 | template <typename Tindices> |
122 | std::vector<Tindices> ParseRowStartIndices( |
123 | const tensorflow::Tensor& tensor, |
124 | const Tindices num_nonzero_entries_in_sparse_mat) { |
125 | std::vector<Tindices> out; |
126 | auto vec = tensor.vec<Tindices>(); |
127 | out.reserve(vec.size() + 1); |
128 | for (size_t i = 0; i < vec.dimension(0); ++i) { |
129 | out.push_back(vec(i)); |
130 | } |
131 | out.push_back(num_nonzero_entries_in_sparse_mat); |
132 | return out; |
133 | } |
134 | |
// Returns true if any dense row other than the last one is empty, i.e. if two
// consecutive entries of `row_start_indices` are equal.
//
// The final pair of entries (the last dense row) is deliberately not examined
// because that row is always non-empty. Vectors with fewer than three entries
// therefore yield false; this includes the empty vector.
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
  // The bound is written as `i + 1 < size()` rather than `i < size() - 1`:
  // size() is unsigned, so `size() - 1` would wrap around for an empty vector
  // and the loop would read out of range.
  for (size_t i = 1; i + 1 < row_start_indices.size(); ++i) {
    if (row_start_indices[i] == row_start_indices[i - 1]) {
      return true;
    }
  }
  return false;
}
146 | |
147 | namespace { |
148 | |
149 | // Ensures indices, values, shape are all of the proper ranks and are |
150 | // compatible. |
151 | Status ValidateSparseTensorShape(const Tensor& indices, const Tensor& values, |
152 | const Tensor& shape) { |
153 | // Indices must be a matrix, and values/shape must be a vector. |
154 | if (!TensorShapeUtils::IsMatrix(indices.shape())) { |
155 | return errors::InvalidArgument("Sparse indices must be rank 2 but is rank " , |
156 | indices.shape().dim_sizes().size()); |
157 | } |
158 | if (!TensorShapeUtils::IsVector(values.shape())) { |
159 | return errors::InvalidArgument("Sparse values must be rank 1 but is rank " , |
160 | values.shape().dims()); |
161 | } |
162 | if (!TensorShapeUtils::IsVector(shape.shape())) { |
163 | return errors::InvalidArgument("Sparse shape must be rank 1 but is rank " , |
164 | shape.shape().dims()); |
165 | } |
166 | // Indices shape must be compatible with the values vector and dense shape. |
167 | int64_t nnz = indices.dim_size(0); |
168 | int64_t ndims = indices.dim_size(1); |
169 | if (values.dim_size(0) != nnz) { |
170 | return errors::InvalidArgument("Number of elements in indices (" , nnz, |
171 | ") and values (" , values.dim_size(0), |
172 | ") do not match" ); |
173 | } |
174 | if (shape.NumElements() != ndims) { |
175 | return errors::InvalidArgument("Index rank (" , ndims, ") and shape rank (" , |
176 | shape.NumElements(), ") do not match" ); |
177 | } |
178 | |
179 | return OkStatus(); |
180 | } |
181 | |
182 | // Creates a debug string for the index tuple in indices(row, :). |
183 | template <typename IndexTensor> |
184 | string CreateIndexString(const IndexTensor& indices, int64_t row) { |
185 | const int64_t ndims = indices.dimension(1); |
186 | string index_str = strings::StrCat("indices[" , row, ", :] = [" ); |
187 | for (int64_t dim = 0; dim < ndims; ++dim) { |
188 | strings::StrAppend(&index_str, indices(row, dim), |
189 | dim < ndims - 1 ? ", " : "]" ); |
190 | } |
191 | if (ndims == 0) { |
192 | strings::StrAppend(&index_str, "]" ); |
193 | } |
194 | return index_str; |
195 | } |
196 | |
197 | // Ensures all sparse indices are within correct bounds. |
198 | template <typename Tindices> |
199 | Status ValidateSparseTensorIndicesUnordered(const Tensor& indices, |
200 | const Tensor& shape) { |
201 | // Ensure no index is out-of-bounds. |
202 | const auto indices_mat = indices.flat_inner_dims<Tindices>(); |
203 | const auto shape_vec = shape.flat<Tindices>(); |
204 | int64_t nnz = indices.dim_size(0); |
205 | int64_t ndims = indices.dim_size(1); |
206 | |
207 | for (int64_t i = 0; i < nnz; ++i) { |
208 | for (int64_t dim = 0; dim < ndims; ++dim) { |
209 | const Tindices idx = indices_mat(i, dim); |
210 | if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) { |
211 | string index_str = CreateIndexString(indices_mat, i); |
212 | return errors::InvalidArgument("Sparse index tuple " , index_str, |
213 | " is out of bounds" ); |
214 | } |
215 | } |
216 | } |
217 | |
218 | return OkStatus(); |
219 | } |
220 | |
221 | // Ensures all sparse indices are within correct bounds and are |
222 | // lexicographically ordered. |
223 | template <typename Tindices> |
224 | Status ValidateSparseTensorIndicesOrdered(const Tensor& indices, |
225 | const Tensor& shape) { |
226 | const auto indices_mat = indices.flat_inner_dims<Tindices>(); |
227 | const auto shape_vec = shape.flat<Tindices>(); |
228 | int64_t nnz = indices.dim_size(0); |
229 | int64_t ndims = indices.dim_size(1); |
230 | |
231 | if (nnz == 0) { |
232 | return OkStatus(); |
233 | } |
234 | |
235 | // First set of indices must be within range. |
236 | for (int64_t dim = 0; dim < ndims; ++dim) { |
237 | const Tindices idx = indices_mat(0, dim); |
238 | if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) { |
239 | string index_str = CreateIndexString(indices_mat, 0); |
240 | return errors::InvalidArgument("Sparse index tuple " , index_str, |
241 | " is out of bounds" ); |
242 | } |
243 | } |
244 | |
245 | // Remaining set of indices must be within range and lexicographically |
246 | // larger than the previous. |
247 | for (int64_t i = 1; i < nnz; ++i) { |
248 | bool different = false; |
249 | for (int64_t dim = 0; dim < ndims; ++dim) { |
250 | const Tindices idx = indices_mat(i, dim); |
251 | const Tindices prev_idx = indices_mat(i - 1, dim); |
252 | // If indices are already different from previous i, the new index can |
253 | // be anything within the valid range. |
254 | if (TF_PREDICT_TRUE(different)) { |
255 | if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) { |
256 | string index_str = CreateIndexString(indices_mat, i); |
257 | return errors::InvalidArgument("Sparse index tuple " , index_str, |
258 | " is out of bounds" ); |
259 | } |
260 | } else { |
261 | // Otherwise, the new index must be >= previous and <= shape(dim). |
262 | if (TF_PREDICT_FALSE(idx < prev_idx || idx >= shape_vec(dim))) { |
263 | string index_str = CreateIndexString(indices_mat, i); |
264 | // Check if index is actually out of bounds. |
265 | if (TF_PREDICT_FALSE(idx < 0 || idx >= shape_vec(dim))) { |
266 | return errors::InvalidArgument("Sparse index tuple " , index_str, |
267 | " is out of bounds" ); |
268 | } else { |
269 | return errors::InvalidArgument("Sparse index tuple " , index_str, |
270 | " is out of order" ); |
271 | } |
272 | } else if (TF_PREDICT_TRUE(idx > prev_idx)) { |
273 | different = true; |
274 | } |
275 | } // if (different) |
276 | } // for dim in [0, ndims) |
277 | |
278 | if (TF_PREDICT_FALSE(!different)) { |
279 | string index_str = CreateIndexString(indices_mat, i); |
280 | return errors::InvalidArgument("Sparse index tuple " , index_str, |
281 | " is repeated" ); |
282 | } |
283 | } // for i in [1, nnz) |
284 | |
285 | return OkStatus(); |
286 | } |
287 | |
288 | } // namespace |
289 | |
290 | template <typename Tindices> |
291 | Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, |
292 | const Tensor& shape, |
293 | IndexValidation index_validation) { |
294 | TF_RETURN_IF_ERROR(ValidateSparseTensorShape(indices, values, shape)); |
295 | switch (index_validation) { |
296 | case IndexValidation::kOrdered: |
297 | return ValidateSparseTensorIndicesOrdered<Tindices>(indices, shape); |
298 | case IndexValidation::kUnordered: |
299 | return ValidateSparseTensorIndicesUnordered<Tindices>(indices, shape); |
300 | case IndexValidation::kNone: { |
301 | } |
302 | } |
303 | return OkStatus(); |
304 | } |
305 | |
// Explicitly instantiates the function templates defined in this file
// (FindNextDenseRowStartIndex, GetStartIndicesOfEachDenseRow,
// ContainsEmptyRows, ParseRowStartIndices, and ValidateSparseTensor) for the
// index type `TypeIndex`. The expansion deliberately omits the trailing
// semicolon; it is supplied at each use below.
#define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex)                           \
  template TypeIndex FindNextDenseRowStartIndex<TypeIndex>(                 \
      const TypeIndex sparse_index_begin,                                   \
      const TTypes<TypeIndex>::ConstMatrix& indices_mat);                   \
  template std::vector<TypeIndex> GetStartIndicesOfEachDenseRow<TypeIndex>( \
      const TTypes<TypeIndex>::ConstMatrix& indices_mat,                    \
      bool* contains_empty_rows);                                           \
  template bool ContainsEmptyRows<TypeIndex>(                               \
      const std::vector<TypeIndex>& row_start_indices);                     \
  template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>(          \
      const tensorflow::Tensor& tensor,                                     \
      const TypeIndex num_nonzero_entries_in_sparse_mat);                   \
  template Status ValidateSparseTensor<TypeIndex>(                          \
      const Tensor& indices, const Tensor& values, const Tensor& shape,     \
      IndexValidation index_validation)

// Index types for which the utilities above are instantiated.
REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
328 | |
329 | } // namespace sparse_utils |
330 | } // namespace tensorflow |
331 | |