1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ |
18 | |
#include <climits>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
32 | |
33 | namespace tensorflow { |
34 | |
35 | // Functionality common to asynchronous QueueInterface implementations. |
36 | class QueueBase : public QueueInterface { |
37 | public: |
38 | // As a possible value of 'capacity'. |
39 | static constexpr int32_t kUnbounded = INT_MAX; |
40 | |
41 | // Args: |
42 | // component_dtypes: The types of each component in a queue-element tuple. |
43 | // component_shapes: The shapes of each component in a queue-element tuple, |
44 | // which must either be empty (if the shapes are not specified) or |
45 | // or have the same size as component_dtypes. |
46 | // name: A name to use for the queue. |
47 | QueueBase(int32_t capacity, const DataTypeVector& component_dtypes, |
48 | const std::vector<TensorShape>& component_shapes, |
49 | const string& name); |
50 | |
51 | // Implementations of QueueInterface methods -------------------------------- |
52 | const DataTypeVector& component_dtypes() const override { |
53 | return component_dtypes_; |
54 | } |
55 | |
56 | Status ValidateTuple(const Tuple& tuple) override; |
57 | Status ValidateManyTuple(const Tuple& tuple) override; |
58 | |
59 | void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, |
60 | DoneCallback callback) override; |
61 | |
62 | // Other public methods ----------------------------------------------------- |
63 | const std::vector<TensorShape>& component_shapes() const { |
64 | return component_shapes_; |
65 | } |
66 | |
67 | int32 capacity() const { return capacity_; } |
68 | |
69 | bool is_closed() const override { |
70 | mutex_lock lock(mu_); |
71 | return closed_; |
72 | } |
73 | |
74 | // Copies the index^th slice (in the first dimension) of parent into element. |
75 | static Status CopySliceToElement(const Tensor& parent, Tensor* element, |
76 | int64_t index); |
77 | |
78 | // Copies element into the index^th slice (in the first dimension) of parent. |
79 | // NOTE(mrry): This method is deprecated. Use |
80 | // `tensorflow::batch_util::CopySliceToElement()` defined in |
81 | // "./batch_util.h" instead. |
82 | ABSL_DEPRECATED( |
83 | "Use `tensorflow::batch_util::CopySliceToElement()` defined in " |
84 | "\"./batch_util.h\" instead." ) |
85 | static Status CopyElementToSlice(const Tensor& element, Tensor* parent, |
86 | int64_t index); |
87 | |
88 | protected: |
89 | enum Action { kEnqueue, kDequeue }; |
90 | enum RunResult { kNoProgress, kProgress, kComplete }; |
91 | |
92 | // Tries to enqueue/dequeue (or close) based on whatever is at the |
93 | // front of enqueue_attempts_/dequeue_attempts_. Appends to |
94 | // *finished the callback for any finished attempt (so it may be |
95 | // called once mu_ is released). Returns true if any progress was |
96 | // made. |
97 | struct CleanUp { |
98 | CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) |
99 | : finished(f), to_deregister(ct), cm(cm) {} |
100 | DoneCallback finished; |
101 | CancellationToken to_deregister; |
102 | CancellationManager* cm; |
103 | }; |
104 | |
105 | // Returns the number of components in a queue-element tuple. |
106 | int32 num_components() const { return component_dtypes_.size(); } |
107 | |
108 | // True if shapes were specified. If so, inputs will be validated |
109 | // against them, etc. |
110 | bool specified_shapes() const { return component_shapes_.size() > 0; } |
111 | |
112 | // Code common to Validate*Tuple(). |
113 | Status ValidateTupleCommon(const Tuple& tuple) const; |
114 | |
115 | TensorShape ManyOutShape(int i, int64_t batch_size) { |
116 | TensorShape shape({batch_size}); |
117 | shape.AppendShape(component_shapes_[i]); |
118 | return shape; |
119 | } |
120 | |
121 | void Cancel(Action action, CancellationManager* cancellation_manager, |
122 | CancellationToken token); |
123 | |
124 | // Helper for cancelling all pending Enqueue(Many) operations when |
125 | // Close is called with cancel_pending_enqueues. |
126 | void CloseAndCancel(); |
127 | |
128 | bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up) |
129 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
130 | |
131 | // Tries to make progress on the enqueues or dequeues at the front |
132 | // of the *_attempts_ queues. |
133 | void FlushUnlocked(); |
134 | |
135 | ~QueueBase() override; |
136 | |
137 | // Helpers for implementing MatchesNodeDef(). |
138 | static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes); |
139 | Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const; |
140 | Status MatchesNodeDefCapacity(const NodeDef& node_def, |
141 | int32_t capacity) const; |
142 | Status MatchesNodeDefTypes(const NodeDef& node_def) const; |
143 | Status MatchesNodeDefShapes(const NodeDef& node_def) const; |
144 | |
145 | protected: |
146 | const int32 capacity_; |
147 | const DataTypeVector component_dtypes_; |
148 | const std::vector<TensorShape> component_shapes_; |
149 | const string name_; |
150 | mutable mutex mu_; |
151 | bool closed_ TF_GUARDED_BY(mu_); |
152 | |
153 | struct Attempt; |
154 | typedef std::function<RunResult(Attempt*)> RunCallback; |
155 | struct Attempt { |
156 | int32 elements_requested; |
157 | DoneCallback done_callback; // must be run outside mu_ |
158 | OpKernelContext* context; |
159 | CancellationManager* cancellation_manager; // not owned |
160 | CancellationToken cancellation_token; |
161 | RunCallback run_callback; // must be run while holding mu_ |
162 | bool is_cancelled; |
163 | Tuple tuple; |
164 | // tuples is used by some implementations allowing dynamic shapes. |
165 | std::vector<Tuple> tuples; |
166 | |
167 | Attempt(int32_t elements_requested, DoneCallback done_callback, |
168 | OpKernelContext* context, CancellationManager* cancellation_manager, |
169 | CancellationToken cancellation_token, RunCallback run_callback) |
170 | : elements_requested(elements_requested), |
171 | done_callback(done_callback), |
172 | context(context), |
173 | cancellation_manager(cancellation_manager), |
174 | cancellation_token(cancellation_token), |
175 | run_callback(run_callback), |
176 | is_cancelled(false) {} |
177 | }; |
178 | std::deque<Attempt> enqueue_attempts_ TF_GUARDED_BY(mu_); |
179 | std::deque<Attempt> dequeue_attempts_ TF_GUARDED_BY(mu_); |
180 | |
181 | TF_DISALLOW_COPY_AND_ASSIGN(QueueBase); |
182 | }; |
183 | |
184 | } // namespace tensorflow |
185 | |
186 | #endif // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ |
187 | |