1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/data_flow_ops.cc.
17
18#include "tensorflow/core/kernels/padding_fifo_queue.h"
19
20#include <deque>
21#include <vector>
22
23#include "tensorflow/core/framework/node_def.pb.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_shape.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/kernels/queue_base.h"
29#include "tensorflow/core/lib/core/errors.h"
30#include "tensorflow/core/platform/logging.h"
31#include "tensorflow/core/platform/mutex.h"
32#include "tensorflow/core/platform/types.h"
33#include "tensorflow/core/util/batch_util.h"
34
35namespace tensorflow {
36
37PaddingFIFOQueue::PaddingFIFOQueue(
38 int capacity, const DataTypeVector& component_dtypes,
39 const std::vector<PartialTensorShape>& component_shapes, const string& name)
40 : FIFOQueue(capacity, component_dtypes,
41 ConvertShapesPartialDimensionsToZero(component_shapes), name),
42 partial_shapes_(component_shapes) {}
43
44Status PaddingFIFOQueue::Initialize() {
45 Status s = FIFOQueue::Initialize();
46 if (!s.ok()) return s;
47
48 if (component_dtypes_.size() != partial_shapes_.size()) {
49 return errors::InvalidArgument(
50 "Shapes must be provided for all components, but received ",
51 component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
52 " shapes.");
53 }
54
55 return OkStatus();
56}
57
58/* static */
59Status PaddingFIFOQueue::GetElementComponent(
60 const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
61 Tensor* out_tensor) {
62 TensorShape element_shape(tuple[component].shape());
63 TF_RETURN_IF_ERROR(
64 ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor));
65 *out_tensor = tuple[component];
66 return OkStatus();
67}
68
// Dequeues `num_elements` elements and delivers them to `callback` as a single
// batched tuple. Each batch component has a leading dimension equal to the
// number of dequeued elements. Dimensions that are unknown in partial_shapes_
// are padded to the largest size seen among the dequeued elements, and the
// padding is filled with zeros (SetElementZero).
//
// If `num_elements` is 0, the callback runs immediately with empty tensors.
// Otherwise an attempt is queued under mu_ and FlushUnlocked() is called after
// the lock is released. NOTE(review): presumably FlushUnlocked runs pending
// attempts; confirm in queue_base / fifo_queue.
//
// When the queue is closed and holds too few elements:
//   - with `allow_small_batch`, whatever remains is returned as a smaller
//     batch;
//   - otherwise the context status is set to OutOfRange.
// If the op has already been cancelled, the status is set to Cancelled and
// the callback receives an empty Tuple.
void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                      bool allow_small_batch,
                                      CallbackWithTuple callback) {
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().
      // See similar comment in fifo_queue.cc
      Tensor element;
      // Here, ManyOutShape returns zeros for undetermined shapes,
      // which is exactly what we want to use.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(component_dtypes_[i],
                                             ManyOutShape(i, 0), &element));
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    // RegisterCallback returning false means the op was already cancelled;
    // in that case no attempt is queued (handled after the lock is released).
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      // Arguments: elements requested, the cancel/failure callback (delivers
      // an empty Tuple), the context, cancellation info, and the run function
      // below, which is invoked with mu_ held.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            // All component queues are kept the same length, so component 0
            // gives the number of queued elements.
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuples.empty()) {
                // Restore already-dequeued elements to the front of the queue.
                // Iterating from the last tuple to the first while pushing to
                // the front preserves the original element order.
                for (int64_t i = attempt->tuples.size() - 1; i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    Tensor element;
                    Status s = GetElementComponent(attempt->tuples[i], j,
                                                   attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to PaddingFIFOQueue: ",
                                           s.error_message()));
                    }
                    // NOTE(review): this push happens even when
                    // GetElementComponent failed, which inserts a
                    // default-constructed Tensor. That keeps the component
                    // queues the same length but leaves a bogus element in
                    // the queue.
                    queues_[j].push_front(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuples.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some enqueue attempts containing
                  // values.  If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                // NOTE(review): `queue_size` was read before the restore
                // above, and `elements_requested` counts only the elements
                // still missing. This message can therefore understate both
                // the requested count and the current size.
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "PaddingFIFOQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            // Move up to `queue_size` elements into attempt->tuples, one at
            // a time.
            RunResult result = kNoProgress;
            for (; queue_size > 0; --queue_size) {
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->tuples.push_back(tuple);
              tuple.clear();
              --attempt->elements_requested;

              if (attempt->elements_requested == 0) {
                // Finished. Allocate attempt->tuple and
                // copy from attempt->tuples to attempt->tuple.
                attempt->tuple.reserve(num_components());
                std::vector<Tuple>& tuples = attempt->tuples;

                // dynamic_shape[i] is true when component i's partial shape
                // is not fully defined, i.e. its batch tensor needs padding.
                std::vector<bool> dynamic_shape;
                const int64_t batch_size = tuples.size();

                for (int i = 0; i < num_components(); ++i) {
                  const PartialTensorShape partial_shape =
                      PartialTensorShape({batch_size})
                          .Concatenate(partial_shapes_[i]);
                  TensorShape shape({batch_size});

                  // Known dimensions come from the partial shape. Unknown
                  // ones (-1) take the maximum size across dequeued elements.
                  for (int j = 0; j < partial_shape.dims() - 1; ++j) {
                    if (partial_shape.dim_size(j + 1) > -1) {
                      shape.AddDim(partial_shape.dim_size(j + 1));
                    } else {
                      // Expand sizes to match.
                      int64_t max_val = 0;
                      for (const Tuple& t : tuples) {
                        max_val = std::max(max_val, t[i].shape().dim_size(j));
                      }
                      shape.AddDim(max_val);
                    }
                  }

                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;

                  bool has_dynamic_shape = !partial_shape.IsFullyDefined();
                  if (has_dynamic_shape) {
                    // Set all values to zero because not all values
                    // will get written over.
                    attempt->context->SetStatus(SetElementZero(&element));
                    if (!attempt->context->status().ok()) return kComplete;
                  }

                  dynamic_shape.push_back(has_dynamic_shape);
                  attempt->tuple.emplace_back(element);
                }

                // Copy each dequeued element into row `index` of its batch
                // tensor. Padded components use CopyElementToLargerSlice;
                // fully defined ones can move the source into the slice.
                for (size_t index = 0; index < tuples.size(); ++index) {
                  for (int i = 0; i < num_components(); ++i) {
                    if (dynamic_shape[i]) {
                      // Slightly slower copy operation
                      attempt->context->SetStatus(CopyElementToLargerSlice(
                          tuples[index][i], &attempt->tuple[i], index));
                    } else {
                      attempt->context->SetStatus(
                          batch_util::CopyElementToSlice(
                              std::move(tuples[index][i]), &attempt->tuple[i],
                              index));
                    }
                    if (!attempt->context->status().ok()) return kComplete;
                  }
                }
                // Hand the assembled batch to the user callback via
                // done_callback; `tuple` is captured by value.
                tuple = attempt->tuple;
                attempt->tuples.clear();
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}
235
236Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
237 TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
238 for (size_t i = 0; i < tuple.size(); ++i) {
239 if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
240 return errors::InvalidArgument("Shape mismatch in tuple component ", i,
241 ". Expected ",
242 partial_shapes_[i].DebugString(), ", got ",
243 tuple[i].shape().DebugString());
244 }
245 }
246 return OkStatus();
247}
248
249Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
250 TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
251 const int64_t batch_size = tuple[0].dim_size(0);
252 for (size_t i = 0; i < tuple.size(); ++i) {
253 // Expected shape is [batch_size] + partial_shapes_[i]
254 const PartialTensorShape expected_shape =
255 PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
256 if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
257 return errors::InvalidArgument("Shape mismatch in tuple component ", i,
258 ". Expected ",
259 expected_shape.DebugString(), ", got ",
260 tuple[i].shape().DebugString());
261 }
262 }
263 return OkStatus();
264}
265
266Status PaddingFIFOQueue::CompatibleNodeDefShapes(
267 const NodeDef& node_def) const {
268 std::vector<PartialTensorShape> requested_shapes;
269 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
270 if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
271 partial_shapes_)) {
272 return errors::InvalidArgument(
273 "Shared queue '", name_, "' has component shapes ",
274 PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
275 " but requested component shapes were ",
276 PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
277 } else {
278 return OkStatus();
279 }
280}
281
282Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
283 if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
284 !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
285 return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
286 node_def.op());
287 }
288 TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
289 TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
290 TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
291 return OkStatus();
292}
293
294static Status ValidateElementToLargerSlice(const Tensor& element,
295 Tensor* parent) {
296 DCHECK_NE(parent->dim_size(0), 0);
297 if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
298 TensorShape chip_shape = parent->shape();
299 chip_shape.RemoveDim(0);
300 return errors::Internal(
301 "HandleElementToLargerSlice Cannot copy slice: number of entries in "
302 "element is greater than number of elements in parent slice. ",
303 "Shapes are: [element]: ", element.shape().DebugString(),
304 ", [parent slice]: ", chip_shape.DebugString());
305 }
306 return OkStatus();
307}
308
309template <typename T, int NDIMS>
310Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
311 int index) {
312 Status s = ValidateElementToLargerSlice(element, parent);
313 if (!s.ok()) {
314 return s;
315 }
316 if (element.NumElements() == 0) {
317 return OkStatus();
318 }
319 auto element_t = element.tensor<T, NDIMS>();
320 auto parent_t = parent->tensor<T, NDIMS + 1>();
321 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
322 slice_indices[0] = index;
323 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
324 slice_size[0] = 1;
325 for (size_t i = 1; i < slice_size.size(); ++i) {
326 slice_size[i] = element_t.dimension(i - 1);
327 }
328 parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
329 return OkStatus();
330}
331
332namespace {
333
334template <int NDIMS>
335Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
336 int index) {
337#define HANDLE_TYPE(T) \
338 case DataTypeToEnum<T>::value: { \
339 return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
340 }
341
342 switch (element.dtype()) {
343 TF_CALL_ALL_TYPES(HANDLE_TYPE);
344#undef HANDLE_TYPE
345 default:
346 return errors::Unimplemented(
347 "HandleElementToLargerSliceWithRank Unhandled data type: ",
348 DataTypeString(element.dtype()));
349 }
350}
351
352} // namespace
353
354Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
355 Tensor* parent, int index) {
356 if (parent->dims() != element.dims() + 1) {
357 return errors::Internal(
358 "Mismatched ranks. Element's rank is: ", element.dims(),
359 " but element is meant to be a slice in output Tensor having rank: ",
360 parent->dims(), " (should be: ", element.dims() + 1, ")");
361 }
362
363#define HANDLE_DIMS(NDIMS) \
364 case NDIMS: { \
365 TF_RETURN_IF_ERROR( \
366 HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
367 return OkStatus(); \
368 }
369
370 switch (element.dims()) {
371 HANDLE_DIMS(0);
372 HANDLE_DIMS(1);
373 HANDLE_DIMS(2);
374 HANDLE_DIMS(3);
375 HANDLE_DIMS(4);
376#undef HANDLE_DIMS
377 default:
378 return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
379 element.dims());
380 }
381}
382
383// Static method
384Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
385#define HANDLE_TYPE(T) \
386 if (element->dtype() == DataTypeToEnum<T>::value) { \
387 element->flat<T>().setConstant(T()); \
388 return OkStatus(); \
389 }
390 TF_CALL_ALL_TYPES(HANDLE_TYPE);
391#undef HANDLE_TYPE
392 return errors::Unimplemented("SetElementZero Unhandled data type: ",
393 DataTypeString(element->dtype()));
394}
395
396std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
397 const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
398 std::vector<TensorShape> shapes(partial_shapes.size());
399 for (size_t i = 0; i < shapes.size(); ++i) {
400 const PartialTensorShape& partial = partial_shapes[i];
401 TensorShape& shape = shapes[i];
402 for (int64_t s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
403 }
404 return shapes;
405}
406
407} // namespace tensorflow
408