1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
17#define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
18
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
33
34namespace tensorflow {
35
36/// QueueRunner class imitates the behavior of the python version of QueueRunner
37/// which creates a thread for each enqueue op, runs close op on completion.
38class QueueRunner : public RunnerInterface {
39 public:
40 /// Creates a new QueueRunner from proto.
41 // TODO(yuefengz): we may want to initialize from queues and ops in the
42 // future.
43 static Status New(const QueueRunnerDef& queue_runner_def,
44 std::unique_ptr<QueueRunner>* result);
45
46 /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
47 static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
48 std::unique_ptr<QueueRunner>* result);
49
50 /// Adds a callback that the queue runner will call when it detects an error.
51 void AddErrorCallback(const std::function<void(Status)>& cb);
52
53 /// Delete the previously registered callbacks.
54 void ClearErrorCallbacks();
55
56 /// The destructor would join all the threads.
57 ~QueueRunner();
58
59 /// Starts the queue runner with the given session.
60 Status Start(Session* sess);
61
62 /// Starts the queue runner with the given session and sets the run arguments
63 /// for sess->Run. It also collects and stores the cost model.
64 Status StartAndCollectCostGraph(Session* sess,
65 const RunOptions& run_options = RunOptions());
66
67 /// Starts the queue runner with the given session, and wait for up to the
68 /// specified time (in milliseconds) for the queues to start to fill up.
69 Status Start(Session* sess, int wait_for_ms);
70 Status StartAndCollectCostGraph(Session* session, int wait_for_ms,
71 const RunOptions& run_options = RunOptions());
72
73 /// Requests to stop and runs the cancel op. It would be called in a separate
74 /// thread when coordinator is set. If there is no coordinator it should be
75 /// called before calling Join.
76 void Stop(Session* sess);
77
78 /// Joins all the threads. Returns okay if all threads run successfully;
79 /// otherwise returns the first captured failure status.
80 Status Join() final;
81
82 /// Returns the latest status.
83 Status GetStatus();
84
85 // Returns the stored cost model.
86 Status ExportCostGraph(CostGraphDef* cost_graph) const override;
87
88 private:
89 QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {}
90
91 // Initializes the instance with the QueueRunnerDef proto.
92 Status Init(const QueueRunnerDef& queue_runner_def);
93
94 // The Run function for each thread.
95 void Run(Session* sess, const string& enqueue_op);
96
97 // Updates the internal status; it only keeps OK or the first unexpected error
98 // status.
99 void UpdateStatus(const Status& status);
100
101 bool IsQueueClosed(Status status) const {
102 return queue_closed_exception_types_.count(
103 static_cast<int>(status.code())) > 0;
104 }
105
106 bool IsRunning() const override { return !stopped_; }
107
108 void SetRunArgumentsAndCostGraph(const RunOptions& run_options);
109
110 Status RealRun(Session* sess, const string& op, bool update_costs);
111
112 string queue_name_;
113 std::vector<string> enqueue_op_names_;
114 string close_op_name_;
115 string cancel_op_name_;
116 // code::Code casted to int to avoid a hash function.
117 std::unordered_set<int> queue_closed_exception_types_;
118
119 std::unique_ptr<thread::ThreadPool> thread_pool_;
120 mutex mu_;
121 int runs_ = 0;
122 Status status_ TF_GUARDED_BY(mu_);
123 Status enqueue_status_ TF_GUARDED_BY(mu_);
124 std::unique_ptr<BlockingCounter> counter_;
125
126 Coordinator* coord_;
127
128 std::atomic<bool> stopped_;
129
130 mutex cb_mu_;
131 std::vector<std::function<void(Status)>> callbacks_;
132
133 mutable std::unique_ptr<mutex> cg_mu_;
134 std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_);
135 RunOptions run_options_;
136};
137
138} // namespace tensorflow
139
140#endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
141