/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <limits.h>

#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace barrier {

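// A Barrier assembles multi-component values keyed by string. Each key maps
// to a tuple of num_components() value tensors, and the components of a key
// may be inserted independently via TryInsertMany. Once all components of a
// key have been supplied, the completed tuple is moved onto an internal
// PriorityQueue, ordered by the key's insertion index so that earlier keys
// come out first, where it can be retrieved with TryTakeMany.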
class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64_t>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // First queue component is for the input index;
    // second queue component is for the key;
    // remaining queue components are for the value.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes specified because
    // we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }

  Status Initialize() { return ready_queue_->Initialize(); }

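  // Inserts the slices of `values` (one per key in `keys`) as component
  // `component_index` of the corresponding incomplete tuples. Any tuples that
  // become complete as a result are enqueued onto the ready queue before
  // `callback` is invoked.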
  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // given value at component_index. Tuples that become complete here
    // are enqueued onto the ready queue at the very end.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed. Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ". Number of new insertions: ", num_inserted,
                ". Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing. Store
      // them in a vector
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows? Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64_t>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // We have something to Enqueue. Convert the Tuples into a single
      // tuple by slicing entries into new Tensors. This part is slow
      // but seems the cleanest solution for now.
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }

    // Enqueue the now-complete tuples onto the ready queue.
    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid early closing of the queue, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly), the queue has entries in it.
        [this, ctx, callback]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32_t ready = ready_size();
            if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }

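  // Dequeues up to `num_elements` completed tuples from the ready queue and
  // hands them to `callback` as (indices, keys, values). If the barrier is
  // closed and `allow_small_batch` is true, fewer than `num_elements` may be
  // delivered; if nothing can ever be delivered, the status is OutOfRange.
  // `timeout` is currently ignored (the TakeMany kernel only accepts
  // timeout_ms == -1).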
  void TryTakeMany(int num_elements, bool allow_small_batch, int64_t timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver at most num_elements; if fewer elements are
          // available, we deliver at most available_elements. If no elements
          // are available at all, the call to TryTakeMany should fail with
          // OutOfRange; we trigger that error by treating the request below
          // as at least 1.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If there are no available elements, or fewer than the number we
        // must deliver, fail with OutOfRange.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
        });
  }

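  // Marks the barrier as closed. If `cancel_pending_enqueues` is true, or if
  // there are no incomplete tuples left, the underlying ready queue is closed
  // as well; otherwise the queue is left open so that already-started tuples
  // can still be completed.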
  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() const override { return "A barrier"; }

 protected:
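  // Inserts element `i` of (`keys`, `values`) into the incomplete map as
  // component `component_index`. If this completes the tuple for that key,
  // the key is removed from the incomplete map and a ready tuple of the form
  // {input_index, key, value_0, ..., value_{n-1}} is appended to
  // `*ready_tuples`. Must be called with `mu_` held.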
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values, const TensorShape& element_shape,
                         int component_index, int i,
                         std::vector<Tuple>* ready_tuples, bool* new_elements)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<tstring>();
    auto values_matrix = values.flat_outer_dims<T>();

    TensorTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ". Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ". Insertion index: ", i,
            ". Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          &gtl::LookupOrInsert(&incomplete_, keys_vec(i), TensorTuple());
    }
    TensorTuple& element = *element_ptr;

    if (element.empty()) {  // Never seen before key
      // Added a new element, for keeping track of the insertion index
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the
      // input_index_, so that tensors that entered the Barrier
      // earlier have higher priority in the queue.
      Tensor allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({}),
                                            &allocate_index_tensor));

      allocate_index_tensor.scalar<int64_t>()() = input_index_;
      element.push_back(allocate_index_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(Tensor(uninitialized));
      }
    }
    const Tensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    Tensor next_element;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(values.dtype(), element_shape, &next_element));
    element[1 + component_index] = next_element;
    next_element.flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      Tensor key;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &key));
      ready_tuple.push_back(element[0]);                 // index
      ready_tuple.push_back(key);                        // key
      ready_tuple[1].scalar<tstring>()() = keys_vec(i);  // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(element[j]);
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return OkStatus();
  }

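  // Closes the underlying ready queue (at most once for a plain close, and at
  // most once more for a cancelling close) and runs `callback`.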
  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    }
  }

 private:
  typedef std::vector<Tensor> TensorTuple;
  mutex mu_;
  bool closed_ TF_GUARDED_BY(mu_);
  bool queue_closed_ TF_GUARDED_BY(mu_);
  bool queue_cancelled_ TF_GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ TF_GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  const std::vector<TensorShape> value_component_shapes_;
  const string name_;
  int64_t input_index_ TF_GUARDED_BY(mu_);
  std::unordered_map<string, TensorTuple> incomplete_ TF_GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};

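// BarrierOp produces a handle to a Barrier resource. The resource itself is
// created on first use by the ResourceOpKernel base class and, when shared,
// is validated against the requested component types and shapes.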
class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(context,
                value_component_shapes_.size() == value_component_types_.size(),
                errors::InvalidArgument(
                    "All of the component shapes must be specified"));

    int32_t value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1. Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }

 private:
  Status CreateResource(Barrier** barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return OkStatus();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);

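// Common base class for kernels that operate on an existing Barrier: it looks
// up the Barrier from the "handle" input, runs the subclass's ComputeAsync,
// and unrefs the Barrier once the subclass's callback has completed.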
class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};

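// BarrierInsertMany<T> inserts a batch of values, all for a single
// component_index, into the barrier under the given keys.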
template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " > num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY

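// BarrierTakeMany removes up to "num_elements" completed elements from the
// barrier and outputs their insertion indices, keys, and value components.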
class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32_t num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values
          // have all been written to successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
        });
  }

 private:
  int64_t timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);

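// BarrierClose closes the barrier, optionally cancelling pending enqueues.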
class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

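// BarrierIncompleteSize outputs the number of incomplete (partially inserted)
// keys currently in the barrier.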
class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

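// BarrierReadySize outputs the number of completed elements currently waiting
// in the barrier's ready queue.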
class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow