/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See core/ops/sparse_ops.cc for documentation.
//
// NOTE: the kernels in this file are only suitable for execution on CPUs,
// except for the SparseToDenseGPU kernel compiled in below when GOOGLE_CUDA
// or TENSORFLOW_USE_ROCM is defined.

#define EIGEN_USE_THREADS

#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/kernels/sparse_to_dense_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

namespace {

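// Validates the shapes of the four SparseToDense inputs: sparse_indices must
// be at most rank 2, output_shape must be a rank-1 tensor with one entry per
// index column, sparse_values must be a scalar or a vector with one entry per
// index row, and default_value must be a scalar.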
Status CheckSparseToDenseShapes(const Tensor& indices,
                                const Tensor& output_shape,
                                const Tensor& sparse_values,
                                const Tensor& default_value) {
  // sparse_indices
  if (indices.dims() > 2) {
    return errors::InvalidArgument(
        "sparse_indices should be a scalar, vector, or matrix, "
        "got shape ",
        indices.shape().DebugString());
  }
  const int64_t num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
  const int64_t num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;

  // output_shape
  if (!TensorShapeUtils::IsVector(output_shape.shape())) {
    return errors::InvalidArgument("output_shape must be rank 1, got shape ",
                                   output_shape.shape().DebugString());
  }

  if (output_shape.NumElements() != num_dims) {
    return errors::InvalidArgument(
        "output_shape has incorrect number of elements: ",
        output_shape.NumElements(), " should be: ", num_dims);
  }

  // sparse_values
  const int64_t num_values = sparse_values.NumElements();
  if (sparse_values.dims() != 0 &&
      (sparse_values.dims() != 1 || num_values != num_elems)) {
    return errors::InvalidArgument("sparse_values has incorrect shape ",
                                   sparse_values.shape().DebugString(),
                                   ", should be [] or [", num_elems, "]");
  }

  // default_value
  if (!TensorShapeUtils::IsScalar(default_value.shape())) {
    return errors::InvalidArgument("default_value should be a scalar.");
  }
  return OkStatus();
}

}  // end namespace

// Operator to convert sparse representations to dense.
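// The dense output is first filled with `default_value`; the entries named
// by `sparse_indices` are then overwritten with the corresponding entries of
// `sparse_values` (or with the single value when `sparse_values` is a
// scalar). For example, sparse_indices = [[0, 1], [2, 3]], output_shape =
// [3, 4], sparse_values = [5, 6], and default_value = 0 produce a 3x4 tensor
// of zeros except output[0, 1] = 5 and output[2, 3] = 6.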
template <typename T, typename Index>
class SparseToDense : public OpKernel {
 public:
  explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_indices", &validate_indices_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& output_shape = c->input(1);
    const Tensor& sparse_values = c->input(2);
    const Tensor& default_value = c->input(3);
    OP_REQUIRES_OK(c, CheckSparseToDenseShapes(indices, output_shape,
                                               sparse_values, default_value));

    const int64_t num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
    const int64_t num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;

    auto output_shape_vec = output_shape.flat<Index>();
    TensorShape output_tensor_shape;
    OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(output_shape_vec.data(),
                                                  output_shape_vec.size(),
                                                  &output_tensor_shape));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output));

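    // Normalize the indices to an int64 matrix of shape
    // [num_elems, num_dims]: scalar and vector index tensors are reshaped,
    // and int32 indices are cast, so the SparseTensor helper below always
    // sees the same layout.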
    const Tensor* indices_shaped;
    std::unique_ptr<Tensor> indices_shaped_holder;
    if (indices.dtype() == DT_INT64 && indices.dims() == 2) {
      indices_shaped = &indices;
    } else {
      TensorShape ix_shape({num_elems, num_dims});
      indices_shaped_holder = MakeUnique<Tensor>(DT_INT64, ix_shape);
      indices_shaped = indices_shaped_holder.get();
      if (indices.dtype() == DT_INT64) {
        CHECK(indices_shaped_holder->CopyFrom(indices, ix_shape));
      } else {
        indices_shaped_holder->matrix<int64_t>() =
            indices.shaped<Index, 2>(ix_shape.dim_sizes())
                .template cast<int64_t>();
      }
    }

    // If we received a scalar, we'll need to create a new
    // tensor with copies of the values as a vec.
    const Tensor* sparse_values_b;
    std::unique_ptr<Tensor> sparse_values_b_holder;

    if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
      sparse_values_b_holder = MakeUnique<Tensor>(DataTypeToEnum<T>::value,
                                                  TensorShape({num_elems}));
      sparse_values_b = sparse_values_b_holder.get();
      sparse_values_b_holder->vec<T>().setConstant(sparse_values.scalar<T>()());
    } else {
      sparse_values_b = &sparse_values;
    }

    // Assume SparseTensor is lexicographically sorted.
    gtl::InlinedVector<int64_t, 8> order(output->shape().dims());
    std::iota(order.begin(), order.end(), 0);
    sparse::SparseTensor st;
    OP_REQUIRES_OK(
        c, sparse::SparseTensor::Create(*indices_shaped, *sparse_values_b,
                                        output->shape(), order, &st));

    if (validate_indices_) {
      OP_REQUIRES_OK(c, st.IndicesValid());
    }

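    // Fill the dense output with the default value, then scatter the sparse
    // values into it. ToDense is called with initialize=false because the
    // buffer has already been filled; it returns false if any index is out
    // of bounds.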
    output->flat<T>().setConstant(default_value.scalar<T>()());
    OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
                errors::InvalidArgument(
                    "Indices are not valid (out of bounds). Shape: ",
                    output->shape().DebugString()));
  }

 private:
  bool validate_indices_;
};

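// Register the CPU kernel for every supported value type T and for both
// int32 and int64 index types.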
#define REGISTER_KERNELS(type, index_type)                             \
  REGISTER_KERNEL_BUILDER(Name("SparseToDense")                        \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          SparseToDense<type, index_type>);

#define REGISTER_KERNELS_ALL(type) \
  REGISTER_KERNELS(type, int32);   \
  REGISTER_KERNELS(type, int64_t);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL);
REGISTER_KERNELS_ALL(bool);
REGISTER_KERNELS_ALL(tstring);
REGISTER_KERNELS_ALL(complex64);
REGISTER_KERNELS_ALL(complex128);

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
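// Asynchronous GPU implementation. Shape checking still happens on the host;
// the output shape vector is then copied to the device and the scatter
// itself is delegated to functor::LaunchSparseToDense, which is responsible
// for calling `done`.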
template <typename T, typename Index>
class SparseToDenseGPU : public AsyncOpKernel {
 public:
  explicit SparseToDenseGPU(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_indices", &validate_indices_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    auto* stream = c->op_device_context()->stream();
    OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."),
                      done);

    const Tensor& indices = c->input(0);
    const Tensor& output_shape = c->input(1);
    const Tensor& sparse_values = c->input(2);
    const Tensor& default_value = c->input(3);
    OP_REQUIRES_OK_ASYNC(c,
                         CheckSparseToDenseShapes(indices, output_shape,
                                                  sparse_values, default_value),
                         done);

    auto output_shape_vec = output_shape.flat<Index>();
    TensorShape output_tensor_shape;
    OP_REQUIRES_OK_ASYNC(c,
                         TensorShapeUtils::MakeShape(output_shape_vec.data(),
                                                     output_shape_vec.size(),
                                                     &output_tensor_shape),
                         done);
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, output_tensor_shape, &output), done);

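    // The op registration marks "output_shape" as host memory, so copy it
    // into a device temporary before launching the kernel that needs it on
    // the GPU.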
    Tensor output_shape_tensor;
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DataTypeToEnum<Index>::value,
                         {output_shape_vec.size()}, &output_shape_tensor),
        done);
    auto output_shape_data =
        AsDeviceMemory(output_shape_tensor.template flat<Index>().data(),
                       output_shape_tensor.template flat<Index>().size());
    OP_REQUIRES_ASYNC(
        c,
        stream
            ->ThenMemcpy(&output_shape_data, output_shape_vec.data(),
                         output_shape_tensor.NumElements() * sizeof(Index))
            .ok(),
        errors::InvalidArgument(
            "failed to copy output_shape vector from host to "
            "device in SparseToDenseOp"),
        done);

    functor::LaunchSparseToDense<T, Index>()(
        c, done, this, validate_indices_, indices, sparse_values,
        output_shape_tensor, default_value.scalar<T>()(), output);
  }

 private:
  bool validate_indices_;
};

// TODO(b/184077412): SparseToDense causes an illegal access error.

#define REGISTER_GPU_KERNELS(type, index_type)                         \
  REGISTER_KERNEL_BUILDER(Name("SparseToDense")                        \
                              .Device(DEVICE_GPU)                      \
                              .HostMemory("default_value")             \
                              .HostMemory("output_shape")              \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          SparseToDenseGPU<type, index_type>);

#define REGISTER_GPU_KERNELS_ALL(type) \
  REGISTER_GPU_KERNELS(type, int32);   \
  REGISTER_GPU_KERNELS(type, int64_t);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS_ALL);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS_ALL)
REGISTER_GPU_KERNELS_ALL(bool)

#undef REGISTER_GPU_KERNELS_ALL
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow