1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
17#define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
18
#include <climits>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
32
33namespace tensorflow {
34
35// Functionality common to asynchronous QueueInterface implementations.
36class QueueBase : public QueueInterface {
37 public:
38 // As a possible value of 'capacity'.
39 static constexpr int32_t kUnbounded = INT_MAX;
40
41 // Args:
42 // component_dtypes: The types of each component in a queue-element tuple.
43 // component_shapes: The shapes of each component in a queue-element tuple,
44 // which must either be empty (if the shapes are not specified) or
45 // or have the same size as component_dtypes.
46 // name: A name to use for the queue.
47 QueueBase(int32_t capacity, const DataTypeVector& component_dtypes,
48 const std::vector<TensorShape>& component_shapes,
49 const string& name);
50
51 // Implementations of QueueInterface methods --------------------------------
52 const DataTypeVector& component_dtypes() const override {
53 return component_dtypes_;
54 }
55
56 Status ValidateTuple(const Tuple& tuple) override;
57 Status ValidateManyTuple(const Tuple& tuple) override;
58
59 void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
60 DoneCallback callback) override;
61
62 // Other public methods -----------------------------------------------------
63 const std::vector<TensorShape>& component_shapes() const {
64 return component_shapes_;
65 }
66
67 int32 capacity() const { return capacity_; }
68
69 bool is_closed() const override {
70 mutex_lock lock(mu_);
71 return closed_;
72 }
73
74 // Copies the index^th slice (in the first dimension) of parent into element.
75 static Status CopySliceToElement(const Tensor& parent, Tensor* element,
76 int64_t index);
77
78 // Copies element into the index^th slice (in the first dimension) of parent.
79 // NOTE(mrry): This method is deprecated. Use
80 // `tensorflow::batch_util::CopySliceToElement()` defined in
81 // "./batch_util.h" instead.
82 ABSL_DEPRECATED(
83 "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
84 "\"./batch_util.h\" instead.")
85 static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
86 int64_t index);
87
88 protected:
89 enum Action { kEnqueue, kDequeue };
90 enum RunResult { kNoProgress, kProgress, kComplete };
91
92 // Tries to enqueue/dequeue (or close) based on whatever is at the
93 // front of enqueue_attempts_/dequeue_attempts_. Appends to
94 // *finished the callback for any finished attempt (so it may be
95 // called once mu_ is released). Returns true if any progress was
96 // made.
97 struct CleanUp {
98 CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
99 : finished(f), to_deregister(ct), cm(cm) {}
100 DoneCallback finished;
101 CancellationToken to_deregister;
102 CancellationManager* cm;
103 };
104
105 // Returns the number of components in a queue-element tuple.
106 int32 num_components() const { return component_dtypes_.size(); }
107
108 // True if shapes were specified. If so, inputs will be validated
109 // against them, etc.
110 bool specified_shapes() const { return component_shapes_.size() > 0; }
111
112 // Code common to Validate*Tuple().
113 Status ValidateTupleCommon(const Tuple& tuple) const;
114
115 TensorShape ManyOutShape(int i, int64_t batch_size) {
116 TensorShape shape({batch_size});
117 shape.AppendShape(component_shapes_[i]);
118 return shape;
119 }
120
121 void Cancel(Action action, CancellationManager* cancellation_manager,
122 CancellationToken token);
123
124 // Helper for cancelling all pending Enqueue(Many) operations when
125 // Close is called with cancel_pending_enqueues.
126 void CloseAndCancel();
127
128 bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
129 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
130
131 // Tries to make progress on the enqueues or dequeues at the front
132 // of the *_attempts_ queues.
133 void FlushUnlocked();
134
135 ~QueueBase() override;
136
137 // Helpers for implementing MatchesNodeDef().
138 static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
139 Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const;
140 Status MatchesNodeDefCapacity(const NodeDef& node_def,
141 int32_t capacity) const;
142 Status MatchesNodeDefTypes(const NodeDef& node_def) const;
143 Status MatchesNodeDefShapes(const NodeDef& node_def) const;
144
145 protected:
146 const int32 capacity_;
147 const DataTypeVector component_dtypes_;
148 const std::vector<TensorShape> component_shapes_;
149 const string name_;
150 mutable mutex mu_;
151 bool closed_ TF_GUARDED_BY(mu_);
152
153 struct Attempt;
154 typedef std::function<RunResult(Attempt*)> RunCallback;
155 struct Attempt {
156 int32 elements_requested;
157 DoneCallback done_callback; // must be run outside mu_
158 OpKernelContext* context;
159 CancellationManager* cancellation_manager; // not owned
160 CancellationToken cancellation_token;
161 RunCallback run_callback; // must be run while holding mu_
162 bool is_cancelled;
163 Tuple tuple;
164 // tuples is used by some implementations allowing dynamic shapes.
165 std::vector<Tuple> tuples;
166
167 Attempt(int32_t elements_requested, DoneCallback done_callback,
168 OpKernelContext* context, CancellationManager* cancellation_manager,
169 CancellationToken cancellation_token, RunCallback run_callback)
170 : elements_requested(elements_requested),
171 done_callback(done_callback),
172 context(context),
173 cancellation_manager(cancellation_manager),
174 cancellation_token(cancellation_token),
175 run_callback(run_callback),
176 is_cancelled(false) {}
177 };
178 std::deque<Attempt> enqueue_attempts_ TF_GUARDED_BY(mu_);
179 std::deque<Attempt> dequeue_attempts_ TF_GUARDED_BY(mu_);
180
181 TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
182};
183
184} // namespace tensorflow
185
186#endif // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
187