/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in to TensorFlow because
// we would first like to demonstrate successful end-to-end use of these ops in
// eval and polish the API a bit, e.g., taking two SparseTensor inputs rather
// than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return group shape, which is the first n-1 dimensions of shape.
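// For example, input_shape [2, 3, 4] yields grouped_shape [2, 3].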
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", absl::StrJoin(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1].
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return OkStatus();
}

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index,
                               const bool validate_indices,
                               sparse::SparseTensor* tensor) {
  // Assume row-major order.
  TensorShape shape;
  const Tensor& shape_tensor = ctx->input(base_index + 2);
  if (shape_tensor.dims() != 1) {
    return errors::InvalidArgument("Shape must be a 1D tensor.");
  }
  TF_RETURN_IF_ERROR(
      TensorShape::BuildTensorShape(shape_tensor.vec<int64_t>(), &shape));
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64_t> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  Status status = sparse::SparseTensor::Create(
      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);

  if (!validate_indices || !status.ok()) return status;
  return tensor->IndicesValid();
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which `group`
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32_t j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64_t i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// Compute row-major strides for `shape`; these let us calculate the row-major
// index into the flattened output.
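// For example, shape [2, 3, 4] yields strides [12, 4, 1], so element
// (i, j, k) has flat index 12*i + 4*j + k.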
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64_t product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains pairs of
// group indices (i.e., values for all but the last dimension of
// `output_shape`) and set values, each of which will occupy the last dimension
// of `output_shape`. `sets` should be sorted in ascending order by group
// indices.
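// For example, with output_shape [2, 2] and sets {([0], {5, 7}), ([1], {9})},
// the output indices are [[0, 0], [0, 1], [1, 0]], the output values are
// [5, 7, 9], and the output dense shape is [2, 2].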
template <typename T>
void OutputSparseTensor(
    OpKernelContext* ctx, const TensorShape& output_shape,
    const int64_t num_values,
    const std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>>&
        sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64_t>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64_t value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64_t group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32_t i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64_t>();
  for (int32_t i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}

bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from a group in `input_tensor`. "Group" is defined by
// `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
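// For example, with input_strides [12, 4, 1] (i.e., input shape [2, 3, 4]) and
// group_indices [1, 2], reading starts at flat index 1*12 + 2*4 = 20 and ends
// at 23, i.e., the 4 values of row (1, 2).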
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64_t>& group_indices,
                            absl::flat_hash_set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               ", != input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64_t i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of
// the `SparseTensor` from which `group` was created, and is used to sanity
// check the indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             absl::flat_hash_set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64_t i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));

  // Output shape is the input shape with the last dimension removed; each
  // output element is the size of the set of values along that last
  // dimension.
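  // For example, an input with dense shape [2, 3, 4] produces an output of
  // shape [2, 3], where each element is the number of distinct values among
  // the corresponding 4 input values.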
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
  absl::flat_hash_set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64_t);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(tstring);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: This is not a default; this function fails if no 'set_operation'
  // attribute is provided.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const absl::flat_hash_set<T>& set1,
                         const absl::flat_hash_set<T>& set2,
                         absl::btree_set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetDifference(const absl::flat_hash_set<T>& set1,
                   const absl::flat_hash_set<T>& set2,
                   absl::btree_set<T>* result) {
  for (const T& elem : set1) {
    if (!set2.contains(elem)) result->insert(elem);
  }
}

template <typename T>
void SetIntersection(const absl::flat_hash_set<T>& set1,
                     const absl::flat_hash_set<T>& set2,
                     absl::btree_set<T>* result) {
  // Iterate over the smaller set to minimize lookups.
  if (set1.size() <= set2.size()) {
    for (const T& elem : set1) {
      if (set2.contains(elem)) result->insert(elem);
    }
  } else {
    for (const T& elem : set2) {
      if (set1.contains(elem)) result->insert(elem);
    }
  }
}

template <typename T>
void SetUnion(const absl::flat_hash_set<T>& set1,
              const absl::flat_hash_set<T>& set2, absl::btree_set<T>* result) {
  result->insert(set1.begin(), set1.end());
  result->insert(set2.begin(), set2.end());
}

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const absl::flat_hash_set<T>& set1,
                                          const absl::flat_hash_set<T>& set2,
                                          absl::btree_set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      SetDifference<T>(set1, set2, result);
      break;
    case B_MINUS_A:
      SetDifference<T>(set2, set1, result);
      break;
    case INTERSECTION:
      SetIntersection<T>(set1, set2, result);
      break;
    case UNION:
      SetUnion<T>(set1, set2, result);
      break;
  }
}

// Validate that the shapes are equal.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   absl::StrJoin(shape1, ","), "] vs [",
                                   absl::StrJoin(shape2, ","), "]");
  }
  return OkStatus();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return OkStatus();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
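// For example, flat_group_index 5 with group_shape [2, 3] yields
// group_indices [1, 2], since 5 = 1*3 + 2.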
void PopulateGroupIndices(const int64_t flat_group_index,
                          VarDimArray group_shape,
                          std::vector<int64_t>* group_indices) {
  group_indices->clear();
  int64_t running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  absl::flat_hash_set<T> set1_group_set;
  absl::flat_hash_set<T> set2_group_set;
  std::vector<int64_t> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    absl::btree_set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
      group_sets.push_back({group_indices, std::move(group_set)});
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  absl::flat_hash_set<T> set1_group_set;
  absl::flat_hash_set<T> set2_group_set;
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64_t> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32_t i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    absl::btree_set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
      group_sets.push_back({group_indices, std::move(group_set)});
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as
// "max" for the purposes of comparison; i.e., non-empty < empty.
// Return 0 if both groups are empty, or both non-empty with the same values.
// Return <0 if set1 < set2, or set2 is empty.
// Return >0 if set2 < set1, or set1 is empty.
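// For example, [0, 1] vs [0, 2] yields a negative result, while [] vs [0, 0]
// yields a positive one (set1 is exhausted, so it compares as "max").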
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64_t>& set1_group_indices,
                   const std::vector<int64_t>& set2_group_indices,
                   int64_t* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32_t i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  sparse::SparseTensor set1_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));

  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));

  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  absl::flat_hash_set<T> set1_group_set;
  absl::flat_hash_set<T> set2_group_set;
  auto set1_grouper =
      set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Empty indices vector represents iteration end in `CompareGroups`.
  const std::vector<int64_t> group_iter_end;
  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64_t>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? group_iter_end
                                              : (*set1_group_it).group();
    const std::vector<int64_t>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? group_iter_end
                                              : (*set2_group_it).group();

    int64_t compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64_t>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    absl::btree_set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
      group_sets.push_back({*group_indices, std::move(group_set)});
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the number
// of elements in that result row.
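// For example, with set1 = [[1, 2], [3, 4]], set2 = [[1, 3], [4, 5]], and
// set_operation "intersection", the result rows are {1} and {4}, so the
// output is a SparseTensor of dense shape [2, 1] with values [1, 4].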
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("T"),           \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow