/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/kernels/priority_queue.h"

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

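// PriorityQueue is a TypedQueue whose tuples carry an extra leading
// component: a scalar int64 priority that keys the per-component priority
// queues and so determines dequeue order. The remaining components hold the
// user-visible values.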
PriorityQueue::PriorityQueue(int32_t capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}

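// Runs the base-class initialization, then checks that the priority component
// (component 0) is declared as a scalar int64.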
Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() && !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return OkStatus();
}

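// Pops the top entry from each per-component priority queue into *tuple.
// Callers must hold mu_ and ensure the queue is non-empty.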
void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    Tensor tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(tensor);
  }
}

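// Attempts to enqueue a single tuple. A cancellation callback is registered
// and the work is recorded as an enqueue attempt whose closure runs under
// mu_: it fails if the queue is closed, rejects non-scalar priorities, and
// otherwise pushes each component (keyed by the priority) once capacity is
// available.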
void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64_t priority = tuple[0].scalar<int64_t>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, tuple[i]);
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

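// Copies element `index` of component `component` out of a batched input
// tuple into a newly allocated temp tensor.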
/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, Tensor* out_element) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_element));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], out_element, index));
  return OkStatus();
}

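// Attempts to enqueue a batch of tuples. Each batch element is split out of
// the input and enqueued individually under mu_, keyed by its own scalar
// int64 priority. The attempt makes progress whenever capacity is available
// and completes once every batch element has been enqueued, or fails if the
// queue is closed.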
void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64_t batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              Tensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              if (!TensorShapeUtils::IsScalar(priority_element.shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_element.shape().DebugString()));
                return kComplete;
              }
              const int64_t priority = priority_element.scalar<int64_t>()();
              for (int i = 0; i < num_components(); ++i) {
                Tensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

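// Attempts to dequeue a single tuple. The dequeue attempt completes with the
// next element in the queue's priority order once one is available, or with
// OutOfRange if the queue is closed and empty.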
void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32_t s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

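// Attempts to dequeue num_elements tuples as one batched tuple. Requires
// fully specified component shapes. The attempt waits until the underlying
// queue holds at least num_elements entries (or, if the queue is closed and
// allow_small_batch is true, at least one entry), then pops elements in
// priority order and copies each into a slice of the batched output.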
void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output(). Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx. For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx? (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this, allow_small_batch](
              Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested. *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a sorted set of
            // entries. To guarantee this, the underlying queue must already
            // contain at least this many entries; dequeuing a portion at a
            // time as elements arrive would produce unordered output across
            // a DequeueMany call.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow small batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

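// Checks whether a shared queue resource created from node_def is compatible
// with this queue: same op (PriorityQueue or PriorityQueueV2), same capacity,
// and matching component types and shapes.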
Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return OkStatus();
}

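// The node def's "component_types" attr does not include the implicit int64
// priority component, so prepend DT_INT64 before comparing against
// component_dtypes_.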
Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return OkStatus();
}

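// Likewise, the node def's "shapes" attr omits the scalar priority component,
// so prepend a scalar shape before comparing against component_shapes_.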
Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return OkStatus();
}

}  // namespace tensorflow