/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

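// Builds a GraphDef that wraps the function `sig` in a single call node,
// feeding every input argument from a client-terminated _Recv node and
// forwarding every output into a client-terminated _Send node, so that the
// function can be registered and run on a remote worker. The rendezvous keys
// used to feed the inputs and fetch the outputs are returned in `send_keys`
// and `recv_keys`.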
/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  const string& func_name = sig.name();
  const FunctionDef* func_def = flib_def.Find(sig.name());
  if (func_def == nullptr) {
    return errors::InvalidArgument("Function ", func_name,
                                   " not found in flib_def.");
  }

  // Build a smaller flib_def containing only the functions used by the given
  // function, plus that function itself.
  FunctionLibraryDefinition pruned_flib_def =
      flib_def.ReachableDefinitions(*func_def);
  TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));

  Graph g(pruned_flib_def);

  std::vector<Node*> input_nodes;
  input_nodes.reserve(sig.input_arg_size());

  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto input_node_builder =
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* input_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(input_node_builder).Finalize(&g, &input_node));
    input_nodes.push_back(input_node);

    // src_incarnation = 1 works because the transfer is across the same
    // device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

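  // Create the call node for the function itself and wire each recv node to
  // the corresponding input argument.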
  NodeDef function_node_def;
  function_node_def.set_name(func_name);
  function_node_def.set_op(func_name);
  function_node_def.set_device(target);
  for (const auto& p : attrs) {
    (*function_node_def.mutable_attr())[p.first] = p.second;
  }
  TF_ASSIGN_OR_RETURN(Node * function_node,
                      g.AddNode(std::move(function_node_def)));
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    g.AddEdge(input_nodes[i], 0, function_node, i);
  }

  // Construct send nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto output_node_builder =
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(func_name, i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* output_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(output_node_builder).Finalize(&g, &output_node));

    g.AddEdge(function_node, i, output_node, 0);

    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }

  // Inline the function node into the graph.
  InlineFunctionBodyOptions inline_options;
  inline_options.inlined_function_body_placer =
      InlinedFunctionBodyPlacer::SingleDevice();
  // When the remote call is a partition of a multi-device function, and the
  // Send/Recv nodes depend on the frame names in the original graph, we must
  // retain the original frame names. Since the graph contains a single
  // function call, we do not need to add a unique prefix to frame names
  // inside the inlined graph.
  inline_options.uniquify_frame_names = false;
  std::unique_ptr<FunctionBody> function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
                                             &function_body));
  TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
                                        function_body.get(), inline_options));

  g.ToGraphDef(gdef);

  // Since we have inlined `function_node`, we can prune its function
  // definition from the library.
  *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto();

  return OkStatus();
}

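// Releases every worker that was borrowed from the worker cache by
// Instantiate().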
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache()->ReleaseWorker(function_data.target,
                                                   function_data.wi);
  }
}

void ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  std::shared_ptr<WorkerCacheInterface> worker_cache =
      worker_session_->GetSharedWorkerCache();
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache()->ListWorkers(&workers);
    done(errors::InvalidArgument(
        "Could not find worker with target: ", target,
        ". Available workers: ", absl::StrJoin(workers, ", ")));
    return;
  }

  // Make an RPC and obtain a graph handle.
  GraphDef gdef;
  auto* send_keys = new std::vector<string>;
  auto* recv_keys = new std::vector<string>;
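  // The key vectors are heap-allocated so they outlive this stack frame;
  // ownership passes to the RegisterGraphAsync callback below, which deletes
  // them.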
  auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::InvalidArgument("Function ", function_name,
                                     " not found in lib_def.");
    }
    const OpDef& sig = fdef->signature();
    TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def,
                                              &gdef, send_keys, recv_keys));
    return OkStatus();
  };
  Status s;
  if (options.lib_def) {
    s = construct_graph_fn(options.lib_def);
  } else {
    s = construct_graph_fn(&lib_def);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

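  // Register the constructed graph with the remote worker. The returned graph
  // handle identifies the graph in subsequent RunGraph calls.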
  auto* req = new RegisterGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  *req->mutable_graph_def() = std::move(gdef);
  StripDefaultAttributes(*OpRegistry::Global(),
                         req->mutable_graph_def()->mutable_node());
  req->mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  auto* resp = new RegisterGraphResponse;

  wi->RegisterGraphAsync(
      req, resp,
      [this, handle, req, resp, worker_cache, wi, function_name, target,
       send_keys, recv_keys, done](const Status& status) {
        if (status.ok()) {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.push_back(FunctionData(resp->graph_handle(), target,
                                                worker_cache, wi, *send_keys,
                                                *recv_keys));
          VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on "
                  << target << " (this: " << this << ")"
                  << " with handle: " << *handle;
        }
        done(status);
        delete recv_keys;
        delete send_keys;
        delete req;
        delete resp;
      });
}

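// Runs a previously instantiated function: each argument is fed under its
// stored send key and each result is fetched under its recv key through a
// single RunGraphAsync call to the remote worker.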
void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

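  // Build the RunGraphRequest: feed each argument under its send key and
  // request each output by its recv key.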
  RunGraphRequest* req = new RunGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  req->set_graph_handle(function_data->graph_handle);
  req->set_step_id(opts.step_id);
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req->add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req->add_recv_key(recv_key);
  }

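  // The request, response, and call options are deleted by the cleanup object
  // inside the callback once `done` has been invoked.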
  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, req, resp,
      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
        Status* local_status = new Status(status);
        auto cleanup =
            gtl::MakeCleanup([call_options, req, resp, local_status, done] {
              done(*local_status);
              delete call_options;
              delete req;
              delete resp;
              delete local_status;
            });
        if (!local_status->ok()) {
          return;
        }
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            local_status->Update(
                errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            local_status->Update(errors::Internal(
                "Could not convert tensor proto: ", tp->DebugString()));
            return;
          }
        }
      });
}

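// Variant of Run() taking FunctionArg/FunctionRet. Only plain Tensor
// arguments are supported; eager::RemoteTensorHandle arguments are rejected.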
void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle,
    gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) {
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    } else {
      done(
          errors::Internal("ClusterFunctionLibraryRuntime doesn't support "
                           "eager::RemoteTensorHandle."));
      return;
    }
  }
  std::vector<Tensor>* ret_tensors = new std::vector<Tensor>;
  return Run(opts, handle, tensors, ret_tensors,
             [rets, ret_tensors, done = std::move(done)](const Status& s) {
               if (s.ok()) {
                 for (const auto& t : *ret_tensors) {
                   rets->push_back(t);
                 }
               }
               delete ret_tensors;
               done(s);
             });
}

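// Asks the remote worker to release the per-step resources that were created
// for `step_id` when the function's graph was run.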
void ClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }
  CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
  cleanup_req->set_step_id(step_id);
  CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
  wi->CleanupGraphAsync(
      cleanup_req, cleanup_resp,
      [cleanup_req, cleanup_resp, done](const Status& cleanup_status) {
        done(cleanup_status);
        delete cleanup_req;
        delete cleanup_resp;
      });
}

}  // namespace tensorflow