1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_kernel.h" |
17 | #include "tensorflow/core/framework/register_types.h" |
18 | #include "tensorflow/core/framework/tensor.h" |
19 | #include "tensorflow/core/framework/tensor_util.h" |
20 | #include "tensorflow/core/framework/types.h" |
21 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | template <typename T> |
26 | class SparseAddGradOp : public OpKernel { |
27 | public: |
28 | explicit SparseAddGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} |
29 | |
30 | void Compute(OpKernelContext *ctx) override { |
31 | // Gradient for op: SparseAdd(a, b) == sum. |
32 | const Tensor *backprop_val_grad, *a_indices, *b_indices, *sum_indices; |
33 | OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad" , &backprop_val_grad)); |
34 | OP_REQUIRES_OK(ctx, ctx->input("a_indices" , &a_indices)); |
35 | OP_REQUIRES_OK(ctx, ctx->input("b_indices" , &b_indices)); |
36 | OP_REQUIRES_OK(ctx, ctx->input("sum_indices" , &sum_indices)); |
37 | |
38 | OP_REQUIRES(ctx, |
39 | TensorShapeUtils::IsMatrix(a_indices->shape()) && |
40 | TensorShapeUtils::IsMatrix(b_indices->shape()) && |
41 | TensorShapeUtils::IsMatrix(sum_indices->shape()), |
42 | errors::InvalidArgument( |
43 | "Input indices should be matrices but received shapes: " , |
44 | a_indices->shape().DebugString(), " and " , |
45 | b_indices->shape().DebugString(), " and " , |
46 | sum_indices->shape().DebugString())); |
47 | OP_REQUIRES( |
48 | ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()), |
49 | errors::InvalidArgument( |
50 | "Input backprop_val_grad should be a vector but received shape: " , |
51 | backprop_val_grad->shape().DebugString())); |
52 | OP_REQUIRES( |
53 | ctx, |
54 | a_indices->dim_size(1) == b_indices->dim_size(1) && |
55 | b_indices->dim_size(1) == sum_indices->dim_size(1), |
56 | errors::InvalidArgument("The densified operands should have the same " |
57 | "ndims; for A, B, sum got: " , |
58 | a_indices->dim_size(1), b_indices->dim_size(1), |
59 | sum_indices->dim_size(1))); |
60 | OP_REQUIRES( |
61 | ctx, backprop_val_grad->NumElements() == sum_indices->dim_size(0), |
62 | errors::InvalidArgument("# elements of backprop_val_grad and # rows of " |
63 | "sum_indices should match (#nnz of sum): got " , |
64 | backprop_val_grad->NumElements(), " and " , |
65 | sum_indices->dim_size(0))); |
66 | |
67 | const int num_dims = a_indices->dim_size(1); |
68 | const int64_t a_nnz = a_indices->dim_size(0); |
69 | const int64_t b_nnz = b_indices->dim_size(0); |
70 | const int64_t sum_nnz = backprop_val_grad->NumElements(); |
71 | |
72 | const auto a_indices_mat = a_indices->matrix<int64_t>(); |
73 | const auto b_indices_mat = b_indices->matrix<int64_t>(); |
74 | const auto sum_indices_mat = sum_indices->matrix<int64_t>(); |
75 | |
76 | Tensor *a_val_grad, *b_val_grad; |
77 | OP_REQUIRES_OK(ctx, |
78 | ctx->allocate_output(0, TensorShape({a_nnz}), &a_val_grad)); |
79 | OP_REQUIRES_OK(ctx, |
80 | ctx->allocate_output(1, TensorShape({b_nnz}), &b_val_grad)); |
81 | |
82 | T *a_val_grad_flat = a_val_grad->flat<T>().data(); |
83 | T *b_val_grad_flat = b_val_grad->flat<T>().data(); |
84 | const T *backprop_val_grad_flat = backprop_val_grad->flat<T>().data(); |
85 | memset(a_val_grad_flat, 0, sizeof(T) * a_nnz); |
86 | memset(b_val_grad_flat, 0, sizeof(T) * b_nnz); |
87 | |
88 | #define COMPARE(a_or_b, idx) \ |
89 | switch (sparse::DimComparator::cmp(a_or_b##_indices_mat, sum_indices_mat, \ |
90 | idx, k, num_dims)) { \ |
91 | case 0: \ |
92 | a_or_b##_val_grad_flat[idx] = backprop_val_grad_flat[k]; \ |
93 | ++idx; \ |
94 | break; \ |
95 | case -1: \ |
96 | ++idx; \ |
97 | a_or_b##_idx_geq = false; \ |
98 | break; \ |
99 | case 1: \ |
100 | break; \ |
101 | } |
102 | |
103 | // Set-intersect the indices; fill in grads for positions in the |
104 | // intersection. |
105 | int64_t i = 0, j = 0, k = 0; |
106 | bool a_idx_geq, b_idx_geq; |
107 | while (i < a_nnz && j < b_nnz && k < sum_nnz) { |
108 | a_idx_geq = b_idx_geq = true; |
109 | COMPARE(a, i); |
110 | COMPARE(b, j); |
111 | // increment pointer into sum_indices iff both the current A, B indices >= |
112 | // the current sum index. |
113 | if (a_idx_geq && b_idx_geq) ++k; |
114 | } |
115 | |
116 | // at most one loop below will run |
117 | while (i < a_nnz && k < sum_nnz) { |
118 | a_idx_geq = true; |
119 | COMPARE(a, i); |
120 | if (a_idx_geq) ++k; |
121 | } |
122 | while (j < b_nnz && k < sum_nnz) { |
123 | b_idx_geq = true; |
124 | COMPARE(b, j); |
125 | if (b_idx_geq) ++k; |
126 | } |
127 | #undef COMPARE |
128 | } |
129 | }; |
130 | |
131 | #define REGISTER_KERNELS(type) \ |
132 | REGISTER_KERNEL_BUILDER( \ |
133 | Name("SparseAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
134 | SparseAddGradOp<type>) |
135 | |
136 | // This op should work for any T that SparseAdd is registered with. |
137 | REGISTER_KERNELS(float); |
138 | REGISTER_KERNELS(double); |
139 | REGISTER_KERNELS(int64_t); |
140 | REGISTER_KERNELS(int32); |
141 | REGISTER_KERNELS(int16); |
142 | REGISTER_KERNELS(int8); |
143 | REGISTER_KERNELS(complex64); |
144 | REGISTER_KERNELS(complex128); |
145 | #undef REGISTER_KERNELS |
146 | } // namespace tensorflow |
147 | |