/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
class WorkerSession;
class CoordinationServiceAgent;

// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g. (schematic; the actual execution API is the asynchronous
// ExecuteAsync() declared below):
//   GraphMgr gmgr(worker_env, device_mgr);
//   string handle;
//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b }, ...,
//                             &handle));
//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
//                                 { "b", Tensor({3, 4}) } };
//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
//   EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
 public:
  explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
  ~GraphMgr();

  // Registers a graph identified by the session "handle". On success,
  // fills in "graph_handle", which the caller uses in subsequent
  // ExecuteAsync() and Deregister() calls. The registered graph retains
  // a reference to cluster_flr to do cross-process function calls.
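  //
  // A registration sketch, assuming "gdef", the options protos, "session",
  // and "cluster_flr" are supplied by the surrounding worker code
  // ("collective_graph_key" may be BuildGraphOptions::kNoCollectiveGraphKey
  // for graphs without collective ops):
  //
  //   string graph_handle;
  //   TF_CHECK_OK(gmgr.Register(session_handle, gdef, graph_options,
  //                             debug_options, config_proto,
  //                             collective_graph_key, session,
  //                             cluster_flr, &graph_handle));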
  Status Register(const string& handle, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto, int64_t collective_graph_key,
                  WorkerSession* session,
                  DistributedFunctionLibraryRuntime* cluster_flr,
                  string* graph_handle);

  // Executes one step of a registered graph "handle".
  //
  // Tensors in "in" are sent to the step's rendezvous before the executors
  // start. If "response" is not nullptr, it collects per-step metadata
  // (e.g., the cost graph). When the step finishes, "done" is invoked with
  // the step's status; fetched tensors are then retrieved with RecvOutputs()
  // or RecvOutputsAsync() using the same "step_id".
  typedef std::map<string, Tensor> NamedTensors;
  typedef std::function<void(const Status&)> StatusCallback;
  void ExecuteAsync(const string& handle, const int64_t step_id,
                    const ExecutorOpts& opts, const NamedTensors& in,
                    WorkerSession* session, StepStatsCollector* collector,
                    MutableRunGraphResponseWrapper* response,
                    CancellationManager* cancellation_manager,
                    CoordinationServiceAgent* coordination_service_agent,
                    StatusCallback done);
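  //
  // A minimal sketch of driving one step, assuming "handle", "opts", "in",
  // "session", "response", and "cancellation_manager" are set up by the
  // surrounding worker code (a Notification makes the asynchronous call
  // synchronous purely for illustration):
  //
  //   Notification n;
  //   Status status;
  //   gmgr.ExecuteAsync(handle, step_id, opts, in, session,
  //                     /*collector=*/nullptr, response,
  //                     cancellation_manager,
  //                     /*coordination_service_agent=*/nullptr,
  //                     [&](const Status& s) {
  //                       status = s;
  //                       n.Notify();
  //                     });
  //   n.WaitForNotification();
  //   TF_CHECK_OK(status);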

  // Sends the named tensors "in" to the rendezvous of the step identified
  // by "step_id".
  Status SendInputs(const int64_t step_id, const NamedTensors& in);

  // Receives the tensors named by the keys of "*out" from the step's
  // rendezvous, blocking until all of them are available.
  Status RecvOutputs(const int64_t step_id, NamedTensors* out);

  // Asynchronous version of RecvOutputs(); invokes "done" once all requested
  // tensors have been received or an error occurs.
  void RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
                        StatusCallback done);
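  //
  // A sketch of fetching a step's outputs once "done" has fired, assuming
  // the registered graph fetches a tensor named "c" and "step_id" matches
  // the id passed to ExecuteAsync():
  //
  //   GraphMgr::NamedTensors out = { { "c", Tensor() } };
  //   TF_CHECK_OK(gmgr.RecvOutputs(step_id, &out));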

  // Deregisters a graph.
  Status Deregister(const string& handle);

  // Deregisters all graphs.
  Status DeregisterAll();

 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    std::unique_ptr<Graph> graph = nullptr;
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64_t build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    // TODO(zhifengc): Keep a copy of the original graph if the need arises.
    // TODO(zhifengc): Stats, updated by multiple runs potentially.
    // TODO(zhifengc): Dup-detection. Ensure each step_id is run only once.
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions,
    // one per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices. Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Back-pointer to the owning GraphMgr, used on destruction to deregister
    // this item's cost models when cost models are enabled.
    GraphMgr* graph_mgr;

    int64_t collective_graph_key;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  const DeviceMgr* device_mgr_;  // Not owned.

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64_t next_id_ TF_GUARDED_BY(mu_) = 0;

  // If true, blocks until the device has finished all queued operations in
  // a step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // leak memory over time. We should implement a timeout-based
  // mechanism to gc these graphs.
  std::unordered_map<string, Item*> table_;

  void StartParallelExecutors(
      const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
      CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
      CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
      WorkerSession* session, int64_t start_time_usecs,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done);

  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
  bool skip_cost_models_ = true;

  void BuildCostModel(Item* item, StepStatsCollector* collector,
                      CostGraphDef* cost_graph);

  Status InitItem(const string& handle, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto, int64_t collective_graph_key,
                  WorkerSession* session,
                  DistributedFunctionLibraryRuntime* cluster_flr, Item* item);

  Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
                                         Graph* graph, Device* device);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_