1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/training/queue_runner.h"
17#include "tensorflow/core/kernels/ops_util.h"
18#include "tensorflow/core/platform/env.h"
19
20namespace tensorflow {
21
22Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
23 std::unique_ptr<QueueRunner>* result) {
24 result->reset(new QueueRunner());
25 return (*result)->Init(queue_runner_def);
26}
27
28Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
29 Coordinator* coord,
30 std::unique_ptr<QueueRunner>* result) {
31 result->reset(new QueueRunner());
32 (*result)->coord_ = coord;
33 return (*result)->Init(queue_runner_def);
34}
35
36void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
37 mutex_lock l(cb_mu_);
38 callbacks_.push_back(cb);
39}
40
41void QueueRunner::ClearErrorCallbacks() {
42 mutex_lock l(cb_mu_);
43 callbacks_.clear();
44}
45
46Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
47 queue_name_ = queue_runner_def.queue_name();
48 enqueue_op_names_.clear();
49 enqueue_op_names_.insert(enqueue_op_names_.end(),
50 queue_runner_def.enqueue_op_name().begin(),
51 queue_runner_def.enqueue_op_name().end());
52 size_t op_names_size = enqueue_op_names_.size();
53 if (op_names_size > kint32max) {
54 return Status(error::INVALID_ARGUMENT,
55 "Enqueue ops to run cannot exceed kint32max");
56 }
57 runs_ = static_cast<int>(op_names_size);
58 if (runs_ == 0) {
59 return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
60 }
61 close_op_name_ = queue_runner_def.close_op_name();
62 cancel_op_name_ = queue_runner_def.cancel_op_name();
63 if (queue_runner_def.queue_closed_exception_types_size() == 0) {
64 queue_closed_exception_types_.insert(error::OUT_OF_RANGE);
65 } else {
66 for (const auto& code : queue_runner_def.queue_closed_exception_types()) {
67 queue_closed_exception_types_.insert(static_cast<int>(code));
68 }
69 }
70
71 int nthreads = runs_;
72 if (coord_) {
73 // One more thread to call Stop()
74 nthreads++;
75 }
76 thread_pool_.reset(new thread::ThreadPool(
77 Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads));
78
79 return OkStatus();
80}
81
82QueueRunner::~QueueRunner() {
83 // Cannot run Stop() here because the session might already be closed or
84 // destroyed.
85 Join().IgnoreError();
86}
87
88Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
89
90Status QueueRunner::StartAndCollectCostGraph(Session* sess,
91 const RunOptions& run_options) {
92 SetRunArgumentsAndCostGraph(run_options);
93 return Start(sess, 0);
94}
95
96Status QueueRunner::Start(Session* sess, int wait_for) {
97 counter_.reset(new BlockingCounter(runs_));
98 for (const string& enqueue_op : enqueue_op_names_) {
99 thread_pool_->Schedule(
100 std::bind(&QueueRunner::Run, this, sess, enqueue_op));
101 }
102 if (coord_) {
103 thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
104 }
105 // Wait for up to 'wait_for' milliseconds.
106 if (wait_for > 0) {
107 if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
108 return Status(error::DEADLINE_EXCEEDED,
109 "Queues not fed before the timeout");
110 }
111 // Check the status of the queue runner as well as the result of the enqueue
112 // operations.
113 mutex_lock l(mu_);
114 if (!enqueue_status_.ok()) {
115 return enqueue_status_;
116 } else {
117 return status_;
118 }
119 }
120 return OkStatus();
121}
122
123Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
124 const RunOptions& run_options) {
125 SetRunArgumentsAndCostGraph(run_options);
126 return Start(session, wait_for_ms);
127}
128
129void QueueRunner::Stop(Session* sess) {
130 if (coord_ != nullptr) {
131 coord_->WaitForStop();
132 }
133 if (!cancel_op_name_.empty()) {
134 UpdateStatus(RealRun(sess, cancel_op_name_, false));
135 }
136 stopped_ = true;
137}
138
139Status QueueRunner::Join() {
140 thread_pool_.reset();
141 mutex_lock l(mu_);
142 return status_;
143}
144
145void QueueRunner::UpdateStatus(const Status& status) {
146 {
147 mutex_lock l(mu_);
148 if (!status_.ok() || status.ok() || IsQueueClosed(status)) {
149 return;
150 }
151 status_ = status;
152 }
153 if (coord_) {
154 coord_->ReportStatus(status);
155 }
156 mutex_lock l(cb_mu_);
157 for (auto& cb : callbacks_) {
158 cb(status);
159 }
160}
161
162void QueueRunner::Run(Session* sess, const string& enqueue_op) {
163 bool first_iteration = true;
164 Status status;
165 while (status.ok()) {
166 if (coord_ && coord_->ShouldStop()) {
167 break;
168 }
169 status = RealRun(sess, enqueue_op, true);
170 if (first_iteration) {
171 if (!status.ok()) {
172 mutex_lock l(mu_);
173 enqueue_status_ = status;
174 }
175 counter_->DecrementCount();
176 first_iteration = false;
177 }
178 }
179 bool last_run = false;
180 {
181 mutex_lock l(mu_);
182 runs_--;
183 last_run = (runs_ == 0);
184 }
185
186 // Close the queue unless the coordinator is shutting down since the cancel op
187 // will be run anyway in this case.
188 if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
189 if (last_run && !close_op_name_.empty()) {
190 UpdateStatus(RealRun(sess, close_op_name_, false));
191 }
192 } else if (!status.ok()) {
193 LOG(ERROR) << "Queue runner thread got a failure status: "
194 << status.ToString();
195 UpdateStatus(status);
196 if (coord_) {
197 coord_->RequestStop().IgnoreError();
198 }
199 }
200}
201
202Status QueueRunner::GetStatus() {
203 mutex_lock l(mu_);
204 return status_;
205}
206
207Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
208 if (!cg_mu_) {
209 return Status(error::FAILED_PRECONDITION,
210 "This QueueRunner doesn't collect a cost graph.");
211 }
212 mutex_lock l(*cg_mu_);
213 cost_graph->MergeFrom(*cost_graph_);
214 return OkStatus();
215}
216
217void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) {
218 cg_mu_.reset(new mutex());
219 {
220 mutex_lock l(*cg_mu_);
221 cost_graph_.reset(new CostGraphDef());
222 }
223 run_options_ = run_options;
224}
225
226Status QueueRunner::RealRun(Session* sess, const string& op,
227 bool update_costs) {
228 Status s;
229 if (update_costs && cg_mu_) {
230 RunMetadata metadata;
231 s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
232 mutex_lock l(*cg_mu_);
233 cost_graph_->Swap(metadata.mutable_cost_graph());
234 } else {
235 s = sess->Run({}, {}, {op}, nullptr);
236 }
237 return s;
238}
239
240} // namespace tensorflow
241