1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CC_TRAINING_COORDINATOR_H_ |
17 | #define TENSORFLOW_CC_TRAINING_COORDINATOR_H_ |
18 | |
19 | #include <atomic> |
20 | #include <memory> |
21 | #include <unordered_set> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/framework/cost_graph.pb.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/platform/macros.h" |
27 | #include "tensorflow/core/platform/mutex.h" |
28 | #include "tensorflow/core/protobuf/config.pb.h" |
29 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | /// The abstract interface for runners which must implement the Join and the |
34 | /// IsRunning function. |
35 | class RunnerInterface { |
36 | public: |
37 | virtual ~RunnerInterface() {} |
38 | virtual Status Join() = 0; |
39 | virtual Status ExportCostGraph(CostGraphDef* cost_graph) const { |
40 | return Status(error::INVALID_ARGUMENT, "No cost model to export." ); |
41 | } |
42 | /// Returns true iff the runner is running, i.e. if it is trying to populate |
43 | /// its queue. |
44 | virtual bool IsRunning() const = 0; |
45 | }; |
46 | |
47 | /// Coordinator class manages the termination of a collection of QueueRunners. |
48 | /// Without a coordinator, QueueRunners have to be joined in a specific order; |
49 | /// otherwise the QueueRunner::Join() could sometimes hang. The |
50 | /// Coordinator::RequestStop() plays the key role which notifies all running |
51 | /// threads under a coordinator to stop. This function could be called by any |
52 | /// thread or any client. |
53 | /// Usage, in the client: |
54 | /// Coordinator coord; |
55 | /// std::unique_ptr<QueueRunner> qr(&coord, ...); |
56 | /// qr.Start(session); |
57 | /// coord.RegisterRunner(std::move(qr)); |
58 | /// /// do some work |
59 | /// TF_CHECK_OK(coord.Join()); |
60 | /// In each thread of QueueRunner, the coordinator needs to be used as: |
61 | /// void Run() { |
62 | /// while (!coord->ShouldStop()) { |
63 | /// /// do some work |
64 | /// if (error) { |
65 | /// coord->RequestStop(); |
66 | /// coord->ReportStatus(error_status); |
67 | /// } |
68 | /// } |
69 | /// } |
70 | class Coordinator { |
71 | public: |
72 | Coordinator(); |
73 | |
74 | /// Constructor with a list of error codes which would not be taken as errors |
75 | /// in status reporting. |
76 | Coordinator(const std::vector<error::Code>& clean_stop_errors); |
77 | |
78 | /// In the destructor, RequestStop() and Join() would be called. |
79 | ~Coordinator(); |
80 | |
81 | /// Registers a runner, i.e. a unit of running threads which is usually a |
82 | /// QueueRunner. It takes the ownership of runner to avoid lifecycle-related |
83 | /// problems. Note, the coordinator would not start these threads; they are |
84 | /// supposed to be in running state when they are registered here. |
85 | Status RegisterRunner(std::unique_ptr<RunnerInterface> runner); |
86 | |
87 | /// Returns true iff all the registered runners have been stopped. |
88 | bool (); |
89 | |
90 | /// Requests all running threads to stop. |
91 | Status RequestStop(); |
92 | |
93 | /// Returns true if its RequestStop() has been called. |
94 | bool ShouldStop(); |
95 | |
96 | /// Joins all threads, returns OK or the first reported and unexpected status. |
97 | Status Join(); |
98 | |
99 | /// Reports status to the coordinator. This is usually called by threads. |
100 | void ReportStatus(const Status& status); |
101 | |
102 | /// Returns the latest status. |
103 | Status GetStatus(); |
104 | |
105 | /// Returns immediately if the coordinator is stopped or blocks until |
106 | /// RequestStop() is called. |
107 | void WaitForStop(); |
108 | |
109 | // Returns the cost graph from stored run metadata in registered runners. |
110 | Status ExportCostGraph(CostGraphDef* cost_graph) const; |
111 | |
112 | private: |
113 | std::unordered_set<int> clean_stop_errors_; |
114 | condition_variable wait_for_stop_; |
115 | |
116 | mutex mu_; |
117 | bool should_stop_ TF_GUARDED_BY(mu_); |
118 | |
119 | mutex status_lock_; |
120 | Status status_ TF_GUARDED_BY(status_lock_); |
121 | |
122 | mutable mutex runners_lock_; |
123 | std::vector<std::unique_ptr<RunnerInterface>> runners_ |
124 | TF_GUARDED_BY(runners_lock_); |
125 | |
126 | TF_DISALLOW_COPY_AND_ASSIGN(Coordinator); |
127 | }; |
128 | |
129 | } // namespace tensorflow |
130 | |
131 | #endif // TENSORFLOW_CC_TRAINING_COORDINATOR_H_ |
132 | |