/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include <stddef.h>

#include <algorithm>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver,
    NcclCommunicatorInterface* nccl_communicator, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      nccl_communicator_(nccl_communicator),
      task_name_(task_name),
      gpu_ring_order_(
          config.gpu_options().experimental().collective_ring_order()) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupLocal(device, group_params, cancel_mgr, done);
}

namespace {
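// Returns the name of the registered collective implementation to use for
// the instance type, preferring the NCCL variant for broadcast, reduction and
// gather when `nccl` is true.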
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
  switch (cp->instance.type) {
    case BROADCAST_COLLECTIVE:
      return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";

    case REDUCTION_COLLECTIVE:
      return nccl ? "NcclReduce" : "RingReduce";

    case GATHER_COLLECTIVE:
      return nccl ? "NcclGather" : "RingGather";

    case PERMUTE_COLLECTIVE:
      return "Permute";

    case ALL_TO_ALL_COLLECTIVE:
      return "AllToAll";

    default:
      return "undef";
  }
}

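// Extracts the task name (e.g. "/job:worker/replica:0/task:1") from a fully
// specified device name. CHECK-fails if the device name cannot be parsed.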
string TaskNameFromDeviceName(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_device;
  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
  string task_name;
  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
  return task_name;
}

struct RankFormatter {
  void operator()(std::string* out, CollGroupMember m) const {
    out->append(std::to_string(m.rank));
  }
};

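// Validates user-specified ranks: either no member specifies a rank, or every
// member specifies one and the specified ranks are all distinct.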
Status CheckUserSpecifiedRanks(const std::vector<CollGroupMember>& members) {
  absl::flat_hash_set<int> user_ranks = {};
  bool at_least_one_member_with_no_rank = false;
  bool at_least_one_member_with_user_rank = false;
  for (const auto& m : members) {
    if (m.rank == -1) {
      at_least_one_member_with_no_rank = true;
    } else {
      at_least_one_member_with_user_rank = true;
      user_ranks.insert(m.rank);
    }
  }

  auto received_ranks = absl::StrJoin(members, ",", RankFormatter());
  if (at_least_one_member_with_no_rank && at_least_one_member_with_user_rank) {
    return errors::InvalidArgument(
        "Only part of the group members have user given rank specified.",
        "Received ranks: ", received_ranks);
  }

  if (at_least_one_member_with_user_rank &&
      user_ranks.size() < members.size()) {
    return errors::InvalidArgument(
        "Duplicate ranks specified for group members. Received ranks: ",
        received_ranks);
  }
  return OkStatus();
}
}  // namespace

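// Records `device` as a member of the group identified by
// group_params->group_key, creating the GroupRec on first use. If the group
// is still incomplete the callback is queued; once the last member joins (or
// an error occurs) every queued callback is invoked and the completed group
// is copied back into each caller's CollGroupParams.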
void CollectiveParamResolverLocal::CompleteGroupLocal(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, StatusCallback done) {
  VLOG(1) << "CompleteGroup device=" << device.name() << ": "
          << group_params->ToString();
  std::vector<StatusCallback> to_be_called;

  GroupRec* gr = nullptr;
  Status status;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(group_params->group_key);
    if (it == group_table_.end()) {
      gr = new GroupRec;
      mutex_lock grl(gr->mu);
      gr->group.group_key = group_params->group_key;
      gr->group.group_size = group_params->group_size;
      gr->group.device_type = group_params->device_type;
      if (nccl_communicator_ != nullptr) {
        gr->group.runtime_details.communicator_key =
            nccl_communicator_->GenerateCommunicatorKey();
      }
      // Store GroupRec in group_table_ which is shared between all devices on
      // this worker.
      group_table_[gr->group.group_key].reset(gr);
      VLOG(2) << "New group_key=" << gr->group.group_key
              << " group_size=" << gr->group.group_size
              << " runtime_details=" << gr->group.runtime_details.ToString();
    } else {
      gr = it->second.get();
    }
  }
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    done(status);
    return;
  }

  if (cancel_mgr != nullptr) {
    CancellationToken token = cancel_mgr->get_cancellation_token();
    bool is_cancelled = !cancel_mgr->RegisterCallback(
        token, std::bind(&CollectiveParamResolverLocal::CancelGroup, this,
                         group_params->group_key));
    if (is_cancelled) {
      done(errors::Cancelled("CompleteGroup is cancelled before it starts"));
      return;
    }
    done = [cancel_mgr, token,
            original_done = std::move(done)](const Status& status) {
      cancel_mgr->TryDeregisterCallback(token);
      original_done(status);
    };
  }

  {
    mutex_lock gr_lock(gr->mu);
    // If there is ever an error associated with a group key, we store the
    // error status and invoke all waiting and future callbacks with this
    // error status.
    VLOG(2) << "gr device_type=" << gr->group.device_type
            << " cp device_type=" << group_params->device_type
            << " current device=" << device.name();
    if (gr->status.ok()) {
      // Check for consistency with existing GroupRec.
      if (group_params->device_type != gr->group.device_type) {
        gr->status = errors::Internal(
            "Device ", device.name(),
            " is joining a group with incompatible device type ",
            gr->group.device_type.type_string(),
            " (group_key=", gr->group.group_key, ")");
      } else if (group_params->group_size != gr->group.group_size) {
        gr->status = errors::Internal(
            "Device ", device.name(), " is joining a group with size ",
            group_params->group_size, ", but that group has size ",
            gr->group.group_size, " (group_key=", gr->group.group_key, ")");
      }
    }
    bool new_device = false;
    if (gr->status.ok()) {
      // Insert device if not already present.
      auto it = gr->incarnations_by_device_name.find(device.name());
      if (it == gr->incarnations_by_device_name.end()) {
        if (gr->group.members.size() == gr->group.group_size) {
          // The group is already full.
          gr->status =
              errors::Internal("Device ", device.name(),
                               " is joining a group that is already full",
                               " (group_key=", gr->group.group_key, ")");
        } else {
          // This is a new device that has not yet joined the group.
          gr->incarnations_by_device_name[device.name()] =
              device.incarnation();
          CollGroupMember member;
          member.device = device;
          if (group_params->user_specified_rank == -1 ||
              (group_params->user_specified_rank >= 0 &&
               group_params->user_specified_rank < gr->group.group_size)) {
            member.rank = group_params->user_specified_rank;
          } else {
            gr->status = errors::InvalidArgument(
                "User Provided rank is invalid. It should be between [0, "
                "group_size)");
          }
          gr->group.members.push_back(std::move(member));
          new_device = true;
          if (VLOG_IS_ON(1)) {
            string dev_buf;
            for (const auto& m : gr->group.members) {
              strings::StrAppend(&dev_buf, ",", m.device.name());
            }
            VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                    << " group_size=" << gr->group.group_size << " (current"
                    << " devices)=(" << dev_buf << ") (number of"
                    << " devices pending)="
                    << (gr->group.group_size - gr->group.members.size());
          }
        }
      } else {
        // If the device already exists, check if the incarnation matches.
        if (it->second != device.incarnation()) {
          gr->status = errors::FailedPrecondition(
              "Device ", device.name(),
              " current incarnation doesn't match with one in the group. This "
              "usually means this worker has restarted but the collective "
              "leader hasn't, or this worker connects to a wrong cluster.");
        }
      }
    }

    if (gr->status.ok()) {
      // If the group is not yet complete, queue to wait for it.
      VLOG(2) << "group_size " << gr->group.group_size << " set size "
              << gr->group.members.size() << " gr " << gr;

      if (gr->group.members.size() < gr->group.group_size) {
        gr->pending_done.push_back(std::move(done));
        gr->pending_params.push_back(group_params);
        return;
      }
      CHECK_EQ(gr->group.members.size(), gr->group.group_size);
      // We have a full group. Fill in the remaining fields of gr->group.
      auto st = CheckUserSpecifiedRanks(gr->group.members);
      if (!st.ok()) {
        gr->status = st;
      }
      if (new_device) {
        FinishGroup(gr);
      }
      // Copy the completed group to all pending CollGroupParams.
      *group_params = gr->group;
      for (auto* params : gr->pending_params) {
        *params = gr->group;
      }
    }
    // At this point, we either have a full group, or an error status. Ensure
    // that all callbacks are invoked with the appropriate status.
    to_be_called.swap(gr->pending_done);
    gr->pending_params.clear();
    status = gr->status;
  }
  done(status);
  for (int i = 0; i < to_be_called.size(); ++i) {
    to_be_called[i](status);
  }
}

namespace {
struct DevRec {
  string task;
  string device;
  int original_rank;
  int local_rank;
  int global_rank;
  const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollGroupParams and localities.
GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) {
  GlobalDeviceMap gdm;
  for (int i = 0; i < gp.members.size(); ++i) {
    TaskDeviceMap& tdm = gdm[gp.members[i].task];
    DevRec* dr = &tdm[gp.members[i].device.name()];
    dr->task = gp.members[i].task;
    dr->device = gp.members[i].device.name();
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &gp.members[i].device.locality();
  }
  return gdm;
}

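// Parses the comma-separated gpu_ring_order override (a list of GPU ids) and
// assigns the corresponding local ranks in *tdm. Returns false if the string
// is malformed or does not cover exactly the devices in *tdm.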
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
  std::vector<string> split_gpu_ring_order_str =
      str_util::Split(gpu_ring_order_str, ',');
  if (split_gpu_ring_order_str.size() != tdm->size()) return false;

  // gpu id -> local rank
  gtl::FlatMap<int32, int32> gpu_ranks;
  for (int32_t rank = 0;
       rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
    int32_t tmp;
    if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
      gpu_ranks[tmp] = rank;
    } else {
      return false;
    }
  }

  for (auto& tdm_it : *tdm) {
    DeviceNameUtils::ParsedName parsed_name;
    DevRec* dr = &tdm_it.second;
    if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
      return false;
    }
    auto rank_it = gpu_ranks.find(parsed_name.id);
    if (rank_it == gpu_ranks.end()) return false;
    dr->local_rank = rank_it->second;
  }
  VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
  return true;
}

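// Assigns a local rank to every device in *tdm. Uses the user-supplied
// gpu_ring_order when it parses cleanly; otherwise greedily follows the
// strongest interconnect links starting from the lowest-ranked device.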
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths. Note that this
  // algorithm is not optimal and may not always find the best ring order.
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

  int next_rank = 0;
  while (true) {
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge.
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}

// The first time a CollGroupParams is established for a group, we compute a
// rank order for all the devices in the group that is appropriate for a ring
// algorithm.
GlobalDeviceMap EstablishGlobalRank(const CollGroupParams& gp,
                                    const string& gpu_ring_order) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(gp);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(gpu_ring_order, &tdm);
  }
  // Connect the global rank order by the lexicographical order of the tasks.
  std::set<string> tasks;
  for (const CollGroupMember& member : gp.members) {
    tasks.insert(member.task);
  }
  int next_rank = 0;
  for (const string& task : tasks) {
    TaskDeviceMap* tdm = &gdm[task];
    for (auto& it : *tdm) {
      it.second.global_rank = it.second.local_rank + next_rank;
    }
    next_rank += tdm->size();
  }
  return gdm;
}

// Count the devices associated with each task and set
// gp->same_num_devices_per_task.
void SetDevPerTask(CollGroupParams* gp) {
  gp->num_devices_per_task.clear();
  for (const CollGroupMember& member : gp->members) {
    gp->num_devices_per_task[member.task]++;
  }
  gp->same_num_devices_per_task = false;
  int dev_per_task = -1;
  for (const auto& task_dev : gp->num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = task_dev.second;
    } else if (dev_per_task != task_dev.second) {
      return;
    }
  }
  gp->same_num_devices_per_task = true;
}

}  // namespace

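// Called once the group is full: fills in each member's task and is_local,
// establishes the default device ranking, and records per-task device counts.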
void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
  // Populate group member task and is_local.
  for (CollGroupMember& member : gr->group.members) {
    member.task = TaskNameFromDeviceName(member.device.name());
    member.is_local = member.task == task_name_;
  }
  // Establish the order of the members by considering localities of all
  // devices.
  CompleteDefaultRanking(&gr->group);
  SetDevPerTask(&gr->group);
  gr->group.num_tasks =
      static_cast<int32>(gr->group.num_devices_per_task.size());
}

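// Cancels resolution of the group identified by group_key. If the group has
// already completed this is a no-op; otherwise the group is marked cancelled
// and every pending callback is invoked with a Cancelled status.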
void CollectiveParamResolverLocal::CancelGroup(int32 group_key) {
  std::vector<StatusCallback> pending_done;
  GroupRec* gr = nullptr;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(group_key);
    if (it == group_table_.end()) {
      return;
    }
    gr = it->second.get();
  }
  {
    mutex_lock l(gr->mu);
    if (gr->group.members.size() == gr->group.group_size) {
      // The group is already complete. There's no need to cancel.
      return;
    }
    gr->status = errors::Cancelled("group is cancelled");
    pending_done.swap(gr->pending_done);
    gr->pending_params.clear();
  }
  for (const StatusCallback& done : pending_done) {
    done(errors::Cancelled("group is cancelled"));
  }
}

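// Sets cp->default_rank to this device's index in the ordered group members,
// and assigns each member without a user-specified rank its index as rank.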
void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->group.members.size()) << cp->ToString();
  for (int i = 0; i < cp->group.group_size; ++i) {
    if (cp->group.members[i].device.name() == device) {
      cp->default_rank = i;
    }
    // Set member rank to default rank if not user specified.
    if (cp->group.members[i].rank == -1) {
      cp->group.members[i].rank = i;
    }
  }
}

void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const CollectiveParams* cp, InstanceRec* ir) {
  ir->shared->instance = cp->instance;
  ir->shared->default_rank = -1;
}

// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) {
  // Sort gp->members to avoid non-determinism.
  std::sort(gp->members.begin(), gp->members.end(),
            [](const CollGroupMember& lhs, const CollGroupMember& rhs) {
              DeviceNameUtils::ParsedName lhs_device_name, rhs_device_name;
              if (DeviceNameUtils::ParseFullName(lhs.device.name(),
                                                 &lhs_device_name) &&
                  DeviceNameUtils::ParseFullName(rhs.device.name(),
                                                 &rhs_device_name)) {
                if (lhs_device_name.job == rhs_device_name.job) {
                  if (lhs_device_name.task == rhs_device_name.task) {
                    return lhs_device_name.id < rhs_device_name.id;
                  } else {
                    return lhs_device_name.task < rhs_device_name.task;
                  }
                } else {
                  return lhs_device_name.job < rhs_device_name.job;
                }
              }
              return lhs.device.name() < rhs.device.name();
            });
  // Establish an instance-specific default rank order for devices
  // based on localities. This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_);
  // Reorder the members to reflect the new global ranking.
  std::vector<CollGroupMember> new_members(gp->group_size);
  for (const auto& git : gdm) {
    const TaskDeviceMap& tdm = git.second;
    for (const auto& tit : tdm) {
      const DevRec& dr = tit.second;
      new_members[dr.global_rank] = std::move(gp->members[dr.original_rank]);
    }
  }

  if (VLOG_IS_ON(2)) {
    string buf;
    for (const auto& m : new_members)
      strings::StrAppend(&buf, "\n", m.device.name());
    VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
            << buf;
  }
  gp->members = std::move(new_members);
}

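// Looks up the InstanceRec for (group_key, instance_key), creating and
// initializing a new one if none exists. Sets *created accordingly and
// propagates any pending abort status into the returned record.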
CollectiveParamResolverLocal::InstanceRec*
CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp,
                                                     bool* created) {
  *created = false;
  InstanceRec* irec = nullptr;
  {
    mutex_lock l(instance_mu_);
    auto group_it = instance_table_.find(cp->group.group_key);
    if (group_it != instance_table_.end()) {
      auto instance_it = group_it->second.find(cp->instance.instance_key);
      if (instance_it != group_it->second.end()) {
        irec = instance_it->second.get();
      }
    }
    if (irec == nullptr) {
      // Create new InstanceRec.
      irec = new InstanceRec;
      *created = true;
      {
        mutex_lock il(irec->mu);
        irec->known.resize(cp->group.group_size, false);
      }
      InitInstanceSharedParams(cp, irec);
      instance_table_[cp->group.group_key][cp->instance.instance_key].reset(
          irec);
    }
  }
  Status status;
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    mutex_lock l(irec->mu);
    irec->status = status;
  }
  return irec;
}

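// Copies the already-initialized group for group_key into *group. Returns an
// error if the group was never initialized or its initialization failed.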
Status CollectiveParamResolverLocal::LookupGroup(int32_t group_key,
                                                 CollGroupParams* group) {
  mutex_lock l(group_mu_);
  auto group_rec = group_table_.find(group_key);
  if (group_rec == group_table_.end()) {
    return errors::InvalidArgument("Group ", group_key,
                                   " is not "
                                   "initialized. Please call group "
                                   "initialization op first before invoking "
                                   "collective op.");
  }
  mutex_lock lock(group_rec->second->mu);
  if (!group_rec->second->status.ok()) {
    return errors::FailedPrecondition(
        "Failed to run collective due to "
        "unsuccessful group initialization. "
        "Group initialization failed with error ",
        group_rec->second->status.ToString());
  }
  *group = group_rec->second->group;
  return OkStatus();
}

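// Resolves all collective parameters for this device: completes the group
// first (unless the caller indicates it is already initialized) and then
// completes the instance-level parameters.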
void CollectiveParamResolverLocal::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
          << cp->ToString();
  if (cp->run_group_initialization) {
    CompleteGroupLocal(device, &cp->group, cancel_mgr,
                       [this, device, cp, done](const Status& s) {
                         if (s.ok()) {
                           CompleteInstanceLocal(device.name(), cp, done);
                         } else {
                           done(s);
                         }
                       });
  } else {
    // For Collective V3 ops, the group is already initialized. Fetch the
    // attributes of the initialized group to pass to instance initialization.
    const auto s = LookupGroup(cp->group.group_key, &cp->group);
    if (s.ok()) {
      CompleteInstanceLocal(device.name(), cp, done);
    } else {
      done(s);
    }
  }
}


void CollectiveParamResolverLocal::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteInstance is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation. Ideally the choice would take the topology and link
// strengths into account.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  // We use the NCCL implementation if this is an environment which supports
  // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and
  // also if indicated either in `ConfigProto` or `communication_hint`.
  //
  // After enough testing, we may simplify this logic to use NCCL whenever
  // available.
  CollectiveImplementationInterface* col_impl;
  bool use_nccl =
      (nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
      cp->group.device_type == DEVICE_GPU &&
      CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
          .ok();
  cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}

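// Completes the instance-level parameters for `device`, verifying that any
// preexisting InstanceRec for this instance key agrees on type and data type.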
void CollectiveParamResolverLocal::CompleteInstanceLocal(
    const string& device, CollectiveParams* cp, const StatusCallback& done) {
  VLOG(1) << "CompleteInstanceLocal " << device
          << " instance_key: " << cp->instance.instance_key << " group_key "
          << cp->group.group_key;

  bool created_irec;
  InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
  if (!created_irec) {
    // Check that the preexisting IRec is consistent with the params passed
    // into this invocation.
    if (ir->shared->instance.type != cp->instance.type ||
        ir->shared->instance.data_type != cp->instance.data_type) {
      done(errors::Internal("Collective instance ", cp->instance.instance_key,
                            " expected type ", ir->shared->instance.type,
                            " and data_type ", ir->shared->instance.data_type,
                            " but got type ", cp->instance.type,
                            " and data_type ", cp->instance.data_type));
      return;
    }
  }
  CompleteInstanceFromInitializedIRec(device, cp, ir, done);
}

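// Fills in the remaining instance fields from the shared InstanceRec, checks
// shape consistency, picks the collective implementation, and (for
// broadcasts) waits for source discovery before invoking `done`.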
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
    const string& device, CollectiveParams* cp, InstanceRec* ir,
    const StatusCallback& done) {
  auto expected_shape = cp->instance.shape;
  Status status;
  // Populate the fields common across instance.
  {
    mutex_lock l(ir->mu);
    status = ir->status;
    if (status.ok()) {
      // custom operator= does a deep copy.
      cp->instance = ir->shared->instance;
    }
  }
  if (!status.ok()) {
    done(status);
    return;
  }
  if (expected_shape != cp->instance.shape) {
    done(errors::InvalidArgument(
        "Shape mismatch in the collective instance ", cp->instance.instance_key,
        ". Op at device ", device, " expected shape ",
        expected_shape.DebugString(), " but another member in the group ",
        "expected shape ", cp->instance.shape.DebugString(), ". This is likely",
        " due to different input shapes at different members of the collective",
        " op."));
    return;
  }
  // Populate the fields common across task.
  AssignCollectiveType(cp);
  SetDefaultRank(device, cp);

  CollectiveImplementationInterface* col_impl;
  status = CollectiveRegistry::LookupParamResolverInstance(
      cp->instance.impl_details.collective_name, &col_impl);
  if (!status.ok()) {
    done(status);
    return;
  }

  // We may need to wait for the group, if this is a broadcast, for source
  // discovery.
  if (cp->instance.type == BROADCAST_COLLECTIVE) {
    WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) {
      Status s;
      if (ir != irec) {
        s = errors::Internal("Expected ir ", ir, " and irec ", irec,
                             " to be equal");
      } else {
        mutex_lock l(irec->mu);
        s = irec->status;
        cp->source_rank = irec->source_rank;
      }
      if (s.ok()) {
        s = col_impl->InitializeCollectiveParams(cp);
      }
      done(s);
    });
  } else {
    done(col_impl->InitializeCollectiveParams(cp));
  }
}

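// Marks this rank as known for the instance and, for broadcasts, records the
// source rank. Invokes `f` once every member of the group has checked in, or
// immediately if the instance already has an error status.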
void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
                                                CollectiveParams* cp,
                                                const IRConsumer& f) {
  std::vector<IRConsumer> ready_waiters;
  do {
    mutex_lock l(ir->mu);
    if (!ir->status.ok()) {
      break;
    }
    CHECK_EQ(cp->group.group_size, ir->known.size());
    CHECK_GE(cp->default_rank, 0);
    if (!ir->known[cp->default_rank]) {
      ir->known[cp->default_rank] = true;
      ++ir->known_count;
      if (cp->is_source) {
        // Initialize source rank.
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
                                        " already has source ", ir->source_rank,
                                        ", received second claim from ",
                                        cp->default_rank);
        } else {
          ir->source_rank = cp->default_rank;
        }
      }
    }
    if (ir->known_count < cp->group.group_size) {
      ir->known_waiters.push_back(f);
      return;
    }
    CHECK_EQ(ir->known_count, cp->group.group_size);
    if (ir->source_rank < 0) {
      // NOTE(ayushd): changing the error message below would also require
      // updating CompleteParamsBroadcastForgotSend test in
      // CollectiveParamResolverLocalTest.
      ir->status =
          errors::Internal("Instance ", cp->instance.instance_key,
                           " found no source for broadcast. This "
                           "could mean that there were group_size=",
                           ir->known_count, " BcastRecvs but no BcastSend.");
    }
    if (!ir->known_waiters.empty()) {
      ready_waiters = std::move(ir->known_waiters);
    }
  } while (false);
  f(ir);
  for (auto& f : ready_waiters) {
    f(ir);
  }
}

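// Records the abort status (the first abort wins) and propagates it to all
// local group and instance state.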
void CollectiveParamResolverLocal::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  StartAbortLocal(s);
}

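// Fails every pending group and instance resolution with status `s`, invoking
// all queued callbacks outside of the relevant locks.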
void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
  std::vector<StatusCallback> pending_done;
  {
    mutex_lock l(group_mu_);
    for (const auto& item : group_table_) {
      GroupRec* gr = item.second.get();
      {
        mutex_lock gl(gr->mu);
        gr->status = s;
        for (auto& done : gr->pending_done) {
          pending_done.push_back(std::move(done));
        }
        gr->pending_done.clear();
        gr->pending_params.clear();
      }
    }
  }
  for (const StatusCallback& done : pending_done) {
    done(s);
  }
  std::vector<InstanceRec*> instances;
  {
    mutex_lock l(instance_mu_);
    for (const auto& group_entry : instance_table_) {
      for (const auto& item : group_entry.second) {
        instances.push_back(item.second.get());
      }
    }
  }
  for (InstanceRec* ir : instances) {
    std::vector<IRConsumer> known_waiters;
    {
      mutex_lock il(ir->mu);
      ir->status = s;
      known_waiters.swap(ir->known_waiters);
    }
    for (const IRConsumer& done : known_waiters) {
      done(ir);
    }
  }
}

}  // namespace tensorflow