/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/kernels/priority_queue.h"

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

PriorityQueue::PriorityQueue(int32_t capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}

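// Validates, beyond the base TypedQueue checks, that component 0 (the
// priority index) has dtype int64 and, if shapes are specified, that it is a
// scalar.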
Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() &&
      !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return OkStatus();
}

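// Pops the top entry from each component's underlying priority queue and
// appends the resulting tensors to *tuple. Per the *Locked naming convention,
// this expects mu_ to be held and the queue to be non-empty.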
void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    Tensor tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(tensor);
  }
}

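// Attempts to enqueue a single tuple whose component 0 holds a scalar
// priority. The attempt stays pending until capacity is available; a
// cancellation callback is registered so a blocked attempt can be aborted.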
void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64_t priority = tuple[0].scalar<int64_t>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, tuple[i]);
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

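// Copies the `index`-th slice of the batched tensor `tuple[component]` into a
// freshly allocated `out_element`, with the batch dimension removed.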
/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, Tensor* out_element) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_element));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], out_element, index));
  return OkStatus();
}

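// Enqueues a batch of tuples one element at a time as capacity becomes
// available. The attempt may make partial progress across invocations and
// completes once every element of the batch has been inserted.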
void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64_t batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              Tensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              if (!TensorShapeUtils::IsScalar(priority_element.shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_element.shape().DebugString()));
                return kComplete;
              }
              const int64_t priority = priority_element.scalar<int64_t>()();
              for (int i = 0; i < num_components(); ++i) {
                Tensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

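// Attempts to dequeue a single tuple (the current top of the priority heap).
// If the queue is empty and not closed, the attempt stays pending until an
// element arrives.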
void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if
      // possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32_t s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

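// Attempts to dequeue `num_elements` tuples as a single batched tuple.
// Requires fully specified component shapes. To keep the output sorted by
// priority, the attempt waits until enough elements are enqueued (or, when
// the queue is closed and allow_small_batch is true, returns what remains).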
void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().  Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx.  For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx?  (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if
      // possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this, allow_small_batch](
              Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested.  *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a
            // sorted set of entries.  In order to do this, the underlying
            // queue must have at least this many entries already.
            // Doing the dynamic thing and pulling out a portion at a
            // time leads to unordered output in calls to DequeueMany.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow small batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

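// Verifies that a shared-queue NodeDef is compatible with this queue: same op
// (PriorityQueue or PriorityQueueV2), capacity, component types, and shapes.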
Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return OkStatus();
}

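// Checks that the NodeDef's component_types attr, with the implicit leading
// DT_INT64 priority component prepended, matches this queue's dtypes.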
Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return OkStatus();
}

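// Checks that the NodeDef's shapes attr, with the implicit leading scalar
// priority component prepended, matches this queue's component shapes.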
Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return OkStatus();
}

}  // namespace tensorflow