1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// Ops for operating with sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval, and polish the API a bit, e.g. taking
// two SparseTensor inputs rather than one dense and one sparse.
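//
// Illustrative example of the intended semantics (values are hypothetical):
// given two dense rank-2 inputs
//   set1 = [[1, 2, 2], [3, 3, 3]]
//   set2 = [[2, 4, 4], [3, 5, 5]]
// the "intersection" operation treats each row as a set and yields a
// SparseTensor with dense shape [2, 1] holding the per-row intersections
// {2} and {3}: values [2, 3] at indices [[0, 0], [1, 0]].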
20 | |
21 | #define EIGEN_USE_THREADS |
22 | |
23 | #include <algorithm> |
24 | #include <numeric> |
25 | #include <string> |
26 | #include <utility> |
27 | #include <vector> |
28 | |
29 | #include "absl/container/btree_set.h" |
30 | #include "absl/container/flat_hash_set.h" |
31 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
32 | #include "tensorflow/core/framework/op_kernel.h" |
33 | #include "tensorflow/core/framework/register_types.h" |
34 | #include "tensorflow/core/framework/tensor.h" |
35 | #include "tensorflow/core/framework/tensor_util.h" |
36 | #include "tensorflow/core/framework/types.h" |
37 | #include "tensorflow/core/lib/core/status.h" |
38 | #include "tensorflow/core/platform/env.h" |
39 | #include "tensorflow/core/platform/errors.h" |
40 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
41 | |
42 | namespace tensorflow { |
43 | |
44 | using ShapeArray = sparse::SparseTensor::ShapeArray; |
45 | using VarDimArray = sparse::SparseTensor::VarDimArray; |
46 | |
47 | // Validate rank >= 2. |
48 | void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) { |
49 | const auto rank = shape.dims(); |
50 | OP_REQUIRES(ctx, rank >= 2, |
51 | errors::InvalidArgument("Invalid rank " , rank, "." )); |
52 | } |
53 | |
// Return the group shape, which is the first n-1 dimensions of `input_shape`.
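// For example, an input shape of [2, 3, 4] yields a grouped shape of [2, 3].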
55 | Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) { |
56 | if (input_shape.size() < 2) { |
57 | // TODO(irving): Why can't 2 be 1 here? |
58 | return errors::InvalidArgument("Shape [" , absl::StrJoin(input_shape, "," ), |
59 | "] has rank " , input_shape.size(), " < 2" ); |
60 | } |
61 | // grouped_shape is input_shape[:-1] |
62 | *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1); |
63 | return OkStatus(); |
64 | } |
65 | |
66 | // Build `SparseTensor` from indices, values, and shape in inputs |
67 | // [base_index, base_index + 3), and validate its rank and indices. |
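// Example: with base_index 0, inputs 0..2 might be indices [[0, 0], [1, 0]],
// values [7, 8], and shape [2, 1] (illustrative values only), describing a
// 2x1 sparse tensor with both elements present.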
68 | Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index, |
69 | const bool validate_indices, |
70 | sparse::SparseTensor* tensor) { |
71 | // Assume row-major order. |
72 | TensorShape shape; |
73 | const Tensor& shape_tensor = ctx->input(base_index + 2); |
74 | if (shape_tensor.dims() != 1) { |
75 | return errors::InvalidArgument("Shape must be a 1D tensor." ); |
76 | } |
77 | TF_RETURN_IF_ERROR( |
78 | TensorShape::BuildTensorShape(shape_tensor.vec<int64_t>(), &shape)); |
79 | CheckRankAtLeast2(ctx, shape); |
80 | std::vector<int64_t> order(shape.dims()); |
81 | std::iota(order.begin(), order.end(), 0); |
82 | |
83 | Status status = sparse::SparseTensor::Create( |
84 | ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor); |
85 | |
86 | if (!validate_indices || !status.ok()) return status; |
87 | return tensor->IndicesValid(); |
88 | } |
89 | |
90 | // TODO(ptucker): CheckGroup is just a sanity check on the result of |
91 | // SparseTensor.group, consider removing. |
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which `group`
// was created, and is used to sanity check the indices in `group`.
94 | template <typename T> |
95 | void CheckGroup(OpKernelContext* ctx, const sparse::Group& group, |
96 | const VarDimArray& sparse_tensor_shape) { |
97 | const auto& indices = group.indices(); |
98 | const auto& values = group.values<T>(); |
99 | |
100 | // Sanity check: group is non-empty, and indices and values are same size. |
101 | const auto num_values = values.dimension(0); |
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
103 | OP_REQUIRES( |
104 | ctx, indices.dimension(0) == num_values, |
105 | errors::Internal("shape[0] of group indices " , indices.dimension(0), |
106 | " != values " , num_values, "." )); |
107 | |
108 | // Sanity check: valid indices. |
109 | const auto group_rank = indices.dimension(1); |
110 | const auto expected_rank = sparse_tensor_shape.size(); |
111 | OP_REQUIRES(ctx, expected_rank == group_rank, |
112 | errors::Internal("Rank expected " , expected_rank, ", got " , |
113 | group_rank, "." )); |
114 | for (int32_t j = 0; j < expected_rank; ++j) { |
115 | const auto dim_size = sparse_tensor_shape[j]; |
116 | OP_REQUIRES( |
117 | ctx, dim_size > 0, |
118 | errors::Internal("Invalid dim_size[" , j, "] = " , dim_size, "." )); |
119 | for (int64_t i = 0; i < num_values; ++i) { |
120 | const auto index = indices(i, j); |
121 | OP_REQUIRES(ctx, dim_size > index, |
122 | errors::Internal("indices[" , i, ", " , j, "] expected < " , |
123 | dim_size, ", got " , index, "." )); |
124 | } |
125 | } |
126 | } |
127 | |
// Return row-major strides for `shape`, which let us calculate the flat index
// into the flattened output.
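// Example: Strides({2, 3, 4}) returns {12, 4, 1}, so the flat index of
// element (i, j, k) is i * 12 + j * 4 + k.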
129 | const ShapeArray Strides(const VarDimArray& shape) { |
130 | ShapeArray result(shape.size()); |
131 | int64_t product = 1; |
132 | for (int i = shape.size() - 1; i >= 0; --i) { |
133 | result[i] = product; |
134 | product *= shape[i]; |
135 | } |
136 | return result; |
137 | } |
138 | |
139 | // TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to |
140 | // eliminate the intermediate `values` data structure - iterate once to |
141 | // determine `num_values`, allocate output tensors, then write results directly |
142 | // to output tensors. |
143 | |
144 | // TODO(ptucker): Consider sharding work across multiple threads. See |
145 | // SparseCrossOp for an example. |
146 | |
147 | // Output `SparseTensor` of shape `output_shape`. `sets` contains pairs of |
148 | // group indices (i.e., values for all but the last dimension of `output_shape`) |
149 | // and set values, each of which will occupy the last dimension of |
150 | // `output_shape`. `sets` should be sorted in ascending order by group indices. |
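// Example: with output_shape [2, 2], num_values 3, and
// sets = {([0], {3, 7}), ([1], {5})}, the outputs are
// indices [[0, 0], [0, 1], [1, 0]], values [3, 7, 5], and shape [2, 2].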
151 | template <typename T> |
152 | void OutputSparseTensor( |
153 | OpKernelContext* ctx, const TensorShape& output_shape, |
154 | const int64_t num_values, |
155 | const std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>>& |
156 | sets) { |
157 | // Allocate 3 output tensors for sparse data. |
158 | Tensor *out_indices_t, *out_values_t, *out_shape_t; |
159 | OP_REQUIRES_OK(ctx, ctx->allocate_output( |
160 | 0, TensorShape({num_values, output_shape.dims()}), |
161 | &out_indices_t)); |
162 | OP_REQUIRES_OK( |
163 | ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t)); |
164 | OP_REQUIRES_OK(ctx, ctx->allocate_output( |
165 | 2, TensorShape({output_shape.dims()}), &out_shape_t)); |
166 | auto out_indices_mat = out_indices_t->matrix<int64_t>(); |
167 | auto out_values_flat = out_values_t->vec<T>(); |
168 | |
169 | // For each set, write its indices and values to output tensors. |
170 | int64_t value_index = 0; |
171 | for (auto it = sets.begin(); it != sets.end(); ++it) { |
172 | const auto& group_indices = it->first; |
173 | OP_REQUIRES( |
174 | ctx, group_indices.size() == output_shape.dims() - 1, |
175 | errors::Internal("Invalid number of indices " , group_indices.size(), |
176 | ", expected " , output_shape.dims() - 1, "." )); |
177 | const auto& set = it->second; |
178 | |
179 | // For each set item, write its indices and value to output tensors. |
180 | int64_t group_value_index = 0; |
181 | for (auto value = set.begin(); value != set.end(); |
182 | ++value, ++value_index, ++group_value_index) { |
183 | // First n-1 dimensions are the group, last dimension is the position in |
184 | // the set. |
185 | for (int32_t i = 0; i < group_indices.size(); ++i) { |
186 | out_indices_mat(value_index, i) = group_indices[i]; |
187 | } |
188 | out_indices_mat(value_index, group_indices.size()) = group_value_index; |
189 | |
190 | out_values_flat(value_index) = *value; |
191 | } |
192 | } |
193 | |
194 | // Write output shape. |
195 | auto out_shape_flat = out_shape_t->vec<int64_t>(); |
196 | for (int32_t i = 0; i < output_shape.dims(); ++i) { |
197 | out_shape_flat(i) = output_shape.dim_size(i); |
198 | } |
199 | } |
200 | |
201 | bool ValidateIndicesFromContext(OpKernelConstruction* ctx) { |
202 | bool result; |
203 | if (ctx->GetAttr("validate_indices" , &result).ok()) { |
204 | return result; |
205 | } |
206 | return true; |
207 | } |
208 | |
209 | // Populate `result` set from group in `tensor`. "Group" is defined by |
210 | // `group_indices`, which are values for the first n-1 dimensions of |
211 | // `input_tensor`. `input_strides` is provided to avoid recalculating it |
212 | // multiple times, and is used to calculate the flat index into `input_tensor` |
213 | // values. |
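// Example: for an input of shape [2, 3, 4] (strides [12, 4, 1]) and
// group_indices [1, 2], this reads the length-4 slice at flat indices
// [20, 24), i.e. the last-dimension vector at position (1, 2).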
214 | template <typename T> |
215 | void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor, |
216 | const VarDimArray& input_strides, |
217 | const std::vector<int64_t>& group_indices, |
218 | absl::flat_hash_set<T>* result) { |
219 | OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1, |
220 | errors::Internal("group_indices.size " , group_indices.size(), |
221 | ", != input_strides.size-1 " , |
222 | input_strides.size() - 1, "." )); |
223 | result->clear(); |
224 | auto input_flat = input_tensor.flat<T>(); |
225 | const auto start = std::inner_product( |
226 | group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL); |
227 | const TensorShape& input_shape = input_tensor.shape(); |
228 | const auto end = start + input_shape.dim_size(input_shape.dims() - 1); |
229 | for (int64_t i = start; i < end; ++i) { |
230 | result->insert(input_flat(i)); |
231 | } |
232 | } |
233 | |
// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of
// the `SparseTensor` from which `group` was created, and is used to sanity
// check the indices in `group`.
237 | template <typename T> |
238 | void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group, |
239 | const VarDimArray& sparse_tensor_shape, |
240 | absl::flat_hash_set<T>* result) { |
241 | CheckGroup<T>(ctx, group, sparse_tensor_shape); |
242 | result->clear(); |
243 | const auto& group_values = group.values<T>(); |
244 | for (int64_t i = 0; i < group_values.size(); ++i) { |
245 | result->insert(group_values(i)); |
246 | } |
247 | } |
248 | |
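// Op that computes the number of distinct values along the last dimension of
// a `SparseTensor`. Example: for an input with dense shape [2, 3] whose rows
// hold values {1, 1, 2} and {3}, the output is a dense tensor of shape [2]
// with values [2, 1]; groups with no values yield 0.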
249 | template <typename T> |
250 | class SetSizeOp : public OpKernel { |
251 | public: |
252 | explicit SetSizeOp(OpKernelConstruction* ctx) |
253 | : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {} |
254 | |
255 | void Compute(OpKernelContext* ctx) override; |
256 | |
257 | private: |
258 | const bool validate_indices_; |
259 | }; |
260 | |
261 | template <typename T> |
262 | void SetSizeOp<T>::Compute(OpKernelContext* ctx) { |
263 | sparse::SparseTensor set_st; |
264 | OP_REQUIRES_OK(ctx, |
265 | SparseTensorFromContext(ctx, 0, validate_indices_, &set_st)); |
266 | |
267 | // Output shape is same as input except for last dimension, which reduces |
268 | // to the set size of values along that dimension. |
269 | ShapeArray output_shape; |
270 | OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape)); |
271 | const auto output_strides = Strides(output_shape); |
272 | |
273 | TensorShape output_shape_ts; |
274 | OP_REQUIRES_OK(ctx, |
275 | TensorShapeUtils::MakeShape(output_shape, &output_shape_ts)); |
276 | Tensor* out_t; |
277 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t)); |
278 | auto out = out_t->flat<int32>(); |
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));
280 | |
281 | // Group by all but last dimension, create a set of group values, and add set |
282 | // size to output. |
283 | VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1); |
284 | absl::flat_hash_set<T> group_set; |
285 | for (const auto& group : set_st.group(group_ix)) { |
286 | PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set); |
287 | |
288 | const auto group_key = group.group(); |
289 | const auto output_index = std::inner_product( |
290 | group_key.begin(), group_key.end(), output_strides.begin(), 0LL); |
291 | out(output_index) = group_set.size(); |
292 | } |
293 | } |
294 | |
295 | #define _SET_SIZE_REGISTER_KERNEL_BUILDER(T) \ |
296 | REGISTER_KERNEL_BUILDER( \ |
297 | Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
298 | SetSizeOp<T>); |
299 | _SET_SIZE_REGISTER_KERNEL_BUILDER(int8); |
300 | _SET_SIZE_REGISTER_KERNEL_BUILDER(int16); |
301 | _SET_SIZE_REGISTER_KERNEL_BUILDER(int32); |
302 | _SET_SIZE_REGISTER_KERNEL_BUILDER(int64_t); |
303 | _SET_SIZE_REGISTER_KERNEL_BUILDER(uint8); |
304 | _SET_SIZE_REGISTER_KERNEL_BUILDER(uint16); |
305 | _SET_SIZE_REGISTER_KERNEL_BUILDER(tstring); |
306 | #undef _SET_SIZE_REGISTER_KERNEL_BUILDER |
307 | |
308 | enum InputTypes { |
309 | DENSE_DENSE = 0, |
310 | DENSE_SPARSE = 1, |
311 | SPARSE_SPARSE = 2, |
312 | }; |
313 | |
314 | enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 }; |
315 | |
316 | SetOperation SetOperationFromContext(OpKernelConstruction* ctx) { |
317 | string set_operation_str; |
318 | if (!ctx->GetAttr("set_operation" , &set_operation_str).ok()) { |
319 | ctx->CtxFailure(errors::InvalidArgument("Missing set_operation." )); |
320 | } else { |
321 | std::transform(set_operation_str.begin(), set_operation_str.end(), |
322 | set_operation_str.begin(), ::tolower); |
323 | if ("a-b" == set_operation_str) { |
324 | return A_MINUS_B; |
325 | } |
326 | if ("b-a" == set_operation_str) { |
327 | return B_MINUS_A; |
328 | } |
329 | if ("intersection" == set_operation_str) { |
330 | return INTERSECTION; |
331 | } |
332 | if ("union" != set_operation_str) { |
333 | ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation " , |
334 | set_operation_str, "." )); |
335 | } |
336 | } |
  // NOTE: UNION is not a default; this function fails via CtxFailure above if
  // no valid 'set_operation' attribute is provided.
339 | return UNION; |
340 | } |
341 | |
342 | // Abstract base class for performing set operations across the last dimension |
343 | // of 2 input tensors. |
344 | template <typename T> |
345 | class SetOperationOp : public OpKernel { |
346 | public: |
347 | SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types) |
348 | : OpKernel(ctx), |
349 | set_operation_(SetOperationFromContext(ctx)), |
350 | validate_indices_(ValidateIndicesFromContext(ctx)), |
351 | input_types_(input_types) {} |
352 | |
353 | void Compute(OpKernelContext* ctx) override; |
354 | |
355 | private: |
356 | void ApplySetOperation(const absl::flat_hash_set<T>& set1, |
357 | const absl::flat_hash_set<T>& set2, |
358 | absl::btree_set<T>* result) const; |
359 | void ComputeDenseToDense(OpKernelContext* ctx) const; |
360 | void ComputeDenseToSparse(OpKernelContext* ctx) const; |
361 | void ComputeSparseToSparse(OpKernelContext* ctx) const; |
362 | const SetOperation set_operation_; |
363 | const bool validate_indices_; |
364 | const InputTypes input_types_; |
365 | }; |
366 | |
367 | template <typename T> |
368 | void SetDifference(const absl::flat_hash_set<T>& set1, |
369 | const absl::flat_hash_set<T>& set2, |
370 | absl::btree_set<T>* result) { |
371 | for (const T& elem : set1) { |
372 | if (!set2.contains(elem)) result->insert(elem); |
373 | } |
374 | } |
375 | |
376 | template <typename T> |
377 | void SetIntersection(const absl::flat_hash_set<T>& set1, |
378 | const absl::flat_hash_set<T>& set2, |
379 | absl::btree_set<T>* result) { |
380 | if (set1.size() <= set2.size()) { |
381 | for (const T& elem : set1) { |
382 | if (set2.contains(elem)) result->insert(elem); |
383 | } |
384 | } else { |
385 | for (const T& elem : set2) { |
386 | if (set1.contains(elem)) result->insert(elem); |
387 | } |
388 | } |
389 | } |
390 | |
391 | template <typename T> |
392 | void SetUnion(const absl::flat_hash_set<T>& set1, |
393 | const absl::flat_hash_set<T>& set2, absl::btree_set<T>* result) { |
394 | result->insert(set1.begin(), set1.end()); |
395 | result->insert(set2.begin(), set2.end()); |
396 | } |
397 | |
398 | template <typename T> |
399 | void SetOperationOp<T>::ApplySetOperation(const absl::flat_hash_set<T>& set1, |
400 | const absl::flat_hash_set<T>& set2, |
401 | absl::btree_set<T>* result) const { |
402 | switch (set_operation_) { |
403 | case A_MINUS_B: |
404 | SetDifference<T>(set1, set2, result); |
405 | break; |
406 | case B_MINUS_A: |
407 | SetDifference<T>(set2, set1, result); |
408 | break; |
409 | case INTERSECTION: |
410 | SetIntersection<T>(set1, set2, result); |
411 | break; |
412 | case UNION: |
413 | SetUnion<T>(set1, set2, result); |
414 | break; |
415 | } |
416 | } |
417 | |
418 | // Validate shapes have the same dimensions. |
419 | Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) { |
420 | if (shape1 != shape2) { |
421 | return errors::InvalidArgument("Mismatched shapes [" , |
422 | absl::StrJoin(shape1, "," ), "] vs [" , |
423 | absl::StrJoin(shape2, "," ), "]" ); |
424 | } |
425 | return OkStatus(); |
426 | } |
427 | |
428 | // Validate ranks are the same, and all but last dimension are the same. |
429 | // Return GroupShape. |
430 | Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2, |
431 | ShapeArray* group_shape) { |
432 | ShapeArray group_shape_1; |
433 | TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1)); |
434 | ShapeArray group_shape_2; |
435 | TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2)); |
436 | TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2)); |
437 | *group_shape = group_shape_1; |
438 | return OkStatus(); |
439 | } |
440 | |
441 | // Split `flat_group_index` into separate dimensions based on `group_shape`. |
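// Example: flat_group_index 5 with group_shape [2, 3] yields group_indices
// [1, 2], since 5 = 1 * 3 + 2.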
442 | void PopulateGroupIndices(const int64_t flat_group_index, |
443 | VarDimArray group_shape, |
444 | std::vector<int64_t>* group_indices) { |
445 | group_indices->clear(); |
446 | int64_t running_flat_group_index = flat_group_index; |
447 | for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0; |
448 | --group_dim_index) { |
449 | const auto group_dim = group_shape[group_dim_index]; |
450 | group_indices->insert(group_indices->begin(), |
451 | running_flat_group_index % group_dim); |
452 | running_flat_group_index /= group_dim; |
453 | } |
454 | } |
455 | |
456 | ShapeArray TensorShapeToArray(const TensorShape& t) { |
457 | ShapeArray vec(t.dims()); |
458 | for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i); |
459 | return vec; |
460 | } |
461 | |
462 | // `ctx` contains set1 and set2 dense tensors. |
463 | // Iterate over groups in set1 and set2, applying `ApplySetOperation` to each, |
464 | // and outputting the result `SparseTensor`. A "group" is a collection of values |
465 | // with the same first n-1 dimensions in set1 and set2. |
466 | template <typename T> |
467 | void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const { |
468 | const Tensor& set1_t = ctx->input(0); |
469 | const Tensor& set2_t = ctx->input(1); |
470 | // The following should stay in sync with `_dense_to_dense_shape` shape |
471 | // assertions in python/ops/set_ops.py, and `SetShapeFn` for |
472 | // `DenseToDenseSetOperation` in ops/set_ops.cc. |
473 | ShapeArray group_shape; |
474 | const auto shape1 = TensorShapeToArray(set1_t.shape()); |
475 | const auto shape2 = TensorShapeToArray(set2_t.shape()); |
476 | OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape)); |
477 | |
478 | const auto set1_strides = Strides(shape1); |
479 | const auto set2_strides = Strides(shape2); |
480 | |
481 | std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets; |
482 | int64_t num_result_values = 0; |
483 | int64_t max_set_size = 0; |
484 | |
485 | absl::flat_hash_set<T> set1_group_set; |
486 | absl::flat_hash_set<T> set2_group_set; |
487 | std::vector<int64_t> group_indices; |
488 | int64_t num_elements; |
489 | OP_REQUIRES_OK(ctx, |
490 | TensorShapeUtils::NumElements(group_shape, &num_elements)); |
491 | for (int64_t flat_group_index = 0; flat_group_index < num_elements; |
492 | ++flat_group_index) { |
493 | PopulateGroupIndices(flat_group_index, group_shape, &group_indices); |
494 | PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices, |
495 | &set1_group_set); |
496 | PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices, |
497 | &set2_group_set); |
498 | |
499 | absl::btree_set<T> group_set; |
500 | ApplySetOperation(set1_group_set, set2_group_set, &group_set); |
501 | if (!group_set.empty()) { |
502 | const auto set_size = group_set.size(); |
503 | if (set_size > max_set_size) { |
504 | max_set_size = set_size; |
505 | } |
506 | num_result_values += set_size; |
507 | group_sets.push_back({group_indices, std::move(group_set)}); |
508 | } |
509 | } |
510 | |
511 | TensorShape output_shape; |
512 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape)); |
513 | output_shape.AddDim(max_set_size); |
514 | OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets); |
515 | } |
516 | |
517 | // `ctx` contains dense set1 and sparse set2 tensors. |
518 | // Iterate over groups in set1 and set2, applying `ApplySetOperation` to each, |
// and outputting the result `SparseTensor`. A "group" is a collection of values
520 | // with the same first n-1 dimensions in set1 and set2. |
521 | template <typename T> |
522 | void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const { |
523 | const Tensor& set1_t = ctx->input(0); |
524 | sparse::SparseTensor set2_st; |
525 | OP_REQUIRES_OK(ctx, |
526 | SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st)); |
527 | // The following should stay in sync with `_dense_to_sparse_shape` shape |
528 | // assertions in python/ops/set_ops.py, and `SetShapeFn` for |
529 | // `DenseToSparseSetOperation` in ops/set_ops.cc. |
530 | ShapeArray group_shape; |
531 | OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()), |
532 | set2_st.shape(), &group_shape)); |
533 | |
534 | const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape())); |
535 | |
536 | std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets; |
537 | int64_t num_result_values = 0; |
538 | int64_t max_set_size = 0; |
539 | |
540 | absl::flat_hash_set<T> set1_group_set; |
541 | absl::flat_hash_set<T> set2_group_set; |
542 | auto set2_grouper = |
543 | set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1)); |
544 | auto set2_group_it = set2_grouper.begin(); |
545 | std::vector<int64_t> group_indices; |
546 | int64_t num_elements; |
547 | OP_REQUIRES_OK(ctx, |
548 | TensorShapeUtils::NumElements(group_shape, &num_elements)); |
549 | for (int64_t flat_group_index = 0; flat_group_index < num_elements; |
550 | ++flat_group_index) { |
551 | PopulateGroupIndices(flat_group_index, group_shape, &group_indices); |
552 | |
553 | // Get values from set1. |
554 | PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices, |
555 | &set1_group_set); |
556 | |
557 | // Get values from set2, if applicable. |
558 | set2_group_set.clear(); |
559 | if (set2_group_it != set2_grouper.end()) { |
560 | const auto& group = *set2_group_it; |
561 | const auto set2_group_indices = group.group(); |
562 | OP_REQUIRES( |
563 | ctx, set2_group_indices.size() == group_indices.size(), |
564 | errors::InvalidArgument("Invalid number of group indices " , |
565 | set2_group_indices.size(), ", expected " , |
566 | group_indices.size(), "." )); |
567 | bool group_match = true; |
568 | for (int32_t i = 0; group_match && (i < set2_group_indices.size()); ++i) { |
569 | if (set2_group_indices[i] != group_indices[i]) { |
570 | group_match = false; |
571 | } |
572 | } |
573 | if (group_match) { |
574 | PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(), |
575 | &set2_group_set); |
576 | ++set2_group_it; |
577 | } |
578 | } |
579 | |
580 | absl::btree_set<T> group_set; |
581 | ApplySetOperation(set1_group_set, set2_group_set, &group_set); |
582 | if (!group_set.empty()) { |
583 | const auto set_size = group_set.size(); |
584 | if (set_size > max_set_size) { |
585 | max_set_size = set_size; |
586 | } |
587 | num_result_values += set_size; |
588 | group_sets.push_back({group_indices, std::move(group_set)}); |
589 | } |
590 | } |
591 | |
592 | TensorShape output_shape; |
593 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape)); |
594 | output_shape.AddDim(max_set_size); |
595 | OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets); |
596 | } |
597 | |
598 | // This is used to determine which group iterator is less than the other, based |
599 | // on row-major ordering of indices. |
600 | // An empty index list indicates end of iteration, which is interpreted as "max" |
601 | // for the purposes of comparison; i.e., non-empty < empty. |
// Return 0 if both groups are empty, or both non-empty with the same values.
// Return <0 if set1 < set2, or set2 is empty.
// Return >0 if set2 < set1, or set1 is empty.
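// Example: yields <0 for [0, 1] vs [0, 2], >0 for [1, 0] vs [0, 9], and >0
// for [] vs [0, 0] (set1 is exhausted, so set2's group goes first).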
605 | void CompareGroups(OpKernelContext* ctx, |
606 | const std::vector<int64_t>& set1_group_indices, |
607 | const std::vector<int64_t>& set2_group_indices, |
608 | int64_t* result) { |
609 | if (set1_group_indices.empty()) { |
610 | *result = set2_group_indices.empty() ? 0 : 1; |
611 | return; |
612 | } |
613 | if (set2_group_indices.empty()) { |
614 | *result = set1_group_indices.empty() ? 0 : -1; |
615 | return; |
616 | } |
617 | OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(), |
618 | errors::InvalidArgument("Mismatched group dims " , |
619 | set1_group_indices.size(), " vs " , |
620 | set2_group_indices.size(), "." )); |
621 | for (int32_t i = 0; i < set1_group_indices.size(); ++i) { |
622 | *result = set1_group_indices[i] - set2_group_indices[i]; |
623 | if (*result != 0) { |
624 | return; |
625 | } |
626 | } |
627 | } |
628 | |
629 | // `ctx` contains set1 and set2 sparse tensors. |
630 | // Iterate over groups in set1 and set2, applying `ApplySetOperation` to each, |
// and outputting the result `SparseTensor`. A "group" is a collection of values
632 | // with the same first n-1 dimensions in set1 and set2. |
633 | template <typename T> |
634 | void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const { |
635 | sparse::SparseTensor set1_st; |
636 | OP_REQUIRES_OK(ctx, |
637 | SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st)); |
638 | |
639 | sparse::SparseTensor set2_st; |
640 | OP_REQUIRES_OK(ctx, |
641 | SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st)); |
642 | |
643 | // The following should stay in sync with `_sparse_to_sparse_shape` shape |
644 | // assertions in python/ops/set_ops.py, and `SetShapeFn` for |
645 | // `SparseToSparseSetOperation` in ops/set_ops.cc. |
646 | ShapeArray group_shape; |
647 | OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(), |
648 | &group_shape)); |
649 | |
650 | std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets; |
651 | int64_t num_result_values = 0; |
652 | int64_t max_set_size = 0; |
653 | |
654 | absl::flat_hash_set<T> set1_group_set; |
655 | absl::flat_hash_set<T> set2_group_set; |
656 | auto set1_grouper = |
657 | set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1)); |
658 | auto set1_group_it = set1_grouper.begin(); |
659 | auto set2_grouper = |
660 | set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1)); |
661 | auto set2_group_it = set2_grouper.begin(); |
662 | |
663 | // Empty indices vector represents iteration end in `CompareGroups`. |
664 | const std::vector<int64_t> group_iter_end; |
665 | // Group by rows, and iterate over rows of both sets in parallel, creating a |
666 | // set for each row. |
667 | while ((set1_group_it != set1_grouper.end()) || |
668 | (set2_group_it != set2_grouper.end())) { |
669 | const std::vector<int64_t>& set1_group_indices = |
670 | (set1_group_it == set1_grouper.end()) ? group_iter_end |
671 | : (*set1_group_it).group(); |
672 | const std::vector<int64_t>& set2_group_indices = |
673 | (set2_group_it == set2_grouper.end()) ? group_iter_end |
674 | : (*set2_group_it).group(); |
675 | |
    int64_t compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices,
                  &compare_groups);
    // CompareGroups fails via OP_REQUIRES on mismatched group dims without
    // setting `compare_groups`; return early rather than read it
    // uninitialized.
    if (!ctx->status().ok()) return;
678 | const std::vector<int64_t>* group_indices = nullptr; |
679 | |
680 | // Get values from set1, if applicable. |
681 | set1_group_set.clear(); |
682 | if (compare_groups <= 0) { |
683 | PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(), |
684 | &set1_group_set); |
685 | ++set1_group_it; |
686 | group_indices = &set1_group_indices; |
687 | } |
688 | |
689 | // Get values from set2, if applicable. |
690 | set2_group_set.clear(); |
691 | if (compare_groups >= 0) { |
692 | PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(), |
693 | &set2_group_set); |
694 | ++set2_group_it; |
695 | group_indices = &set2_group_indices; |
696 | } |
697 | |
698 | absl::btree_set<T> group_set; |
699 | ApplySetOperation(set1_group_set, set2_group_set, &group_set); |
700 | if (!group_set.empty()) { |
701 | const auto set_size = group_set.size(); |
702 | if (set_size > max_set_size) { |
703 | max_set_size = set_size; |
704 | } |
705 | num_result_values += set_size; |
706 | group_sets.push_back({*group_indices, std::move(group_set)}); |
707 | } |
708 | } |
709 | |
710 | TensorShape output_shape; |
711 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape)); |
712 | output_shape.AddDim(max_set_size); |
713 | OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets); |
714 | } |
715 | |
// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate result
717 | // sparse tensor with [b, n3] values, where each row `i` contains the result of |
718 | // the set operation on elements from set1[i] and set2[i]. `n3` is the number |
719 | // of elements in that result row. |
720 | template <typename T> |
721 | void SetOperationOp<T>::Compute(OpKernelContext* ctx) { |
722 | switch (input_types_) { |
723 | case DENSE_DENSE: |
724 | ComputeDenseToDense(ctx); |
725 | break; |
726 | case DENSE_SPARSE: |
727 | ComputeDenseToSparse(ctx); |
728 | break; |
729 | case SPARSE_SPARSE: |
730 | ComputeSparseToSparse(ctx); |
731 | break; |
732 | } |
733 | } |
734 | |
735 | template <typename T> |
736 | class DenseToDenseSetOperationOp : public SetOperationOp<T> { |
737 | public: |
738 | explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx) |
739 | : SetOperationOp<T>(ctx, DENSE_DENSE) {} |
740 | }; |
741 | |
742 | #define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \ |
743 | REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation") \ |
744 | .Device(DEVICE_CPU) \ |
745 | .TypeConstraint<T>("T"), \ |
746 | DenseToDenseSetOperationOp<T>); |
747 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8); |
748 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16); |
749 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32); |
750 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t); |
751 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8); |
752 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16); |
753 | _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring); |
754 | #undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER |
755 | |
756 | template <typename T> |
757 | class DenseToSparseSetOperationOp : public SetOperationOp<T> { |
758 | public: |
759 | explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx) |
760 | : SetOperationOp<T>(ctx, DENSE_SPARSE) {} |
761 | }; |
762 | |
763 | #define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \ |
764 | REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation") \ |
765 | .Device(DEVICE_CPU) \ |
766 | .TypeConstraint<T>("T"), \ |
767 | DenseToSparseSetOperationOp<T>); |
768 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8); |
769 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16); |
770 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32); |
771 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t); |
772 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8); |
773 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16); |
774 | _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring); |
775 | #undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER |
776 | |
777 | template <typename T> |
778 | class SparseToSparseSetOperationOp : public SetOperationOp<T> { |
779 | public: |
780 | explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx) |
781 | : SetOperationOp<T>(ctx, SPARSE_SPARSE) {} |
782 | }; |
783 | |
784 | #define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \ |
785 | REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation") \ |
786 | .Device(DEVICE_CPU) \ |
787 | .TypeConstraint<T>("T"), \ |
788 | SparseToSparseSetOperationOp<T>); |
789 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8); |
790 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16); |
791 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32); |
792 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t); |
793 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8); |
794 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16); |
795 | _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring); |
796 | #undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER |
797 | |
798 | } // namespace tensorflow |
799 | |