/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Strict weak ordering comparator for Tensor keys holding scalar int64s
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64_t>{}(lhs.scalar<int64_t>()(),
                                rhs.scalar<int64_t>()());
  }
};

// Equality comparator for Tensor keys holding scalar int64s
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64_t>{}(lhs.scalar<int64_t>()(),
                                    rhs.scalar<int64_t>()());
  }
};

// Hash functor for Tensor keys holding scalar int64s
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64_t>{}(key.scalar<int64_t>()());
  }
};
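
// A minimal sketch (not part of the kernel) of how these functors plug into
// the standard containers chosen below. MakeScalarKey is a hypothetical
// helper introduced purely for illustration:
//
//   Tensor MakeScalarKey(int64_t v) {
//     Tensor t(DT_INT64, TensorShape({}));
//     t.scalar<int64_t>()() = v;
//     return t;
//   }
//
//   std::map<Tensor, int, KeyTensorLess> ordered;
//   std::unordered_map<Tensor, int, KeyTensorHash, KeyTensorEqual> unordered;
//   ordered[MakeScalarKey(42)] = 1;    // iterates in ascending key order
//   unordered[MakeScalarKey(42)] = 1;  // hashed on the scalar int64 value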

// Primary template.
template <bool Ordered, typename Data>
struct MapTraits;

// Partial specialization for ordered.
template <typename Data>
struct MapTraits<true, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType = std::map<KeyType, Data, KeyTensorLess>;
};

// Partial specialization for unordered.
template <typename Data>
struct MapTraits<false, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType =
      std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
};
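
// For example, MapTraits<true, int>::MapType is
// std::map<Tensor, int, KeyTensorLess>, while MapTraits<false, int>::MapType
// is the std::unordered_map equivalent; StagingMap below selects between the
// two at compile time through its Ordered template parameter.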

// Wrapper around map/unordered_map.
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ TF_GUARDED_BY(mu_);
  std::size_t capacity_ TF_GUARDED_BY(mu_);
  std::size_t memory_limit_ TF_GUARDED_BY(mu_);
  std::size_t current_bytes_ TF_GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ TF_GUARDED_BY(mu_);
  MapType map_ TF_GUARDED_BY(mu_);

 private:
  // Private methods

  // If the map is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. This is because they are
    // waiting for specific keys to appear in the map
    // so we don't know which one to wake up.
    not_empty_.notify_all();
  }
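
  // A sketch of how the two condition variables pair with the waits further
  // down (illustrative only): removers block on not_empty_ until their key
  // appears, while inserters block on full_ until a bound clears.
  //
  //   // remover side, see get()/pop() below:
  //   while ((it = map_.find(*key)) == map_.end()) not_empty_.wait(lock);
  //
  //   // inserter side, see put() below:
  //   while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full())
  //     full_.wait(lock);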

  bool has_capacity() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get number of bytes in the incomplete tuple; uninitialized
  // slots contribute nothing
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          // Parenthesized so the ternary binds to has_value(),
          // not to the accumulated sum
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }
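
  // For example, a tuple whose slots hold a DT_FLOAT tensor of shape {2, 3}
  // and an unset optional contributes 24 bytes: 6 floats of 4 bytes each,
  // and nothing for the empty slot.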

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64_t>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return OkStatus();
  }

  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64_t>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor into the output tuple
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're moving rather than copying
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return OkStatus();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return errors::InvalidArgument("The tensor for index '", index,
                                     "' for key '", key.scalar<int64_t>()(),
                                     "' was already initialized.");
    }

    return OkStatus();
  }

  // Check that the indices are strictly ordered
  Status check_index_ordering(const Tensor& indices) {
    if (indices.NumElements() == 0) {
      return errors::InvalidArgument("Indices are empty");
    }

    auto findices = indices.flat<int>();

    for (std::size_t i = 0; i < findices.dimension(0) - 1; ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return errors::InvalidArgument("Indices are not strictly ordered");
    }

    return OkStatus();
  }
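
  // For example, indices [0, 2, 3] pass this check, while an empty vector,
  // [2, 2] (duplicate) and [3, 1] (descending) are all rejected.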

  // Check that bytes fit within the memory limit
  Status check_memory_limit(std::size_t bytes)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'.");
    }

    return OkStatus();
  }

  // Insert incomplete data into the staging area
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // Search for the key in our incomplete set. This lookup must happen
    // after the wait above: the lock is released while waiting, so an
    // iterator obtained earlier could have been invalidated by another
    // thread mutating incomplete_.
    auto it = incomplete_.find(key);

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize the empty tuple with the given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index
    // Update with given data and insert complete entries
    // into the main map
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return OkStatus();
  }
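
  // Illustrative walk-through, assuming dtypes_ has three elements: a put
  // with indices [0, 2] parks the partial tuple {v0, nullopt, v2} in
  // incomplete_; a later put with the same key and indices [1] fills the
  // remaining slot, and the completed tuple then moves from incomplete_
  // into map_ via put_complete() below.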

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return OkStatus();
  }

 public:
  // public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return OkStatus();
  }

  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // get() copies rather than moves, so the tensors remain in the
    // staging area and current_bytes_ stays unchanged; decrementing it
    // here would make a later pop() double-count the removal.

    return OkStatus();
  }

  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return OkStatus();
  }
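
  // Example of the consumption model (illustrative): for a stored tuple of
  // size 3, pop(key, [0], ...) moves slot 0 out but keeps the entry alive;
  // a later pop(key, [1, 2], ...) consumes the remaining slots, after which
  // the entry is erased and blocked inserters are notified.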

  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it
    auto it = map_.begin();

    // Record the key first so that any error raised while
    // copying reports a valid key
    *key = it->first;

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return OkStatus();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return OkStatus();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() const override { return "StagingMap"; }
};
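
// A minimal usage sketch of StagingMap outside the kernel machinery
// (illustrative only; the Op implementations below are its real clients).
// It stages a single float tensor under key 7 and pops it back out:
//
//   auto* staging = new StagingMap<false>({DT_FLOAT}, /*capacity=*/0,
//                                         /*memory_limit=*/0);
//   Tensor key(DT_INT64, TensorShape({}));
//   key.scalar<int64_t>()() = 7;
//   Tensor indices(DT_INT32, TensorShape({1}));
//   indices.flat<int>()(0) = 0;
//   StagingMap<false>::OptionalTuple in;
//   in.emplace_back(Tensor(DT_FLOAT, TensorShape({2})));
//   TF_CHECK_OK(staging->put(&key, &indices, &in));
//   StagingMap<false>::Tuple out;
//   TF_CHECK_OK(staging->pop(&key, &indices, &out));  // out holds one tensor
//   staging->Unref();  // ResourceBase is reference counted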

template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64_t capacity;
    int64_t memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return OkStatus();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return OkStatus();
}
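
// Note on sharing (an assumption worth stating): LookupOrCreate keys the
// resource on the (container, name) pair that ContainerInfo derives from
// the NodeDef, so e.g. a MapStage op and a MapUnstage op carrying matching
// 'container' and 'shared_name' attrs resolve to one shared StagingMap.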

template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));
    OP_REQUIRES(ctx, key_tensor->NumElements() > 0,
                errors::InvalidArgument("key must not be empty"));

    OP_REQUIRES(ctx, key_tensor->NumElements() == 1,
                errors::InvalidArgument(
                    "key must be an int64 scalar, got tensor with shape: ",
                    key_tensor->shape()));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapStageOp<true>);

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageOp<true>);

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(
        DEVICE_DEFAULT),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapPeekOp<true>);

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop the first (key, value) pair off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Set the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageNoKeyOp<true>);

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapSizeOp<true>);

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapIncompleteSizeOp<true>);

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_DEFAULT),
                        MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_DEFAULT),
                        MapClearOp<true>);

}  // namespace
}  // namespace tensorflow