/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <cstddef>
#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

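// A bounded queue that hands back its elements in uniformly random order.
// Tuples are stored per component in queues_; DequeueLocked picks a random
// element, and while the queue is open dequeues leave at least
// min_after_dequeue elements behind so the shuffle samples from a
// sufficiently large pool.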
class RandomShuffleQueue : public TypedQueue<std::vector<Tensor> > {
 public:
  RandomShuffleQueue(int32_t capacity, int32_t min_after_dequeue, int64_t seed,
                     int64_t seed2, const DataTypeVector& component_dtypes,
                     const std::vector<TensorShape>& component_shapes,
                     const string& name);

  Status Initialize() override;  // Must be called before any other method.

  // Implementations of QueueInterface methods --------------------------------
  void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                  DoneCallback callback) override;
  void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                      DoneCallback callback) override;
  void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
  void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                      bool allow_small_batch,
                      CallbackWithTuple callback) override;
  Status MatchesNodeDef(const NodeDef& node_def) override;

  int32 size() const override {
    mutex_lock lock(mu_);
    return queues_[0].size();
  }

 private:
  ~RandomShuffleQueue() override {}

  // Helper for dequeuing a single random element from queues_.
  void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

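  // Copies element `index` of component `component` out of a batched input
  // `tuple` into a freshly allocated `out_tensor`.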
  static Status GetElementComponentFromBatch(const Tuple& tuple, int64_t index,
                                             int component,
                                             OpKernelContext* ctx,
                                             Tensor* out_tensor);

  const int32 min_after_dequeue_;
  const int64_t original_seed_;
  const int64_t original_seed2_;

  random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
  random::SingleSampleAdapter<random::PhiloxRandom> generator_
      TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
};

RandomShuffleQueue::RandomShuffleQueue(
    int32_t capacity, int32_t min_after_dequeue, int64_t seed, int64_t seed2,
    const DataTypeVector& component_dtypes,
    const std::vector<TensorShape>& component_shapes, const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name),
      min_after_dequeue_(min_after_dequeue),
      original_seed_(seed),
      original_seed2_(seed2),
      generator_(&parent_generator_) {
  if (seed == 0 && seed2 == 0) {
    // If both seeds are unspecified, use completely random seeds.
    seed = random::New64();
    seed2 = random::New64();
  }
  parent_generator_ = random::PhiloxRandom(seed, seed2);
}

Status RandomShuffleQueue::Initialize() {
  TF_RETURN_IF_ERROR(TypedQueue::Initialize());

  mutex_lock lock(mu_);
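  // Reserve space up front for the minimum number of elements the queue is
  // expected to retain once dequeues begin.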
  for (int i = 0; i < num_components(); ++i) {
    queues_[i].reserve(min_after_dequeue_);
  }
  return OkStatus();
}

void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), size_t{0});
  int64_t index = generator_() % queues_[0].size();
  (*tuple).reserve(num_components());
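  // Remove the randomly chosen element from each component vector in O(1) by
  // overwriting it with the last element and popping the back; element order
  // within queues_ is irrelevant because every dequeue samples at random.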
  for (int i = 0; i < num_components(); ++i) {
    (*tuple).push_back(queues_[i][index]);
    queues_[i][index] = queues_[i].back();
    queues_[i].pop_back();
  }
}

void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                                    DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
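      // Register an attempt that enqueues the tuple once spare capacity is
      // available; it is retried via FlushUnlocked() until it completes or the
      // queue is closed or cancelled.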
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].push_back(tuple[i]);
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

/* static */
Status RandomShuffleQueue::GetElementComponentFromBatch(const Tuple& tuple,
                                                        int64_t index,
                                                        int component,
                                                        OpKernelContext* ctx,
                                                        Tensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], out_tensor, index));
  return OkStatus();
}

void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple,
                                        OpKernelContext* ctx,
                                        DoneCallback callback) {
  const int64_t batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
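            // Move batch elements into the queue one at a time while there is
            // spare capacity; `index` walks forward through the batch as
            // elements_requested counts down toward zero.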
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                Tensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].push_back(element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
                                    CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "RandomShuffleQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ",
                  queue_size, ")"));
              return kComplete;
            }
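            // While the queue is still open, keep at least min_after_dequeue_
            // elements behind so random sampling draws from a large enough
            // pool; once the queue is closed it may be drained completely.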
            if (!closed_) queue_size -= min_after_dequeue_;
            if (queue_size > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                        bool allow_small_batch,
                                        CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(errors::InvalidArgument(
        "RandomShuffleQueue's DequeueMany and DequeueUpTo require the "
        "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output(). Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx. For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx? (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status s = ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0),
                                    &element);
      if (!s.ok()) {
        ctx->SetStatus(s);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuple.empty()) {
                // Restore already-dequeued elements to the queue.
                for (int64_t i = attempt->tuple[0].dim_size(0) -
                                 attempt->elements_requested - 1;
                     i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    Tensor element;
                    Status s = GetElementComponentFromBatch(
                        attempt->tuple, i, j, attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to RandomShuffleQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_back(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuple.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some other attempts containing
                  // values. If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "RandomShuffleQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
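            // While the queue remains open, leave at least min_after_dequeue_
            // elements in place so later dequeues still sample from a
            // well-mixed pool; a closed queue may be drained to empty.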
            if (!closed_) queue_size -= min_after_dequeue_;
            for (; queue_size > 0; --queue_size) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
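                // The full batch has been assembled; hand it to the caller.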
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() &&
      !MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) {
    return errors::InvalidArgument("Expected RandomShuffleQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));

  int32_t min_after_dequeue = -1;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue));
  if (min_after_dequeue != min_after_dequeue_) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has min_after_dequeue ", min_after_dequeue_,
        " but requested min_after_dequeue was ", min_after_dequeue, ".");
  }

  int64_t seed = -1;
  int64_t seed2 = -1;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed));
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2));
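  // Seeds of (0, 0) mean "pick seeds at random", so only an explicit,
  // mismatched seed request conflicts with a shared queue.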
  if ((seed != 0 || seed2 != 0) &&
      (seed != original_seed_ || seed2 != original_seed2_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has random seeds (", original_seed_, ", ",
        original_seed2_, ") but requested seeds are (", seed, ", ", seed2,
        ").");
  }

  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));

  return OkStatus();
}

// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
// backed by RandomShuffleQueue) that persists across different graph
// executions and sessions. Running this op produces a single-element tensor
// of handles to Queues on the corresponding device.
class RandomShuffleQueueOp : public TypedQueueOp {
 public:
  explicit RandomShuffleQueueOp(OpKernelConstruction* context)
      : TypedQueueOp(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_after_dequeue", &min_after_dequeue_));
    OP_REQUIRES(context, min_after_dequeue_ >= 0,
                errors::InvalidArgument("min_after_dequeue ",
                                        min_after_dequeue_, " must be >= 0"));
    OP_REQUIRES(
        context, min_after_dequeue_ < capacity_,
        errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_,
                                " must be < capacity ", capacity_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));

    OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
  }

 private:
  Status CreateResource(QueueInterface** ret) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    RandomShuffleQueue* queue = new RandomShuffleQueue(
        capacity_, min_after_dequeue_, seed_, seed2_, component_types_,
        component_shapes_, cinfo_.name());
    return CreateTypedQueue(queue, ret);
  }

  int32 min_after_dequeue_;
  int64_t seed_;
  int64_t seed2_;
  std::vector<TensorShape> component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
};

REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);

}  // namespace tensorflow