1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ |
17 | #define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ |
18 | |
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
23 | |
24 | #include "tensorflow/cc/training/coordinator.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/lib/core/threadpool.h" |
27 | #include "tensorflow/core/platform/blocking_counter.h" |
28 | #include "tensorflow/core/platform/mutex.h" |
29 | #include "tensorflow/core/protobuf/config.pb.h" |
30 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
31 | #include "tensorflow/core/protobuf/queue_runner.pb.h" |
32 | #include "tensorflow/core/public/session.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | /// QueueRunner class imitates the behavior of the python version of QueueRunner |
37 | /// which creates a thread for each enqueue op, runs close op on completion. |
38 | class QueueRunner : public RunnerInterface { |
39 | public: |
40 | /// Creates a new QueueRunner from proto. |
41 | // TODO(yuefengz): we may want to initialize from queues and ops in the |
42 | // future. |
43 | static Status New(const QueueRunnerDef& queue_runner_def, |
44 | std::unique_ptr<QueueRunner>* result); |
45 | |
46 | /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage. |
47 | static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord, |
48 | std::unique_ptr<QueueRunner>* result); |
49 | |
50 | /// Adds a callback that the queue runner will call when it detects an error. |
51 | void AddErrorCallback(const std::function<void(Status)>& cb); |
52 | |
53 | /// Delete the previously registered callbacks. |
54 | void ClearErrorCallbacks(); |
55 | |
56 | /// The destructor would join all the threads. |
57 | ~QueueRunner(); |
58 | |
59 | /// Starts the queue runner with the given session. |
60 | Status Start(Session* sess); |
61 | |
62 | /// Starts the queue runner with the given session and sets the run arguments |
63 | /// for sess->Run. It also collects and stores the cost model. |
64 | Status StartAndCollectCostGraph(Session* sess, |
65 | const RunOptions& run_options = RunOptions()); |
66 | |
67 | /// Starts the queue runner with the given session, and wait for up to the |
68 | /// specified time (in milliseconds) for the queues to start to fill up. |
69 | Status Start(Session* sess, int wait_for_ms); |
70 | Status StartAndCollectCostGraph(Session* session, int wait_for_ms, |
71 | const RunOptions& run_options = RunOptions()); |
72 | |
73 | /// Requests to stop and runs the cancel op. It would be called in a separate |
74 | /// thread when coordinator is set. If there is no coordinator it should be |
75 | /// called before calling Join. |
76 | void Stop(Session* sess); |
77 | |
78 | /// Joins all the threads. Returns okay if all threads run successfully; |
79 | /// otherwise returns the first captured failure status. |
80 | Status Join() final; |
81 | |
82 | /// Returns the latest status. |
83 | Status GetStatus(); |
84 | |
85 | // Returns the stored cost model. |
86 | Status ExportCostGraph(CostGraphDef* cost_graph) const override; |
87 | |
88 | private: |
89 | QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {} |
90 | |
91 | // Initializes the instance with the QueueRunnerDef proto. |
92 | Status Init(const QueueRunnerDef& queue_runner_def); |
93 | |
94 | // The Run function for each thread. |
95 | void Run(Session* sess, const string& enqueue_op); |
96 | |
97 | // Updates the internal status; it only keeps OK or the first unexpected error |
98 | // status. |
99 | void UpdateStatus(const Status& status); |
100 | |
101 | bool IsQueueClosed(Status status) const { |
102 | return queue_closed_exception_types_.count( |
103 | static_cast<int>(status.code())) > 0; |
104 | } |
105 | |
106 | bool IsRunning() const override { return !stopped_; } |
107 | |
108 | void SetRunArgumentsAndCostGraph(const RunOptions& run_options); |
109 | |
110 | Status RealRun(Session* sess, const string& op, bool update_costs); |
111 | |
112 | string queue_name_; |
113 | std::vector<string> enqueue_op_names_; |
114 | string close_op_name_; |
115 | string cancel_op_name_; |
116 | // code::Code casted to int to avoid a hash function. |
117 | std::unordered_set<int> queue_closed_exception_types_; |
118 | |
119 | std::unique_ptr<thread::ThreadPool> thread_pool_; |
120 | mutex mu_; |
121 | int runs_ = 0; |
122 | Status status_ TF_GUARDED_BY(mu_); |
123 | Status enqueue_status_ TF_GUARDED_BY(mu_); |
124 | std::unique_ptr<BlockingCounter> counter_; |
125 | |
126 | Coordinator* coord_; |
127 | |
128 | std::atomic<bool> stopped_; |
129 | |
130 | mutex cb_mu_; |
131 | std::vector<std::function<void(Status)>> callbacks_; |
132 | |
133 | mutable std::unique_ptr<mutex> cg_mu_; |
134 | std::unique_ptr<CostGraphDef> cost_graph_ TF_GUARDED_BY(cg_mu_); |
135 | RunOptions run_options_; |
136 | }; |
137 | |
138 | } // namespace tensorflow |
139 | |
140 | #endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ |
141 | |