1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
16
17#include "absl/strings/escaping.h"
18#include "tensorflow/core/common_runtime/device.h"
19#include "tensorflow/core/common_runtime/device_mgr.h"
20#include "tensorflow/core/distributed_runtime/cancellable_call.h"
21#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
22#include "tensorflow/core/distributed_runtime/worker_cache.h"
23#include "tensorflow/core/framework/device_attributes.pb.h"
24#include "tensorflow/core/platform/errors.h"
25#include "tensorflow/core/platform/status.h"
26#include "tensorflow/core/protobuf/config.pb.h"
27#include "tensorflow/core/util/device_name_utils.h"
28
29namespace tensorflow {
30namespace {
31
32class CompleteGroupCall : public CancellableCall {
33 public:
34 CompleteGroupCall(const CollGroupParams& group,
35 const DeviceAttributes& device,
36 CancellationManager* cancel_mgr,
37 const string& remote_worker, WorkerCacheInterface* wc)
38 : CancellableCall(cancel_mgr, remote_worker, wc) {
39 req_.set_group_key(group.group_key);
40 req_.set_group_size(group.group_size);
41 req_.set_device_type(group.device_type.type_string());
42 *req_.mutable_device_attributes() = device;
43 }
44 ~CompleteGroupCall() override {}
45
46 void IssueCall(const StatusCallback& done) override {
47 wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
48 }
49
50 CompleteGroupRequest req_;
51 CompleteGroupResponse resp_;
52};
53
54class CompleteInstanceCall : public CancellableCall {
55 public:
56 CompleteInstanceCall(const CollGroupParams& group,
57 const CollInstanceParams& instance,
58 const string& node_name, const string& device_name,
59 bool is_source, CancellationManager* cancel_mgr,
60 const string& remote_worker, WorkerCacheInterface* wc)
61 : CancellableCall(cancel_mgr, remote_worker, wc) {
62 req_.set_name(node_name);
63 req_.set_type(instance.type);
64 req_.set_data_type(instance.data_type);
65 instance.shape.AsProto(req_.mutable_shape());
66 req_.set_group_key(group.group_key);
67 req_.set_group_size(group.group_size);
68 req_.set_instance_key(instance.instance_key);
69 req_.set_device_type(group.device_type.type_string());
70 for (int32_t offset : instance.impl_details.subdiv_offsets) {
71 req_.add_subdiv_offset(offset);
72 }
73 req_.set_device(device_name);
74 req_.set_is_source(is_source);
75 }
76
77 ~CompleteInstanceCall() override {}
78
79 void IssueCall(const StatusCallback& done) override {
80 wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
81 }
82
83 CompleteInstanceRequest req_;
84 CompleteInstanceResponse resp_;
85};
86
87} // namespace
88
89CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
90 const ConfigProto& config, const DeviceMgr* dev_mgr,
91 DeviceResolverDistributed* dev_resolver,
92 NcclCommunicatorInterface* nccl_communicator,
93 WorkerCacheInterface* worker_cache, const string& task_name)
94 : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver,
95 nccl_communicator, task_name),
96 worker_cache_(worker_cache),
97 group_leader_(task_name == config.experimental().collective_group_leader()
98 ? ""
99 : config.experimental().collective_group_leader()) {
100 VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name
101 << "} config.collective_group_leader={"
102 << config.experimental().collective_group_leader() << "}"
103 << " config.collective_nccl={"
104 << config.experimental().collective_nccl() << "}";
105}
106
// Completes all of `cp` (group then instance) for `device`, invoking `done`
// with the final status.  Group resolution may involve an RPC to the group
// leader; see CompleteGroupDistributed.
void CollectiveParamResolverDistributed::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
          << ": " << cp->ToString();
  if (cp->run_group_initialization) {
    CompleteGroupDistributed(
        device, &cp->group, cancel_mgr,
        [this, device, cp, cancel_mgr, done](Status s) {
          if (s.ok()) {
            // Group resolution succeeded: publish every member's device
            // attributes to the local device resolver before completing the
            // instance.
            std::vector<DeviceAttributes> devices;
            devices.reserve(cp->group.group_size);
            for (const CollGroupMember& m : cp->group.members) {
              devices.push_back(m.device);
            }
            s = dev_resolver_->UpdateDeviceAttributes(devices);
          }
          if (s.ok()) {
            CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
          } else {
            done(s);
          }
        });
  } else {
    // For Collective V3 ops, the group is already initialized. Fetch
    // attributes for the already-initialized group to pass to instance
    // initialization.
    auto s = LookupGroup(cp->group.group_key, &cp->group);
    if (s.ok()) {
      CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
    } else {
      done(s);
    }
  }
}
141
// Thin forwarding wrapper: CompleteGroupDistributed decides whether group
// resolution is local (this task is the leader, or the group is cached) or
// requires an RPC to the group leader.
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupDistributed(device, group_params, cancel_mgr, done);
}
147
148void CollectiveParamResolverDistributed::CompleteInstanceAsync(
149 const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
150 CancellationManager* cancel_mgr, const StatusCallback& done) {
151 GroupRec* gr = GetCachedGroup(request->group_key());
152 if (gr == nullptr) {
153 done(errors::FailedPrecondition(
154 "group ", request->group_key(),
155 " not found. This normally means the server has restarted"));
156 return;
157 }
158 CollectiveParams* cp = new CollectiveParams;
159 {
160 mutex_lock l(gr->mu);
161 if (!gr->status.ok()) {
162 done(gr->status);
163 return;
164 } else if (gr->group.members.size() != gr->group.group_size) {
165 done(errors::FailedPrecondition(
166 "group ", request->group_key(),
167 " failed to resolve. This normally means the server has restarted"));
168 return;
169 }
170 cp->group = gr->group;
171 }
172 cp->name = request->name();
173 cp->instance.type = CollectiveType(request->type());
174 cp->instance.instance_key = request->instance_key();
175 cp->instance.data_type = request->data_type();
176 cp->instance.shape = TensorShape(request->shape());
177 cp->is_source = request->is_source();
178 for (int32_t offset : request->subdiv_offset()) {
179 cp->instance.impl_details.subdiv_offsets.push_back(offset);
180 }
181 StatusCallback done_and_cleanup = [cp, done](const Status& s) {
182 done(s);
183 cp->Unref();
184 };
185 CompleteInstanceDistributed(
186 request->device(), cp, cancel_mgr,
187 [this, cp, response, done_and_cleanup](Status status) {
188 if (status.ok()) {
189 // Now source_rank should be known, so retrieve it.
190 bool created_irec;
191 InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
192 {
193 mutex_lock l(ir->mu);
194 status = ir->status;
195 if (ir->status.ok()) {
196 response->set_instance_key(cp->instance.instance_key);
197 response->set_source_rank(ir->source_rank);
198 }
199 }
200 }
201 done_and_cleanup(status);
202 });
203}
204
205CollectiveParamResolverDistributed::GroupRec*
206CollectiveParamResolverDistributed::GetCachedGroup(int32_t group_key) {
207 mutex_lock l(group_mu_);
208 auto it = group_table_.find(group_key);
209 if (it == group_table_.end()) {
210 return nullptr;
211 }
212 return it->second.get();
213}
214
// Installs a group record built from the leader's CompleteGroupResponse into
// group_table_, or — if another thread already installed one for this key —
// validates the response against the existing record.  Returns an Internal
// error if the response is malformed or disagrees with the cached
// communicator_key.
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  {
    mutex_lock grl(gr->mu);
    gr->group.device_type = DeviceType(resp.device_type());
    gr->group.group_key = resp.group_key();
    gr->group.group_size = resp.group_size();
    gr->group.num_tasks = resp.num_tasks();
    // An empty attribute list indicates a leader running an incompatible
    // (older) TensorFlow version.
    if (resp.device_attributes().empty()) {
      return errors::Internal(
          "CompleteGroupResponse device_attributes is empty. Make sure you're "
          "running the same version of Tensorflow on all workers.");
    }
    if (resp.device_attributes_size() != gr->group.group_size) {
      return errors::Internal(
          "CompleteGroupResponse group_size doesn't match device_name list");
    }
    gr->group.members.reserve(resp.device_attributes().size());
    for (const DeviceAttributes& device : resp.device_attributes()) {
      CollGroupMember member;
      member.device = device;
      gr->group.members.push_back(std::move(member));
      gr->incarnations_by_device_name[device.name()] = device.incarnation();
    }
    gr->group.runtime_details.communicator_key = resp.communicator_key();
    FinishGroup(gr.get());
  }
  GroupRec* previous_gr = nullptr;
  {
    // Group membership should never change. Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(resp.group_key());
    if (it == group_table_.end()) {
      VLOG(2) << "UpdateGroupCache: communicator_key="
              << absl::CEscape(resp.communicator_key());
      group_table_[gr->group.group_key] = std::move(gr);
    } else {
      // Another thread won the race; keep its record and verify consistency
      // below.  The freshly built `gr` is discarded when it goes out of scope.
      previous_gr = it->second.get();
    }
  }
  if (previous_gr != nullptr) {
    mutex_lock grl(previous_gr->mu);
    if (previous_gr->group.runtime_details.communicator_key !=
        resp.communicator_key()) {
      return errors::Internal(
          "UpdateGroupCache: CompleteGroupResponse for group ",
          resp.group_key(),
          " gives communicator_key=", absl::CEscape(resp.communicator_key()),
          " but cache already holds communicator_key=",
          absl::CEscape(previous_gr->group.runtime_details.communicator_key));
    }
  }
  return OkStatus();
}
272
// Resolves the group for `group_params`.  Three cases: this task is the
// leader (resolve locally), the group is not yet cached (RPC to the leader,
// then cache and resolve locally), or the group is already cached (resolve
// locally).
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << group_params->group_key
          << " dev: " << device.name()
          << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  } else if (GetCachedGroup(group_params->group_key) == nullptr) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call = new CompleteGroupCall(
        *group_params, device, cancel_mgr, group_leader_, worker_cache_);
    // Register the in-flight RPC so a later StartAbort() can cancel it.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    call->Start([this, device, group_params, call, cancel_mgr, abortion_token,
                 done](const Status& s) {
      // RPC finished: it can no longer be cancelled by StartAbort().
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        // Merge the leader's response into the local group cache, then run
        // local group completion against the now-populated cache.
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, group_params, cancel_mgr, done);
        } else {
          done(status);
        }
      } else {
        done(s);
      }
      delete call;
    });
    return;
  } else {
    // The group is already cached locally; no RPC needed.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  }
}
315
316bool CollectiveParamResolverDistributed::InstanceIsCached(
317 int32_t group_key, int32_t instance_key) {
318 mutex_lock l(instance_mu_);
319 auto group_it = instance_table_.find(group_key);
320 if (group_it == instance_table_.end()) {
321 return false;
322 }
323 auto instance_it = group_it->second.find(instance_key);
324 return instance_it != group_it->second.end();
325}
326
327Status CollectiveParamResolverDistributed::UpdateInstanceCache(
328 CollectiveParams* cp, const CompleteInstanceResponse& resp) {
329 int32_t source_rank = resp.source_rank();
330 bool created_irec;
331 InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
332 mutex_lock l(ir->mu);
333 if (!ir->status.ok()) {
334 return ir->status;
335 }
336 if (ir->source_rank != source_rank) {
337 if (ir->source_rank >= 0) {
338 ir->status = errors::Internal(
339 "UpdateInstanceCache: CompleteInstanceResponse for instance ",
340 cp->instance.instance_key, " gives source_rank=", source_rank,
341 " but cache already holds value=", ir->source_rank);
342 return ir->status;
343 }
344 ir->source_rank = source_rank;
345 }
346 if (ir->known_count < cp->group.group_size) {
347 ir->known_count = cp->group.group_size;
348 const int ir_known_size = ir->known.size();
349 if (ir_known_size != cp->group.group_size) {
350 ir->status = errors::Internal(
351 "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
352 cp->instance.instance_key, " has known.size()=", ir->known.size(),
353 " < group_size=", cp->group.group_size);
354 return ir->status;
355 }
356 for (int i = 0; i < ir_known_size; ++i) {
357 ir->known[i] = true;
358 }
359 }
360 return ir->status;
361}
362
// Completes the instance portion of `cp` for `device`.  Local when this task
// is the leader or the instance is already cached; otherwise issues a
// CompleteInstance RPC to the leader and folds the answer into the local
// cache before finishing locally.
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, cp, done);
  } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
    // The leader's answer for this instance is already cached.
    return CompleteInstanceLocal(device, cp, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    // Register the in-flight RPC so a later StartAbort() can cancel it.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    call->Start([this, device, cp, call, abortion_token, done](Status s) {
      // RPC finished: it can no longer be cancelled by StartAbort().
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        s = UpdateInstanceCache(cp, call->resp_);
      }
      if (s.ok()) {
        CompleteInstanceLocal(device, cp, done);
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}
399
// Aborts this resolver with status `s`.  Only the first abort takes effect;
// subsequent calls are logged and ignored.
void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  // Abort local resolution state, then cancel any outstanding leader RPCs
  // that were registered with abortion_cancel_mgr_.
  StartAbortLocal(s);
  abortion_cancel_mgr_.StartCancel();
}
414
415} // namespace tensorflow
416