/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_tensor_dense_add_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
// NOTE: does not support GPU yet.

namespace {

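// Validates the COO inputs: a_indices must be a [nnz, ndims] matrix, a_values
// a length-nnz vector, and a_shape a length-ndims vector whose entries match
// the shape of the dense operand b exactly (no broadcasting). Every entry of
// a_indices must also lie within that shape.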
template <typename Index>
Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values,
                      const Tensor *a_shape, const Tensor *b) {
  if (!TensorShapeUtils::IsMatrix(a_indices->shape())) {
    return errors::InvalidArgument(
        "Input a_indices should be a matrix but received shape: ",
        a_indices->shape().DebugString());
  }
  if (!TensorShapeUtils::IsVector(a_values->shape()) ||
      !TensorShapeUtils::IsVector(a_shape->shape())) {
    return errors::InvalidArgument(
        "Inputs a_values and a_shape should be vectors "
        "but received shapes: ",
        a_values->shape().DebugString(), " and ",
        a_shape->shape().DebugString());
  }
  int64_t nnz = a_indices->dim_size(0);
  int64_t ndims = a_indices->dim_size(1);
  if (a_values->dim_size(0) != nnz) {
    return errors::InvalidArgument("Dimensions ", nnz, " and ",
                                   a_values->dim_size(0),
                                   " are not compatible");
  }
  if (a_shape->dim_size(0) != ndims) {
    return errors::InvalidArgument("Dimensions ", ndims, " and ",
                                   a_shape->dim_size(0), " are not compatible");
  }
  if (a_shape->NumElements() != b->dims()) {
    return errors::InvalidArgument(
        "Two operands have different ranks; received: ", a_shape->NumElements(),
        " and ", b->dims());
  }
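  // The sparse operand's dense shape must match b dimension by dimension;
  // broadcasting is not supported.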
  const auto a_shape_flat = a_shape->flat<Index>();
  for (int i = 0; i < b->dims(); ++i) {
    if (a_shape_flat(i) != b->dim_size(i)) {
      return errors::InvalidArgument(
          "Dimension ", i,
          " does not equal (no broadcasting is supported): sparse side ",
          a_shape_flat(i), " vs dense side ", b->dim_size(i));
    }
  }

  // Check for invalid indices.
  const auto a_indices_mat = a_indices->flat_inner_dims<Index>();

  for (int64_t zidx = 0; zidx < nnz; ++zidx) {
    for (int64_t didx = 0; didx < ndims; ++didx) {
      const Index idx = a_indices_mat(zidx, didx);
      if (idx < 0 || idx >= a_shape_flat(didx)) {
        return errors::InvalidArgument(
            "Sparse tensor has an invalid index on dimension ", didx,
            ": a_indices(", zidx, ",", didx, ") = ", idx,
            ", dense tensor shape: ", a_shape_flat);
      }
    }
  }

  return OkStatus();
}

}  // namespace

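// Adds a sparse tensor, given in COO form as (a_indices, a_values, a_shape),
// to a dense tensor b of identical shape, producing a dense output. This is
// the kernel behind the "SparseTensorDenseAdd" op registered at the bottom of
// this file (exposed in Python as tf.raw_ops.SparseTensorDenseAdd).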
template <typename Device, typename T, typename Index>
class SparseTensorDenseAddOp : public OpKernel {
 public:
  explicit SparseTensorDenseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor *a_indices_t, *a_values_t, *a_shape_t, *b;
    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices_t));
    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape_t));
    OP_REQUIRES_OK(ctx, ctx->input("b", &b));
    OP_REQUIRES_OK(
        ctx, ValidateInputs<Index>(a_indices_t, a_values_t, a_shape_t, b));

    Tensor *out_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t));

    const int ndims = static_cast<int>(a_indices_t->dim_size(1));
    const auto a_indices_mat = a_indices_t->flat_inner_dims<Index>();
    const auto a_values_flat = a_values_t->flat<T>();

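    // Dispatch on the runtime rank so each case works with a fixed-rank Eigen
    // tensor: first copy the dense operand b into the output, then scatter-add
    // the sparse values at a_indices. The functor returns -1 on success or the
    // first dimension with an out-of-bounds index.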
    switch (ndims) {
#define NDIMS_CASE(N)                                                         \
  case N: {                                                                   \
    auto out_tensor = out_t->tensor<T, N>();                                  \
    out_tensor.device(ctx->eigen_device<Device>()) = b->tensor<T, N>();      \
    const Index result =                                                      \
        functor::ScatterNdFunctor<Device, T, Index, N,                        \
                                  scatter_op::UpdateOp::ADD>()(               \
            ctx->eigen_device<Device>(), a_indices_mat, a_values_flat,        \
            out_tensor);                                                      \
    OP_REQUIRES(                                                              \
        ctx, result == -1,                                                    \
        errors::InvalidArgument(                                              \
            "Sparse tensor has some invalid index on dimension ", result,     \
            "; dense tensor shape: ", b->shape().DebugString()));             \
  } break;

      NDIMS_CASE(1);
      NDIMS_CASE(2);
      NDIMS_CASE(3);
      NDIMS_CASE(4);
      NDIMS_CASE(5);
      default:
        OP_REQUIRES(
            ctx, false,
            errors::InvalidArgument("Only tensors with ranks between 1 and 5 "
                                    "are currently supported. Tensor rank: ",
                                    ndims));
#undef NDIMS_CASE
    }
  }
};

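// CPU specialization of the scatter-add functor used above: a sequential loop
// over the nonzeros that bounds-checks every coordinate before writing, so an
// invalid index is reported (by returning the offending dimension) instead of
// touching memory outside `out`.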
namespace functor {
template <typename T, typename Index, int NDIMS>
struct ScatterNdFunctor<CPUDevice, T, Index, NDIMS, scatter_op::UpdateOp::ADD> {
  Index operator()(const CPUDevice &d,
                   typename TTypes<Index>::ConstMatrix indices,
                   typename TTypes<T>::ConstFlat updates,
                   typename TTypes<T, NDIMS>::Tensor out) {
    Eigen::array<Eigen::DenseIndex, NDIMS> idx;
    const int num_nnz = static_cast<int>(indices.dimension(0));
    for (int i = 0; i < num_nnz; ++i) {
      for (int d = 0; d < NDIMS; ++d) {
        idx[d] = internal::SubtleMustCopy(indices(i, d));
        if (!FastBoundsCheck(idx[d], out.dimension(d))) {
          return d;  // on failure: d nonnegative
        }
      }
      out(idx) += updates(i);
    }
    return -1;  // on success
  }
};
}  // namespace functor

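// Register CPU kernels for every numeric value type and for both supported
// index types (int32 and int64).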
#define REGISTER_KERNELS_CPU(TypeT, TypeIndex)                        \
  REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseAdd")                \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<TypeT>("T")             \
                              .TypeConstraint<TypeIndex>("Tindices"), \
                          SparseTensorDenseAddOp<CPUDevice, TypeT, TypeIndex>)

#define REGISTER_KERNELS(T)         \
  REGISTER_KERNELS_CPU(T, int64_t); \
  REGISTER_KERNELS_CPU(T, int32)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_CPU
}  // namespace tensorflow