/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

template <typename T, typename Treal>
class SparseAddOp : public OpKernel {
 public:
  explicit SparseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    // (0) validations
    const Tensor *a_indices, *b_indices, *a_values_t, *b_values_t, *a_shape,
        *b_shape, *thresh_t;

    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices));
    OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices));
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsMatrix(a_indices->shape()) &&
                    TensorShapeUtils::IsMatrix(b_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be matrices but received shapes: ",
                    a_indices->shape().DebugString(), " and ",
                    b_indices->shape().DebugString()));
    const int64_t a_nnz = a_indices->dim_size(0);
    const int64_t b_nnz = b_indices->dim_size(0);
    const int num_dims = a_indices->dim_size(1);
    OP_REQUIRES(ctx, b_indices->dim_size(1) == num_dims,
                errors::InvalidArgument(
                    "Input indices must have the same dimension, got ",
                    num_dims, " and ", b_indices->dim_size(1)));

    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
    OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t));

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(a_values_t->shape()) &&
                    TensorShapeUtils::IsVector(b_values_t->shape()),
                errors::InvalidArgument(
                    "Input values should be vectors but received shapes: ",
                    a_values_t->shape().DebugString(), " and ",
                    b_values_t->shape().DebugString()));
    auto a_values = a_values_t->vec<T>();
    auto b_values = b_values_t->vec<T>();
    OP_REQUIRES(
        ctx, a_values.size() == a_nnz && b_values.size() == b_nnz,
        errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz,
                                " non-empty input values, got ",
                                a_values.size(), " and ", b_values.size()));

    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape));
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(a_shape->shape()) &&
                    TensorShapeUtils::IsVector(b_shape->shape()),
                errors::InvalidArgument(
                    "Input shapes should be a vector but received shapes ",
                    a_shape->shape().DebugString(), " and ",
                    b_shape->shape().DebugString()));
    OP_REQUIRES(
        ctx, a_shape->NumElements() == num_dims,
        errors::InvalidArgument("Second dimension of a_indices and length of "
                                "a_shape must match, got ",
                                num_dims, " and ", a_shape->NumElements()));
    OP_REQUIRES(ctx, num_dims > 0,
                errors::InvalidArgument("Tensors must not be empty"));
    OP_REQUIRES(
        ctx, a_shape->IsSameSize(*b_shape),
        errors::InvalidArgument(
            "Operands do not have the same ranks; got shapes: ",
            a_shape->SummarizeValue(10), " and ", b_shape->SummarizeValue(10)));
    const auto a_shape_flat = a_shape->flat<int64_t>();
    const auto b_shape_flat = b_shape->flat<int64_t>();
    for (int i = 0; i < a_shape->NumElements(); ++i) {
      OP_REQUIRES(ctx, a_shape_flat(i) == b_shape_flat(i),
                  errors::InvalidArgument(
                      "Operands' shapes do not match: got ", a_shape_flat(i),
                      " and ", b_shape_flat(i), " for dimension ", i));
    }

    OP_REQUIRES_OK(ctx, ctx->input("thresh", &thresh_t));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(thresh_t->shape()),
                errors::InvalidArgument(
                    "The magnitude threshold must be a scalar: got shape ",
                    thresh_t->shape().DebugString()));
    // thresh is Treal so that the std::abs() comparison below also works for
    // complex{64,128} values.
    const Treal thresh = thresh_t->scalar<Treal>()();
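    // For example, for T = complex64, Treal is float, and a summed entry is
    // kept only when its complex magnitude is at least thresh.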

    // (1) do a pass over inputs, and append values and indices to vectors
    auto a_indices_mat = a_indices->matrix<int64_t>();
    auto b_indices_mat = b_indices->matrix<int64_t>();
    std::vector<std::pair<bool, int64_t>> entries_to_copy;  // from_a?, idx
    entries_to_copy.reserve(a_nnz + b_nnz);
    std::vector<T> out_values;

    // The input and output sparse tensors are assumed to be ordered along
    // increasing dimension number.
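    // Example (2-D, thresh >= 0): if a = {[0,0]: 1, [0,2]: 2} and
    // b = {[0,1]: 3, [0,2]: -2}, the merge emits [0,0] from a, [0,1] from b,
    // and the sum 0 at [0,2], which survives only when thresh == 0.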
    int64_t i = 0, j = 0;
    T s;
    while (i < a_nnz && j < b_nnz) {
      switch (sparse::DimComparator::cmp(a_indices_mat, b_indices_mat, i, j,
                                         num_dims)) {
        case -1:  // a's index is lexicographically smaller: copy from a.
          entries_to_copy.emplace_back(true, i);
          out_values.push_back(a_values(i));
          ++i;
          break;
        case 0:  // indices match: sum the values and apply the threshold.
          s = a_values(i) + b_values(j);
          if (thresh <= std::abs(s)) {
            entries_to_copy.emplace_back(true, i);
            out_values.push_back(s);
          }
          ++i;
          ++j;
          break;
        case 1:  // b's index is lexicographically smaller: copy from b.
          entries_to_copy.emplace_back(false, j);
          out_values.push_back(b_values(j));
          ++j;
          break;
      }
    }

#define HANDLE_LEFTOVERS(A_OR_B, IDX, IS_A)     \
  while (IDX < A_OR_B##_nnz) {                  \
    entries_to_copy.emplace_back(IS_A, IDX);    \
    out_values.push_back(A_OR_B##_values(IDX)); \
    ++IDX;                                      \
  }
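// For example, HANDLE_LEFTOVERS(a, i, true) expands to:
//   while (i < a_nnz) {
//     entries_to_copy.emplace_back(true, i);
//     out_values.push_back(a_values(i));
//     ++i;
//   }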

    // at most one of these calls appends new values
    HANDLE_LEFTOVERS(a, i, true);
    HANDLE_LEFTOVERS(b, j, false);
#undef HANDLE_LEFTOVERS

    // (2) allocate and fill output tensors
    const int64_t sum_nnz = out_values.size();
    Tensor *out_indices_t, *out_values_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({sum_nnz, num_dims}),
                                        &out_indices_t));
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(1, TensorShape({sum_nnz}), &out_values_t));
    auto out_indices_mat = out_indices_t->matrix<int64_t>();
    auto out_values_flat = out_values_t->vec<T>();

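    // Copy each output index row from whichever operand contributed it.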
    for (i = 0; i < sum_nnz; ++i) {
      const bool from_a = entries_to_copy[i].first;
      const int64_t idx = entries_to_copy[i].second;
      out_indices_mat.chip<0>(i) =
          from_a ? a_indices_mat.chip<0>(idx) : b_indices_mat.chip<0>(idx);
    }
    if (sum_nnz > 0) {
      std::copy_n(out_values.begin(), sum_nnz, &out_values_flat(0));
    }
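    // The output shape is the common input shape (a_shape == b_shape was
    // checked above).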
    ctx->set_output(2, *a_shape);
  }
};

#define REGISTER_KERNELS(type, thresh_type)                           \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("SparseAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseAddOp<type, thresh_type>)
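// For example, REGISTER_KERNELS(complex64, float) registers
// SparseAddOp<complex64, float>, whose threshold is the real-valued
// magnitude bound applied to complex sums.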

// The real-typed registrations below cover TF_CALL_REAL_NUMBER_TYPES minus
// uint8 (std::abs() on uint8 does not compile), plus the complex types.
REGISTER_KERNELS(float, float);
REGISTER_KERNELS(double, double);
REGISTER_KERNELS(int64_t, int64_t);
REGISTER_KERNELS(int32, int32);
REGISTER_KERNELS(int16, int16);
REGISTER_KERNELS(int8, int8);
REGISTER_KERNELS(complex64, float);
REGISTER_KERNELS(complex128, double);
#undef REGISTER_KERNELS
}  // namespace tensorflow