/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/queue_base.h"

#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

namespace {
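// Copies the `index`-th slice along dimension 0 of `parent` into `element`,
// which must already have the dtype and shape of a single slice.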
template <DataType DT>
Status HandleSliceToElement(const Tensor& parent, Tensor* element,
                            int64_t index) {
  typedef typename EnumToDataType<DT>::Type T;
  DCHECK_NE(parent.dim_size(0), 0);
  DCHECK_GE(index, 0);
  if (element->NumElements() != (parent.NumElements() / parent.dim_size(0))) {
    TensorShape chip_shape = parent.shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleSliceToElement Cannot copy slice: number of elements does not "
        "match. Shapes are: [element]: ",
        element->shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  auto parent_as_matrix = parent.flat_outer_dims<T>();
  element->flat<T>() = parent_as_matrix.chip(index, 0);
  return OkStatus();
}

}  // namespace

QueueBase::QueueBase(int32_t capacity, const DataTypeVector& component_dtypes,
                     const std::vector<TensorShape>& component_shapes,
                     const string& name)
    : capacity_(capacity),
      component_dtypes_(component_dtypes),
      component_shapes_(component_shapes),
      name_(name),
      closed_(false) {}

QueueBase::~QueueBase() {}

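// Checks that `tuple` has the expected number of components and that each
// component has the dtype declared for this queue.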
Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
  if (tuple.size() != static_cast<size_t>(num_components())) {
    return errors::InvalidArgument(
        "Wrong number of components in tuple. Expected ", num_components(),
        ", got ", tuple.size());
  }
  for (size_t i = 0; i < tuple.size(); ++i) {
    if (tuple[i].dtype() != component_dtypes_[i]) {
      return errors::InvalidArgument(
          "Type mismatch in tuple component ", i, ". Expected ",
          DataTypeString(component_dtypes_[i]), ", got ",
          DataTypeString(tuple[i].dtype()));
    }
  }
  return OkStatus();
}

// static
string QueueBase::ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const TensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

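// The MatchesNodeDef* helpers below check that a NodeDef attempting to share
// this queue was created with attributes (op type, capacity, component types,
// and shapes) consistent with the existing queue resource.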
Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def,
                                   const string& op) const {
  if (node_def.op() != op) {
    return errors::InvalidArgument("Shared queue '", name_, "' has type '", op,
                                   "' that does not match type of Node '",
                                   node_def.name(), "': ", node_def.op());
  }
  return OkStatus();
}

Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def,
                                         int32_t capacity) const {
  int32_t requested_capacity = -1;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity));
  if (requested_capacity < 0) requested_capacity = kUnbounded;
  if (requested_capacity != capacity) {
    return errors::InvalidArgument("Shared queue '", name_, "' has capacity ",
                                   capacity, " but requested capacity was ",
                                   requested_capacity);
  }
  return OkStatus();
}

Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return OkStatus();
}

Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return OkStatus();
}

// TODO(mrry): If these checks become a bottleneck, find a way to
// reduce the number of times that they are called.
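// Validates a single-element tuple: if shapes were specified at construction,
// each component must exactly match the corresponding component shape.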
Status QueueBase::ValidateTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  if (specified_shapes()) {
    for (size_t i = 0; i < tuple.size(); ++i) {
      if (!component_shapes_[i].IsSameSize(tuple[i].shape())) {
        return errors::InvalidArgument(
            "Shape mismatch in tuple component ", i, ". Expected ",
            component_shapes_[i].DebugString(), ", got ",
            tuple[i].shape().DebugString());
      }
    }
  }
  return OkStatus();
}

// TODO(mrry): If these checks become a bottleneck, find a way to
// reduce the number of times that they are called.
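// Validates a batched tuple: if shapes were specified, component i must have
// shape [batch_size] + component_shapes_[i]; otherwise all components must
// agree on the size of dimension 0 (the batch size).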
Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  const int64_t batch_size = tuple[0].dim_size(0);
  if (specified_shapes()) {
    for (size_t i = 0; i < tuple.size(); ++i) {
      // Expected shape is [batch_size] + component_shapes_[i]
      const TensorShape expected_shape = ManyOutShape(i, batch_size);
      if (!expected_shape.IsSameSize(tuple[i].shape())) {
        return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                       ". Expected ",
                                       expected_shape.DebugString(), ", got ",
                                       tuple[i].shape().DebugString());
      }
    }
  } else {
    for (size_t i = 1; i < tuple.size(); ++i) {
      if (tuple[i].dim_size(0) != batch_size) {
        return errors::InvalidArgument(
            "All input tensors must have the same size in the 0th ",
            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
            ", and should have ", batch_size);
      }
    }
  }
  return OkStatus();
}

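// Cancels the pending enqueue or dequeue attempt registered under
// (cancellation_manager, token): marks it cancelled, sets a Cancelled status
// on its kernel context, and runs its completion callback outside the lock.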
void QueueBase::Cancel(Action action, CancellationManager* cancellation_manager,
                       CancellationToken token) {
  DoneCallback callback = nullptr;
  {
    mutex_lock lock(mu_);
    std::deque<Attempt>* attempts =
        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;

    for (Attempt& attempt : *attempts) {
      if (attempt.cancellation_manager == cancellation_manager &&
          attempt.cancellation_token == token) {
        if (!attempt.is_cancelled) {
          attempt.is_cancelled = true;
          if (action == kEnqueue) {
            attempt.context->SetStatus(
                errors::Cancelled("Enqueue operation was cancelled"));
          } else {
            attempt.context->SetStatus(
                errors::Cancelled("Dequeue operation was cancelled"));
          }
          std::swap(callback, attempt.done_callback);
        }
        break;
      }
    }
  }
  if (callback) {
    callback();
    FlushUnlocked();
  }
}

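// Closes the queue and cancels all pending enqueue attempts. Their completion
// callbacks are collected under the lock and invoked after it is released.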
void QueueBase::CloseAndCancel() {
  std::vector<DoneCallback> callbacks;
  {
    mutex_lock lock(mu_);
    closed_ = true;
    for (Attempt& attempt : enqueue_attempts_) {
      if (!attempt.is_cancelled) {
        attempt.is_cancelled = true;
        attempt.context->SetStatus(
            errors::Cancelled("Enqueue operation was cancelled"));
        callbacks.emplace_back(std::move(attempt.done_callback));
      }
    }
  }
  for (const DoneCallback& callback : callbacks) {
    callback();
  }
  FlushUnlocked();
}

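// Closes the queue. If cancel_pending_enqueues is true, pending enqueues are
// cancelled immediately; otherwise the close is appended to the enqueue
// attempt queue, so it takes effect only after earlier pending enqueue
// attempts have been processed.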
void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
                      DoneCallback callback) {
  if (cancel_pending_enqueues) {
    CloseAndCancel();
    callback();
  } else {
    {
      mutex_lock lock(mu_);
      enqueue_attempts_.emplace_back(
          0, callback, ctx, nullptr, CancellationManager::kInvalidToken,
          [this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("Queue '", name_, "' is already closed."));
            } else {
              closed_ = true;
            }
            return kComplete;
          });
    }
    FlushUnlocked();
  }
}

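// Runs pending attempts of the given kind from the front of the deque until
// one reports no progress or the deque is empty. Completed attempts are moved
// into *clean_up so their callbacks can run after mu_ is released. Returns
// true if any attempt made progress.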
bool QueueBase::TryAttemptLocked(Action action,
                                 std::vector<CleanUp>* clean_up) {
  std::deque<Attempt>* attempts =
      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;

  bool progress = false;
  bool done = false;
  while (!done && !attempts->empty()) {
    if (attempts->front().is_cancelled) {
      if (action == kEnqueue) {
        if (closed_) {
          VLOG(1) << "Skipping cancelled enqueue attempt";
        } else {
          LOG(WARNING)
              << name_
              << ": Skipping cancelled enqueue attempt with queue not closed";
        }
      } else {
        if (closed_) {
          VLOG(1) << "Skipping cancelled dequeue attempt";
        } else {
          LOG(WARNING)
              << name_
              << ": Skipping cancelled dequeue attempt with queue not closed";
        }
      }
      attempts->pop_front();
    } else {
      Attempt* cur_attempt = &attempts->front();
      switch (cur_attempt->run_callback(cur_attempt)) {
        case kNoProgress:
          done = true;
          break;
        case kProgress:
          done = true;
          progress = true;
          break;
        case kComplete:
          progress = true;
          clean_up->emplace_back(std::move(cur_attempt->done_callback),
                                 cur_attempt->cancellation_token,
                                 cur_attempt->context->cancellation_manager());
          attempts->pop_front();
          break;
      }
    }
  }
  return progress;
}

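// Repeatedly retries pending enqueue and dequeue attempts (each may unblock
// the other) until neither makes progress, then runs the completion callbacks
// of finished attempts with mu_ released. Ref()/Unref() keep the queue alive
// while attempts are retried under the lock.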
void QueueBase::FlushUnlocked() {
  std::vector<CleanUp> clean_up;
  Ref();
  {
    mutex_lock lock(mu_);
    bool changed;
    do {
      changed = TryAttemptLocked(kEnqueue, &clean_up);
      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
    } while (changed);
  }
  Unref();
  for (const auto& to_clean : clean_up) {
    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
      // NOTE(mrry): We can safely ignore the return value of
      // DeregisterCallback because the mutex mu_ ensures that the
      // cleanup action only executes once.
      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
    }
    to_clean.finished();
  }
}

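// Static helpers that forward to batch_util: CopySliceToElement extracts the
// index-th slice of a batched tensor into a single element, and
// CopyElementToSlice writes a single element into the index-th slice of a
// batched tensor.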
Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
                                     int64_t index) {
  return batch_util::CopySliceToElement(parent, element, index);
}

/* static */
Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
                                     int64_t index) {
  return batch_util::CopyElementToSlice(element, parent, index);
}

}  // namespace tensorflow