/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/sparse_slice_grad_op.h"

#include <cstring>  // For memset.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

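// CPU specialization of the SparseSliceGrad functor. It scatters
// `backprop_val_grad` (the gradient w.r.t. the values of the slice) back to
// the positions of the input values that fall inside the slice, leaving a
// zero gradient everywhere else. Both index matrices are expected to be in
// canonical row-major order, with `output_indices` expressed relative to
// `input_start`, which is what the forward SparseSlice op produces.
//
// Illustrative 1-D example (not taken from a test): for
// input_indices = [[0], [2], [4]], input_start = [1],
// output_indices = [[1], [3]], and backprop_val_grad = [g0, g1],
// the result is val_grad = [0, g0, g1].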
template <typename T>
struct SparseSliceGradFunctor<CPUDevice, T> {
  void operator()(OpKernelContext *ctx,
                  typename TTypes<T>::ConstFlat backprop_val_grad,
                  typename TTypes<int64_t>::ConstMatrix input_indices_mat,
                  typename TTypes<int64_t>::ConstFlat input_start_flat,
                  typename TTypes<int64_t>::ConstMatrix output_indices_mat,
                  typename TTypes<T>::Flat val_grad) const {
    const int64_t input_nnz = input_indices_mat.dimension(0);
    const int num_dims = input_indices_mat.dimension(1);

    T *val_grad_flat = val_grad.data();
    const T *backprop_val_grad_flat = backprop_val_grad.data();
    // Zero-initialize the result: input positions outside the slice get a
    // zero gradient.
    memset(val_grad_flat, 0, sizeof(T) * input_nnz);

    // Fill in gradients for the positions where the input and output indices
    // coincide. Since both index lists are sorted in row-major order and
    // `output_indices` is relative to `input_start`, a single forward scan
    // with two cursors suffices.
    int64_t j = 0;
    for (int64_t i = 0; i < input_nnz && j < backprop_val_grad.dimension(0);
         ++i) {
      bool is_same = true;
      for (int d = 0; d < num_dims; ++d) {
        const int64_t a = input_indices_mat(i, d);
        const int64_t b = output_indices_mat(j, d);
        const int64_t offset = input_start_flat(d);
        if (a != b + offset) {
          is_same = false;
          break;
        }
      }
      if (is_same) {
        val_grad_flat[i] = backprop_val_grad_flat[j];
        ++j;
      }
    }
    OP_REQUIRES(
        ctx, backprop_val_grad.dimension(0) == j,
        errors::Internal("Elements of backprop_val_grad aren't all propagated. "
                         "Num elements: ",
                         backprop_val_grad.dimension(0), ", used: ", j));
  }
};

}  // namespace functor

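// The SparseSliceGrad kernel. It validates the shapes and mutual consistency
// of its inputs (`backprop_val_grad`, `input_indices`, `input_start`,
// `output_indices`), allocates the values-gradient output, and delegates the
// actual scatter to SparseSliceGradFunctor for the chosen device.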
template <typename Device, typename T>
class SparseSliceGradOp : public OpKernel {
 public:
  explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor *backprop_val_grad, *input_indices, *output_indices,
        *input_start;
    OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad", &backprop_val_grad));
    OP_REQUIRES_OK(ctx, ctx->input("input_indices", &input_indices));
    OP_REQUIRES_OK(ctx, ctx->input("input_start", &input_start));
    OP_REQUIRES_OK(ctx, ctx->input("output_indices", &output_indices));

    // Validate the shapes of the inputs and their consistency with each other.
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsMatrix(input_indices->shape()) &&
                    TensorShapeUtils::IsMatrix(output_indices->shape()),
                errors::InvalidArgument(
                    "Input and output indices should be matrices "
                    "but received shapes: ",
                    input_indices->shape().DebugString(), " and ",
                    output_indices->shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()),
        errors::InvalidArgument(
            "Input backprop_val_grad should be a vector but received shape: ",
            backprop_val_grad->shape().DebugString()));
    OP_REQUIRES(
        ctx, input_indices->dim_size(1) == output_indices->dim_size(1),
        errors::InvalidArgument("The input and output indices should have "
                                "the same ndims: got ",
                                input_indices->dim_size(1), " and ",
                                output_indices->dim_size(1)));
    OP_REQUIRES(
        ctx, output_indices->dim_size(0) <= input_indices->dim_size(0),
        errors::InvalidArgument("The number of rows of output_indices should "
                                "not exceed that of input_indices: got ",
                                output_indices->dim_size(0), " and ",
                                input_indices->dim_size(0)));
    OP_REQUIRES(
        ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0),
        errors::InvalidArgument("The number of elements of backprop_val_grad "
                                "and the number of rows of output_indices "
                                "should match (#nnz of the slice): got ",
                                backprop_val_grad->NumElements(), " and ",
                                output_indices->dim_size(0)));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()),
                errors::InvalidArgument(
                    "The input_start should be a vector but received shape ",
                    input_start->shape().DebugString()));

    const int num_dims = input_indices->dim_size(1);
    OP_REQUIRES(ctx, num_dims == input_start->NumElements(),
                errors::InvalidArgument(
                    "Expected input_start to be a vector of length ", num_dims,
                    " but got length ", input_start->NumElements()));

    const int64_t input_nnz = input_indices->dim_size(0);

    Tensor *val_grad;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad));

    // Nothing to scatter if the input has no nonzero values; the (empty)
    // output has already been allocated.
    if (input_nnz == 0) return;

    functor::SparseSliceGradFunctor<Device, T>()(
        ctx, backprop_val_grad->flat<T>(), input_indices->matrix<int64_t>(),
        input_start->flat<int64_t>(), output_indices->matrix<int64_t>(),
        val_grad->flat<T>());
  }
};

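// Register the CPU kernel for all standard numeric types.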
#define REGISTER_KERNELS(type)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseSliceGradOp<CPUDevice, type>)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

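// Only the kernel registrations live here; the GPU specialization of
// SparseSliceGradFunctor is assumed to be compiled in a separate CUDA/ROCm
// translation unit (e.g. a sparse_slice_grad_op_gpu.cu.cc file).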
#define REGISTER_KERNELS(type)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseSliceGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SparseSliceGradOp<GPUDevice, type>)
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow