1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h" |
16 | |
17 | #include <map> |
18 | |
19 | #include "tensorflow/core/common_runtime/function.h" |
20 | #include "tensorflow/core/common_runtime/process_function_library_runtime.h" |
21 | #include "tensorflow/core/distributed_runtime/worker_session.h" |
22 | #include "tensorflow/core/framework/function.h" |
23 | #include "tensorflow/core/framework/graph_def_util.h" |
24 | #include "tensorflow/core/framework/node_def.pb.h" |
25 | #include "tensorflow/core/framework/node_def_builder.h" |
26 | #include "tensorflow/core/framework/tensor.pb.h" |
27 | #include "tensorflow/core/graph/node_builder.h" |
28 | #include "tensorflow/core/lib/gtl/cleanup.h" |
29 | #include "tensorflow/core/lib/random/random.h" |
30 | #include "tensorflow/core/protobuf/named_tensor.pb.h" |
31 | #include "tensorflow/core/protobuf/worker.pb.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | /* static */ |
36 | Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( |
37 | const OpDef& sig, AttrSlice attrs, |
38 | const FunctionLibraryRuntime::InstantiateOptions& options, |
39 | const FunctionLibraryDefinition& flib_def, GraphDef* gdef, |
40 | std::vector<string>* send_keys, std::vector<string>* recv_keys) { |
41 | const string& target = options.target; |
42 | const string& func_name = sig.name(); |
43 | const FunctionDef* func_def = flib_def.Find(sig.name()); |
44 | if (func_def == nullptr) { |
45 | return errors::InvalidArgument("Function " , func_name, |
46 | " not found in flib_def." ); |
47 | } |
48 | |
49 | // Build a smaller flib_def containing only the functions used by the given |
50 | // function, plus that function itself. |
51 | FunctionLibraryDefinition pruned_flib_def = |
52 | flib_def.ReachableDefinitions(*func_def); |
53 | TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def)); |
54 | |
55 | Graph g(pruned_flib_def); |
56 | |
57 | std::vector<Node*> input_nodes; |
58 | input_nodes.reserve(sig.input_arg_size()); |
59 | |
60 | // Construct recv nodes for each input argument. |
61 | int i = 0; |
62 | for (const auto& in : sig.input_arg()) { |
63 | // Resolve the input type. |
64 | bool is_type_list; |
65 | DataTypeVector dtypes; |
66 | TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes)); |
67 | // TODO(rohanj): Handle list and variadic number of attrs. Here and below. |
68 | if (is_type_list || dtypes.size() > 1) { |
69 | return errors::Unimplemented("Input arg: " , in.name(), |
70 | " has a list type or variadic number of " |
71 | "attrs. Currently unsupported." ); |
72 | } |
73 | |
74 | auto input_node_builder = |
75 | NodeDefBuilder(strings::StrCat("_recv_" , in.name(), "_" , i), "_Recv" ) |
76 | .Attr("tensor_type" , dtypes[0]) |
77 | .Attr("tensor_name" , in.name()) |
78 | .Attr("send_device" , target) |
79 | .Attr("recv_device" , target) |
80 | .Attr("send_device_incarnation" , 1) |
81 | .Attr("client_terminated" , true) |
82 | .Device(target); |
83 | |
84 | Node* input_node; |
85 | TF_RETURN_IF_ERROR( |
86 | NodeBuilder(input_node_builder).Finalize(&g, &input_node)); |
87 | input_nodes.push_back(input_node); |
88 | |
89 | // src_incarnation = 1 works because the transfer is across the same device. |
90 | // TODO(rohanj): Find the src_incarnation for the remote device and set it. |
91 | const string& key = Rendezvous::CreateKey( |
92 | target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0)); |
93 | send_keys->push_back(key); |
94 | ++i; |
95 | } |
96 | |
97 | NodeDef function_node_def; |
98 | function_node_def.set_name(func_name); |
99 | function_node_def.set_op(func_name); |
100 | i = 0; |
101 | function_node_def.set_device(target); |
102 | for (const auto& p : attrs) { |
103 | (*function_node_def.mutable_attr())[p.first] = p.second; |
104 | } |
105 | TF_ASSIGN_OR_RETURN(Node * function_node, |
106 | g.AddNode(std::move(function_node_def))); |
107 | for (size_t i = 0; i < input_nodes.size(); ++i) { |
108 | g.AddEdge(input_nodes[i], 0, function_node, i); |
109 | } |
110 | |
111 | // Construct output nodes for each output. |
112 | i = 0; |
113 | for (const auto& out : sig.output_arg()) { |
114 | // Resolve the output type. |
115 | bool is_type_list; |
116 | DataTypeVector dtypes; |
117 | TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes)); |
118 | // TODO(rohanj): Handle list and variadic number of attrs. Here and below. |
119 | if (is_type_list || dtypes.size() > 1) { |
120 | return errors::Unimplemented("Output arg: " , out.name(), |
121 | " has a list type or variadic number of " |
122 | "attrs. Currently unsupported." ); |
123 | } |
124 | |
125 | auto output_node_builder = |
126 | NodeDefBuilder(strings::StrCat("_send_" , out.name(), "_" , i), "_Send" ) |
127 | .Input(func_name, i, dtypes[0]) |
128 | .Attr("tensor_name" , out.name()) |
129 | .Attr("send_device" , target) |
130 | .Attr("recv_device" , target) |
131 | .Attr("send_device_incarnation" , 1) |
132 | .Attr("client_terminated" , true) |
133 | .Device(target); |
134 | |
135 | Node* output_node; |
136 | TF_RETURN_IF_ERROR( |
137 | NodeBuilder(output_node_builder).Finalize(&g, &output_node)); |
138 | |
139 | g.AddEdge(function_node, i, output_node, 0); |
140 | |
141 | const string& key = |
142 | Rendezvous::CreateKey(target, 1 /* src_incarnation */, target, |
143 | out.name(), FrameAndIter(0, 0)); |
144 | recv_keys->push_back(key); |
145 | ++i; |
146 | } |
147 | |
148 | // Inline function node into the graph. |
149 | InlineFunctionBodyOptions inline_options; |
150 | inline_options.inlined_function_body_placer = |
151 | InlinedFunctionBodyPlacer::SingleDevice(); |
152 | // When the remote call is a partition of a multi-device function, and the |
153 | // Send/Recv nodes depend on the frame names in the original graph, we must |
154 | // retain the original frame names. Since the graph contains a single function |
155 | // call, we do not need to add a unique prefix to frame names inside the |
156 | // inlined graph. |
157 | inline_options.uniquify_frame_names = false; |
158 | std::unique_ptr<FunctionBody> function_body; |
159 | TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def, |
160 | &function_body)); |
161 | TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node, |
162 | function_body.get(), inline_options)); |
163 | |
164 | g.ToGraphDef(gdef); |
165 | |
166 | // Since we have inlined `function_node`, we can prune its function definition |
167 | // from the library. |
168 | *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto(); |
169 | |
170 | return OkStatus(); |
171 | } |
172 | |
173 | ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() { |
174 | for (auto& function_data : function_data_) { |
175 | worker_session_->worker_cache()->ReleaseWorker(function_data.target, |
176 | function_data.wi); |
177 | } |
178 | } |
179 | |
180 | void ClusterFunctionLibraryRuntime::Instantiate( |
181 | const string& function_name, const FunctionLibraryDefinition& lib_def, |
182 | AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, |
183 | FunctionLibraryRuntime::LocalHandle* handle, |
184 | FunctionLibraryRuntime::DoneCallback done) { |
185 | auto target = options.target; |
186 | VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target |
187 | << " (this: " << this << ")" ; |
188 | std::shared_ptr<WorkerCacheInterface> worker_cache = |
189 | worker_session_->GetSharedWorkerCache(); |
190 | WorkerInterface* wi = worker_cache->GetOrCreateWorker(target); |
191 | |
192 | if (wi == nullptr) { |
193 | std::vector<string> workers; |
194 | worker_session_->worker_cache()->ListWorkers(&workers); |
195 | done(errors::InvalidArgument( |
196 | "Could not find worker with target: " , target, |
197 | " Available workers: " , absl::StrJoin(workers, ", " ))); |
198 | return; |
199 | } |
200 | |
201 | // Make RPC and obtain a graph handle. |
202 | GraphDef gdef; |
203 | auto* send_keys = new std::vector<string>; |
204 | auto* recv_keys = new std::vector<string>; |
205 | auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) { |
206 | const FunctionDef* fdef = lib_def->Find(function_name); |
207 | const OpDef& sig = fdef->signature(); |
208 | TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def, |
209 | &gdef, send_keys, recv_keys)); |
210 | return OkStatus(); |
211 | }; |
212 | Status s; |
213 | if (options.lib_def) { |
214 | s = construct_graph_fn(options.lib_def); |
215 | } else { |
216 | s = construct_graph_fn(&lib_def); |
217 | } |
218 | if (!s.ok()) { |
219 | done(s); |
220 | return; |
221 | } |
222 | |
223 | auto* req = new RegisterGraphRequest; |
224 | req->set_session_handle(worker_session_->session_name()); |
225 | req->set_create_worker_session_called(create_worker_session_called_); |
226 | *req->mutable_graph_def() = std::move(gdef); |
227 | StripDefaultAttributes(*OpRegistry::Global(), |
228 | req->mutable_graph_def()->mutable_node()); |
229 | req->mutable_graph_options() |
230 | ->mutable_optimizer_options() |
231 | ->set_do_function_inlining(true); |
232 | auto* resp = new RegisterGraphResponse; |
233 | |
234 | wi->RegisterGraphAsync( |
235 | req, resp, |
236 | [this, handle, req, resp, worker_cache, wi, function_name, target, |
237 | send_keys, recv_keys, done](const Status& status) { |
238 | if (status.ok()) { |
239 | mutex_lock l(mu_); |
240 | *handle = function_data_.size(); |
241 | function_data_.push_back(FunctionData(resp->graph_handle(), target, |
242 | worker_cache, wi, *send_keys, |
243 | *recv_keys)); |
244 | VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on " |
245 | << target << " (this: " << this << ")" |
246 | << " with handle: " << *handle; |
247 | } |
248 | done(status); |
249 | delete recv_keys; |
250 | delete send_keys; |
251 | delete req; |
252 | delete resp; |
253 | }); |
254 | } |
255 | |
256 | void ClusterFunctionLibraryRuntime::Run( |
257 | const FunctionLibraryRuntime::Options& opts, |
258 | FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args, |
259 | std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) { |
260 | FunctionData* function_data = nullptr; |
261 | { |
262 | mutex_lock l(mu_); |
263 | CHECK_LE(handle, function_data_.size()); |
264 | function_data = &function_data_[handle]; |
265 | } |
266 | |
267 | WorkerInterface* wi = function_data->wi; |
268 | |
269 | if (wi == nullptr) { |
270 | done(errors::Internal("Could not find worker" )); |
271 | return; |
272 | } |
273 | |
274 | RunGraphRequest* req = new RunGraphRequest; |
275 | req->set_session_handle(worker_session_->session_name()); |
276 | req->set_create_worker_session_called(create_worker_session_called_); |
277 | req->set_graph_handle(function_data->graph_handle); |
278 | req->set_step_id(opts.step_id); |
279 | int i = 0; |
280 | for (const auto& send_key : function_data->send_keys) { |
281 | NamedTensorProto* send = req->add_send(); |
282 | send->set_name(send_key); |
283 | args[i].AsProtoTensorContent(send->mutable_tensor()); |
284 | i++; |
285 | } |
286 | const std::vector<string>& recv_keys = function_data->recv_keys; |
287 | for (const auto& recv_key : recv_keys) { |
288 | req->add_recv_key(recv_key); |
289 | } |
290 | |
291 | RunGraphResponse* resp = new RunGraphResponse(); |
292 | CallOptions* call_options = new CallOptions(); |
293 | wi->RunGraphAsync( |
294 | call_options, req, resp, |
295 | [call_options, req, resp, rets, recv_keys, done](const Status& status) { |
296 | Status* local_status = new Status(status); |
297 | auto cleanup = |
298 | gtl::MakeCleanup([call_options, req, resp, local_status, done] { |
299 | done(*local_status); |
300 | delete call_options; |
301 | delete req; |
302 | delete resp; |
303 | delete local_status; |
304 | }); |
305 | if (!local_status->ok()) { |
306 | return; |
307 | } |
308 | std::map<string, TensorProto*> mapped_recvs; |
309 | for (auto& recv : *resp->mutable_recv()) { |
310 | mapped_recvs[recv.name()] = recv.mutable_tensor(); |
311 | } |
312 | |
313 | for (const auto& recv_key : recv_keys) { |
314 | TensorProto* tp = mapped_recvs[recv_key]; |
315 | if (tp == nullptr) { |
316 | local_status->Update( |
317 | errors::Internal("Could not find key: " , recv_key)); |
318 | return; |
319 | } |
320 | Tensor t; |
321 | if (t.FromProto(*tp)) { |
322 | rets->push_back(t); |
323 | } else { |
324 | local_status->Update(errors::Internal( |
325 | "Could not convert tensor proto: " , tp->DebugString())); |
326 | return; |
327 | } |
328 | } |
329 | }); |
330 | } |
331 | |
332 | void ClusterFunctionLibraryRuntime::Run( |
333 | const FunctionLibraryRuntime::Options& opts, |
334 | FunctionLibraryRuntime::LocalHandle handle, |
335 | gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets, |
336 | FunctionLibraryRuntime::DoneCallback done) { |
337 | std::vector<Tensor> tensors; |
338 | for (const auto& arg : args) { |
339 | if (arg.index() == 0) { |
340 | tensors.push_back(absl::get<Tensor>(arg)); |
341 | } else { |
342 | done( |
343 | errors::Internal("ClusterFunctionLibraryRuntime doesn't support " |
344 | "eager::RemoteTensorHandle." )); |
345 | return; |
346 | } |
347 | } |
348 | std::vector<Tensor>* ret_tensors = new std::vector<Tensor>; |
349 | return Run(opts, handle, tensors, ret_tensors, |
350 | [rets, ret_tensors, done = std::move(done)](const Status& s) { |
351 | if (s.ok()) { |
352 | for (const auto& t : *ret_tensors) { |
353 | rets->push_back(t); |
354 | } |
355 | } |
356 | delete ret_tensors; |
357 | done(s); |
358 | }); |
359 | } |
360 | |
361 | void ClusterFunctionLibraryRuntime::CleanUp( |
362 | uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, |
363 | FunctionLibraryRuntime::DoneCallback done) { |
364 | FunctionData* function_data = nullptr; |
365 | { |
366 | mutex_lock l(mu_); |
367 | DCHECK_LE(handle, function_data_.size()); |
368 | function_data = &function_data_[handle]; |
369 | } |
370 | |
371 | WorkerInterface* wi = function_data->wi; |
372 | |
373 | if (wi == nullptr) { |
374 | done(errors::Internal("Could not find worker" )); |
375 | return; |
376 | } |
377 | CleanupGraphRequest* cleanup_req = new CleanupGraphRequest; |
378 | cleanup_req->set_step_id(step_id); |
379 | CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse; |
380 | wi->CleanupGraphAsync( |
381 | cleanup_req, cleanup_resp, |
382 | [cleanup_req, cleanup_resp, done](const Status& cleanup_status) { |
383 | done(cleanup_status); |
384 | delete cleanup_req; |
385 | delete cleanup_resp; |
386 | }); |
387 | } |
388 | |
389 | } // namespace tensorflow |
390 | |