/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using sparse::SparseTensor;

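// In-process resource that owns the components of SparseTensors added by the
// Add[Many]SparseToTensorsMap ops. Each stored tensor is keyed by a unique
// int64 handle; entries are consumed (erased from the map) when read back by
// TakeManySparseFromTensorsMap.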
class SparseTensorsMap : public ResourceBase {
 public:
  explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}

  string DebugString() const override { return "A SparseTensorsMap"; }

  struct PersistentSparseTensor {
    Tensor indices;
    Tensor values;
    gtl::InlinedVector<int64_t, 8> shape;
  };

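  // Copies the components of `sp` into the map under a freshly allocated
  // handle, which is returned in `*handle`.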
  Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
                         int64_t* handle) {
    Tensor ix;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(sp.indices().dtype(), sp.indices().shape(), &ix));
    ix = sp.indices();

    Tensor values;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(sp.values().dtype(),
                                          sp.values().shape(), &values));
    values = sp.values();
    {
      mutex_lock l(mu_);
      int64_t unique_st_handle = counter_++;  // increment is guarded on purpose
      sp_tensors_[unique_st_handle] = PersistentSparseTensor{
          ix, values,
          gtl::InlinedVector<int64_t, 8>(sp.shape().begin(), sp.shape().end())};
      *handle = unique_st_handle;
    }
    return OkStatus();
  }

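  // Looks up each entry of `handles`, reconstructs the corresponding
  // SparseTensor into `*sparse_tensors`, and erases it from the map. Returns
  // InvalidArgument if any handle is not present.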
  Status RetrieveAndClearSparseTensors(
      OpKernelContext* ctx, const TTypes<int64_t>::ConstVec& handles,
      std::vector<SparseTensor>* sparse_tensors) {
    sparse_tensors->clear();
    sparse_tensors->reserve(handles.size());
    {
      mutex_lock l(mu_);
      for (size_t i = 0; i < handles.size(); ++i) {
        const int64_t handle = handles(i);
        auto sp_iter = sp_tensors_.find(handle);
        if (sp_iter == sp_tensors_.end()) {
          return errors::InvalidArgument(
              "Unable to find SparseTensor: ", handle, " in map: ", name_);
        }
        const Tensor* ix = &sp_iter->second.indices;
        const Tensor* values = &sp_iter->second.values;
        const auto& shape = sp_iter->second.shape;
        SparseTensor tensor;
        TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
        sparse_tensors->push_back(std::move(tensor));
        sp_tensors_.erase(sp_iter);
      }
    }

    return OkStatus();
  }

 protected:
  ~SparseTensorsMap() override {}

 private:
  string name_;

  mutex mu_;
  int64_t counter_ TF_GUARDED_BY(mu_);
  std::unordered_map<int64_t, PersistentSparseTensor> sp_tensors_
      TF_GUARDED_BY(mu_);
};

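// Base class for kernels that read from or write to a SparseTensorsMap.
// GetMap lazily looks up (or, when writing, creates) the shared map in the
// ResourceManager and caches the pointer for subsequent calls.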
class SparseTensorAccessingOp : public OpKernel {
 public:
  typedef std::function<Status(SparseTensorsMap**)> CreatorCallback;

  explicit SparseTensorAccessingOp(OpKernelConstruction* context)
      : OpKernel(context), sparse_tensors_map_(nullptr) {}

 protected:
  ~SparseTensorAccessingOp() override {
    if (sparse_tensors_map_) sparse_tensors_map_->Unref();
  }

  Status GetMap(OpKernelContext* ctx, bool is_writing,
                SparseTensorsMap** sparse_tensors_map) {
    mutex_lock l(mu_);

    if (sparse_tensors_map_) {
      *sparse_tensors_map = sparse_tensors_map_;
      return OkStatus();
    }

    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(),
                                   is_writing /* use_node_name_as_default */));

    CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) {
      SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name());
      *c = map;
      return OkStatus();
    };

    TF_RETURN_IF_ERROR(
        cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
            cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
            sparse_tensors_map_creator));

    *sparse_tensors_map = sparse_tensors_map_;
    return OkStatus();
  }

 private:
  ContainerInfo cinfo_;

  mutex mu_;
  SparseTensorsMap* sparse_tensors_map_ TF_PT_GUARDED_BY(mu_);
};

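// Stores a single SparseTensor (indices, values, shape) in the map and
// outputs the scalar int64 handle under which it was stored.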
class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    TensorShape input_shape_object;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape->vec<int64_t>().data(),
                                input_shape->NumElements(),
                                &input_shape_object));
    SparseTensor st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 input_shape_object, &st));
    int64_t handle;
    OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));

    Tensor sparse_handle(DT_INT64, TensorShape({}));
    auto sparse_handle_t = sparse_handle.scalar<int64_t>();

    sparse_handle_t() = handle;

    context->set_output(0, sparse_handle);
  }
};

REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);

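// Splits a rank-R (R > 1) SparseTensor along its first (minibatch) dimension
// into N rank-(R-1) SparseTensors, stores each one in the map, and outputs a
// length-N vector of handles. Batch entries with no values are stored as
// empty SparseTensors so that every row receives a handle.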
template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));
    OP_REQUIRES(
        context,
        input_values->shape().dim_size(0) == input_indices->shape().dim_size(0),
        errors::InvalidArgument(
            "Number of values must match first dimension of indices. ", "Got ",
            input_values->shape().dim_size(0),
            " values, indices shape: ", input_indices->shape().DebugString()));
    OP_REQUIRES(
        context,
        input_shape->shape().dim_size(0) == input_indices->shape().dim_size(1),
        errors::InvalidArgument(
            "Number of dimensions must match second dimension of indices. ",
            "Got ", input_shape->shape().dim_size(0),
            " dimensions, indices shape: ",
            input_indices->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    auto input_shape_vec = input_shape->vec<int64_t>();

    TensorShape tensor_input_shape;
    OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(input_shape_vec,
                                                          &tensor_input_shape));
    gtl::InlinedVector<int64_t, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    const int64_t N = input_shape_vec(0);

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64_t>();

    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // The output shape (all dimensions but the minibatch dimension) is the
    // same for every minibatch entry, so compute it once up front.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_vec.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension.
    std::unordered_set<int64_t> visited;
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64_t b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64_t num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64_t>();
      auto output_values_t = output_values.vec<T>();

      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(output_indices, output_values,
                                          output_shape, &st_i));
      int64_t handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty SparseTensor for batch
    // entries the grouper didn't find.
    if (visited.size() < static_cast<size_t>(N)) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st;
      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
                                                   output_shape, &empty_st));

      for (int64_t b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64_t handle;
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};

#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

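// Inverse of AddManySparseToTensorsMap: consumes a vector of N handles,
// removes the referenced SparseTensors from the map, and concatenates them
// along a new leading minibatch dimension. Non-primary dimensions are padded
// up to the maximum size seen across the N inputs, since
// SparseTensor::Concat requires consistent shapes there.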
template <typename T>
class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    SparseTensorsMap* map = nullptr;
    OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));

    const Tensor& sparse_handles = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
                errors::InvalidArgument(
                    "sparse_handles should be a vector but received shape ",
                    sparse_handles.shape().DebugString()));

    int64_t N = sparse_handles.shape().dim_size(0);

    OP_REQUIRES(
        context, N > 0,
        errors::InvalidArgument("Must have at least 1 SparseTensor handle, "
                                "but input vector has 0 elements"));

    std::vector<Tensor> indices_to_concat;
    std::vector<Tensor> values_to_concat;
    std::vector<TensorShape> shapes_to_concat;

    const auto& sparse_handles_t = sparse_handles.vec<int64_t>();

    std::vector<SparseTensor> sparse_tensors;

    OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
                                context, sparse_handles_t, &sparse_tensors));

    for (int64_t i = 0; i < N; ++i) {
      const SparseTensor& st = sparse_tensors[i];
      const Tensor& output_indices = st.indices();
      const Tensor& output_values = st.values();
      const auto output_shape = st.shape();

      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent an index matrix but received shape ",
                      output_indices.shape().DebugString()));
      OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent a values vector but received shape ",
                      output_values.shape().DebugString()));
      OP_REQUIRES(
          context, DataTypeToEnum<T>::value == output_values.dtype(),
          errors::InvalidArgument(
              "Requested SparseTensor of type ",
              DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
              "].values.dtype() == ", DataTypeString(output_values.dtype())));

      int64_t num_entries = output_indices.dim_size(0);
      OP_REQUIRES(context, num_entries == output_values.dim_size(0),
                  errors::InvalidArgument(
                      "Expected row counts of SparseTensor[", i,
                      "].indices and SparseTensor[", i,
                      "].values to match but they do not: ", num_entries,
                      " vs. ", output_values.dim_size(0)));
      int rank = output_indices.dim_size(1);
      OP_REQUIRES(
          context, rank == output_shape.size(),
          errors::InvalidArgument("Expected column counts of SparseTensor[", i,
                                  "].indices to match size of SparseTensor[", i,
                                  "].shape but they do not: ", rank, " vs. ",
                                  output_shape.size()));

      // Now we expand each SparseTensor's indices and shape by
      // prefixing a dimension.
      Tensor expanded_indices(
          DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
      Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
      const auto& output_indices_t = output_indices.matrix<int64_t>();
      auto expanded_indices_t = expanded_indices.matrix<int64_t>();
      auto expanded_shape_t = expanded_shape.vec<int64_t>();
      expanded_indices_t.chip<1>(0).setZero();
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
      expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
      expanded_shape_t(0) = 1;
      // TODO: copy shape from TensorShape to &expanded_shape_t(1)
      // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
      for (int d = 0; d < rank; ++d) {
        expanded_shape_t(d + 1) = output_shape[d];
      }
      TensorShape expanded_tensor_shape(expanded_shape_t);

      indices_to_concat.push_back(std::move(expanded_indices));
      values_to_concat.push_back(output_values);
      shapes_to_concat.push_back(std::move(expanded_tensor_shape));
    }

    int rank = -1;
    for (int i = 0; i < N; ++i) {
      if (rank < 0) rank = shapes_to_concat[i].dims();
      OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
                  errors::InvalidArgument(
                      "Inconsistent rank across SparseTensors: rank prior to "
                      "SparseTensor[",
                      i, "] was: ", rank, " but rank of SparseTensor[", i,
                      "] is: ", shapes_to_concat[i].dims()));
    }

    // SparseTensor::Concat requires consistent shape for all but the
    // primary order dimension (dimension 0 in this case). So we get
    // the maximum value across all the input SparseTensors for each
    // dimension and use that.
    TensorShape preconcat_shape(shapes_to_concat[0]);
    for (int i = 0; i < N; ++i) {
      for (int d = 0; d < rank; ++d) {
        preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
                                            shapes_to_concat[i].dim_size(d)));
      }
    }

    // Dimension 0 is the primary dimension.
    gtl::InlinedVector<int64_t, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<SparseTensor> tensors_to_concat;
    tensors_to_concat.reserve(N);
    for (int i = 0; i < N; ++i) {
      SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(std::move(indices_to_concat[i]),
                                          std::move(values_to_concat[i]),
                                          preconcat_shape, std_order, &tensor));
      tensors_to_concat.push_back(std::move(tensor));
    }

    auto output = SparseTensor::Concat<T>(tensors_to_concat);
    Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));

    std::copy_n(output.shape().data(), output.dims(),
                final_output_shape.vec<int64_t>().data());

    context->set_output(0, output.indices());
    context->set_output(1, output.values());
    context->set_output(2, final_output_shape);
  }
};

#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<type>("dtype"),  \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow