1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/training/queue_runner.h" |
17 | #include "tensorflow/core/kernels/ops_util.h" |
18 | #include "tensorflow/core/platform/env.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, |
23 | std::unique_ptr<QueueRunner>* result) { |
24 | result->reset(new QueueRunner()); |
25 | return (*result)->Init(queue_runner_def); |
26 | } |
27 | |
28 | Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, |
29 | Coordinator* coord, |
30 | std::unique_ptr<QueueRunner>* result) { |
31 | result->reset(new QueueRunner()); |
32 | (*result)->coord_ = coord; |
33 | return (*result)->Init(queue_runner_def); |
34 | } |
35 | |
36 | void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) { |
37 | mutex_lock l(cb_mu_); |
38 | callbacks_.push_back(cb); |
39 | } |
40 | |
41 | void QueueRunner::ClearErrorCallbacks() { |
42 | mutex_lock l(cb_mu_); |
43 | callbacks_.clear(); |
44 | } |
45 | |
46 | Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { |
47 | queue_name_ = queue_runner_def.queue_name(); |
48 | enqueue_op_names_.clear(); |
49 | enqueue_op_names_.insert(enqueue_op_names_.end(), |
50 | queue_runner_def.enqueue_op_name().begin(), |
51 | queue_runner_def.enqueue_op_name().end()); |
52 | size_t op_names_size = enqueue_op_names_.size(); |
53 | if (op_names_size > kint32max) { |
54 | return Status(error::INVALID_ARGUMENT, |
55 | "Enqueue ops to run cannot exceed kint32max" ); |
56 | } |
57 | runs_ = static_cast<int>(op_names_size); |
58 | if (runs_ == 0) { |
59 | return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run." ); |
60 | } |
61 | close_op_name_ = queue_runner_def.close_op_name(); |
62 | cancel_op_name_ = queue_runner_def.cancel_op_name(); |
63 | if (queue_runner_def.queue_closed_exception_types_size() == 0) { |
64 | queue_closed_exception_types_.insert(error::OUT_OF_RANGE); |
65 | } else { |
66 | for (const auto& code : queue_runner_def.queue_closed_exception_types()) { |
67 | queue_closed_exception_types_.insert(static_cast<int>(code)); |
68 | } |
69 | } |
70 | |
71 | int nthreads = runs_; |
72 | if (coord_) { |
73 | // One more thread to call Stop() |
74 | nthreads++; |
75 | } |
76 | thread_pool_.reset(new thread::ThreadPool( |
77 | Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads)); |
78 | |
79 | return OkStatus(); |
80 | } |
81 | |
82 | QueueRunner::~QueueRunner() { |
83 | // Cannot run Stop() here because the session might already be closed or |
84 | // destroyed. |
85 | Join().IgnoreError(); |
86 | } |
87 | |
88 | Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } |
89 | |
90 | Status QueueRunner::StartAndCollectCostGraph(Session* sess, |
91 | const RunOptions& run_options) { |
92 | SetRunArgumentsAndCostGraph(run_options); |
93 | return Start(sess, 0); |
94 | } |
95 | |
96 | Status QueueRunner::Start(Session* sess, int wait_for) { |
97 | counter_.reset(new BlockingCounter(runs_)); |
98 | for (const string& enqueue_op : enqueue_op_names_) { |
99 | thread_pool_->Schedule( |
100 | std::bind(&QueueRunner::Run, this, sess, enqueue_op)); |
101 | } |
102 | if (coord_) { |
103 | thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess)); |
104 | } |
105 | // Wait for up to 'wait_for' milliseconds. |
106 | if (wait_for > 0) { |
107 | if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) { |
108 | return Status(error::DEADLINE_EXCEEDED, |
109 | "Queues not fed before the timeout" ); |
110 | } |
111 | // Check the status of the queue runner as well as the result of the enqueue |
112 | // operations. |
113 | mutex_lock l(mu_); |
114 | if (!enqueue_status_.ok()) { |
115 | return enqueue_status_; |
116 | } else { |
117 | return status_; |
118 | } |
119 | } |
120 | return OkStatus(); |
121 | } |
122 | |
123 | Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, |
124 | const RunOptions& run_options) { |
125 | SetRunArgumentsAndCostGraph(run_options); |
126 | return Start(session, wait_for_ms); |
127 | } |
128 | |
129 | void QueueRunner::Stop(Session* sess) { |
130 | if (coord_ != nullptr) { |
131 | coord_->WaitForStop(); |
132 | } |
133 | if (!cancel_op_name_.empty()) { |
134 | UpdateStatus(RealRun(sess, cancel_op_name_, false)); |
135 | } |
136 | stopped_ = true; |
137 | } |
138 | |
139 | Status QueueRunner::Join() { |
140 | thread_pool_.reset(); |
141 | mutex_lock l(mu_); |
142 | return status_; |
143 | } |
144 | |
145 | void QueueRunner::UpdateStatus(const Status& status) { |
146 | { |
147 | mutex_lock l(mu_); |
148 | if (!status_.ok() || status.ok() || IsQueueClosed(status)) { |
149 | return; |
150 | } |
151 | status_ = status; |
152 | } |
153 | if (coord_) { |
154 | coord_->ReportStatus(status); |
155 | } |
156 | mutex_lock l(cb_mu_); |
157 | for (auto& cb : callbacks_) { |
158 | cb(status); |
159 | } |
160 | } |
161 | |
162 | void QueueRunner::Run(Session* sess, const string& enqueue_op) { |
163 | bool first_iteration = true; |
164 | Status status; |
165 | while (status.ok()) { |
166 | if (coord_ && coord_->ShouldStop()) { |
167 | break; |
168 | } |
169 | status = RealRun(sess, enqueue_op, true); |
170 | if (first_iteration) { |
171 | if (!status.ok()) { |
172 | mutex_lock l(mu_); |
173 | enqueue_status_ = status; |
174 | } |
175 | counter_->DecrementCount(); |
176 | first_iteration = false; |
177 | } |
178 | } |
179 | bool last_run = false; |
180 | { |
181 | mutex_lock l(mu_); |
182 | runs_--; |
183 | last_run = (runs_ == 0); |
184 | } |
185 | |
186 | // Close the queue unless the coordinator is shutting down since the cancel op |
187 | // will be run anyway in this case. |
188 | if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { |
189 | if (last_run && !close_op_name_.empty()) { |
190 | UpdateStatus(RealRun(sess, close_op_name_, false)); |
191 | } |
192 | } else if (!status.ok()) { |
193 | LOG(ERROR) << "Queue runner thread got a failure status: " |
194 | << status.ToString(); |
195 | UpdateStatus(status); |
196 | if (coord_) { |
197 | coord_->RequestStop().IgnoreError(); |
198 | } |
199 | } |
200 | } |
201 | |
202 | Status QueueRunner::GetStatus() { |
203 | mutex_lock l(mu_); |
204 | return status_; |
205 | } |
206 | |
207 | Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { |
208 | if (!cg_mu_) { |
209 | return Status(error::FAILED_PRECONDITION, |
210 | "This QueueRunner doesn't collect a cost graph." ); |
211 | } |
212 | mutex_lock l(*cg_mu_); |
213 | cost_graph->MergeFrom(*cost_graph_); |
214 | return OkStatus(); |
215 | } |
216 | |
217 | void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) { |
218 | cg_mu_.reset(new mutex()); |
219 | { |
220 | mutex_lock l(*cg_mu_); |
221 | cost_graph_.reset(new CostGraphDef()); |
222 | } |
223 | run_options_ = run_options; |
224 | } |
225 | |
226 | Status QueueRunner::RealRun(Session* sess, const string& op, |
227 | bool update_costs) { |
228 | Status s; |
229 | if (update_costs && cg_mu_) { |
230 | RunMetadata metadata; |
231 | s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); |
232 | mutex_lock l(*cg_mu_); |
233 | cost_graph_->Swap(metadata.mutable_cost_graph()); |
234 | } else { |
235 | s = sess->Run({}, {}, {op}, nullptr); |
236 | } |
237 | return s; |
238 | } |
239 | |
240 | } // namespace tensorflow |
241 | |