/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <cstddef>
#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

class RandomShuffleQueue : public TypedQueue<std::vector<Tensor> > {
 public:
  RandomShuffleQueue(int32_t capacity, int32_t min_after_dequeue, int64_t seed,
                     int64_t seed2, const DataTypeVector& component_dtypes,
                     const std::vector<TensorShape>& component_shapes,
                     const string& name);

  Status Initialize() override;  // Must be called before any other method.

  // Implementations of QueueInterface methods --------------------------------
  void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                  DoneCallback callback) override;
  void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                      DoneCallback callback) override;
  void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
  void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                      bool allow_small_batch,
                      CallbackWithTuple callback) override;
  Status MatchesNodeDef(const NodeDef& node_def) override;

  int32 size() const override {
    mutex_lock lock(mu_);
    return queues_[0].size();
  }

 private:
  ~RandomShuffleQueue() override {}

  // Helper for dequeuing a single random element from queues_.
  void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  static Status GetElementComponentFromBatch(const Tuple& tuple, int64_t index,
                                             int component,
                                             OpKernelContext* ctx,
                                             Tensor* out_tensor);

  const int32 min_after_dequeue_;
  const int64_t original_seed_;
  const int64_t original_seed2_;

  random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
  random::SingleSampleAdapter<random::PhiloxRandom> generator_
      TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
};
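
// A minimal sketch of how a caller drives this queue through the
// QueueInterface methods declared above (illustrative only; `queue`, `ctx`,
// and `tuple` are assumed to be set up elsewhere):
//
//   queue->TryEnqueue(tuple, ctx, [] { /* enqueue finished */ });
//   queue->TryDequeue(ctx, [](const QueueInterface::Tuple& t) {
//     // Invoked with one uniformly random element once the queue holds
//     // more than min_after_dequeue elements, or with an empty tuple on
//     // cancellation/error.
//   });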

RandomShuffleQueue::RandomShuffleQueue(
    int32_t capacity, int32_t min_after_dequeue, int64_t seed, int64_t seed2,
    const DataTypeVector& component_dtypes,
    const std::vector<TensorShape>& component_shapes, const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name),
      min_after_dequeue_(min_after_dequeue),
      original_seed_(seed),
      original_seed2_(seed2),
      generator_(&parent_generator_) {
  if (seed == 0 && seed2 == 0) {
    // If both seeds are unspecified, use completely random seeds.
    seed = random::New64();
    seed2 = random::New64();
  }
  parent_generator_ = random::PhiloxRandom(seed, seed2);
}
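
// Note: a fixed nonzero (seed, seed2) pair makes the dequeue order
// reproducible for a given enqueue sequence; (0, 0) requests fresh
// nondeterministic seeds from random::New64() above.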

Status RandomShuffleQueue::Initialize() {
  TF_RETURN_IF_ERROR(TypedQueue::Initialize());

  mutex_lock lock(mu_);
  for (int i = 0; i < num_components(); ++i) {
    queues_[i].reserve(min_after_dequeue_);
  }
  return OkStatus();
}

void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), size_t{0});
  int64_t index = generator_() % queues_[0].size();
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    (*tuple).push_back(queues_[i][index]);
    queues_[i][index] = queues_[i].back();
    queues_[i].pop_back();
  }
}
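
// DequeueLocked uses the standard swap-with-back trick: the randomly chosen
// slot is overwritten with the last element and the vector is popped, giving
// O(1) removal at the cost of element order (which this queue deliberately
// does not preserve). A standalone sketch of the same idea:
//
//   size_t index = rng() % v.size();  // pick a random slot
//   T picked = std::move(v[index]);   // take the element
//   v[index] = std::move(v.back());   // fill the hole with the last element
//   v.pop_back();                     // shrink by one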

void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                                    DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    // RegisterCallback returns false iff the step was already cancelled, in
    // which case the attempt is never queued.
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].push_back(tuple[i]);
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

/* static */
Status RandomShuffleQueue::GetElementComponentFromBatch(const Tuple& tuple,
                                                        int64_t index,
                                                        int component,
                                                        OpKernelContext* ctx,
                                                        Tensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], out_tensor, index));
  return OkStatus();
}
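
// Example: for a batched component of shape [batch, 5, 3], RemoveDim(0)
// yields an element shape of [5, 3], and CopySliceToElement copies row
// `index` along dimension 0 of the batch into out_tensor.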

void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple,
                                        OpKernelContext* ctx,
                                        DoneCallback callback) {
  const int64_t batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              // elements_requested counts down from batch_size to 0, so this
              // visits batch rows in order: 0, 1, ..., batch_size - 1.
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                Tensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].push_back(element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
                                    CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "RandomShuffleQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ",
                  queue_size, ")"));
              return kComplete;
            }
            // While the queue is open, only elements above the
            // min_after_dequeue floor are available for dequeue; once the
            // queue is closed, everything remaining may be drained.
            if (!closed_) queue_size -= min_after_dequeue_;
            if (queue_size > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

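// Worked example of the min_after_dequeue accounting used below: with
// capacity 100 and min_after_dequeue 10, a queue holding 12 elements lets
// DequeueMany take at most 2 elements while the queue is open; once the
// queue is closed the floor no longer applies, and with allow_small_batch
// (DequeueUpTo) the final, possibly short batch can drain the remainder.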
void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                        bool allow_small_batch,
                                        CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(errors::InvalidArgument(
        "RandomShuffleQueue's DequeueMany and DequeueUpTo require the "
        "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output(). Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx. For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx? (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status s = ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0),
                                    &element);
      if (!s.ok()) {
        ctx->SetStatus(s);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuple.empty()) {
                // Restore already-dequeued elements to the queue.
                for (int64_t i = attempt->tuple[0].dim_size(0) -
                                 attempt->elements_requested - 1;
                     i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    Tensor element;
                    Status s = GetElementComponentFromBatch(
                        attempt->tuple, i, j, attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to RandomShuffleQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_back(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuple.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some enqueue attempts containing
                  // values. If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "RandomShuffleQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
            if (!closed_) queue_size -= min_after_dequeue_;
            for (; queue_size > 0; --queue_size) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

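// Shared-queue compatibility check: the attrs in the incoming NodeDef must
// match this queue's configuration. Seeds are special-cased below: requested
// seeds of (0, 0) mean "unspecified" and match any existing seeds.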
Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() &&
      !MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) {
    return errors::InvalidArgument("Expected RandomShuffleQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));

  int32_t min_after_dequeue = -1;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue));
  if (min_after_dequeue != min_after_dequeue_) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has min_after_dequeue ",
        min_after_dequeue_, " but requested min_after_dequeue was ",
        min_after_dequeue, ".");
  }

  int64_t seed = -1;
  int64_t seed2 = -1;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed));
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2));
  if ((seed != 0 || seed2 != 0) &&
      (seed != original_seed_ || seed2 != original_seed2_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has random seeds (", original_seed_, ", ",
        original_seed2_, ") but requested seeds are (", seed, ", ", seed2,
        ").");
  }

  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));

  return OkStatus();
}

// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
// backed by RandomShuffleQueue) that persists across different graph
// executions and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class RandomShuffleQueueOp : public TypedQueueOp {
 public:
  explicit RandomShuffleQueueOp(OpKernelConstruction* context)
      : TypedQueueOp(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_after_dequeue", &min_after_dequeue_));
    OP_REQUIRES(context, min_after_dequeue_ >= 0,
                errors::InvalidArgument("min_after_dequeue ",
                                        min_after_dequeue_, " must be >= 0"));
    OP_REQUIRES(
        context, min_after_dequeue_ < capacity_,
        errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_,
                                " must be < capacity ", capacity_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));

    OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
  }

 private:
  Status CreateResource(QueueInterface** ret) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    RandomShuffleQueue* queue = new RandomShuffleQueue(
        capacity_, min_after_dequeue_, seed_, seed2_, component_types_,
        component_shapes_, cinfo_.name());
    return CreateTypedQueue(queue, ret);
  }

  int32 min_after_dequeue_;
  int64_t seed_;
  int64_t seed2_;
  std::vector<TensorShape> component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
};

REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);

}  // namespace tensorflow