1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
23 | |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/framework/register_types.h" |
26 | #include "tensorflow/core/framework/resource_mgr.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/tensor_util.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
31 | #include "tensorflow/core/util/overflow.h" |
32 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | typedef Eigen::ThreadPoolDevice CPUDevice; |
37 | |
38 | using sparse::SparseTensor; |
39 | |
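// An in-process container that maps int64 handles to the components of
// stored SparseTensors. Handles are issued by AddSparseTensor and consumed
// at most once by RetrieveAndClearSparseTensors. A minimal usage sketch,
// assuming `map` was obtained from the ResourceMgr and `handles` is a
// TTypes<int64_t>::ConstVec of previously issued handles:
//
//   int64_t handle;
//   TF_RETURN_IF_ERROR(map->AddSparseTensor(ctx, sp, &handle));
//   ...
//   std::vector<SparseTensor> out;
//   TF_RETURN_IF_ERROR(
//       map->RetrieveAndClearSparseTensors(ctx, handles, &out));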
40 | class SparseTensorsMap : public ResourceBase { |
41 | public: |
42 | explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {} |
43 | |
  string DebugString() const override { return "A SparseTensorsMap"; }
45 | |
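  // The stored components of a single SparseTensor; the dense shape is kept
  // as a flat vector of dimension sizes rather than as a TensorShape.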
  struct PersistentSparseTensor {
    Tensor indices;
    Tensor values;
    gtl::InlinedVector<int64_t, 8> shape;
  };
51 | |
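  // Stores the components of `sp` in the map under a freshly issued handle,
  // which is returned in `*handle`.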
52 | Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp, |
53 | int64_t* handle) { |
54 | Tensor ix; |
55 | TF_RETURN_IF_ERROR( |
56 | ctx->allocate_temp(sp.indices().dtype(), sp.indices().shape(), &ix)); |
57 | ix = sp.indices(); |
58 | |
    Tensor values;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(sp.values().dtype(),
                                          sp.values().shape(), &values));
62 | values = sp.values(); |
63 | { |
64 | mutex_lock l(mu_); |
65 | int64_t unique_st_handle = counter_++; // increment is guarded on purpose |
66 | sp_tensors_[unique_st_handle] = PersistentSparseTensor{ |
67 | ix, values, |
68 | gtl::InlinedVector<int64_t, 8>(sp.shape().begin(), sp.shape().end())}; |
69 | *handle = unique_st_handle; |
70 | } |
71 | return OkStatus(); |
72 | } |
73 | |
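  // Reconstructs the SparseTensor stored under each entry of `handles` into
  // `*sparse_tensors`, in order, and erases the corresponding map entries so
  // that each handle can be consumed at most once. Fails with
  // InvalidArgument if any handle is not present.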
74 | Status RetrieveAndClearSparseTensors( |
75 | OpKernelContext* ctx, const TTypes<int64_t>::ConstVec& handles, |
76 | std::vector<SparseTensor>* sparse_tensors) { |
77 | sparse_tensors->clear(); |
78 | sparse_tensors->reserve(handles.size()); |
79 | { |
80 | mutex_lock l(mu_); |
81 | for (size_t i = 0; i < handles.size(); ++i) { |
82 | const int64_t handle = handles(i); |
83 | auto sp_iter = sp_tensors_.find(handle); |
84 | if (sp_iter == sp_tensors_.end()) { |
          return errors::InvalidArgument(
              "Unable to find SparseTensor: ", handle, " in map: ", name_);
87 | } |
88 | const Tensor* ix = &sp_iter->second.indices; |
89 | const Tensor* values = &sp_iter->second.values; |
90 | const auto& shape = sp_iter->second.shape; |
91 | SparseTensor tensor; |
92 | TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor)); |
93 | sparse_tensors->push_back(std::move(tensor)); |
94 | sp_tensors_.erase(sp_iter); |
95 | } |
96 | } |
97 | |
98 | return OkStatus(); |
99 | } |
100 | |
101 | protected: |
102 | ~SparseTensorsMap() override {} |
103 | |
104 | private: |
105 | string name_; |
106 | |
107 | mutex mu_; |
108 | int64_t counter_ TF_GUARDED_BY(mu_); |
109 | std::unordered_map<int64_t, PersistentSparseTensor> sp_tensors_ |
110 | TF_GUARDED_BY(mu_); |
111 | }; |
112 | |
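// Base class for ops that read from or write to a SparseTensorsMap resource.
// The map is looked up (or created) in the ResourceMgr on first use and
// cached for the lifetime of the kernel.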
113 | class SparseTensorAccessingOp : public OpKernel { |
114 | public: |
115 | typedef std::function<Status(SparseTensorsMap**)> CreatorCallback; |
116 | |
117 | explicit SparseTensorAccessingOp(OpKernelConstruction* context) |
118 | : OpKernel(context), sparse_tensors_map_(nullptr) {} |
119 | |
120 | protected: |
121 | ~SparseTensorAccessingOp() override { |
122 | if (sparse_tensors_map_) sparse_tensors_map_->Unref(); |
123 | } |
124 | |
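  // Returns the shared SparseTensorsMap in `*sparse_tensors_map`, creating
  // it on first call. When `is_writing` is true the node name is used as the
  // default resource name, so each writing op gets its own map unless a
  // shared_name attribute says otherwise.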
125 | Status GetMap(OpKernelContext* ctx, bool is_writing, |
126 | SparseTensorsMap** sparse_tensors_map) { |
127 | mutex_lock l(mu_); |
128 | |
129 | if (sparse_tensors_map_) { |
130 | *sparse_tensors_map = sparse_tensors_map_; |
131 | return OkStatus(); |
132 | } |
133 | |
134 | TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(), |
135 | is_writing /* use_node_name_as_default */)); |
136 | |
137 | CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) { |
138 | SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name()); |
139 | *c = map; |
140 | return OkStatus(); |
141 | }; |
142 | |
143 | TF_RETURN_IF_ERROR( |
144 | cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>( |
145 | cinfo_.container(), cinfo_.name(), &sparse_tensors_map_, |
146 | sparse_tensors_map_creator)); |
147 | |
148 | *sparse_tensors_map = sparse_tensors_map_; |
149 | return OkStatus(); |
150 | } |
151 | |
152 | private: |
153 | ContainerInfo cinfo_; |
154 | |
155 | mutex mu_; |
156 | SparseTensorsMap* sparse_tensors_map_ TF_PT_GUARDED_BY(mu_); |
157 | }; |
158 | |
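// Stores a single SparseTensor in the map and outputs its handle as a scalar
// int64 tensor.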
159 | class AddSparseToTensorsMapOp : public SparseTensorAccessingOp { |
160 | public: |
161 | explicit AddSparseToTensorsMapOp(OpKernelConstruction* context) |
162 | : SparseTensorAccessingOp(context) {} |
163 | |
164 | void Compute(OpKernelContext* context) override { |
165 | const Tensor* input_indices; |
166 | const Tensor* input_values; |
167 | const Tensor* input_shape; |
168 | SparseTensorsMap* map; |
169 | |
170 | OP_REQUIRES_OK(context, context->input("sparse_indices" , &input_indices)); |
171 | OP_REQUIRES_OK(context, context->input("sparse_values" , &input_values)); |
172 | OP_REQUIRES_OK(context, context->input("sparse_shape" , &input_shape)); |
173 | OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); |
174 | |
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));
189 | |
190 | TensorShape input_shape_object; |
191 | OP_REQUIRES_OK( |
192 | context, TensorShapeUtils::MakeShape(input_shape->vec<int64_t>().data(), |
193 | input_shape->NumElements(), |
194 | &input_shape_object)); |
195 | SparseTensor st; |
196 | OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values, |
197 | input_shape_object, &st)); |
198 | int64_t handle; |
199 | OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle)); |
200 | |
201 | Tensor sparse_handle(DT_INT64, TensorShape({})); |
202 | auto sparse_handle_t = sparse_handle.scalar<int64_t>(); |
203 | |
204 | sparse_handle_t() = handle; |
205 | |
206 | context->set_output(0, sparse_handle); |
207 | } |
208 | }; |
209 | |
210 | REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap" ).Device(DEVICE_CPU), |
211 | AddSparseToTensorsMapOp); |
212 | |
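// Splits a rank-R SparseTensor (R > 1) along its first (minibatch) dimension
// into N rank-(R-1) SparseTensors, stores each in the map, and outputs a
// vector of N handles. Batch entries with no values are stored as empty
// SparseTensors.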
213 | template <typename T> |
214 | class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp { |
215 | public: |
216 | explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context) |
217 | : SparseTensorAccessingOp(context) {} |
218 | |
219 | void Compute(OpKernelContext* context) override { |
220 | const Tensor* input_indices; |
221 | const Tensor* input_values; |
222 | const Tensor* input_shape; |
223 | SparseTensorsMap* map; |
224 | |
225 | OP_REQUIRES_OK(context, context->input("sparse_indices" , &input_indices)); |
226 | OP_REQUIRES_OK(context, context->input("sparse_values" , &input_values)); |
227 | OP_REQUIRES_OK(context, context->input("sparse_shape" , &input_shape)); |
228 | OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); |
229 | |
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));
    OP_REQUIRES(
        context,
        input_values->shape().dim_size(0) == input_indices->shape().dim_size(0),
        errors::InvalidArgument(
            "Number of values must match first dimension of indices. ", "Got ",
            input_values->shape().dim_size(0),
            " values, indices shape: ", input_indices->shape().DebugString()));
    OP_REQUIRES(
        context,
        input_shape->shape().dim_size(0) == input_indices->shape().dim_size(1),
        errors::InvalidArgument(
            "Number of dimensions must match second dimension of indices. ",
            "Got ", input_shape->shape().dim_size(0),
            " dimensions, indices shape: ",
            input_indices->shape().DebugString()));
257 | |
258 | int rank = input_shape->NumElements(); |
259 | |
    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));
264 | |
265 | auto input_shape_vec = input_shape->vec<int64_t>(); |
266 | |
267 | TensorShape tensor_input_shape; |
268 | OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(input_shape_vec, |
269 | &tensor_input_shape)); |
270 | gtl::InlinedVector<int64_t, 8> std_order(rank); |
271 | std::iota(std_order.begin(), std_order.end(), 0); |
272 | SparseTensor input_st; |
273 | OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values, |
274 | tensor_input_shape, std_order, |
275 | &input_st)); |
276 | |
277 | const int64_t N = input_shape_vec(0); |
278 | |
279 | Tensor sparse_handles(DT_INT64, TensorShape({N})); |
280 | auto sparse_handles_t = sparse_handles.vec<int64_t>(); |
281 | |
282 | OP_REQUIRES_OK(context, input_st.IndicesValid()); |
283 | |
    // We can compute the output shape now; it is shared by all minibatch
    // entries.
286 | TensorShape output_shape; |
287 | OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( |
288 | input_shape_vec.data() + 1, |
289 | input_shape->NumElements() - 1, &output_shape)); |
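    // For example, an input dense shape of [N, d1, d2] yields a per-entry
    // output shape of [d1, d2].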
290 | |
291 | // Get groups by minibatch dimension |
292 | std::unordered_set<int64_t> visited; |
293 | sparse::GroupIterable minibatch = input_st.group({0}); |
294 | for (const auto& subset : minibatch) { |
295 | const int64_t b = subset.group()[0]; |
296 | visited.insert(b); |
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));
302 | |
303 | const auto indices = subset.indices(); |
304 | const auto values = subset.values<T>(); |
305 | const int64_t num_entries = values.size(); |
306 | |
      Tensor output_indices(DT_INT64, {num_entries, rank - 1});
      Tensor output_values(DataTypeToEnum<T>::value, {num_entries});
309 | |
310 | auto output_indices_t = output_indices.matrix<int64_t>(); |
311 | auto output_values_t = output_values.vec<T>(); |
312 | |
313 | for (int i = 0; i < num_entries; ++i) { |
314 | for (int d = 1; d < rank; ++d) { |
315 | output_indices_t(i, d - 1) = indices(i, d); |
316 | } |
317 | output_values_t(i) = values(i); |
318 | } |
319 | |
320 | SparseTensor st_i; |
321 | OP_REQUIRES_OK(context, |
322 | SparseTensor::Create(output_indices, output_values, |
323 | output_shape, &st_i)); |
324 | int64_t handle; |
325 | OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle)); |
326 | sparse_handles_t(b) = handle; |
327 | } |
328 | |
329 | // Fill in any gaps; we must provide an empty ST for batch entries |
330 | // the grouper didn't find. |
    if (static_cast<int64_t>(visited.size()) < N) {
332 | Tensor empty_indices(DT_INT64, {0, rank - 1}); |
333 | Tensor empty_values(DataTypeToEnum<T>::value, {0}); |
334 | SparseTensor empty_st; |
335 | OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values, |
336 | output_shape, &empty_st)); |
337 | |
338 | for (int64_t b = 0; b < N; ++b) { |
339 | // We skipped this batch entry. |
340 | if (visited.find(b) == visited.end()) { |
341 | int64_t handle; |
342 | OP_REQUIRES_OK(context, |
343 | map->AddSparseTensor(context, empty_st, &handle)); |
344 | sparse_handles_t(b) = handle; |
345 | } |
346 | } |
347 | } |
348 | |
349 | context->set_output(0, sparse_handles); |
350 | } |
351 | }; |
352 | |
353 | #define REGISTER_KERNELS(type) \ |
354 | REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \ |
355 | .Device(DEVICE_CPU) \ |
356 | .TypeConstraint<type>("T"), \ |
357 | AddManySparseToTensorsMapOp<type>) |
358 | |
359 | TF_CALL_ALL_TYPES(REGISTER_KERNELS); |
360 | #undef REGISTER_KERNELS |
361 | |
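// Retrieves (and removes) the N SparseTensors referenced by the input handle
// vector and concatenates them along a new leading minibatch dimension: the
// output dense shape is [N, max(d1), max(d2), ...] over the retrieved
// tensors.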
362 | template <typename T> |
363 | class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp { |
364 | public: |
365 | explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context) |
366 | : SparseTensorAccessingOp(context) {} |
367 | |
368 | void Compute(OpKernelContext* context) override { |
369 | SparseTensorsMap* map = nullptr; |
370 | OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map)); |
371 | |
372 | const Tensor& sparse_handles = context->input(0); |
373 | |
    OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
                errors::InvalidArgument(
                    "sparse_handles should be a vector but received shape ",
                    sparse_handles.shape().DebugString()));
378 | |
379 | int64_t N = sparse_handles.shape().dim_size(0); |
380 | |
    OP_REQUIRES(
        context, N > 0,
        errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
                                "but the input vector has 0 elements"));
385 | |
386 | std::vector<Tensor> indices_to_concat; |
387 | std::vector<Tensor> values_to_concat; |
388 | std::vector<TensorShape> shapes_to_concat; |
389 | |
390 | const auto& sparse_handles_t = sparse_handles.vec<int64_t>(); |
391 | |
392 | std::vector<SparseTensor> sparse_tensors; |
393 | |
394 | OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors( |
395 | context, sparse_handles_t, &sparse_tensors)); |
396 | |
397 | for (int64_t i = 0; i < N; ++i) { |
398 | const SparseTensor& st = sparse_tensors[i]; |
399 | const Tensor& output_indices = st.indices(); |
400 | const Tensor& output_values = st.values(); |
401 | const auto output_shape = st.shape(); |
402 | |
      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent an index matrix but received shape ",
                      output_indices.shape().DebugString()));
      OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent a values vector but received shape ",
                      output_values.shape().DebugString()));
      OP_REQUIRES(
          context, DataTypeToEnum<T>::value == output_values.dtype(),
          errors::InvalidArgument(
              "Requested SparseTensor of type ",
              DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
              "].values.dtype() == ", DataTypeString(output_values.dtype())));
419 | |
      int64_t num_entries = output_indices.dim_size(0);
      OP_REQUIRES(context, num_entries == output_values.dim_size(0),
                  errors::InvalidArgument(
                      "Expected row counts of SparseTensor[", i,
                      "].indices and SparseTensor[", i,
                      "].values to match but they do not: ", num_entries,
                      " vs. ", output_values.dim_size(0)));
      int rank = output_indices.dim_size(1);
      OP_REQUIRES(
          context, rank == output_shape.size(),
          errors::InvalidArgument("Expected column counts of SparseTensor[", i,
                                  "].indices to match size of SparseTensor[", i,
                                  "].shape "
                                  "but they do not: ",
                                  rank, " vs. ", output_shape.size()));
435 | |
      // Now we expand each SparseTensor's indices and shape by prefixing a
      // new leading dimension of size 1.
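      // For example, an indices row [i, j] becomes [0, i, j], and a dense
      // shape [d0, d1] becomes [1, d0, d1].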
438 | Tensor expanded_indices( |
439 | DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)})); |
440 | Tensor expanded_shape(DT_INT64, TensorShape({1 + rank})); |
441 | const auto& output_indices_t = output_indices.matrix<int64_t>(); |
442 | auto expanded_indices_t = expanded_indices.matrix<int64_t>(); |
443 | auto expanded_shape_t = expanded_shape.vec<int64_t>(); |
444 | expanded_indices_t.chip<1>(0).setZero(); |
445 | Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1); |
446 | Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank); |
447 | expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t; |
448 | expanded_shape_t(0) = 1; |
      // Copy the original shape into positions 1..rank of the expanded shape.
      for (int j = 0; j < rank; ++j) {
        expanded_shape_t(j + 1) = output_shape[j];
      }
454 | TensorShape expanded_tensor_shape(expanded_shape_t); |
455 | |
456 | indices_to_concat.push_back(std::move(expanded_indices)); |
457 | values_to_concat.push_back(output_values); |
458 | shapes_to_concat.push_back(std::move(expanded_tensor_shape)); |
459 | } |
460 | |
461 | int rank = -1; |
462 | for (int i = 0; i < N; ++i) { |
463 | if (rank < 0) rank = shapes_to_concat[i].dims(); |
      OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
                  errors::InvalidArgument(
                      "Inconsistent rank across SparseTensors: rank prior to "
                      "SparseTensor[",
                      i, "] was: ", rank, " but rank of SparseTensor[", i,
                      "] is: ", shapes_to_concat[i].dims()));
470 | } |
471 | |
472 | // SparseTensor::Concat requires consistent shape for all but the |
473 | // primary order dimension (dimension 0 in this case). So we get |
474 | // the maximum value across all the input SparseTensors for each |
475 | // dimension and use that. |
476 | TensorShape preconcat_shape(shapes_to_concat[0]); |
477 | for (int i = 0; i < N; ++i) { |
478 | for (int d = 0; d < rank; ++d) { |
479 | preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d), |
480 | shapes_to_concat[i].dim_size(d))); |
481 | } |
482 | } |
483 | |
484 | // Dimension 0 is the primary dimension. |
485 | gtl::InlinedVector<int64_t, 8> std_order(rank); |
486 | std::iota(std_order.begin(), std_order.end(), 0); |
487 | |
488 | std::vector<SparseTensor> tensors_to_concat; |
489 | tensors_to_concat.reserve(N); |
490 | for (int i = 0; i < N; ++i) { |
491 | SparseTensor tensor; |
492 | OP_REQUIRES_OK(context, |
493 | SparseTensor::Create(std::move(indices_to_concat[i]), |
494 | std::move(values_to_concat[i]), |
495 | preconcat_shape, std_order, &tensor)); |
496 | tensors_to_concat.push_back(std::move(tensor)); |
497 | } |
498 | |
499 | auto output = SparseTensor::Concat<T>(tensors_to_concat); |
500 | Tensor final_output_shape(DT_INT64, TensorShape({output.dims()})); |
501 | |
502 | std::copy_n(output.shape().data(), output.dims(), |
503 | final_output_shape.vec<int64_t>().data()); |
504 | |
505 | context->set_output(0, output.indices()); |
506 | context->set_output(1, output.values()); |
507 | context->set_output(2, final_output_shape); |
508 | } |
509 | }; |
510 | |
511 | #define REGISTER_KERNELS(type) \ |
512 | REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \ |
513 | .Device(DEVICE_CPU) \ |
514 | .TypeConstraint<type>("dtype"), \ |
515 | TakeManySparseFromTensorsMapOp<type>) |
516 | |
517 | TF_CALL_ALL_TYPES(REGISTER_KERNELS); |
518 | #undef REGISTER_KERNELS |
519 | |
520 | } // namespace tensorflow |
521 | |