1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/kernels/sparse_tensor_dense_add_op.h" |
19 | |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/op_requires.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_util.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | typedef Eigen::ThreadPoolDevice CPUDevice; |
31 | // NOTE: does not support GPU yet. |
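// This kernel computes out = b + dense(a), where the sparse operand `a` is
// given in COO form as (a_indices, a_values, a_shape): the output is first
// initialized to `b`, then each a_values(i) is added at position
// a_indices(i, :). No broadcasting is supported; `a` and `b` must have
// identical shapes.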
32 | |
33 | namespace { |
34 | |
35 | template <typename Index> |
36 | Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values, |
37 | const Tensor *a_shape, const Tensor *b) { |
38 | if (!TensorShapeUtils::IsMatrix(a_indices->shape())) { |
39 | return errors::InvalidArgument( |
40 | "Input a_indices should be a matrix but received shape: " , |
41 | a_indices->shape().DebugString()); |
42 | } |
43 | if (!TensorShapeUtils::IsVector(a_values->shape()) || |
44 | !TensorShapeUtils::IsVector(a_shape->shape())) { |
45 | return errors::InvalidArgument( |
46 | "Inputs a_values and a_shape should be vectors " |
47 | "but received shapes: " , |
48 | a_values->shape().DebugString(), " and " , |
49 | a_shape->shape().DebugString()); |
50 | } |
51 | int64_t nnz = a_indices->dim_size(0); |
52 | int64_t ndims = a_indices->dim_size(1); |
53 | if (a_values->dim_size(0) != nnz) { |
54 | return errors::InvalidArgument("Dimensions " , nnz, " and " , |
55 | a_values->dim_size(0), |
56 | " are not compatible" ); |
57 | } |
58 | if (a_shape->dim_size(0) != ndims) { |
59 | return errors::InvalidArgument("Dimensions " , ndims, " and " , |
60 | a_shape->dim_size(0), " are not compatible" ); |
61 | } |
62 | if (a_shape->NumElements() != b->dims()) { |
63 | return errors::InvalidArgument( |
64 | "Two operands have different ranks; received: " , a_shape->NumElements(), |
65 | " and " , b->dims()); |
66 | } |
67 | const auto a_shape_flat = a_shape->flat<Index>(); |
68 | for (int i = 0; i < b->dims(); ++i) { |
69 | if (a_shape_flat(i) != b->dim_size(i)) { |
70 | return errors::InvalidArgument( |
71 | "Dimension " , i, |
72 | " does not equal (no broadcasting is supported): sparse side " , |
73 | a_shape_flat(i), " vs dense side " , b->dim_size(i)); |
74 | } |
75 | } |
76 | |
77 | // Check for invalid indices. |
78 | const auto a_indices_mat = a_indices->flat_inner_dims<Index>(); |
79 | |
80 | for (int64_t zidx = 0; zidx < nnz; ++zidx) { |
81 | for (int64_t didx = 0; didx < ndims; ++didx) { |
82 | const Index idx = a_indices_mat(zidx, didx); |
83 | if (idx < 0 || idx >= a_shape_flat(didx)) { |
        return errors::InvalidArgument(
            "Sparse tensor has an invalid index on dimension ", didx,
            ": a_indices(", zidx, ",", didx, ") = ", idx,
            ", dense tensor shape: ", a_shape->SummarizeValue(ndims));
90 | } |
91 | } |
92 | } |
93 | |
94 | return OkStatus(); |
95 | } |
96 | |
97 | } // namespace |
98 | |
99 | template <typename Device, typename T, typename Index> |
100 | class SparseTensorDenseAddOp : public OpKernel { |
101 | public: |
102 | explicit SparseTensorDenseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} |
103 | |
104 | void Compute(OpKernelContext *ctx) override { |
105 | const Tensor *a_indices_t, *a_values_t, *a_shape_t, *b; |
106 | OP_REQUIRES_OK(ctx, ctx->input("a_indices" , &a_indices_t)); |
107 | OP_REQUIRES_OK(ctx, ctx->input("a_values" , &a_values_t)); |
108 | OP_REQUIRES_OK(ctx, ctx->input("a_shape" , &a_shape_t)); |
109 | OP_REQUIRES_OK(ctx, ctx->input("b" , &b)); |
110 | OP_REQUIRES_OK( |
111 | ctx, ValidateInputs<Index>(a_indices_t, a_values_t, a_shape_t, b)); |
112 | |
113 | Tensor *out_t; |
114 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t)); |
115 | |
116 | const int ndims = static_cast<int>(a_indices_t->dim_size(1)); |
117 | const auto a_indices_mat = a_indices_t->flat_inner_dims<Index>(); |
118 | const auto a_values_flat = a_values_t->flat<T>(); |
119 | |
120 | switch (ndims) { |
121 | #define NDIMS_CASE(N) \ |
122 | case N: { \ |
123 | auto out_tensor = out_t->tensor<T, N>(); \ |
124 | out_tensor.device(ctx->eigen_device<Device>()) = b->tensor<T, N>(); \ |
125 | const Index result = \ |
126 | functor::ScatterNdFunctor<Device, T, Index, N, \ |
127 | scatter_op::UpdateOp::ADD>()( \ |
128 | ctx->eigen_device<Device>(), a_indices_mat, a_values_flat, \ |
129 | out_tensor); \ |
130 | OP_REQUIRES( \ |
131 | ctx, result == -1, \ |
132 | errors::InvalidArgument( \ |
133 | "Sparse tensor has some invalid index on dimension ", result, \ |
134 | "; dense tensor shape: ", b->shape().DebugString())); \ |
135 | } break; |
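      // For reference, NDIMS_CASE(2) below expands to (roughly):
      //   case 2: {
      //     auto out_tensor = out_t->tensor<T, 2>();
      //     out_tensor.device(ctx->eigen_device<Device>()) = b->tensor<T, 2>();
      //     const Index result =
      //         functor::ScatterNdFunctor<Device, T, Index, 2,
      //                                   scatter_op::UpdateOp::ADD>()(
      //             ctx->eigen_device<Device>(), a_indices_mat, a_values_flat,
      //             out_tensor);
      //     OP_REQUIRES(ctx, result == -1, ...);
      //   } break;
      // i.e. copy the dense operand into the output, then scatter-add the
      // sparse values; the functor returns -1 on success or the first
      // out-of-bounds dimension on failure.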
136 | |
137 | NDIMS_CASE(1); |
138 | NDIMS_CASE(2); |
139 | NDIMS_CASE(3); |
140 | NDIMS_CASE(4); |
141 | NDIMS_CASE(5); |
142 | default: |
143 | OP_REQUIRES( |
144 | ctx, false, |
145 | errors::InvalidArgument("Only tensors with ranks between 1 and 5 " |
146 | "are currently supported. Tensor rank: " , |
147 | ndims)); |
148 | #undef NDIMS_CASE |
149 | } |
150 | } |
151 | }; |
152 | |
153 | namespace functor { |
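// CPU implementation of the scatter-add: for each of the nnz index rows,
// bounds-check the coordinates and accumulate the corresponding update into
// `out`. Returns -1 on success, or the first dimension whose coordinate is
// out of bounds.
//
// Worked example (illustrative): with NDIMS = 2,
//   indices = [[0, 1], [1, 2]], updates = [5, 7], out of shape [2, 3],
// the loop performs out(0, 1) += 5 and out(1, 2) += 7, then returns -1.
// If indices instead contained a row [0, 3], FastBoundsCheck would fail on
// dimension 1 (3 >= 3) and the functor would return 1.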
154 | template <typename T, typename Index, int NDIMS> |
155 | struct ScatterNdFunctor<CPUDevice, T, Index, NDIMS, scatter_op::UpdateOp::ADD> { |
156 | Index operator()(const CPUDevice &d, |
157 | typename TTypes<Index>::ConstMatrix indices, |
158 | typename TTypes<T>::ConstFlat updates, |
159 | typename TTypes<T, NDIMS>::Tensor out) { |
160 | Eigen::array<Eigen::DenseIndex, NDIMS> idx; |
161 | const int num_nnz = static_cast<int>(indices.dimension(0)); |
162 | for (int i = 0; i < num_nnz; ++i) { |
      for (int dim = 0; dim < NDIMS; ++dim) {
        idx[dim] = internal::SubtleMustCopy(indices(i, dim));
        if (!FastBoundsCheck(idx[dim], out.dimension(dim))) {
          return dim;  // on failure: the offending dimension, nonnegative
        }
      }
169 | out(idx) += updates(i); |
170 | } |
171 | return -1; // on success |
172 | } |
173 | }; |
174 | } // namespace functor |
175 | |
176 | #define REGISTER_KERNELS_CPU(TypeT, TypeIndex) \ |
177 | REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseAdd") \ |
178 | .Device(DEVICE_CPU) \ |
179 | .TypeConstraint<TypeT>("T") \ |
180 | .TypeConstraint<TypeIndex>("Tindices"), \ |
181 | SparseTensorDenseAddOp<CPUDevice, TypeT, TypeIndex>) |
182 | |
183 | #define REGISTER_KERNELS(T) \ |
184 | REGISTER_KERNELS_CPU(T, int64_t); \ |
185 | REGISTER_KERNELS_CPU(T, int32) |
186 | |
187 | TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); |
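// The call above instantiates one pair of kernels per numeric T; for example,
// REGISTER_KERNELS(float) registers:
//   SparseTensorDenseAddOp<CPUDevice, float, int64_t>  (Tindices = int64)
//   SparseTensorDenseAddOp<CPUDevice, float, int32>    (Tindices = int32)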
188 | #undef REGISTER_KERNELS |
189 | #undef REGISTER_KERNELS_CPU |
190 | } // namespace tensorflow |
191 | |