/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_fill_empty_rows_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

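// Illustrative example of the op's behavior (a sketch, not taken from the
// original file): for a SparseTensor with dense_shape = [4, 2],
// indices = [[0, 0], [2, 1]], values = [a, b] and default_value = d, the
// filled result has indices = [[0, 0], [1, 0], [2, 1], [3, 0]],
// values = [a, d, b, d], empty_row_indicator = [false, true, false, true]
// and reverse_index_map = [0, 2].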
template <typename T, typename Tindex>
struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
  Status operator()(OpKernelContext* context, const Tensor& default_value_t,
                    const Tensor& indices_t, const Tensor& values_t,
                    const Tensor& dense_shape_t,
                    typename AsyncOpKernel::DoneCallback done) {
    (void)done;  // Unused (only used in GPU implementation)
    const int kOutputIndicesOutput = 0;
    const int kOutputValuesOutput = 1;
    const int kEmptyRowIndicatorOutput = 2;
    const int kReverseIndexMapOutput = 3;

    const T& default_value = default_value_t.scalar<T>()();
    const auto indices = indices_t.matrix<Tindex>();
    const auto values = values_t.vec<T>();
    const auto dense_shape = dense_shape_t.vec<Tindex>();

    const Tindex N = indices_t.shape().dim_size(0);
    const Tindex dense_rows = dense_shape(0);

    bool* empty_row_indicator = nullptr;
    if (context->output_required(kEmptyRowIndicatorOutput)) {
      Tensor* empty_row_indicator_t = nullptr;
      TensorShape output_shape;
      TF_RETURN_IF_ERROR(
          TensorShape::BuildTensorShape({dense_rows}, &output_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kEmptyRowIndicatorOutput, output_shape, &empty_row_indicator_t));
      empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
    }
    Tindex* reverse_index_map = nullptr;
    if (context->output_required(kReverseIndexMapOutput)) {
      Tensor* reverse_index_map_t = nullptr;
      TensorShape output_shape;
      TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape({N}, &output_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kReverseIndexMapOutput, output_shape, &reverse_index_map_t));
      reverse_index_map = reverse_index_map_t->vec<Tindex>().data();
    }

    int rank = indices_t.shape().dim_size(1);

    if (dense_rows == 0) {
      if (N != 0) {
        return errors::InvalidArgument(
            "Received SparseTensor with dense_shape[0] = 0 but "
            "indices.shape[0] = ",
            N);
      }
      Tensor* output_indices_t;
      TensorShape output_indices_shape;
      TF_RETURN_IF_ERROR(
          TensorShape::BuildTensorShape({0, rank}, &output_indices_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputIndicesOutput, output_indices_shape, &output_indices_t));
      Tensor* output_values_t;
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputValuesOutput, TensorShape({0}), &output_values_t));

      // Exit early, nothing more to do.
      return OkStatus();
    }

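    // Count the entries in each dense row and record whether the row indices
    // appear in non-decreasing order (e.g. rows [0, 0, 2, 3] are ordered,
    // while [2, 0] are not); ordered, fully populated inputs can be forwarded
    // unchanged below.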
    bool rows_are_ordered = true;
    Tindex last_indices_row = 0;
    std::vector<Tindex> csr_offset(dense_rows, 0);
    for (int i = 0; i < N; ++i) {
      const Tindex row = indices(i, 0);
      if (row < 0 || row >= dense_rows) {
        return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row,
                                       " >= ", dense_rows);
      }
      ++csr_offset[row];
      rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
      last_indices_row = row;
    }
    bool all_rows_full = true;
    for (int row = 0; row < dense_rows; ++row) {
      // csr_offset here describes the number of elements in this dense row
      bool row_empty = (csr_offset[row] == 0);
      if (empty_row_indicator) {
        empty_row_indicator[row] = row_empty;
      }
      all_rows_full = all_rows_full & !row_empty;
      // In the filled version, each row has at least one element.
      csr_offset[row] = std::max(csr_offset[row], Tindex{1});
      // Update csr_offset to hold the cumulative number of elements up to and
      // including this row:
      //   csr_offset(0) == #{elements of row 0}
      //   csr_offset(1) == #{elements of row 1} + #{elements of row 0}
      //   ..
      //   csr_offset(i) == starting index for elements in row i + 1.
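      // A hypothetical example (not from the original code): with
      // dense_rows = 4 and per-row counts [2, 0, 3, 1], the max(., 1) step
      // yields [2, 1, 3, 1] and the running sum below gives
      // csr_offset = [2, 3, 6, 7], so the filled output holds 7 entries.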
      if (row > 0) {
        csr_offset[row] += csr_offset[row - 1];
      }
    }

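    // Fast path: if every row already has at least one entry and the rows
    // appear in order, the inputs can be passed through unchanged and the
    // reverse index map is the identity.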
    if (all_rows_full && rows_are_ordered) {
      context->set_output(kOutputIndicesOutput, indices_t);
      context->set_output(kOutputValuesOutput, values_t);
      if (reverse_index_map) {
        for (Tindex i = 0; i < N; ++i) {
          reverse_index_map[i] = i;
        }
      }
    } else {
      Tensor* output_indices_t;
      const Tindex N_full = csr_offset[dense_rows - 1];
      TensorShape output_indices_shape;
      TF_RETURN_IF_ERROR(
          TensorShape::BuildTensorShape({N_full, rank}, &output_indices_shape));
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputIndicesOutput, output_indices_shape, &output_indices_t));
      auto output_indices = output_indices_t->matrix<Tindex>();

      Tensor* output_values_t;
      TF_RETURN_IF_ERROR(context->allocate_output(
          kOutputValuesOutput, TensorShape({N_full}), &output_values_t));
      auto output_values = output_values_t->vec<T>();

      std::vector<Tindex> filled_count(dense_rows, 0);

      // Fill in values for rows that are not missing
      for (Tindex i = 0; i < N; ++i) {
        const Tindex row = indices(i, 0);
        Tindex& offset = filled_count[row];
        const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset;
        offset++;  // Increment the filled count for this row.
        std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
        output_values(output_i) = values(i);
        // We'll need this reverse index map to backprop correctly.
        if (reverse_index_map) {
          reverse_index_map[i] = output_i;
        }
      }

      // Fill in values for rows that are missing
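      // (Continuing the sketch above, row 1 is empty, so output position
      // csr_offset[0] == 2 receives index (1, 0, ..., 0) and default_value.)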
      for (Tindex row = 0; row < dense_rows; ++row) {
        const Tindex row_count = filled_count[row];
        if (row_count == 0) {  // We haven't filled this row
          const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1];
          // Set the row index of the new entry; the remaining index columns
          // are filled with zeros below.
          output_indices(starting_index, 0) = row;
          for (Tindex col = 1; col < rank; ++col) {
            output_indices(starting_index, col) = 0;
          }
          output_values(starting_index) = default_value;
        }
      }
    }

    return OkStatus();
  }
};

}  // namespace functor

namespace {

template <typename Device, typename T, typename Tindex>
void SparseFillEmptyRowsOpImpl(OpKernelContext* context,
                               AsyncOpKernel::DoneCallback done = nullptr) {
  // Note that setting this empty lambda as the default parameter value directly
  // can cause strange compiler/linker errors, so we do it like this instead.
  if (!done) {
    done = [] {};
  }

  const int kIndicesInput = 0;
  const int kValuesInput = 1;
  const int kDenseShapeInput = 2;
  const int kDefaultValueInput = 3;

  const Tensor& indices_t = context->input(kIndicesInput);
  const Tensor& values_t = context->input(kValuesInput);
  const Tensor& dense_shape_t = context->input(kDenseShapeInput);
  const Tensor& default_value_t = context->input(kDefaultValueInput);

  OP_REQUIRES_ASYNC(
      context, TensorShapeUtils::IsVector(dense_shape_t.shape()),
      errors::InvalidArgument("dense_shape must be a vector, saw: ",
                              dense_shape_t.shape().DebugString()),
      done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(indices_t.shape()),
                    errors::InvalidArgument("indices must be a matrix, saw: ",
                                            indices_t.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(values_t.shape()),
                    errors::InvalidArgument("values must be a vector, saw: ",
                                            values_t.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(
      context, indices_t.dim_size(0) == values_t.dim_size(0),
      errors::InvalidArgument("The length of `values` (", values_t.dim_size(0),
                              ") must match the first dimension of `indices` (",
                              indices_t.dim_size(0), ")."),
      done);
  OP_REQUIRES_ASYNC(
      context, TensorShapeUtils::IsScalar(default_value_t.shape()),
      errors::InvalidArgument("default_value must be a scalar, saw: ",
                              default_value_t.shape().DebugString()),
      done);
  // TODO(ebrevdo): add shape checks between values, indices,
  // Also add check that dense rank > 0.
  OP_REQUIRES_ASYNC(context, dense_shape_t.NumElements() != 0,
                    errors::InvalidArgument("Dense shape cannot be empty."),
                    done);

  using FunctorType = functor::SparseFillEmptyRows<Device, T, Tindex>;
  OP_REQUIRES_OK_ASYNC(context,
                       FunctorType()(context, default_value_t, indices_t,
                                     values_t, dense_shape_t, done),
                       done);
}

}  // namespace

template <typename Device, typename T, typename Tindex>
class SparseFillEmptyRowsOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    SparseFillEmptyRowsOpImpl<Device, T, Tindex>(context);
  }
};

#define REGISTER_KERNELS(D, T, Tindex)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")    \
                              .Device(DEVICE_##D)        \
                              .HostMemory("dense_shape") \
                              .TypeConstraint<T>("T"),   \
                          SparseFillEmptyRowsOp<D##Device, T, Tindex>)

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
TF_CALL_ALL_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The GPU implementation is async because it requires waiting for a
// host->device memcpy before the output is allocated (similar to
// SegmentSumGPUOp).
template <typename T, typename Tindex>
class SparseFillEmptyRowsGPUOp : public AsyncOpKernel {
 public:
  explicit SparseFillEmptyRowsGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    SparseFillEmptyRowsOpImpl<GPUDevice, T, Tindex>(context, done);
  }
};

#define REGISTER_KERNELS(T, Tindex)                      \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")    \
                              .Device(DEVICE_GPU)        \
                              .HostMemory("dense_shape") \
                              .TypeConstraint<T>("T"),   \
                          SparseFillEmptyRowsGPUOp<T, Tindex>)

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex)                                            \
  template <>                                                                  \
  Status SparseFillEmptyRows<GPUDevice, T, Tindex>::operator()(                \
      OpKernelContext* context, const Tensor& default_value_t,                 \
      const Tensor& indices_t, const Tensor& values_t,                         \
      const Tensor& dense_shape_t, typename AsyncOpKernel::DoneCallback done); \
  extern template struct SparseFillEmptyRows<GPUDevice, T, Tindex>;
#define DECLARE_GPU_SPEC_INT64(T) DECLARE_GPU_SPEC(T, int64_t)
TF_CALL_POD_TYPES(DECLARE_GPU_SPEC_INT64)
#undef DECLARE_GPU_SPEC_INT64
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_KERNELS_TINDEX(T) REGISTER_KERNELS(T, int64)
TF_CALL_POD_TYPES(REGISTER_KERNELS_TINDEX)
#undef REGISTER_KERNELS_TINDEX

#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

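// An illustrative sketch of the gradient below (hypothetical values, not from
// the original file): if the forward op produced reverse_index_map = [0, 2]
// and grad_values has N_full = 4 entries, then d_values = [grad_values(0),
// grad_values(2)] and d_default_value = grad_values(1) + grad_values(3), i.e.
// the sum over all output slots that the forward op filled with the default
// value.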
template <typename T, typename Tindex>
struct SparseFillEmptyRowsGrad<CPUDevice, T, Tindex> {
  Status operator()(OpKernelContext* context,
                    typename TTypes<Tindex>::ConstVec reverse_index_map,
                    typename TTypes<T>::ConstVec grad_values,
                    typename TTypes<T>::Vec d_values,
                    typename TTypes<T>::Scalar d_default_value) {
    const CPUDevice& device = context->eigen_device<CPUDevice>();
    const Tindex N = reverse_index_map.dimension(0);
    const Tindex N_full = grad_values.dimension(0);

    T& d_default_value_scalar = d_default_value();
    d_default_value_scalar = T();

    Tensor visited_t;
    TF_RETURN_IF_ERROR(
        context->allocate_temp(DT_BOOL, TensorShape({N_full}), &visited_t));
    auto visited = visited_t.vec<bool>();
    visited.device(device) = visited.constant(false);

    for (int i = 0; i < N; ++i) {
      // Locate the index of the output of the forward prop associated
      // with this location in the input of the forward prop. Copy
      // the gradient into it. Mark it as visited.
      int64_t reverse_index = reverse_index_map(i);
      if (reverse_index < 0 || reverse_index >= N_full) {
        return errors::InvalidArgument(
            "Elements in reverse index must be in [0, ", N_full, ") but got ",
            reverse_index);
      }
      d_values(i) = grad_values(reverse_index);
      visited(reverse_index) = true;
    }
    for (int j = 0; j < N_full; ++j) {
      // The default value gradient gets the accumulated remainder of
      // the backprop values (since the default value was used to fill
      // in these slots in the forward calculation).
      if (!visited(j)) {
        d_default_value_scalar += grad_values(j);
      }
    }
    return OkStatus();
  }
};

}  // namespace functor

template <typename Device, typename T, typename Tindex>
class SparseFillEmptyRowsGradOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* reverse_index_map_t;
    const Tensor* grad_values_t;
    OP_REQUIRES_OK(context,
                   context->input("reverse_index_map", &reverse_index_map_t));
    OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
        errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
                                reverse_index_map_t->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
                errors::InvalidArgument("grad_values must be a vector, saw: ",
                                        grad_values_t->shape().DebugString()));

    const auto reverse_index_map = reverse_index_map_t->vec<Tindex>();
    const auto grad_values = grad_values_t->vec<T>();

    const Tindex N = reverse_index_map_t->shape().dim_size(0);

    Tensor* d_values_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "d_values", TensorShape({N}), &d_values_t));
    auto d_values = d_values_t->vec<T>();
    Tensor* d_default_value_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("d_default_value", TensorShape({}),
                                            &d_default_value_t));
    auto d_default_value = d_default_value_t->scalar<T>();

    OP_REQUIRES_OK(context,
                   functor::SparseFillEmptyRowsGrad<Device, T, Tindex>()(
                       context, reverse_index_map, grad_values, d_values,
                       d_default_value));
  }
};

#define REGISTER_KERNELS(D, T, Tindex)                    \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \
                              .Device(DEVICE_##D)         \
                              .TypeConstraint<T>("T"),    \
                          SparseFillEmptyRowsGradOp<D##Device, T, Tindex>)

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex)                                 \
  template <>                                                       \
  Status SparseFillEmptyRowsGrad<GPUDevice, T, Tindex>::operator()( \
      OpKernelContext* context,                                     \
      typename TTypes<Tindex>::ConstVec reverse_index_map,          \
      typename TTypes<T>::ConstVec grad_values,                     \
      typename TTypes<T>::Vec d_values,                             \
      typename TTypes<T>::Scalar d_default_value);                  \
  extern template struct SparseFillEmptyRowsGrad<GPUDevice, T, Tindex>;
#define DECLARE_GPU_SPEC_INT64(T) DECLARE_GPU_SPEC(T, int64_t)
TF_CALL_REAL_NUMBER_TYPES(DECLARE_GPU_SPEC_INT64);
#undef DECLARE_GPU_SPEC_INT64
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T, int64)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_KERNELS
}  // namespace tensorflow