1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/sparse_slice_grad_op.h" |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/register_types.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/tensor_util.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | |
24 | namespace tensorflow { |
25 | |
// Eigen's thread-pool device; used to specialize the functor and to register
// the CPU kernels below.
typedef Eigen::ThreadPoolDevice CPUDevice;
27 | |
28 | namespace functor { |
29 | |
30 | template <typename T> |
31 | struct SparseSliceGradFunctor<CPUDevice, T> { |
32 | void operator()(OpKernelContext *ctx, |
33 | typename TTypes<T>::ConstFlat backprop_val_grad, |
34 | typename TTypes<int64_t>::ConstMatrix input_indices_mat, |
35 | typename TTypes<int64_t>::ConstFlat input_start_flat, |
36 | typename TTypes<int64_t>::ConstMatrix output_indices_mat, |
37 | typename TTypes<T>::Flat val_grad) const { |
38 | const int64_t input_nnz = input_indices_mat.dimension(0); |
39 | const int num_dims = input_indices_mat.dimension(1); |
40 | |
41 | T *val_grad_flat = val_grad.data(); |
42 | const T *backprop_val_grad_flat = backprop_val_grad.data(); |
43 | memset(val_grad_flat, 0, sizeof(T) * input_nnz); |
44 | |
45 | // Fill gradients for position where indices of input and output are same. |
46 | int64_t j = 0; |
47 | for (int64_t i = 0; i < input_nnz && j < backprop_val_grad.dimension(0); |
48 | ++i) { |
49 | bool is_same = true; |
50 | for (int d = 0; d < num_dims; ++d) { |
51 | const int64_t a = input_indices_mat(i, d); |
52 | const int64_t b = output_indices_mat(j, d); |
53 | const int64_t offset = input_start_flat(d); |
54 | if (a != b + offset) { |
55 | is_same = false; |
56 | break; |
57 | } |
58 | } |
59 | if (is_same) { |
60 | val_grad_flat[i] = backprop_val_grad_flat[j]; |
61 | ++j; |
62 | } |
63 | } |
64 | OP_REQUIRES( |
65 | ctx, backprop_val_grad.dimension(0) == j, |
66 | errors::Internal("Elements of backprop_val_grad aren't all propagated. " |
67 | "Num elements:" , |
68 | backprop_val_grad.dimension(0), ", used: " , j)); |
69 | } |
70 | }; |
71 | |
72 | } // namespace functor |
73 | |
74 | template <typename Device, typename T> |
75 | class SparseSliceGradOp : public OpKernel { |
76 | public: |
77 | explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} |
78 | |
79 | void Compute(OpKernelContext *ctx) override { |
80 | const Tensor *backprop_val_grad, *input_indices, *output_indices, *input_start; |
81 | OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad" , &backprop_val_grad)); |
82 | OP_REQUIRES_OK(ctx, ctx->input("input_indices" , &input_indices)); |
83 | OP_REQUIRES_OK(ctx, ctx->input("input_start" , &input_start)); |
84 | OP_REQUIRES_OK(ctx, ctx->input("output_indices" , &output_indices)); |
85 | |
86 | OP_REQUIRES(ctx, |
87 | TensorShapeUtils::IsMatrix(input_indices->shape()) && |
88 | TensorShapeUtils::IsMatrix(output_indices->shape()), |
89 | errors::InvalidArgument( |
90 | "Input and output indices should be matrices " |
91 | "but received shapes: " , |
92 | input_indices->shape().DebugString(), " and " , |
93 | output_indices->shape().DebugString())); |
94 | OP_REQUIRES( |
95 | ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()), |
96 | errors::InvalidArgument( |
97 | "Input backprop_val_grad should be a vector but received shape: " , |
98 | backprop_val_grad->shape().DebugString())); |
99 | OP_REQUIRES( |
100 | ctx, |
101 | input_indices->dim_size(1) == output_indices->dim_size(1), |
102 | errors::InvalidArgument("The input and output should have the same " |
103 | "ndims: got: " , input_indices->dim_size(1), " and " , |
104 | output_indices->dim_size(1))); |
105 | OP_REQUIRES( |
106 | ctx, output_indices->dim_size(0) <= input_indices->dim_size(0), |
107 | errors::InvalidArgument("# rows of output_indices should be not greater " |
108 | "than of input_indices, got " , |
109 | output_indices->dim_size(0), " and " , |
110 | input_indices->dim_size(0))); |
111 | OP_REQUIRES( |
112 | ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0), |
113 | errors::InvalidArgument("# elements of backprop_val_grad and # rows of " |
114 | "output_indices should match (#nnz of sum): got " , |
115 | backprop_val_grad->NumElements(), " and " , |
116 | output_indices->dim_size(0))); |
117 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()), |
118 | errors::InvalidArgument( |
119 | "The input_start should be a vector but received shape " , |
120 | input_start->shape().DebugString())); |
121 | |
122 | const int num_dims = input_indices->dim_size(1); |
123 | OP_REQUIRES(ctx, num_dims == input_start->NumElements(), |
124 | errors::InvalidArgument( |
125 | "Expected input_start to be a vector of length " , num_dims, |
126 | " but got length " , input_start->NumElements())); |
127 | |
128 | const int64_t input_nnz = input_indices->dim_size(0); |
129 | |
130 | Tensor *val_grad; |
131 | OP_REQUIRES_OK(ctx, |
132 | ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad)); |
133 | |
134 | if (input_nnz == 0) return; |
135 | |
136 | functor::SparseSliceGradFunctor<Device, T>()( |
137 | ctx, backprop_val_grad->flat<T>(), input_indices->matrix<int64_t>(), |
138 | input_start->flat<int64_t>(), output_indices->matrix<int64_t>(), |
139 | val_grad->flat<T>()); |
140 | } |
141 | }; |
142 | |
// Registers SparseSliceGradOp on CPU; TF_CALL_NUMBER_TYPES expands
// REGISTER_KERNELS once for every numeric dtype T.
#define REGISTER_KERNELS(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseSliceGradOp<CPUDevice, type>)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
150 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Eigen's GPU device type; selects the CUDA/ROCm functor specialization.
typedef Eigen::GpuDevice GPUDevice;

// Registers the GPU kernel for every numeric dtype T when building with
// CUDA or ROCm support.
#define REGISTER_KERNELS(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSliceGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SparseSliceGradOp<GPUDevice, type>)
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
163 | |
164 | } // namespace tensorflow |
165 | |