/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/kernels/padding_fifo_queue.h"

#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

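// A PaddingFIFOQueue is a FIFOQueue whose component shapes may be only
// partially known (e.g. [?, 5]).  The base class stores shapes with every
// unknown dimension mapped to zero (see ConvertShapesPartialDimensionsToZero
// below), while the original partial shapes are kept in partial_shapes_ so
// that DequeueMany can pad each batch out to the largest element along every
// unknown dimension.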
PaddingFIFOQueue::PaddingFIFOQueue(
    int capacity, const DataTypeVector& component_dtypes,
    const std::vector<PartialTensorShape>& component_shapes,
    const string& name)
    : FIFOQueue(capacity, component_dtypes,
                ConvertShapesPartialDimensionsToZero(component_shapes), name),
      partial_shapes_(component_shapes) {}

Status PaddingFIFOQueue::Initialize() {
  Status s = FIFOQueue::Initialize();
  if (!s.ok()) return s;

  if (component_dtypes_.size() != partial_shapes_.size()) {
    return errors::InvalidArgument(
        "Shapes must be provided for all components, but received ",
        component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
        " shapes.");
  }

  return OkStatus();
}

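// Copies component `component` of `tuple` into a freshly allocated tensor.
// Used below to push elements back onto the queue when a partial dequeue
// has to be rolled back after the queue is closed.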
/* static */
Status PaddingFIFOQueue::GetElementComponent(
    const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
    Tensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor));
  *out_tensor = tuple[component];
  return OkStatus();
}

void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                      bool allow_small_batch,
                                      CallbackWithTuple callback) {
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().
      // See similar comment in fifo_queue.cc.
      Tensor element;
      // Here, ManyOutShape returns zeros for undetermined shapes,
      // which is exactly what we want to use.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(component_dtypes_[i],
                                             ManyOutShape(i, 0), &element));
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback; avoid this if
      // possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuples.empty()) {
                // Restore already-dequeued elements to the front of the
                // queue.
                for (int64_t i = attempt->tuples.size() - 1; i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    Tensor element;
                    Status s = GetElementComponent(attempt->tuples[i], j,
                                                   attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to PaddingFIFOQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_front(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuples.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be pending enqueue attempts that contain
                  // values.  If so, we'll yield and wait for them to add
                  // elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "PaddingFIFOQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
            for (; queue_size > 0; --queue_size) {
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->tuples.push_back(tuple);
              tuple.clear();
              --attempt->elements_requested;

              if (attempt->elements_requested == 0) {
                // Finished.  Allocate attempt->tuple and
                // copy from attempt->tuples to attempt->tuple.
                attempt->tuple.reserve(num_components());
                std::vector<Tuple>& tuples = attempt->tuples;

                std::vector<bool> dynamic_shape;
                const int64_t batch_size = tuples.size();

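                // For each component, the output shape is [batch_size]
                // followed by, for every dimension the declared partial
                // shape leaves unknown, the maximum size of that dimension
                // over the batch.  E.g. with partial shape [?], elements of
                // shapes [3] and [5] yield a padded output of shape [2, 5].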
                for (int i = 0; i < num_components(); ++i) {
                  const PartialTensorShape partial_shape =
                      PartialTensorShape({batch_size})
                          .Concatenate(partial_shapes_[i]);
                  TensorShape shape({batch_size});

                  for (int j = 0; j < partial_shape.dims() - 1; ++j) {
                    if (partial_shape.dim_size(j + 1) > -1) {
                      shape.AddDim(partial_shape.dim_size(j + 1));
                    } else {
                      // Expand sizes to match.
                      int64_t max_val = 0;
                      for (const Tuple& t : tuples) {
                        max_val = std::max(max_val, t[i].shape().dim_size(j));
                      }
                      shape.AddDim(max_val);
                    }
                  }

                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;

                  bool has_dynamic_shape = !partial_shape.IsFullyDefined();
                  if (has_dynamic_shape) {
                    // Set all values to zero because not all values
                    // will get written over.
                    attempt->context->SetStatus(SetElementZero(&element));
                    if (!attempt->context->status().ok()) return kComplete;
                  }

                  dynamic_shape.push_back(has_dynamic_shape);
                  attempt->tuple.emplace_back(element);
                }

                for (size_t index = 0; index < tuples.size(); ++index) {
                  for (int i = 0; i < num_components(); ++i) {
                    if (dynamic_shape[i]) {
                      // Slightly slower copy that pads the element into its
                      // slice of the larger output tensor.
                      attempt->context->SetStatus(CopyElementToLargerSlice(
                          tuples[index][i], &attempt->tuple[i], index));
                    } else {
                      attempt->context->SetStatus(
                          batch_util::CopyElementToSlice(
                              std::move(tuples[index][i]), &attempt->tuple[i],
                              index));
                    }
                    if (!attempt->context->status().ok()) return kComplete;
                  }
                }
                tuple = attempt->tuple;
                attempt->tuples.clear();
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  for (size_t i = 0; i < tuple.size(); ++i) {
    if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     partial_shapes_[i].DebugString(),
                                     ", got ", tuple[i].shape().DebugString());
    }
  }
  return OkStatus();
}

Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  const int64_t batch_size = tuple[0].dim_size(0);
  for (size_t i = 0; i < tuple.size(); ++i) {
    // Expected shape is [batch_size] + partial_shapes_[i].
    const PartialTensorShape expected_shape =
        PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
    if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     expected_shape.DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return OkStatus();
}

Status PaddingFIFOQueue::CompatibleNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<PartialTensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
                                              partial_shapes_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has component shapes ",
        PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
        " but requested component shapes were ",
        PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
  } else {
    return OkStatus();
  }
}

Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
    return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
  return OkStatus();
}

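// Checks that `element` has no more entries than a single dim-0 slice of
// `parent`, so the slice assignment below cannot write out of bounds.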
static Status ValidateElementToLargerSlice(const Tensor& element,
                                           Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice. ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return OkStatus();
}

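// Copies `element` into the `index`-th slice of `parent` along dimension 0.
// `parent` must have rank NDIMS + 1; the element is reshaped to a
// [1, d_0, ..., d_{NDIMS-1}] chip and assigned into the matching region of
// the slice, leaving any padding in `parent` untouched.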
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  Status s = ValidateElementToLargerSlice(element, parent);
  if (!s.ok()) {
    return s;
  }
  if (element.NumElements() == 0) {
    return OkStatus();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return OkStatus();
}

namespace {

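// Dispatches HandleElementToLargerSlice on the element's dtype for a
// compile-time rank NDIMS.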
template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element,
                                          Tensor* parent, int index) {
#define HANDLE_TYPE(T)                                                    \
  case DataTypeToEnum<T>::value: {                                        \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index);  \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          DataTypeString(element.dtype()));
  }
}

}  // namespace

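// Copies `element` into slice `index` of `parent` along dimension 0,
// dispatching on the element's runtime rank (ranks 0 through 4 are
// supported).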
Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
                                                  Tensor* parent, int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks. Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return OkStatus();                                                      \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

// Static method. Fills `element` with the default value of its dtype
// (e.g. 0 for numeric types, the empty string for DT_STRING); used to
// zero-initialize padded regions that the per-element copies never write.
Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
#define HANDLE_TYPE(T)                                \
  if (element->dtype() == DataTypeToEnum<T>::value) { \
    element->flat<T>().setConstant(T());              \
    return OkStatus();                                \
  }
  TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               DataTypeString(element->dtype()));
}

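// Maps every unknown dimension (-1) in each partial shape to zero so the
// result can be stored as a fully-defined TensorShape. E.g. a partial shape
// [?, 5] becomes [0, 5].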
std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
    const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
  std::vector<TensorShape> shapes(partial_shapes.size());
  for (size_t i = 0; i < shapes.size(); ++i) {
    const PartialTensorShape& partial = partial_shapes[i];
    TensorShape& shape = shapes[i];
    for (int64_t s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
  }
  return shapes;
}

}  // namespace tensorflow