1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <cstddef>
17#include <functional>
18#include <map>
19#include <mutex>
20#include <numeric>
21#include <unordered_map>
22#include <vector>
23
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/resource_mgr.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/lib/gtl/optional.h"
29#include "tensorflow/core/lib/strings/strcat.h"
30#include "tensorflow/core/platform/env.h"
31#include "tensorflow/core/platform/mutex.h"
32#include "tensorflow/core/platform/thread_annotations.h"
33
34namespace tensorflow {
35namespace {
36
37// Partial Ordering Comparator for Tensor keys containing scalar int64's
38struct KeyTensorLess {
39 bool operator()(const Tensor& lhs, const Tensor& rhs) const {
40 return std::less<int64_t>{}(lhs.scalar<int64_t>()(),
41 rhs.scalar<int64_t>()());
42 }
43};
44
45// Key Equality operator for Tensor keys containing scalar int64's
46struct KeyTensorEqual {
47 bool operator()(const Tensor& lhs, const Tensor& rhs) const {
48 return std::equal_to<int64_t>{}(lhs.scalar<int64_t>()(),
49 rhs.scalar<int64_t>()());
50 }
51};
52
53// Hash for Tensor keys containing scalar int64's
54struct KeyTensorHash {
55 std::size_t operator()(const Tensor& key) const {
56 return std::hash<int64_t>{}(key.scalar<int64_t>()());
57 }
58};
59
60// Primary template.
61template <bool Ordered, typename Data>
62struct MapTraits;
63
64// Partial specialization for ordered.
65template <typename Data>
66struct MapTraits<true, Data> {
67 using KeyType = Tensor;
68 using DataType = Data;
69 using MapType = std::map<KeyType, Data, KeyTensorLess>;
70};
71
72// Partial specialization for unordered.
73template <typename Data>
74struct MapTraits<false, Data> {
75 using KeyType = Tensor;
76 using DataType = Data;
77 using MapType =
78 std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
79};
80
81// Wrapper around map/unordered_map.
82template <bool Ordered>
83class StagingMap : public ResourceBase {
84 public:
85 // Public typedefs
86 using Tuple = std::vector<Tensor>;
87 using OptionalTensor = gtl::optional<Tensor>;
88 using OptionalTuple = std::vector<OptionalTensor>;
89
90 using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
91 using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;
92
93 using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;
94
95 private:
96 // Private variables
97 DataTypeVector dtypes_ TF_GUARDED_BY(mu_);
98 std::size_t capacity_ TF_GUARDED_BY(mu_);
99 std::size_t memory_limit_ TF_GUARDED_BY(mu_);
100 std::size_t current_bytes_ TF_GUARDED_BY(mu_);
101 tensorflow::mutex mu_;
102 tensorflow::condition_variable not_empty_;
103 tensorflow::condition_variable full_;
104 IncompleteType incomplete_ TF_GUARDED_BY(mu_);
105 MapType map_ TF_GUARDED_BY(mu_);
106
107 private:
108 // private methods
109
110 // If map is configured for bounded capacity, notify
111 // waiting inserters that space is now available
112 void notify_inserters_if_bounded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
113 if (has_capacity() || has_memory_limit()) {
114 // Notify all inserters. The removal of an element
115 // may make memory available for many inserters
116 // to insert new elements
117 full_.notify_all();
118 }
119 }
120
121 // Notify all removers waiting to extract values
122 // that data is now available
123 void notify_removers() {
124 // Notify all removers. This is because they are
125 // waiting for specific keys to appear in the map
126 // so we don't know which one to wake up.
127 not_empty_.notify_all();
128 }
129
130 bool has_capacity() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
131 return capacity_ > 0;
132 }
133
134 bool has_memory_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
135 return memory_limit_ > 0;
136 }
137
138 bool would_exceed_memory_limit(std::size_t bytes) const
139 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
140 return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
141 }
142
143 bool is_capacity_full() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
144 return has_capacity() && map_.size() >= capacity_;
145 }
146
147 // Get number of bytes in the tuple
148 std::size_t get_tuple_bytes(const Tuple& tuple) {
149 return std::accumulate(tuple.begin(), tuple.end(),
150 static_cast<std::size_t>(0),
151 [](const std::size_t& lhs, const Tensor& rhs) {
152 return lhs + rhs.TotalBytes();
153 });
154 }
155
156 // Get number of bytes in the incomplete tuple
157 std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
158 return std::accumulate(
159 tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
160 [](const std::size_t& lhs, const OptionalTensor& rhs) {
161 return (lhs + rhs.has_value()) ? rhs.value().TotalBytes() : 0;
162 });
163 }
164
165 // Check that the index is within bounds
166 Status check_index(const Tensor& key, std::size_t index)
167 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
168 if (index >= dtypes_.size()) {
169 return Status(errors::InvalidArgument(
170 "Index '", index, "' for key '", key.scalar<int64_t>()(),
171 "' was out of bounds '", dtypes_.size(), "'."));
172 }
173
174 return OkStatus();
175 }
176
177 Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
178 const Tensor& indices, Tuple* output,
179 bool copy = false)
180 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
181 auto findices = indices.flat<int>();
182
183 // Return values at specified indices
184 for (std::size_t i = 0; i < findices.dimension(0); ++i) {
185 std::size_t index = findices(i);
186
187 TF_RETURN_IF_ERROR(check_index(key, index));
188
189 // Insist on a value present at the specified index
190 if (!(*map_tuple)[index].has_value()) {
191 return Status(errors::InvalidArgument(
192 "Tensor at index '", index, "' for key '", key.scalar<int64_t>()(),
193 "' has already been removed."));
194 }
195
196 // Copy the contained tensor and
197 // remove from the OptionalTuple
198 output->push_back((*map_tuple)[index].value());
199
200 // Clear out the entry if we're not copying (moving)
201 if (!copy) {
202 (*map_tuple)[index].reset();
203 }
204 }
205
206 return OkStatus();
207 }
208
209 // Check that the optional value at the specified index
210 // is uninitialized
211 Status check_index_uninitialized(const Tensor& key, std::size_t index,
212 const OptionalTuple& tuple)
213 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
214 if (tuple[index].has_value()) {
215 return errors::InvalidArgument("The tensor for index '", index,
216 "' for key '", key.scalar<int64_t>()(),
217 "' was already initialized '",
218 dtypes_.size(), "'.");
219 }
220
221 return OkStatus();
222 }
223
224 // Check that the indices are strictly ordered
225 Status check_index_ordering(const Tensor& indices) {
226 if (indices.NumElements() == 0) {
227 return errors::InvalidArgument("Indices are empty");
228 }
229
230 auto findices = indices.flat<int>();
231
232 for (std::size_t i = 0; i < findices.dimension(0) - 1; ++i) {
233 if (findices(i) < findices(i + 1)) {
234 continue;
235 }
236
237 return errors::InvalidArgument("Indices are not strictly ordered");
238 }
239
240 return OkStatus();
241 }
242
243 // Check bytes are within memory limits memory limits
244 Status check_memory_limit(std::size_t bytes)
245 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
246 if (has_memory_limit() && bytes > memory_limit_) {
247 return errors::ResourceExhausted(
248 "Attempted to insert tensors with combined size of '", bytes,
249 "' bytes into Staging Area with a memory limit of '", memory_limit_,
250 "'.");
251 }
252
253 return OkStatus();
254 }
255
256 // Insert incomplete data into the Barrier
257 Status put_incomplete(const KeyType& key, const Tensor& indices,
258 OptionalTuple* tuple, tensorflow::mutex_lock* lock)
259 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
260 auto findices = indices.flat<int>();
261
262 // Search for the key in our incomplete set
263 auto it = incomplete_.find(key);
264
265 // Check that the tuple fits within the memory limit
266 std::size_t tuple_bytes = get_tuple_bytes(*tuple);
267 TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));
268
269 // Wait until we don't exceed the memory limit
270 while (would_exceed_memory_limit(tuple_bytes)) {
271 full_.wait(*lock);
272 }
273
274 // This key isn't present in the incomplete set
275 // Create OptionalTuple and insert
276 if (it == incomplete_.end()) {
277 OptionalTuple empty(dtypes_.size());
278
279 // Initialize empty tuple with given dta
280 for (std::size_t i = 0; i < findices.dimension(0); ++i) {
281 std::size_t index = findices(i);
282 TF_RETURN_IF_ERROR(check_index(key, index));
283
284 // Assign tuple at this index
285 empty[index] = std::move((*tuple)[i]);
286 }
287
288 // Insert into incomplete map
289 incomplete_.insert({key, std::move(empty)});
290
291 // Increment size
292 current_bytes_ += tuple_bytes;
293 }
294 // Found an entry in the incomplete index
295 // Update with given data and insert complete entries
296 // into the main map
297 else {
298 // Reference existing incomplete tuple
299 OptionalTuple& present = it->second;
300
301 // Assign given data
302 for (std::size_t i = 0; i < findices.dimension(0); ++i) {
303 std::size_t index = findices(i);
304 TF_RETURN_IF_ERROR(check_index(key, index));
305 TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));
306
307 // Assign tuple at this index
308 present[index] = std::move((*tuple)[i]);
309 }
310
311 // Increment size
312 current_bytes_ += tuple_bytes;
313
314 // Do we have values at all tuple elements?
315 bool complete =
316 std::all_of(present.begin(), present.end(),
317 [](const OptionalTensor& v) { return v.has_value(); });
318
319 // If so, put the tuple in the actual map
320 if (complete) {
321 OptionalTuple insert_tuple = std::move(it->second);
322
323 // Remove from incomplete
324 incomplete_.erase(it);
325
326 TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
327 }
328 }
329
330 return OkStatus();
331 }
332
333 // Does the insertion into the actual staging area
334 Status put_complete(const KeyType& key, OptionalTuple* tuple)
335 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
336 // Insert key and tuples into the map
337 map_.insert({key, std::move(*tuple)});
338
339 notify_removers();
340
341 return OkStatus();
342 }
343
344 public:
345 // public methods
346 explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
347 std::size_t memory_limit)
348 : dtypes_(dtypes),
349 capacity_(capacity),
350 memory_limit_(memory_limit),
351 current_bytes_(0) {}
352
353 Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
354 tensorflow::mutex_lock lock(mu_);
355
356 // Sanity check the indices
357 TF_RETURN_IF_ERROR(check_index_ordering(*indices));
358
359 // Handle incomplete inserts
360 if (indices->NumElements() != dtypes_.size()) {
361 return put_incomplete(*key, *indices, tuple, &lock);
362 }
363
364 std::size_t tuple_bytes = get_tuple_bytes(*tuple);
365 // Check that tuple_bytes fits within the memory limit
366 TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));
367
368 // Wait until there's space for insertion.
369 while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
370 full_.wait(lock);
371 }
372
373 // Do the put operation
374 TF_RETURN_IF_ERROR(put_complete(*key, tuple));
375
376 // Update the current size
377 current_bytes_ += tuple_bytes;
378
379 return OkStatus();
380 }
381
382 Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
383 tensorflow::mutex_lock lock(mu_);
384
385 // Sanity check the indices
386 TF_RETURN_IF_ERROR(check_index_ordering(*indices));
387
388 typename MapType::iterator it;
389
390 // Wait until the element with the requested key is present
391 while ((it = map_.find(*key)) == map_.end()) {
392 not_empty_.wait(lock);
393 }
394
395 TF_RETURN_IF_ERROR(
396 copy_or_move_tensors(&it->second, *key, *indices, tuple, true));
397
398 // Update bytes in the Staging Area
399 current_bytes_ -= get_tuple_bytes(*tuple);
400
401 return OkStatus();
402 }
403
404 Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
405 tensorflow::mutex_lock lock(mu_);
406
407 // Sanity check the indices
408 TF_RETURN_IF_ERROR(check_index_ordering(*indices));
409
410 typename MapType::iterator it;
411
412 // Wait until the element with the requested key is present
413 while ((it = map_.find(*key)) == map_.end()) {
414 not_empty_.wait(lock);
415 }
416
417 TF_RETURN_IF_ERROR(
418 copy_or_move_tensors(&it->second, *key, *indices, tuple));
419
420 // Remove entry if all the values have been consumed
421 if (!std::any_of(
422 it->second.begin(), it->second.end(),
423 [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
424 map_.erase(it);
425 }
426
427 // Update bytes in the Staging Area
428 current_bytes_ -= get_tuple_bytes(*tuple);
429
430 notify_inserters_if_bounded();
431
432 return OkStatus();
433 }
434
435 Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
436 tensorflow::mutex_lock lock(mu_);
437
438 // Sanity check the indices
439 TF_RETURN_IF_ERROR(check_index_ordering(*indices));
440
441 // Wait until map is not empty
442 while (this->map_.empty()) {
443 not_empty_.wait(lock);
444 }
445
446 // Move from the first element and erase it
447
448 auto it = map_.begin();
449
450 TF_RETURN_IF_ERROR(
451 copy_or_move_tensors(&it->second, *key, *indices, tuple));
452
453 *key = it->first;
454
455 // Remove entry if all the values have been consumed
456 if (!std::any_of(
457 it->second.begin(), it->second.end(),
458 [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
459 map_.erase(it);
460 }
461
462 // Update bytes in the Staging Area
463 current_bytes_ -= get_tuple_bytes(*tuple);
464
465 notify_inserters_if_bounded();
466
467 return OkStatus();
468 }
469
470 Status clear() {
471 tensorflow::mutex_lock lock(mu_);
472 map_.clear();
473 incomplete_.clear();
474 current_bytes_ = 0;
475
476 notify_inserters_if_bounded();
477
478 return OkStatus();
479 }
480
481 std::size_t incomplete_size() {
482 tensorflow::mutex_lock lock(mu_);
483 return incomplete_.size();
484 }
485
486 std::size_t size() {
487 tensorflow::mutex_lock lock(mu_);
488 return map_.size();
489 }
490
491 string DebugString() const override { return "StagingMap"; }
492};
493
494template <bool Ordered>
495Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
496 StagingMap<Ordered>** map) {
497 auto rm = ctx->resource_manager();
498 ContainerInfo cinfo;
499
500 // Lambda for creating the Staging Area
501 auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
502 DataTypeVector dtypes;
503 int64_t capacity;
504 int64_t memory_limit;
505 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
506 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
507 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
508 *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
509 return OkStatus();
510 };
511
512 TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
513 TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
514 cinfo.container(), cinfo.name(), map, create_fn));
515 return OkStatus();
516}
517
518template <bool Ordered>
519class MapStageOp : public OpKernel {
520 public:
521 explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
522
523 void Compute(OpKernelContext* ctx) override {
524 StagingMap<Ordered>* map = nullptr;
525 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
526 core::ScopedUnref scope(map);
527 typename StagingMap<Ordered>::OptionalTuple tuple;
528
529 const Tensor* key_tensor;
530 const Tensor* indices_tensor;
531 OpInputList values_tensor;
532
533 OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
534 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
535 OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));
536 OP_REQUIRES(ctx, key_tensor->NumElements() > 0,
537 errors::InvalidArgument("key must not be empty"));
538
539 OP_REQUIRES(ctx, key_tensor->NumElements() == 1,
540 errors::InvalidArgument(
541 "key must be an int64 scalar, got tensor with shape: ",
542 key_tensor->shape()));
543
544 // Create copy for insertion into Staging Area
545 Tensor key(*key_tensor);
546
547 // Create the tuple to store
548 for (std::size_t i = 0; i < values_tensor.size(); ++i) {
549 tuple.push_back(values_tensor[i]);
550 }
551
552 // Store the tuple in the map
553 OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
554 }
555};
556
// MapStage / OrderedMapStage kernel registrations.
// On DEVICE_DEFAULT the `key` and `indices` inputs are kept in host memory,
// because the staging map reads their contents directly
// (scalar<int64_t>() and flat<int>()).
REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapStageOp<true>);
571
572template <bool Ordered>
573class MapUnstageOp : public OpKernel {
574 public:
575 explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
576
577 // Using this op in such a way that it blocks forever
578 // is an error. As such cancellation is not handled.
579 void Compute(OpKernelContext* ctx) override {
580 StagingMap<Ordered>* map = nullptr;
581 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
582 core::ScopedUnref scope(map);
583 typename StagingMap<Ordered>::Tuple tuple;
584
585 const Tensor* key_tensor;
586 const Tensor* indices_tensor;
587
588 OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
589 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
590 OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));
591
592 OP_REQUIRES(
593 ctx, tuple.size() == indices_tensor->NumElements(),
594 errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
595 " vs. ", indices_tensor->NumElements()));
596
597 for (std::size_t i = 0; i < tuple.size(); ++i) {
598 ctx->set_output(i, tuple[i]);
599 }
600 }
601};
602
// MapUnstage / OrderedMapUnstage kernel registrations.
// On DEVICE_DEFAULT the `key` and `indices` inputs are kept in host memory,
// because the staging map reads their contents directly.
REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageOp<true>);
618
619template <bool Ordered>
620class MapPeekOp : public OpKernel {
621 public:
622 explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
623
624 // Using this op in such a way that it blocks forever
625 // is an error. As such cancellation is not handled.
626 void Compute(OpKernelContext* ctx) override {
627 StagingMap<Ordered>* map = nullptr;
628 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
629 core::ScopedUnref scope(map);
630 typename StagingMap<Ordered>::Tuple tuple;
631
632 const Tensor* key_tensor;
633 const Tensor* indices_tensor;
634
635 OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
636 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
637 OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));
638
639 OP_REQUIRES(
640 ctx, tuple.size() == indices_tensor->NumElements(),
641 errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
642 " vs. ", indices_tensor->NumElements()));
643
644 for (std::size_t i = 0; i < tuple.size(); ++i) {
645 ctx->set_output(i, tuple[i]);
646 }
647 }
648};
649
// MapPeek / OrderedMapPeek kernel registrations.
// On DEVICE_DEFAULT the `key` and `indices` inputs are kept in host memory,
// because the staging map reads their contents directly.
REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(
        DEVICE_DEFAULT),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapPeekOp<true>);
663
664template <bool Ordered>
665class MapUnstageNoKeyOp : public OpKernel {
666 public:
667 explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
668
669 // Using this op in such a way that it blocks forever
670 // is an error. As such cancellation is not handled.
671 void Compute(OpKernelContext* ctx) override {
672 StagingMap<Ordered>* map = nullptr;
673 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
674 core::ScopedUnref scope(map);
675
676 // Pop a random (key, value) off the map
677 typename StagingMap<Ordered>::KeyType key;
678 typename StagingMap<Ordered>::Tuple tuple;
679
680 const Tensor* indices_tensor;
681
682 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
683 OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));
684
685 // Allocate a key tensor and assign the key as the first output
686 ctx->set_output(0, key);
687
688 // Set the rest of the outputs to the tuple Tensors
689 OP_REQUIRES(
690 ctx, tuple.size() == indices_tensor->NumElements(),
691 errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
692 " vs. ", indices_tensor->NumElements()));
693
694 for (std::size_t i = 0; i < tuple.size(); ++i) {
695 ctx->set_output(i + 1, tuple[i]);
696 }
697 }
698};
699
// MapUnstageNoKey / OrderedMapUnstageNoKey kernel registrations.
// On DEVICE_DEFAULT the `key` argument and the `indices` input are kept in
// host memory.
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageNoKeyOp<true>);
715
716template <bool Ordered>
717class MapSizeOp : public OpKernel {
718 public:
719 explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
720
721 void Compute(OpKernelContext* ctx) override {
722 StagingMap<Ordered>* map = nullptr;
723 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
724 core::ScopedUnref scope(map);
725
726 // Allocate size output tensor
727 Tensor* size = nullptr;
728 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));
729
730 // Set it to the actual size
731 size->scalar<int32>().setConstant(map->size());
732 }
733};
734
// MapSize / OrderedMapSize kernel registrations. On DEVICE_DEFAULT the
// scalar `size` output is produced in host memory.
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapSizeOp<true>);
745
746template <bool Ordered>
747class MapIncompleteSizeOp : public OpKernel {
748 public:
749 explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
750
751 void Compute(OpKernelContext* ctx) override {
752 StagingMap<Ordered>* map = nullptr;
753 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
754 core::ScopedUnref scope(map);
755
756 // Allocate size output tensor
757 Tensor* size = nullptr;
758 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));
759
760 // Set it to the actual size
761 size->scalar<int32>().setConstant(map->incomplete_size());
762 }
763};
764
// MapIncompleteSize / OrderedMapIncompleteSize kernel registrations. On
// DEVICE_DEFAULT the scalar `size` output is produced in host memory.
REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapIncompleteSizeOp<true>);
776
777template <bool Ordered>
778class MapClearOp : public OpKernel {
779 public:
780 explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
781
782 void Compute(OpKernelContext* ctx) override {
783 StagingMap<Ordered>* map = nullptr;
784 OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
785 core::ScopedUnref scope(map);
786
787 OP_REQUIRES_OK(ctx, map->clear());
788 }
789};
790
// MapClear / OrderedMapClear kernel registrations. The kernel reads no
// tensor inputs, so no host-memory annotations are used.
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_DEFAULT),
                        MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_DEFAULT),
                        MapClearOp<true>);
799
800} // namespace
801} // namespace tensorflow
802