/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <limits.h>

#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// Needed for batch_util::CopyElementToSlice used in TryInsertMany.
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

namespace barrier {

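// A Barrier accumulates values for string keys across multiple
// BarrierInsertMany calls: each key owns one incomplete tuple with a slot per
// value component. Once every component of a key's tuple has been filled in,
// the tuple (prefixed with its insertion index and key) is moved into a
// PriorityQueue, from which BarrierTakeMany dequeues completed tuples in
// insertion order.
//
// Illustrative sketch (hypothetical values, not part of the original code):
// for a barrier with component_types = {DT_FLOAT, DT_INT64},
//   InsertMany(component_index=0, keys=["a"], values=[1.0])  // "a" incomplete
//   InsertMany(component_index=1, keys=["a"], values=[7])    // "a" complete
// after which TakeMany(1) can return indices=[i], keys=["a"],
// values=([1.0], [7]).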
class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64_t>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // First queue component is for the input index;
    // Second queue component is for the key;
    // remaining queue components are for the value.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes specified because
    // we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }
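
  // Sketch of the resulting queue tuple layout (assuming, for illustration,
  // value_component_types = {DT_FLOAT, DT_INT64} and shapes {[2], []}):
  //   component 0: DT_INT64 scalar  -- insertion index (priority)
  //   component 1: DT_STRING scalar -- key
  //   component 2: DT_FLOAT  [2]    -- value component 0
  //   component 3: DT_INT64  []     -- value component 1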

  Status Initialize() { return ready_queue_->Initialize(); }

  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // given value at component_index.
    // This will be passed to the final callback at the very end.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed. Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ". Number of new insertions: ", num_inserted,
                ". Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing. Store
      // them in a vector.
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows? Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64_t>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // We have something to Enqueue. Convert the Tuples into a single
      // tuple by slicing entries into new Tensors. This part is slow
      // but seems the cleanest solution for now.
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }

    // Enqueue the completed tuples into the ready queue.
    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid early closing of the queue, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly), the queue has entries in it.
        [this, ctx, callback]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32_t ready = ready_size();
            if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }
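
  // For example (hypothetical shapes): if three keys become complete in one
  // TryInsertMany call and each value component has element shape [2], the
  // stacking loop above produces, per component, a single tensor of shape
  // [3, 2] (and shape [3] for the scalar index and key components), which is
  // then enqueued as one batch via TryEnqueueMany.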

  void TryTakeMany(int num_elements, bool allow_small_batch, int64_t timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver a maximum of num_elements; if fewer elements
          // are available, we deliver at most available_elements. If no
          // elements are available, a call to TryTakeMany should fail with
          // OutOfRange; the std::max(..., 1) below triggers that error.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If there are no available elements, or fewer elements than the
        // number we must deliver, fail with OutOfRange.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
        });
  }
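
  // On success the callback receives (for n dequeued elements):
  //   indices: DT_INT64 tensor of shape [n] -- insertion indices
  //   keys:    DT_STRING tensor of shape [n]
  //   values:  one tensor per component, of shape [n] + component_shapes()[j]
  // With allow_small_batch, n may be smaller than the requested num_elements.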

  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback.
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() const override { return "A barrier"; }

 protected:
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values, const TensorShape& element_shape,
                         int component_index, int i,
                         std::vector<Tuple>* ready_tuples, bool* new_elements)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<tstring>();
    auto values_matrix = values.flat_outer_dims<T>();

    TensorTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ". Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ". Insertion index: ", i,
            ". Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          &gtl::LookupOrInsert(&incomplete_, keys_vec(i), TensorTuple());
    }
    TensorTuple& element = *element_ptr;

    if (element.empty()) {  // Key never seen before.
      // We added a new element; record this so the insertion index is
      // advanced for this batch.
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the
      // input_index_, so that tensors that entered the Barrier
      // earlier have higher priority in the queue.
      Tensor allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({}),
                                            &allocate_index_tensor));
      allocate_index_tensor.scalar<int64_t>()() = input_index_;
      element.push_back(allocate_index_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(Tensor(uninitialized));
      }
    }
    const Tensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    Tensor next_element;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(values.dtype(), element_shape, &next_element));
    element[1 + component_index] = next_element;
    next_element.flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      Tensor key;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &key));
      ready_tuple.push_back(element[0]);  // index
      ready_tuple.push_back(key);         // key
      ready_tuple[1].scalar<tstring>()() = keys_vec(i);  // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(element[j]);
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return OkStatus();
  }
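
  // Sketch of one entry's lifecycle (assuming a hypothetical two-component
  // barrier):
  //   after the first insert of component 0 for key "k":
  //     incomplete_["k"] = {index, value0, <uninitialized>}
  //   after the insert of component 1 for key "k":
  //     the tuple is complete -> it is erased from incomplete_ and
  //     {index, "k", value0, value1} is appended to ready_tuples.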

  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    }
  }

 private:
  typedef std::vector<Tensor> TensorTuple;
  mutex mu_;
  bool closed_ TF_GUARDED_BY(mu_);
  bool queue_closed_ TF_GUARDED_BY(mu_);
  bool queue_cancelled_ TF_GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ TF_GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  const std::vector<TensorShape> value_component_shapes_;
  const string name_;
  int64_t input_index_ TF_GUARDED_BY(mu_);
  std::unordered_map<string, TensorTuple> incomplete_ TF_GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};

class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(context,
                value_component_shapes_.size() == value_component_types_.size(),
                errors::InvalidArgument(
                    "All of the component shapes must be specified"));

    int32_t value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1. Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }
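
  // Illustrative attr values only (hypothetical): component_types =
  // [DT_FLOAT, DT_INT64], shapes = [[2], []], capacity = -1 (the only
  // accepted value); the container/shared_name attrs are handled by the
  // ResourceOpKernel base class via cinfo_.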

 private:
  Status CreateResource(Barrier** barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return OkStatus();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);

class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};
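
// Note: BarrierOpKernel::ComputeAsync above looks up the Barrier resource
// (taking a reference) and wraps the done callback so that the barrier is
// Unref'd once the subclass's work finishes; each subclass therefore runs
// with a reference that is released when its callback fires.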

template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " >= num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY
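
// The BarrierInsertMany kernels registered above expect keys of shape [n] and
// a values tensor whose 0th dimension is also n; row i of values is inserted
// for keys[i] at the kernel's component_index attr. E.g. (hypothetical), with
// component shape [2], a valid insert is keys = ["a", "b"] and values of
// shape [2, 2].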

class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32_t num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values
          // have all been written to successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
        });
  }

 private:
  int64_t timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);
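
// The BarrierTakeMany kernel registered above reads a DT_INT32 scalar
// num_elements input and emits indices, keys, and a values output list with
// one tensor per barrier component. E.g. (hypothetical), taking 2 elements
// from a barrier with component shapes {[2], []} yields indices of shape [2],
// keys of shape [2], values[0] of shape [2, 2], and values[1] of shape [2].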

class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow