/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16#include "tensorflow/core/kernels/queue_base.h"
17
18#include <vector>
19#include "tensorflow/core/framework/node_def.pb.h"
20#include "tensorflow/core/framework/tensor_shape.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/platform/mutex.h"
23#include "tensorflow/core/platform/types.h"
24#include "tensorflow/core/util/batch_util.h"
25
namespace tensorflow {

namespace {

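// Copies the `index`-th slice along dimension 0 of `parent` into `element`,
// by viewing `parent` as a matrix over its flattened outer dimensions and
// chipping out one row. Returns an Internal error if the number of elements
// in `element` does not match the slice size.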
template <DataType DT>
Status HandleSliceToElement(const Tensor& parent, Tensor* element,
                            int64_t index) {
  typedef typename EnumToDataType<DT>::Type T;
  DCHECK_NE(parent.dim_size(0), 0);
  DCHECK_GE(index, 0);
  if (element->NumElements() != (parent.NumElements() / parent.dim_size(0))) {
    TensorShape chip_shape = parent.shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleSliceToElement Cannot copy slice: number of elements does not "
        "match. Shapes are: [element]: ",
        element->shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  auto parent_as_matrix = parent.flat_outer_dims<T>();
  element->flat<T>() = parent_as_matrix.chip(index, 0);
  return OkStatus();
}

}  // namespace

QueueBase::QueueBase(int32_t capacity, const DataTypeVector& component_dtypes,
                     const std::vector<TensorShape>& component_shapes,
                     const string& name)
    : capacity_(capacity),
      component_dtypes_(component_dtypes),
      component_shapes_(component_shapes),
      name_(name),
      closed_(false) {}

QueueBase::~QueueBase() {}

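// Checks that `tuple` has one tensor per queue component and that each
// tensor's dtype matches the corresponding component dtype.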
Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
  if (tuple.size() != static_cast<size_t>(num_components())) {
    return errors::InvalidArgument(
        "Wrong number of components in tuple. Expected ", num_components(),
        ", got ", tuple.size());
  }
  for (size_t i = 0; i < tuple.size(); ++i) {
    if (tuple[i].dtype() != component_dtypes_[i]) {
      return errors::InvalidArgument(
          "Type mismatch in tuple component ", i, ". Expected ",
          DataTypeString(component_dtypes_[i]), ", got ",
          DataTypeString(tuple[i].dtype()));
    }
  }
  return OkStatus();
}

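// Renders a list of shapes as a human-readable string, e.g. "[[2,2], [3]]".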
// static
string QueueBase::ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const TensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

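// Verifies that the NodeDef requesting this shared queue has the expected
// op type.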
Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def,
                                   const string& op) const {
  if (node_def.op() != op) {
    return errors::InvalidArgument("Shared queue '", name_, "' has type '", op,
                                   "' that does not match type of Node '",
                                   node_def.name(), "': ", node_def.op());
  }
  return OkStatus();
}

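// Verifies that the "capacity" attr of the requesting NodeDef matches this
// queue's capacity; a negative requested capacity is treated as unbounded.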
Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def,
                                         int32_t capacity) const {
  int32_t requested_capacity = -1;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity));
  if (requested_capacity < 0) requested_capacity = kUnbounded;
  if (requested_capacity != capacity) {
    return errors::InvalidArgument("Shared queue '", name_, "' has capacity ",
                                   capacity, " but requested capacity was ",
                                   requested_capacity);
  }
  return OkStatus();
}

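// Verifies that the "component_types" attr of the requesting NodeDef matches
// this queue's component dtypes.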
Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return OkStatus();
}

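// Verifies that the "shapes" attr of the requesting NodeDef matches this
// queue's component shapes.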
Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return OkStatus();
}

// TODO(mrry): If these checks become a bottleneck, find a way to
// reduce the number of times that they are called.
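// Validates a tuple for a single-element enqueue: correct arity and dtypes,
// and, if shapes were specified at construction, an exact per-component
// shape match.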
Status QueueBase::ValidateTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  if (specified_shapes()) {
    for (size_t i = 0; i < tuple.size(); ++i) {
      if (!component_shapes_[i].IsSameSize(tuple[i].shape())) {
        return errors::InvalidArgument(
            "Shape mismatch in tuple component ", i, ". Expected ",
            component_shapes_[i].DebugString(), ", got ",
            tuple[i].shape().DebugString());
      }
    }
  }
  return OkStatus();
}

// TODO(mrry): If these checks become a bottleneck, find a way to
// reduce the number of times that they are called.
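// Validates a tuple for a batched enqueue_many: with specified shapes, each
// component i must have shape [batch_size] + component_shapes_[i]; otherwise
// all components must agree on the size of dimension 0.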
Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  const int64_t batch_size = tuple[0].dim_size(0);
  if (specified_shapes()) {
    for (size_t i = 0; i < tuple.size(); ++i) {
      // Expected shape is [batch_size] + component_shapes_[i]
      const TensorShape expected_shape = ManyOutShape(i, batch_size);
      if (!expected_shape.IsSameSize(tuple[i].shape())) {
        return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                       ". Expected ",
                                       expected_shape.DebugString(), ", got ",
                                       tuple[i].shape().DebugString());
      }
    }
  } else {
    for (size_t i = 1; i < tuple.size(); ++i) {
      if (tuple[i].dim_size(0) != batch_size) {
        return errors::InvalidArgument(
            "All input tensors must have the same size in the 0th ",
            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
            ", and should have ", batch_size);
      }
    }
  }
  return OkStatus();
}

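// Cancels the pending attempt registered with the given cancellation manager
// and token, if any. The attempt's context is marked with a Cancelled status
// under the lock; its completion callback then runs outside the lock,
// followed by a flush so other attempts can make progress.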
void QueueBase::Cancel(Action action, CancellationManager* cancellation_manager,
                       CancellationToken token) {
  DoneCallback callback = nullptr;
  {
    mutex_lock lock(mu_);
    std::deque<Attempt>* attempts =
        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;

    for (Attempt& attempt : *attempts) {
      if (attempt.cancellation_manager == cancellation_manager &&
          attempt.cancellation_token == token) {
        if (!attempt.is_cancelled) {
          attempt.is_cancelled = true;
          if (action == kEnqueue) {
            attempt.context->SetStatus(
                errors::Cancelled("Enqueue operation was cancelled"));
          } else {
            attempt.context->SetStatus(
                errors::Cancelled("Dequeue operation was cancelled"));
          }
          std::swap(callback, attempt.done_callback);
        }
        break;
      }
    }
  }
  if (callback) {
    callback();
    FlushUnlocked();
  }
}

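// Marks the queue closed and cancels every pending enqueue attempt, invoking
// the cancelled attempts' callbacks outside the lock before flushing so that
// blocked dequeues can observe the closed state.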
void QueueBase::CloseAndCancel() {
  std::vector<DoneCallback> callbacks;
  {
    mutex_lock lock(mu_);
    closed_ = true;
    for (Attempt& attempt : enqueue_attempts_) {
      if (!attempt.is_cancelled) {
        attempt.is_cancelled = true;
        attempt.context->SetStatus(
            errors::Cancelled("Enqueue operation was cancelled"));
        callbacks.emplace_back(std::move(attempt.done_callback));
      }
    }
  }
  for (const DoneCallback& callback : callbacks) {
    callback();
  }
  FlushUnlocked();
}

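// Closes the queue. If cancel_pending_enqueues is true, delegates to
// CloseAndCancel(); otherwise appends a synthetic enqueue attempt that marks
// the queue closed once all previously pending enqueues have drained, and
// fails with a Cancelled status if the queue is already closed.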
void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
                      DoneCallback callback) {
  if (cancel_pending_enqueues) {
    CloseAndCancel();
    callback();
  } else {
    {
      mutex_lock lock(mu_);
      enqueue_attempts_.emplace_back(
          0, callback, ctx, nullptr, CancellationManager::kInvalidToken,
          [this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("Queue '", name_, "' is already closed."));
            } else {
              closed_ = true;
            }
            return kComplete;
          });
    }
    FlushUnlocked();
  }
}

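// Runs attempts at the front of the selected deque until one reports
// kNoProgress or the deque empties. Cancelled attempts are skipped and
// popped; completed attempts have their callbacks and cancellation tokens
// collected into *clean_up for processing outside the lock. Returns true if
// any attempt made progress.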
bool QueueBase::TryAttemptLocked(Action action,
                                 std::vector<CleanUp>* clean_up) {
  std::deque<Attempt>* attempts =
      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;

  bool progress = false;
  bool done = false;
  while (!done && !attempts->empty()) {
    if (attempts->front().is_cancelled) {
      if (action == kEnqueue) {
        if (closed_) {
          VLOG(1) << "Skipping cancelled enqueue attempt";
        } else {
          LOG(WARNING)
              << name_
              << ": Skipping cancelled enqueue attempt with queue not closed";
        }
      } else {
        if (closed_) {
          VLOG(1) << "Skipping cancelled dequeue attempt";
        } else {
          LOG(WARNING)
              << name_
              << ": Skipping cancelled dequeue attempt with queue not closed";
        }
      }
      attempts->pop_front();
    } else {
      Attempt* cur_attempt = &attempts->front();
      switch (cur_attempt->run_callback(cur_attempt)) {
        case kNoProgress:
          done = true;
          break;
        case kProgress:
          done = true;
          progress = true;
          break;
        case kComplete:
          progress = true;
          clean_up->emplace_back(std::move(cur_attempt->done_callback),
                                 cur_attempt->cancellation_token,
                                 cur_attempt->context->cancellation_manager());
          attempts->pop_front();
          break;
      }
    }
  }
  return progress;
}

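// Alternates between running pending enqueue and dequeue attempts until
// neither side makes progress, then, outside the lock, deregisters each
// completed attempt's cancellation callback and invokes its completion
// callback. The Ref()/Unref() pair keeps the queue alive while attempts run.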
void QueueBase::FlushUnlocked() {
  std::vector<CleanUp> clean_up;
  Ref();
  {
    mutex_lock lock(mu_);
    bool changed;
    do {
      changed = TryAttemptLocked(kEnqueue, &clean_up);
      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
    } while (changed);
  }
  Unref();
  for (const auto& to_clean : clean_up) {
    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
      // NOTE(mrry): We can safely ignore the return value of
      // DeregisterCallback because the mutex mu_ ensures that the
      // cleanup action only executes once.
      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
    }
    to_clean.finished();
  }
}

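// Thin wrappers that copy a single element to/from a slice of a batched
// tensor by delegating to the shared helpers in batch_util.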
Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
                                     int64_t index) {
  return batch_util::CopySliceToElement(parent, element, index);
}

/* static */
Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
                                     int64_t index) {
  return batch_util::CopyElementToSlice(element, parent, index);
}

}  // namespace tensorflow