1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_kernel.h" |
17 | #include "tensorflow/core/framework/op_requires.h" |
18 | #include "tensorflow/core/framework/register_types.h" |
19 | #include "tensorflow/core/framework/tensor.h" |
20 | #include "tensorflow/core/framework/tensor_util.h" |
21 | #include "tensorflow/core/framework/types.h" |
22 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
23 | |
24 | namespace tensorflow { |
25 | |
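// Element-wise addition of two SparseTensors, each given as an
// (indices, values, shape) triple. Entries present in both operands are
// summed, and a summed value is emitted only if its magnitude is at least
// "thresh". T is the value type and Treal its real counterpart used for the
// threshold comparison.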
26 | template <typename T, typename Treal> |
27 | class SparseAddOp : public OpKernel { |
28 | public: |
29 | explicit SparseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} |
30 | |
31 | void Compute(OpKernelContext *ctx) override { |
32 | // (0) validations |
33 | const Tensor *a_indices, *b_indices, *a_values_t, *b_values_t, *a_shape, |
34 | *b_shape, *thresh_t; |
35 | |
36 | OP_REQUIRES_OK(ctx, ctx->input("a_indices" , &a_indices)); |
37 | OP_REQUIRES_OK(ctx, ctx->input("b_indices" , &b_indices)); |
38 | OP_REQUIRES(ctx, |
39 | TensorShapeUtils::IsMatrix(a_indices->shape()) && |
40 | TensorShapeUtils::IsMatrix(b_indices->shape()), |
41 | errors::InvalidArgument( |
42 | "Input indices should be matrices but received shapes: " , |
43 | a_indices->shape().DebugString(), " and " , |
44 | b_indices->shape().DebugString())); |
45 | const int64_t a_nnz = a_indices->dim_size(0); |
46 | const int64_t b_nnz = b_indices->dim_size(0); |
47 | const int num_dims = a_indices->dim_size(1); |
48 | OP_REQUIRES(ctx, b_indices->dim_size(1) == num_dims, |
49 | errors::InvalidArgument( |
50 | "Input indices must have the same dimension, got " , |
51 | num_dims, " and " , b_indices->dim_size(1))); |
52 | |
53 | OP_REQUIRES_OK(ctx, ctx->input("a_values" , &a_values_t)); |
54 | OP_REQUIRES_OK(ctx, ctx->input("b_values" , &b_values_t)); |
55 | |
56 | OP_REQUIRES(ctx, |
57 | TensorShapeUtils::IsVector(a_values_t->shape()) && |
58 | TensorShapeUtils::IsVector(b_values_t->shape()), |
59 | errors::InvalidArgument( |
60 | "Input values should be vectors but received shapes: " , |
61 | a_values_t->shape().DebugString(), " and " , |
62 | b_values_t->shape().DebugString())); |
63 | auto a_values = ctx->input(1).vec<T>(); |
64 | auto b_values = ctx->input(4).vec<T>(); |
65 | OP_REQUIRES( |
66 | ctx, a_values.size() == a_nnz && b_values.size() == b_nnz, |
67 | errors::InvalidArgument("Expected " , a_nnz, " and " , b_nnz, |
68 | " non-empty input values, got " , |
69 | a_values.size(), " and " , b_values.size())); |
70 | |
71 | OP_REQUIRES_OK(ctx, ctx->input("a_shape" , &a_shape)); |
72 | OP_REQUIRES_OK(ctx, ctx->input("b_shape" , &b_shape)); |
73 | OP_REQUIRES(ctx, |
74 | TensorShapeUtils::IsVector(a_shape->shape()) && |
75 | TensorShapeUtils::IsVector(b_shape->shape()), |
76 | errors::InvalidArgument( |
77 | "Input shapes should be a vector but received shapes " , |
78 | a_shape->shape().DebugString(), " and " , |
79 | b_shape->shape().DebugString())); |
80 | OP_REQUIRES( |
81 | ctx, a_shape->NumElements() == num_dims, |
82 | errors::InvalidArgument("Second dimension of a_indices and length of " |
83 | "a_shape must match, got " , |
84 | num_dims, " and " , a_shape->NumElements())); |
    OP_REQUIRES(ctx, num_dims > 0,
                errors::InvalidArgument("Tensors must not be empty"));
87 | OP_REQUIRES( |
88 | ctx, a_shape->IsSameSize(*b_shape), |
89 | errors::InvalidArgument( |
90 | "Operands do not have the same ranks; got shapes: " , |
91 | a_shape->SummarizeValue(10), " and " , b_shape->SummarizeValue(10))); |
92 | const auto a_shape_flat = a_shape->flat<int64_t>(); |
93 | const auto b_shape_flat = b_shape->flat<int64_t>(); |
94 | for (int i = 0; i < a_shape->NumElements(); ++i) { |
95 | OP_REQUIRES(ctx, a_shape_flat(i) == b_shape_flat(i), |
96 | errors::InvalidArgument( |
97 | "Operands' shapes do not match: got " , a_shape_flat(i), |
98 | " and " , b_shape_flat(i), " for dimension " , i)); |
99 | } |
100 | |
101 | OP_REQUIRES_OK(ctx, ctx->input("thresh" , &thresh_t)); |
102 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(thresh_t->shape()), |
103 | errors::InvalidArgument( |
104 | "The magnitude threshold must be a scalar: got shape " , |
105 | thresh_t->shape().DebugString())); |
106 | // std::abs() so that it works for complex{64,128} values as well |
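    // Treal is the real counterpart of T (see the kernel registrations at the
    // bottom of this file), e.g. float when T is complex64.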
107 | const Treal thresh = thresh_t->scalar<Treal>()(); |
108 | |
109 | // (1) do a pass over inputs, and append values and indices to vectors |
110 | auto a_indices_mat = a_indices->matrix<int64_t>(); |
111 | auto b_indices_mat = b_indices->matrix<int64_t>(); |
    std::vector<std::pair<bool, int64_t>> entries_to_copy;  // from_a?, idx
113 | entries_to_copy.reserve(a_nnz + b_nnz); |
114 | std::vector<T> out_values; |
115 | |
116 | // The input and output sparse tensors are assumed to be ordered along |
117 | // increasing dimension number. |
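    // Merge the two sorted index lists with cursors i (into a) and j (into b),
    // advancing whichever currently points at the lexicographically smaller
    // index.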
118 | int64_t i = 0, j = 0; |
119 | T s; |
120 | while (i < a_nnz && j < b_nnz) { |
121 | switch (sparse::DimComparator::cmp(a_indices_mat, b_indices_mat, i, j, |
122 | num_dims)) { |
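        // a's index sorts first: copy a's entry unchanged.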
123 | case -1: |
124 | entries_to_copy.emplace_back(true, i); |
125 | out_values.push_back(a_values(i)); |
126 | ++i; |
127 | break; |
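        // Indices match: sum the values and keep the result only if its
        // magnitude reaches the threshold.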
128 | case 0: |
129 | s = a_values(i) + b_values(j); |
130 | if (thresh <= std::abs(s)) { |
131 | entries_to_copy.emplace_back(true, i); |
132 | out_values.push_back(s); |
133 | } |
134 | ++i; |
135 | ++j; |
136 | break; |
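        // b's index sorts first: copy b's entry unchanged.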
137 | case 1: |
138 | entries_to_copy.emplace_back(false, j); |
139 | out_values.push_back(b_values(j)); |
140 | ++j; |
141 | break; |
142 | } |
143 | } |
144 | |
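// Appends the entries remaining in one operand once the other has been
// exhausted by the merge loop above.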
145 | #define HANDLE_LEFTOVERS(A_OR_B, IDX, IS_A) \ |
146 | while (IDX < A_OR_B##_nnz) { \ |
147 | entries_to_copy.emplace_back(IS_A, IDX); \ |
148 | out_values.push_back(A_OR_B##_values(IDX)); \ |
149 | ++IDX; \ |
150 | } |
151 | |
152 | // at most one of these calls appends new values |
153 | HANDLE_LEFTOVERS(a, i, true); |
154 | HANDLE_LEFTOVERS(b, j, false); |
155 | #undef HANDLE_LEFTOVERS |
156 | |
157 | // (2) allocate and fill output tensors |
158 | const int64_t sum_nnz = out_values.size(); |
159 | Tensor *out_indices_t, *out_values_t; |
160 | OP_REQUIRES_OK(ctx, |
161 | ctx->allocate_output(0, TensorShape({sum_nnz, num_dims}), |
162 | &out_indices_t)); |
163 | OP_REQUIRES_OK( |
164 | ctx, ctx->allocate_output(1, TensorShape({sum_nnz}), &out_values_t)); |
165 | auto out_indices_mat = out_indices_t->matrix<int64_t>(); |
166 | auto out_values_flat = out_values_t->vec<T>(); |
167 | |
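    // Gather each surviving entry's coordinates from whichever operand it
    // came from.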
168 | for (i = 0; i < sum_nnz; ++i) { |
169 | const bool from_a = entries_to_copy[i].first; |
170 | const int64_t idx = entries_to_copy[i].second; |
171 | out_indices_mat.chip<0>(i) = |
172 | from_a ? a_indices_mat.chip<0>(idx) : b_indices_mat.chip<0>(idx); |
173 | } |
174 | if (sum_nnz > 0) { |
175 | std::copy_n(out_values.begin(), sum_nnz, &out_values_flat(0)); |
176 | } |
177 | ctx->set_output(2, *a_shape); |
178 | } |
179 | }; |
180 | |
181 | #define REGISTER_KERNELS(type, thresh_type) \ |
182 | REGISTER_KERNEL_BUILDER( \ |
183 | Name("SparseAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
184 | SparseAddOp<type, thresh_type>) |
185 | |
186 | // The list below is equivalent to TF_CALL_REAL_NUMBER_TYPES, minus uint8. This |
187 | // is because std::abs() on uint8 does not compile. |
188 | REGISTER_KERNELS(float, float); |
189 | REGISTER_KERNELS(double, double); |
REGISTER_KERNELS(int64_t, int64_t);
191 | REGISTER_KERNELS(int32, int32); |
192 | REGISTER_KERNELS(int16, int16); |
193 | REGISTER_KERNELS(int8, int8); |
194 | REGISTER_KERNELS(complex64, float); |
195 | REGISTER_KERNELS(complex128, double); |
196 | #undef REGISTER_KERNELS |
197 | } // namespace tensorflow |
198 | |