/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See core/ops/sparse_ops.cc for documentation.
//
// NOTE: the kernels defined directly in this file are only suitable for
// execution on CPUs; a separate GPU implementation is registered below
// when built with CUDA or ROCm support.
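//
// A small worked example of the op's semantics, with illustrative values:
//
//   sparse_indices = [[0, 0], [1, 2]]  (two [row, col] coordinates)
//   output_shape   = [3, 4]
//   sparse_values  = [5, 6]
//   default_value  = 0
//
// yields the dense output
//
//   [[5, 0, 0, 0],
//    [0, 0, 6, 0],
//    [0, 0, 0, 0]]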

#define EIGEN_USE_THREADS

#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/kernels/sparse_to_dense_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

namespace {

Status CheckSparseToDenseShapes(const Tensor& indices,
                                const Tensor& output_shape,
                                const Tensor& sparse_values,
                                const Tensor& default_value) {
  // sparse_indices
  if (indices.dims() > 2) {
    return errors::InvalidArgument(
        "sparse_indices should be a scalar, vector, or matrix, "
        "got shape ",
        indices.shape().DebugString());
  }
  const int64_t num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
  const int64_t num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;
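  // A scalar `indices` addresses a single element of a 1-D output, a
  // length-N vector addresses N elements of a 1-D output, and an [N, D]
  // matrix addresses N elements of a D-dimensional output; the defaults
  // of 1 above cover the lower-rank cases.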

  // output_shape
  if (!TensorShapeUtils::IsVector(output_shape.shape())) {
    return errors::InvalidArgument("output_shape must be rank 1, got shape ",
                                   output_shape.shape().DebugString());
  }

  if (output_shape.NumElements() != num_dims) {
    return errors::InvalidArgument(
        "output_shape has incorrect number of elements: ",
        output_shape.NumElements(), " should be: ", num_dims);
  }

  // sparse_values
  const int64_t num_values = sparse_values.NumElements();
  if (sparse_values.dims() != 0 &&
      (sparse_values.dims() != 1 || num_values != num_elems)) {
    return errors::InvalidArgument("sparse_values has incorrect shape ",
                                   sparse_values.shape().DebugString(),
                                   ", should be [] or [", num_elems, "]");
  }

  // default_value
  if (!TensorShapeUtils::IsScalar(default_value.shape())) {
    return errors::InvalidArgument("default_value should be a scalar.");
  }
  return OkStatus();
}

}  // end namespace

// Operator to convert sparse representations to dense.
template <typename T, typename Index>
class SparseToDense : public OpKernel {
 public:
  explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_indices", &validate_indices_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& output_shape = c->input(1);
    const Tensor& sparse_values = c->input(2);
    const Tensor& default_value = c->input(3);
    OP_REQUIRES_OK(c, CheckSparseToDenseShapes(indices, output_shape,
                                               sparse_values, default_value));

    const int64_t num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
    const int64_t num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;

    auto output_shape_vec = output_shape.flat<Index>();
    TensorShape output_tensor_shape;
    OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(output_shape_vec.data(),
                                                  output_shape_vec.size(),
                                                  &output_tensor_shape));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output));

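    // sparse::SparseTensor below requires its indices as a 2-D int64
    // tensor. Reuse the input when it already has that form; otherwise
    // reshape it (and, for int32 indices, cast it) into a freshly
    // allocated tensor.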
    const Tensor* indices_shaped;
    std::unique_ptr<Tensor> indices_shaped_holder;
    if (indices.dtype() == DT_INT64 && indices.dims() == 2) {
      indices_shaped = &indices;
    } else {
      TensorShape ix_shape({num_elems, num_dims});
      indices_shaped_holder = MakeUnique<Tensor>(DT_INT64, ix_shape);
      indices_shaped = indices_shaped_holder.get();
      if (indices.dtype() == DT_INT64) {
        CHECK(indices_shaped_holder->CopyFrom(indices, ix_shape));
      } else {
        indices_shaped_holder->matrix<int64_t>() =
            indices.shaped<Index, 2>(ix_shape.dim_sizes())
                .template cast<int64_t>();
      }
    }

    // If we received a scalar, we'll need to create a new
    // tensor with copies of the values as a vec.
    const Tensor* sparse_values_b;
    std::unique_ptr<Tensor> sparse_values_b_holder;

    if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
      sparse_values_b_holder = MakeUnique<Tensor>(DataTypeToEnum<T>::value,
                                                  TensorShape({num_elems}));
      sparse_values_b = sparse_values_b_holder.get();
      sparse_values_b_holder->vec<T>().setConstant(sparse_values.scalar<T>()());
    } else {
      sparse_values_b = &sparse_values;
    }

    // Assume SparseTensor is lexicographically sorted.
    gtl::InlinedVector<int64_t, 8> order(output->shape().dims());
    std::iota(order.begin(), order.end(), 0);
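    // `order` is the identity permutation, declaring that the indices are
    // expected in standard row-major (lexicographic) order; this is the
    // ordering that IndicesValid() verifies below when validate_indices
    // is set.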
    sparse::SparseTensor st;
    OP_REQUIRES_OK(
        c, sparse::SparseTensor::Create(*indices_shaped, *sparse_values_b,
                                        output->shape(), order, &st));

    if (validate_indices_) {
      OP_REQUIRES_OK(c, st.IndicesValid());
    }

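    // Two-pass fill: first initialize the entire output with the default
    // value, then scatter the sparse values on top. Passing
    // initialize=false tells ToDense to skip its own zero-fill, since the
    // output already holds the defaults.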
    output->flat<T>().setConstant(default_value.scalar<T>()());
    OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
                errors::InvalidArgument(
                    "Indices are not valid (out of bounds). Shape: ",
                    output->shape().DebugString()));
  }

 private:
  bool validate_indices_;
};

#define REGISTER_KERNELS(type, index_type)                             \
  REGISTER_KERNEL_BUILDER(Name("SparseToDense")                        \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          SparseToDense<type, index_type>);

#define REGISTER_KERNELS_ALL(type) \
  REGISTER_KERNELS(type, int32);   \
  REGISTER_KERNELS(type, int64_t);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL);
REGISTER_KERNELS_ALL(bool);
REGISTER_KERNELS_ALL(tstring);
REGISTER_KERNELS_ALL(complex64);
REGISTER_KERNELS_ALL(complex128);

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T, typename Index>
class SparseToDenseGPU : public AsyncOpKernel {
 public:
  explicit SparseToDenseGPU(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_indices", &validate_indices_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    auto* stream = c->op_device_context()->stream();
    OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."),
                      done);

    const Tensor& indices = c->input(0);
    const Tensor& output_shape = c->input(1);
    const Tensor& sparse_values = c->input(2);
    const Tensor& default_value = c->input(3);
    OP_REQUIRES_OK_ASYNC(c,
                         CheckSparseToDenseShapes(indices, output_shape,
                                                  sparse_values, default_value),
                         done);

    auto output_shape_vec = output_shape.flat<Index>();
    TensorShape output_tensor_shape;
    OP_REQUIRES_OK_ASYNC(c,
                         TensorShapeUtils::MakeShape(output_shape_vec.data(),
                                                     output_shape_vec.size(),
                                                     &output_tensor_shape),
                         done);
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_output(0, output_tensor_shape, &output), done);

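    // The "output_shape" input is pinned to host memory (note the
    // HostMemory constraint in the registration below), so its contents
    // must be staged into a device-resident temporary before the GPU
    // kernels can read them.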
    Tensor output_shape_tensor;
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DataTypeToEnum<Index>::value,
                         {output_shape_vec.size()}, &output_shape_tensor),
        done);
    auto output_shape_data =
        AsDeviceMemory(output_shape_tensor.template flat<Index>().data(),
                       output_shape_tensor.template flat<Index>().size());
    OP_REQUIRES_ASYNC(
        c,
        stream
            ->ThenMemcpy(&output_shape_data, output_shape_vec.data(),
                         output_shape_tensor.NumElements() * sizeof(Index))
            .ok(),
        errors::InvalidArgument(
            "failed to copy output_shape vector from host to "
            "device in SparseToDenseOp"),
        done);

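    // The functor enqueues the device-side index validation (when
    // requested) and the scatter kernels on `stream`; it is also
    // responsible for eventually invoking `done`.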
    functor::LaunchSparseToDense<T, Index>()(
        c, done, this, validate_indices_, indices, sparse_values,
        output_shape_tensor, default_value.scalar<T>()(), output);
  }

 private:
  bool validate_indices_;
};

// TODO(b/184077412): SparseToDense causes an illegal access error.

#define REGISTER_GPU_KERNELS(type, index_type)                         \
  REGISTER_KERNEL_BUILDER(Name("SparseToDense")                        \
                              .Device(DEVICE_GPU)                      \
                              .HostMemory("default_value")             \
                              .HostMemory("output_shape")              \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          SparseToDenseGPU<type, index_type>);

#define REGISTER_GPU_KERNELS_ALL(type) \
  REGISTER_GPU_KERNELS(type, int32);   \
  REGISTER_GPU_KERNELS(type, int64_t);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS_ALL);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS_ALL)
REGISTER_GPU_KERNELS_ALL(bool)

#undef REGISTER_GPU_KERNELS_ALL
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow