1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/queue_op.h" |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/queue_interface.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/platform/macros.h" |
25 | #include "tensorflow/core/platform/refcount.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) { |
31 | OP_REQUIRES_OK(context, context->GetAttr("capacity" , &capacity_)); |
32 | if (capacity_ < 0) { |
33 | capacity_ = QueueBase::kUnbounded; |
34 | } |
35 | OP_REQUIRES_OK(context, |
36 | context->GetAttr("component_types" , &component_types_)); |
37 | } |
38 | |
39 | void QueueOp::Compute(OpKernelContext* context) { |
40 | ResourceOpKernel<QueueInterface>::Compute(context); |
41 | core::RefCountPtr<QueueInterface> resource = get_resource(); |
42 | if (resource != nullptr && context->track_allocations()) { |
43 | context->record_persistent_memory_allocation(resource->MemoryUsed()); |
44 | } |
45 | } |
46 | |
47 | Status QueueOp::VerifyResource(QueueInterface* queue) { |
48 | return queue->MatchesNodeDef(def()); |
49 | } |
50 | |
51 | |
52 | QueueOpKernel::QueueOpKernel(OpKernelConstruction* context) |
53 | : AsyncOpKernel(context) {} |
54 | |
55 | void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) { |
56 | QueueInterface* queue; |
57 | if (ctx->input_dtype(0) == DT_RESOURCE) { |
58 | OP_REQUIRES_OK_ASYNC( |
59 | ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback); |
60 | } else { |
61 | OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle" , &queue), |
62 | callback); |
63 | } |
64 | ComputeAsync(ctx, queue, [callback, queue]() { |
65 | queue->Unref(); |
66 | callback(); |
67 | }); |
68 | } |
69 | |
70 | QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context) |
71 | : QueueOpKernel(context) { |
72 | OP_REQUIRES_OK(context, context->GetAttr("timeout_ms" , &timeout_)); |
73 | // TODO(keveman): Enable timeout. |
74 | OP_REQUIRES(context, timeout_ == -1, |
75 | errors::InvalidArgument("Timeout not supported yet." )); |
76 | } |
77 | |
78 | // Defines an EnqueueOp, the execution of which enqueues a tuple of |
79 | // tensors in the given Queue. |
80 | // |
81 | // The op has 1 + k inputs, where k is the number of components in the |
82 | // tuples stored in the given Queue: |
83 | // - Input 0: queue handle. |
84 | // - Input 1: 0th element of the tuple. |
85 | // - ... |
86 | // - Input (1+k): kth element of the tuple. |
87 | EnqueueOp::EnqueueOp(OpKernelConstruction* context) |
88 | : QueueAccessOpKernel(context) {} |
89 | |
90 | void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
91 | DoneCallback callback) { |
92 | DataTypeVector expected_inputs; |
93 | if (ctx->input_dtype(0) == DT_RESOURCE) { |
94 | expected_inputs.push_back(DT_RESOURCE); |
95 | } else { |
96 | expected_inputs.push_back(DT_STRING_REF); |
97 | } |
98 | for (DataType dt : queue->component_dtypes()) { |
99 | expected_inputs.push_back(dt); |
100 | } |
101 | OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); |
102 | |
103 | QueueInterface::Tuple tuple; |
104 | OpInputList components; |
105 | OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components" , &components), |
106 | callback); |
107 | for (const Tensor& Tcomponent : components) { |
108 | tuple.push_back(Tcomponent); |
109 | } |
110 | |
111 | OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); |
112 | queue->TryEnqueue(tuple, ctx, callback); |
113 | } |
114 | |
115 | // Defines an EnqueueManyOp, the execution of which slices each |
116 | // component of a tuple of tensors along the 0th dimension, and |
117 | // enqueues tuples of slices in the given Queue. |
118 | // |
119 | // The op has 1 + k inputs, where k is the number of components in the |
120 | // tuples stored in the given Queue: |
121 | // - Input 0: queue handle. |
122 | // - Input 1: 0th element of the tuple. |
123 | // - ... |
124 | // - Input (1+k): kth element of the tuple. |
125 | // |
126 | // N.B. All tuple components must have the same size in the 0th |
127 | // dimension. |
128 | EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context) |
129 | : QueueAccessOpKernel(context) {} |
130 | |
131 | void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
132 | DoneCallback callback) { |
133 | DataTypeVector expected_inputs; |
134 | if (ctx->input_dtype(0) == DT_RESOURCE) { |
135 | expected_inputs.push_back(DT_RESOURCE); |
136 | } else { |
137 | expected_inputs.push_back(DT_STRING_REF); |
138 | } |
139 | for (DataType dt : queue->component_dtypes()) { |
140 | expected_inputs.push_back(dt); |
141 | } |
142 | OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); |
143 | |
144 | QueueInterface::Tuple tuple; |
145 | OpInputList components; |
146 | OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components" , &components), |
147 | callback); |
148 | for (const Tensor& Tcomponent : components) { |
149 | tuple.push_back(Tcomponent); |
150 | } |
151 | |
152 | OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); |
153 | queue->TryEnqueueMany(tuple, ctx, callback); |
154 | } |
155 | |
156 | EnqueueManyOp::~EnqueueManyOp() = default; |
157 | |
158 | // Defines a DequeueOp, the execution of which dequeues a tuple of |
159 | // tensors from the given Queue. |
160 | // |
161 | // The op has one input, which is the handle of the appropriate |
162 | // Queue. The op has k outputs, where k is the number of components in |
163 | // the tuples stored in the given Queue, and output i is the ith |
164 | // component of the dequeued tuple. |
165 | DequeueOp::DequeueOp(OpKernelConstruction* context) |
166 | : QueueAccessOpKernel(context) {} |
167 | |
168 | void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
169 | DoneCallback callback) { |
170 | if (ctx->input_dtype(0) == DT_RESOURCE) { |
171 | OP_REQUIRES_OK_ASYNC( |
172 | ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()), |
173 | callback); |
174 | } else { |
175 | OP_REQUIRES_OK_ASYNC( |
176 | ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), |
177 | callback); |
178 | } |
179 | |
180 | queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { |
181 | if (!ctx->status().ok()) { |
182 | callback(); |
183 | return; |
184 | } |
185 | OpOutputList output_components; |
186 | OP_REQUIRES_OK_ASYNC( |
187 | ctx, ctx->output_list("components" , &output_components), callback); |
188 | for (int i = 0; i < ctx->num_outputs(); ++i) { |
189 | output_components.set(i, tuple[i]); |
190 | } |
191 | callback(); |
192 | }); |
193 | } |
194 | |
195 | DequeueOp::~DequeueOp() = default; |
196 | |
197 | // Defines a DequeueManyOp, the execution of which concatenates the |
198 | // requested number of elements from the given Queue along the 0th |
199 | // dimension, and emits the result as a single tuple of tensors. |
200 | // |
201 | // The op has two inputs: |
202 | // - Input 0: the handle to a queue. |
203 | // - Input 1: the number of elements to dequeue. |
204 | // |
205 | // The op has k outputs, where k is the number of components in the |
206 | // tuples stored in the given Queue, and output i is the ith component |
207 | // of the dequeued tuple. |
208 | DequeueManyOp::DequeueManyOp(OpKernelConstruction* context) |
209 | : QueueAccessOpKernel(context) {} |
210 | |
211 | void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
212 | DoneCallback callback) { |
213 | const Tensor& Tnum_elements = ctx->input(1); |
214 | int32_t num_elements = Tnum_elements.flat<int32>()(0); |
215 | |
216 | OP_REQUIRES_ASYNC(ctx, num_elements >= 0, |
217 | errors::InvalidArgument("DequeueManyOp requested " , |
218 | num_elements, " < 0 elements" ), |
219 | callback); |
220 | |
221 | if (ctx->input_dtype(0) == DT_RESOURCE) { |
222 | OP_REQUIRES_OK_ASYNC( |
223 | ctx, |
224 | ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()), |
225 | callback); |
226 | } else { |
227 | OP_REQUIRES_OK_ASYNC(ctx, |
228 | ctx->MatchSignature({DT_STRING_REF, DT_INT32}, |
229 | queue->component_dtypes()), |
230 | callback); |
231 | } |
232 | |
233 | queue->TryDequeueMany( |
234 | num_elements, ctx, false /* allow_small_batch */, |
235 | [ctx, callback](const QueueInterface::Tuple& tuple) { |
236 | if (!ctx->status().ok()) { |
237 | callback(); |
238 | return; |
239 | } |
240 | OpOutputList output_components; |
241 | OP_REQUIRES_OK_ASYNC( |
242 | ctx, ctx->output_list("components" , &output_components), callback); |
243 | for (int i = 0; i < ctx->num_outputs(); ++i) { |
244 | output_components.set(i, tuple[i]); |
245 | } |
246 | callback(); |
247 | }); |
248 | } |
249 | |
250 | DequeueManyOp::~DequeueManyOp() = default; |
251 | |
252 | // Defines a DequeueUpToOp, the execution of which concatenates the |
253 | // requested number of elements from the given Queue along the 0th |
254 | // dimension, and emits the result as a single tuple of tensors. |
255 | // |
256 | // The difference between this op and DequeueMany is the handling when |
257 | // the Queue is closed. While the DequeueMany op will return if there |
258 | // an error when there are less than num_elements elements left in the |
259 | // closed queue, this op will return between 1 and |
260 | // min(num_elements, elements_remaining_in_queue), and will not block. |
261 | // If there are no elements left, then the standard DequeueMany error |
262 | // is returned. |
263 | // |
264 | // This op only works if the underlying Queue implementation accepts |
265 | // the allow_small_batch = true parameter to TryDequeueMany. |
266 | // If it does not, an errors::Unimplemented exception is returned. |
267 | // |
268 | // The op has two inputs: |
269 | // - Input 0: the handle to a queue. |
270 | // - Input 1: the number of elements to dequeue. |
271 | // |
272 | // The op has k outputs, where k is the number of components in the |
273 | // tuples stored in the given Queue, and output i is the ith component |
274 | // of the dequeued tuple. |
275 | // |
276 | // The op has one attribute: allow_small_batch. If the Queue supports |
277 | // it, setting this to true causes the queue to return smaller |
278 | // (possibly zero length) batches when it is closed, up to however |
279 | // many elements are available when the op executes. In this case, |
280 | // the Queue does not block when closed. |
281 | DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context) |
282 | : QueueAccessOpKernel(context) {} |
283 | |
284 | void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
285 | DoneCallback callback) { |
286 | const Tensor& Tnum_elements = ctx->input(1); |
287 | int32_t num_elements = Tnum_elements.flat<int32>()(0); |
288 | |
289 | OP_REQUIRES_ASYNC(ctx, num_elements >= 0, |
290 | errors::InvalidArgument("DequeueUpToOp requested " , |
291 | num_elements, " < 0 elements" ), |
292 | callback); |
293 | |
294 | if (ctx->input_dtype(0) == DT_RESOURCE) { |
295 | OP_REQUIRES_OK_ASYNC( |
296 | ctx, |
297 | ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()), |
298 | callback); |
299 | } else { |
300 | OP_REQUIRES_OK_ASYNC(ctx, |
301 | ctx->MatchSignature({DT_STRING_REF, DT_INT32}, |
302 | queue->component_dtypes()), |
303 | callback); |
304 | } |
305 | |
306 | queue->TryDequeueMany( |
307 | num_elements, ctx, true /* allow_small_batch */, |
308 | [ctx, callback](const QueueInterface::Tuple& tuple) { |
309 | if (!ctx->status().ok()) { |
310 | callback(); |
311 | return; |
312 | } |
313 | OpOutputList output_components; |
314 | OP_REQUIRES_OK_ASYNC( |
315 | ctx, ctx->output_list("components" , &output_components), callback); |
316 | for (int i = 0; i < ctx->num_outputs(); ++i) { |
317 | output_components.set(i, tuple[i]); |
318 | } |
319 | callback(); |
320 | }); |
321 | } |
322 | |
323 | DequeueUpToOp::~DequeueUpToOp() = default; |
324 | |
325 | // Defines a QueueCloseOp, which closes the given Queue. Closing a |
326 | // Queue signals that no more elements will be enqueued in it. |
327 | // |
328 | // The op has one input, which is the handle of the appropriate Queue. |
329 | QueueCloseOp::QueueCloseOp(OpKernelConstruction* context) |
330 | : QueueOpKernel(context) { |
331 | OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues" , |
332 | &cancel_pending_enqueues_)); |
333 | } |
334 | |
335 | void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
336 | DoneCallback callback) { |
337 | queue->Close(ctx, cancel_pending_enqueues_, callback); |
338 | } |
339 | |
340 | // Defines a QueueSizeOp, which computes the number of elements in the |
341 | // given Queue, and emits it as an output tensor. |
342 | // |
343 | // The op has one input, which is the handle of the appropriate Queue; |
344 | // and one output, which is a single-element tensor containing the current |
345 | // size of that Queue. |
346 | QueueSizeOp::QueueSizeOp(OpKernelConstruction* context) |
347 | : QueueOpKernel(context) {} |
348 | |
349 | void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
350 | DoneCallback callback) { |
351 | Tensor* Tqueue_size = nullptr; |
352 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); |
353 | Tqueue_size->flat<int32>().setConstant(queue->size()); |
354 | callback(); |
355 | } |
356 | |
357 | QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context) |
358 | : QueueOpKernel(context) {} |
359 | |
360 | void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, |
361 | DoneCallback callback) { |
362 | Tensor* Tqueue_is_closed = nullptr; |
363 | OP_REQUIRES_OK(ctx, |
364 | ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed)); |
365 | Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed()); |
366 | callback(); |
367 | } |
368 | |
369 | } // namespace tensorflow |
370 | |