1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" |
16 | |
#include <stddef.h>

#include <algorithm>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
42 | |
43 | namespace tensorflow { |
44 | |
// Builds a resolver that completes collective group/instance parameters on
// this worker.
//   config:            nccl_ is read from
//                      config.experimental().collective_nccl(), and
//                      gpu_ring_order_ from
//                      config.gpu_options().experimental()
//                        .collective_ring_order().
//   dev_mgr, dev_resolver, nccl_communicator:
//                      stored as given. NOTE(review): ownership is not visible
//                      here; presumably these are borrowed and must outlive
//                      the resolver — confirm against the header.
//                      nccl_communicator may be null (CompleteGroupLocal
//                      checks for nullptr before using it).
//   task_name:         name of the local task; FinishGroup compares each
//                      member's task against it to set is_local.
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver,
    NcclCommunicatorInterface* nccl_communicator, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      nccl_communicator_(nccl_communicator),
      task_name_(task_name),
      gpu_ring_order_(
          config.gpu_options().experimental().collective_ring_order()) {}
56 | |
57 | void CollectiveParamResolverLocal::CompleteGroupAsync( |
58 | const DeviceAttributes& device, CollGroupParams* group_params, |
59 | CancellationManager* cancel_mgr, const StatusCallback& done) { |
60 | CompleteGroupLocal(device, group_params, cancel_mgr, done); |
61 | } |
62 | |
63 | namespace { |
64 | const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) { |
65 | switch (cp->instance.type) { |
66 | case BROADCAST_COLLECTIVE: |
67 | return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast" ; |
68 | |
69 | case REDUCTION_COLLECTIVE: |
70 | return nccl ? "NcclReduce" : "RingReduce" ; |
71 | |
72 | case GATHER_COLLECTIVE: |
73 | return nccl ? "NcclGather" : "RingGather" ; |
74 | |
75 | case PERMUTE_COLLECTIVE: |
76 | return "Permute" ; |
77 | |
78 | case ALL_TO_ALL_COLLECTIVE: |
79 | return "AllToAll" ; |
80 | |
81 | default: |
82 | return "undef" ; |
83 | } |
84 | } |
85 | |
86 | string TaskNameFromDeviceName(const string& device_name) { |
87 | DeviceNameUtils::ParsedName parsed_device; |
88 | CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device)); |
89 | string task_name; |
90 | CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name)); |
91 | return task_name; |
92 | } |
93 | |
94 | struct RankFormatter { |
95 | void operator()(std::string* out, CollGroupMember m) const { |
96 | out->append(std::to_string(m.rank)); |
97 | } |
98 | }; |
99 | |
100 | Status CheckUserSpecifiedRanks(const std::vector<CollGroupMember> members) { |
101 | absl::flat_hash_set<int> user_ranks = {}; |
102 | bool at_least_one_member_with_no_rank = false; |
103 | bool at_least_one_member_with_user_rank = false; |
104 | for (const auto& m : members) { |
105 | if (m.rank == -1) { |
106 | at_least_one_member_with_no_rank = true; |
107 | } else { |
108 | at_least_one_member_with_user_rank = true; |
109 | user_ranks.insert(m.rank); |
110 | } |
111 | } |
112 | |
113 | auto received_ranks = absl::StrJoin(members, "," , RankFormatter()); |
114 | if (at_least_one_member_with_no_rank && at_least_one_member_with_user_rank) { |
115 | return errors::InvalidArgument( |
116 | "Only part of the group members have user given rank specified." , |
117 | "Received ranks: " , received_ranks); |
118 | } |
119 | |
120 | if (at_least_one_member_with_user_rank && |
121 | user_ranks.size() < members.size()) { |
122 | return errors::InvalidArgument( |
123 | "Duplicate ranks specified for group members. Received ranks: " , |
124 | received_ranks); |
125 | } |
126 | return OkStatus(); |
127 | } |
128 | } // namespace |
129 | |
130 | void CollectiveParamResolverLocal::CompleteGroupLocal( |
131 | const DeviceAttributes& device, CollGroupParams* group_params, |
132 | CancellationManager* cancel_mgr, StatusCallback done) { |
133 | VLOG(1) << "CompleteGroup device=" << device.name() << ": " |
134 | << group_params->ToString(); |
135 | std::vector<StatusCallback> to_be_called; |
136 | |
137 | GroupRec* gr = nullptr; |
138 | Status status; |
139 | { |
140 | mutex_lock l(group_mu_); |
141 | auto it = group_table_.find(group_params->group_key); |
142 | if (it == group_table_.end()) { |
143 | gr = new GroupRec; |
144 | mutex_lock grl(gr->mu); |
145 | gr->group.group_key = group_params->group_key; |
146 | gr->group.group_size = group_params->group_size; |
147 | gr->group.device_type = group_params->device_type; |
148 | if (nccl_communicator_ != nullptr) { |
149 | gr->group.runtime_details.communicator_key = |
150 | nccl_communicator_->GenerateCommunicatorKey(); |
151 | } |
152 | // Store GroupRec in group_table_ which is shared between all devices on |
153 | // this worker. |
154 | group_table_[gr->group.group_key].reset(gr); |
155 | VLOG(2) << "New group_key=" << gr->group.group_key |
156 | << " group_size=" << gr->group.group_size |
157 | << " runtime_details=" << gr->group.runtime_details.ToString(); |
158 | } else { |
159 | gr = it->second.get(); |
160 | } |
161 | } |
162 | { |
163 | mutex_lock l(status_mu_); |
164 | status = status_; |
165 | } |
166 | if (!status.ok()) { |
167 | done(status); |
168 | return; |
169 | } |
170 | |
171 | if (cancel_mgr != nullptr) { |
172 | CancellationToken token = cancel_mgr->get_cancellation_token(); |
173 | bool is_cancelled = !cancel_mgr->RegisterCallback( |
174 | token, std::bind(&CollectiveParamResolverLocal::CancelGroup, this, |
175 | group_params->group_key)); |
176 | if (is_cancelled) { |
177 | done(errors::Cancelled("CompleteGroup is cancelled before it starts" )); |
178 | return; |
179 | } |
180 | done = [cancel_mgr, token, |
181 | original_done = std::move(done)](const Status& status) { |
182 | cancel_mgr->TryDeregisterCallback(token); |
183 | original_done(status); |
184 | }; |
185 | } |
186 | |
187 | { |
188 | mutex_lock gr_lock(gr->mu); |
189 | // If there is ever an error associated with a group key, we store the error |
190 | // status and invoke all waiting and future callbacks with this error |
191 | // status. |
192 | VLOG(2) << "gr device_type=" << gr->group.device_type |
193 | << " cp device_type=" << group_params->device_type |
194 | << " current device=" << device.name(); |
195 | if (gr->status.ok()) { |
196 | // Check for consistency with existing GroupRec. |
197 | if (group_params->device_type != gr->group.device_type) { |
198 | gr->status = errors::Internal( |
199 | "Device " , device.name(), |
200 | " is joining a group with incompatible device type" , |
201 | gr->group.device_type.type_string(), |
202 | " (group_key=" , gr->group.group_key, ")" ); |
203 | } else if (group_params->group_size != gr->group.group_size) { |
204 | gr->status = errors::Internal( |
205 | "Device " , device.name(), " is joining a group with size" , |
206 | group_params->group_size, ", but that group has size " , |
207 | gr->group.group_size, " (group_key=" , gr->group.group_key, ")" ); |
208 | } |
209 | } |
210 | bool new_device = false; |
211 | if (gr->status.ok()) { |
212 | // Insert device if not already present. |
213 | auto it = gr->incarnations_by_device_name.find(device.name()); |
214 | if (it == gr->incarnations_by_device_name.end()) { |
215 | if (gr->group.members.size() == gr->group.group_size) { |
216 | // The group is already full. |
217 | gr->status = |
218 | errors::Internal("Device " , device.name(), |
219 | " is joining a group that is already full" , |
220 | " (group_key=" , gr->group.group_key, ")" ); |
221 | } else { |
222 | // This is a new device that has not yet joined the group. |
223 | gr->incarnations_by_device_name[device.name()] = device.incarnation(); |
224 | CollGroupMember member; |
225 | member.device = device; |
226 | if (group_params->user_specified_rank == -1 || |
227 | (group_params->user_specified_rank >= 0 && |
228 | group_params->user_specified_rank < gr->group.group_size)) { |
229 | member.rank = group_params->user_specified_rank; |
230 | } else { |
231 | gr->status = errors::InvalidArgument( |
232 | "User Provided rank is invalid. It should be between [0, " |
233 | "group_size)" ); |
234 | } |
235 | gr->group.members.push_back(std::move(member)); |
236 | new_device = true; |
237 | if (VLOG_IS_ON(1)) { |
238 | string dev_buf; |
239 | for (const auto& m : gr->group.members) { |
240 | strings::StrAppend(&dev_buf, "," , m.device.name()); |
241 | } |
242 | VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key |
243 | << " group_size=" << gr->group.group_size << " (current" |
244 | << " devices)=(" << dev_buf << ") (number of" |
245 | << " devices pending)=" |
246 | << (gr->group.group_size - gr->group.members.size()); |
247 | } |
248 | } |
249 | } else { |
250 | // If the device already exists, check if the incarnation matches. |
251 | if (it->second != device.incarnation()) { |
252 | gr->status = errors::FailedPrecondition( |
253 | "Device " , device.name(), |
254 | " current incarnation doesn't match with one in the group. This " |
255 | "usually means this worker has restarted but the collective " |
256 | "leader hasn't, or this worker connects to a wrong cluster." ); |
257 | } |
258 | } |
259 | } |
260 | |
261 | if (gr->status.ok()) { |
262 | // If the group is not yet complete, queue to wait for it. |
263 | VLOG(2) << "group_size " << gr->group.group_size << " set size " |
264 | << gr->group.members.size() << " gr " << gr; |
265 | |
266 | if (gr->group.members.size() < gr->group.group_size) { |
267 | gr->pending_done.push_back(std::move(done)); |
268 | gr->pending_params.push_back(group_params); |
269 | return; |
270 | } |
271 | CHECK_EQ(gr->group.members.size(), gr->group.group_size); |
272 | // We get a full group. Fill in remaining fields in gr->group. |
273 | auto st = CheckUserSpecifiedRanks(gr->group.members); |
274 | if (!st.ok()) { |
275 | gr->status = st; |
276 | } |
277 | if (new_device) { |
278 | FinishGroup(gr); |
279 | } |
280 | // Copy to all pending CollGroupParams; |
281 | *group_params = gr->group; |
282 | for (auto* params : gr->pending_params) { |
283 | *params = gr->group; |
284 | } |
285 | } |
286 | // At this point, we either have a full group, or an error status. Ensure |
287 | // that all callbacks are invoked with the appropriate status. |
288 | to_be_called.swap(gr->pending_done); |
289 | gr->pending_params.clear(); |
290 | status = gr->status; |
291 | } |
292 | done(status); |
293 | for (int i = 0; i < to_be_called.size(); ++i) { |
294 | to_be_called[i](status); |
295 | } |
296 | } |
297 | |
298 | namespace { |
299 | struct DevRec { |
300 | string task; |
301 | string device; |
302 | int original_rank; |
303 | int local_rank; |
304 | int global_rank; |
305 | const DeviceLocality* locality; |
306 | }; |
307 | typedef std::unordered_map<string, DevRec> TaskDeviceMap; |
308 | typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap; |
309 | |
310 | // Create a populated GlobalDeviceMap from CollInstanceParams and localities. |
311 | GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) { |
312 | GlobalDeviceMap gdm; |
313 | CHECK_EQ(gp.members.size(), gp.members.size()); |
314 | for (int i = 0; i < gp.members.size(); ++i) { |
315 | TaskDeviceMap& tdm = gdm[gp.members[i].task]; |
316 | DevRec* dr = &tdm[gp.members[i].device.name()]; |
317 | dr->task = gp.members[i].task; |
318 | dr->device = gp.members[i].device.name(); |
319 | dr->original_rank = i; |
320 | dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap. |
321 | dr->global_rank = 0; // Will be populated later by EstablishGlobalRank. |
322 | dr->locality = &gp.members[i].device.locality(); |
323 | } |
324 | return gdm; |
325 | } |
326 | |
327 | bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) { |
328 | std::vector<string> split_gpu_ring_order_str = |
329 | str_util::Split(gpu_ring_order_str, ','); |
330 | if (split_gpu_ring_order_str.size() != tdm->size()) return false; |
331 | |
332 | // gpu id -> local rank |
333 | gtl::FlatMap<int32, int32> gpu_ranks; |
334 | for (int32_t rank = 0; |
335 | rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) { |
336 | int32_t tmp; |
337 | if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) { |
338 | gpu_ranks[tmp] = rank; |
339 | } else { |
340 | return false; |
341 | } |
342 | } |
343 | |
344 | for (auto& tdm_it : *tdm) { |
345 | DeviceNameUtils::ParsedName parsed_name; |
346 | DevRec* dr = &tdm_it.second; |
347 | if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) { |
348 | return false; |
349 | } |
350 | auto rank_it = gpu_ranks.find(parsed_name.id); |
351 | if (rank_it == gpu_ranks.end()) return false; |
352 | dr->local_rank = rank_it->second; |
353 | } |
354 | VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str; |
355 | return true; |
356 | } |
357 | |
// Assigns each device of a single task a local_rank in [0, tdm->size()).
//
// If `gpu_ring_order` is a valid order for this task (see ParseRingOrder), it
// is used as-is. Otherwise a greedy walk builds the order:
//   1. Start from the device with the lowest original_rank.
//   2. From the current device, follow its strongest InterconnectLink to a
//      device of this task that has not been selected yet.
//   3. When no such link exists, continue with the lowest original_rank among
//      the unselected devices.
// Links are only consulted when the starting device's type is "GPU". The
// result is a heuristic and not guaranteed to be an optimal ring.
// CHECK-fails if `tdm` is empty.
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  // (A failed parse may have partially written local ranks; the greedy walk
  // below overwrites every one of them.)
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths. Note that this
  // algorithm is not optimal and may not always find the best ring order.
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  // parsed_name serves as a template for candidate device names. Below only
  // its id is overwritten, so every candidate inherits the job/replica/task
  // and type of the starting device.
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

  int next_rank = 0;
  while (true) {
    // Give the current device the next local rank.
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        // Keep the strongest remaining link; ties keep the first one seen.
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}
438 | |
439 | // The first time a CollGroupParams is established for a group we compute a good |
440 | // rank order for all the devices in the group, that is appropriate for a ring |
441 | // algorithm. |
442 | GlobalDeviceMap EstablishGlobalRank(const CollGroupParams& gp, |
443 | const string& gpu_ring_order) { |
444 | VLOG(1) << "EstablishGlobalRank" ; |
445 | GlobalDeviceMap gdm = BuildDevRecs(gp); |
446 | for (auto& iter : gdm) { |
447 | TaskDeviceMap& tdm = iter.second; |
448 | OrderTaskDeviceMap(gpu_ring_order, &tdm); |
449 | } |
450 | // Connect the global rank order by the lexicographical order of the tasks. |
451 | std::set<string> tasks; |
452 | for (const CollGroupMember& member : gp.members) { |
453 | tasks.insert(member.task); |
454 | } |
455 | int next_rank = 0; |
456 | for (const string& task : tasks) { |
457 | TaskDeviceMap* tdm = &gdm[task]; |
458 | for (auto& it : *tdm) { |
459 | it.second.global_rank = it.second.local_rank + next_rank; |
460 | } |
461 | next_rank += tdm->size(); |
462 | } |
463 | return gdm; |
464 | } |
465 | |
466 | // Count the devices associated with each task and set |
467 | // gp->same_num_devices_per_task. Requires gp->task_names |
468 | // be sorted. |
469 | void SetDevPerTask(CollGroupParams* gp) { |
470 | gp->num_devices_per_task.clear(); |
471 | for (const CollGroupMember& member : gp->members) { |
472 | gp->num_devices_per_task[member.task]++; |
473 | } |
474 | gp->same_num_devices_per_task = false; |
475 | int dev_per_task = -1; |
476 | for (const auto& task_dev : gp->num_devices_per_task) { |
477 | if (dev_per_task == -1) { |
478 | dev_per_task = task_dev.second; |
479 | } else if (dev_per_task != task_dev.second) { |
480 | return; |
481 | } |
482 | } |
483 | gp->same_num_devices_per_task = true; |
484 | } |
485 | |
486 | } // namespace |
487 | |
488 | void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) { |
489 | // Populate group member task and is_local. |
490 | for (CollGroupMember& member : gr->group.members) { |
491 | member.task = TaskNameFromDeviceName(member.device.name()); |
492 | member.is_local = member.task == task_name_; |
493 | } |
494 | // Establish the order of the members by considering localities of all |
495 | // devices. |
496 | CompleteDefaultRanking(&gr->group); |
497 | SetDevPerTask(&gr->group); |
498 | gr->group.num_tasks = |
499 | static_cast<int32>(gr->group.num_devices_per_task.size()); |
500 | } |
501 | |
502 | void CollectiveParamResolverLocal::CancelGroup(int32 group_key) { |
503 | std::vector<StatusCallback> pending_done; |
504 | GroupRec* gr = nullptr; |
505 | { |
506 | mutex_lock l(group_mu_); |
507 | auto it = group_table_.find(group_key); |
508 | if (it == group_table_.end()) { |
509 | return; |
510 | } |
511 | gr = it->second.get(); |
512 | } |
513 | { |
514 | mutex_lock l(gr->mu); |
515 | if (gr->group.members.size() == gr->group.group_size) { |
516 | // The group is already complete. There's no need to cancel. |
517 | return; |
518 | } |
519 | gr->status = errors::Cancelled("group is cancelled" ); |
520 | pending_done.swap(gr->pending_done); |
521 | gr->pending_params.clear(); |
522 | } |
523 | for (const StatusCallback& done : pending_done) { |
524 | done(errors::Cancelled("group is cancelled" )); |
525 | } |
526 | } |
527 | |
528 | void CollectiveParamResolverLocal::SetDefaultRank(const string& device, |
529 | CollectiveParams* cp) { |
530 | CHECK_EQ(cp->group.group_size, cp->group.members.size()) << cp->ToString(); |
531 | for (int i = 0; i < cp->group.group_size; ++i) { |
532 | if (cp->group.members[i].device.name() == device) { |
533 | cp->default_rank = i; |
534 | } |
535 | // Set member rank to default rank if not user specified. |
536 | if (cp->group.members[i].rank == -1) { |
537 | cp->group.members[i].rank = i; |
538 | } |
539 | } |
540 | } |
541 | |
542 | void CollectiveParamResolverLocal::InitInstanceSharedParams( |
543 | const CollectiveParams* cp, InstanceRec* ir) { |
544 | ir->shared->instance = cp->instance; |
545 | ir->shared->default_rank = -1; |
546 | } |
547 | |
548 | // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks |
549 | // to all devices that they are physically connected to and visible to the |
550 | // TensorFlow runtime. This set of devices may be a superset of the devices |
551 | // participating in this instance of collectives. |
552 | void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) { |
553 | // Sort gp->member to avoid indeterminism. |
554 | std::sort(gp->members.begin(), gp->members.end(), |
555 | [](const CollGroupMember& lhs, const CollGroupMember& rhs) { |
556 | DeviceNameUtils::ParsedName lhs_device_name, rhs_device_name; |
557 | if (DeviceNameUtils::ParseFullName(lhs.device.name(), |
558 | &lhs_device_name) && |
559 | DeviceNameUtils::ParseFullName(rhs.device.name(), |
560 | &rhs_device_name)) { |
561 | if (lhs_device_name.job == rhs_device_name.job) { |
562 | if (lhs_device_name.task == rhs_device_name.task) { |
563 | return lhs_device_name.id < rhs_device_name.id; |
564 | } else { |
565 | return lhs_device_name.task < rhs_device_name.task; |
566 | } |
567 | } else { |
568 | return lhs_device_name.job < rhs_device_name.job; |
569 | } |
570 | } |
571 | return lhs.device.name() < rhs.device.name(); |
572 | }); |
573 | // Establish an instance-specific default rank order for devices |
574 | // based on localities. This rank order should be a good ring |
575 | // order, if possible. |
576 | GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_); |
577 | // Reflect the new global ranking on shared |
578 | std::vector<CollGroupMember> new_members(gp->group_size); |
579 | for (const auto& git : gdm) { |
580 | const TaskDeviceMap& tdm = git.second; |
581 | for (const auto& tit : tdm) { |
582 | const DevRec& dr = tit.second; |
583 | new_members[dr.global_rank] = std::move(gp->members[dr.original_rank]); |
584 | } |
585 | } |
586 | |
587 | if (VLOG_IS_ON(2)) { |
588 | string buf; |
589 | for (const auto& m : new_members) |
590 | strings::StrAppend(&buf, "\n" , m.device.name()); |
591 | VLOG(2) << "Optimized device order for group " << gp->group_key << ": " |
592 | << buf; |
593 | } |
594 | gp->members = std::move(new_members); |
595 | } |
596 | |
597 | CollectiveParamResolverLocal::InstanceRec* |
598 | CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp, |
599 | bool* created) { |
600 | *created = false; |
601 | InstanceRec* irec = nullptr; |
602 | { |
603 | mutex_lock l(instance_mu_); |
604 | auto group_it = instance_table_.find(cp->group.group_key); |
605 | if (group_it != instance_table_.end()) { |
606 | auto instance_it = group_it->second.find(cp->instance.instance_key); |
607 | if (instance_it != group_it->second.end()) { |
608 | irec = instance_it->second.get(); |
609 | } |
610 | } |
611 | if (irec == nullptr) { |
612 | // Create new InstanceRec. |
613 | irec = new InstanceRec; |
614 | *created = true; |
615 | { |
616 | mutex_lock il(irec->mu); |
617 | irec->known.resize(cp->group.group_size, false); |
618 | } |
619 | InitInstanceSharedParams(cp, irec); |
620 | instance_table_[cp->group.group_key][cp->instance.instance_key].reset( |
621 | irec); |
622 | } |
623 | } |
624 | Status status; |
625 | { |
626 | mutex_lock l(status_mu_); |
627 | status = status_; |
628 | } |
629 | if (!status.ok()) { |
630 | mutex_lock l(irec->mu); |
631 | irec->status = status; |
632 | } |
633 | return irec; |
634 | } |
635 | |
636 | Status CollectiveParamResolverLocal::LookupGroup(int32_t group_key, |
637 | CollGroupParams* group) { |
638 | mutex_lock l(group_mu_); |
639 | auto group_rec = group_table_.find(group_key); |
640 | if (group_rec == group_table_.end()) { |
641 | return errors::InvalidArgument("Group " , group_key, |
642 | " is not " |
643 | "initialized. Please call group " |
644 | "initialization op first before invoking " |
645 | "collective op." ); |
646 | } |
647 | mutex_lock lock(group_rec->second->mu); |
648 | if (!group_rec->second->status.ok()) { |
649 | return errors::FailedPrecondition( |
650 | "Failed to run collective due to " |
651 | "unsuccessful group initialization. " |
652 | "Group initialization failed with error " , |
653 | group_rec->second->status.ToString()); |
654 | } |
655 | *group = group_rec->second->group; |
656 | return OkStatus(); |
657 | } |
658 | |
659 | void CollectiveParamResolverLocal::CompleteParamsAsync( |
660 | const DeviceAttributes& device, CollectiveParams* cp, |
661 | CancellationManager* cancel_mgr, const StatusCallback& done) { |
662 | VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": " |
663 | << cp->ToString(); |
664 | if (cp->run_group_initialization) { |
665 | CompleteGroupLocal(device, &cp->group, cancel_mgr, |
666 | [this, device, cp, done](const Status& s) { |
667 | if (s.ok()) { |
668 | CompleteInstanceLocal(device.name(), cp, done); |
669 | } else { |
670 | done(s); |
671 | } |
672 | }); |
673 | } else { |
674 | // For Collective V3 ops, group is already initialized. Fetch attributes |
675 | // for the already initialized group to pass to Insitance initialization. |
676 | const auto s = LookupGroup(cp->group.group_key, &cp->group); |
677 | if (s.ok()) { |
678 | CompleteInstanceLocal(device.name(), cp, done); |
679 | } else { |
680 | done(s); |
681 | } |
682 | } |
683 | } |
684 | |
685 | void CollectiveParamResolverLocal::CompleteInstanceAsync( |
686 | const CompleteInstanceRequest* request, CompleteInstanceResponse* response, |
687 | CancellationManager* cancel_mgr, const StatusCallback& done) { |
688 | done( |
689 | errors::Internal("CompleteInstance is not implemented by " |
690 | "CollectiveParamResolverLocal which is " |
691 | "intended only for non-distributed deployment." )); |
692 | } |
693 | |
694 | // TODO(b/111897089): we need a better way to pick the collective |
695 | // implementation. The ideal way would depend upon the topology and link |
696 | // strength before picking a particular implementation. |
697 | void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) { |
698 | // We use the NCCL implementation if this is an environment which supports |
699 | // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and |
700 | // also if indicated either in `ConfigProto` or `communication_hint`. |
701 | // |
702 | // After enough testing, we may simplify this logic to use NCCL whenever |
703 | // available. |
704 | CollectiveImplementationInterface* col_impl; |
705 | bool use_nccl = |
706 | (nccl_ || cp->instance.impl_details.communication_hint == "nccl" ) && |
707 | cp->group.device_type == DEVICE_GPU && |
708 | CollectiveRegistry::LookupParamResolverInstance("NcclReduce" , &col_impl) |
709 | .ok(); |
710 | cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl); |
711 | VLOG(1) << "AssignCollectiveType " |
712 | << cp->instance.impl_details.collective_name; |
713 | } |
714 | |
715 | void CollectiveParamResolverLocal::CompleteInstanceLocal( |
716 | const string& device, CollectiveParams* cp, const StatusCallback& done) { |
717 | VLOG(1) << "CompleteInstanceLocal " << device |
718 | << " instance_key: " << cp->instance.instance_key << " group_key " |
719 | << cp->group.group_key; |
720 | |
721 | bool created_irec; |
722 | InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec); |
723 | if (!created_irec) { |
724 | // Check that the preexisting IRec is consistent with the params passed into |
725 | // this invocation. |
726 | if (ir->shared->instance.type != cp->instance.type || |
727 | ir->shared->instance.data_type != cp->instance.data_type) { |
728 | done(errors::Internal("Collective instance " , cp->instance.instance_key, |
729 | " expected type " , ir->shared->instance.type, |
730 | " and data_type " , ir->shared->instance.data_type, |
731 | " but got type " , cp->instance.type, |
732 | " and data_type " , cp->instance.data_type)); |
733 | return; |
734 | } |
735 | } |
736 | CompleteInstanceFromInitializedIRec(device, cp, ir, done); |
737 | } |
738 | |
739 | void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( |
740 | const string& device, CollectiveParams* cp, InstanceRec* ir, |
741 | const StatusCallback& done) { |
742 | auto expected_shape = cp->instance.shape; |
743 | Status status; |
744 | // Populate the fields common across instance. |
745 | { |
746 | mutex_lock l(ir->mu); |
747 | status = ir->status; |
748 | if (status.ok()) { |
749 | // custom operator= does a deep copy. |
750 | cp->instance = ir->shared->instance; |
751 | } |
752 | } |
753 | if (!status.ok()) { |
754 | done(status); |
755 | return; |
756 | } |
757 | if (expected_shape != cp->instance.shape) { |
758 | done(errors::InvalidArgument( |
759 | "Shape mismatch in the collective instance " , cp->instance.instance_key, |
760 | ". Op at device " , device, " expected shape " , |
761 | expected_shape.DebugString(), " but another member in the group " , |
762 | "expected shape " , cp->instance.shape.DebugString(), ". This is likely" , |
763 | " due to different input shapes at different members of the collective" , |
764 | " op." )); |
765 | return; |
766 | } |
767 | // Populate the fields common across task. |
768 | AssignCollectiveType(cp); |
769 | SetDefaultRank(device, cp); |
770 | |
771 | CollectiveImplementationInterface* col_impl; |
772 | status = CollectiveRegistry::LookupParamResolverInstance( |
773 | cp->instance.impl_details.collective_name, &col_impl); |
774 | if (!status.ok()) { |
775 | done(status); |
776 | return; |
777 | } |
778 | |
779 | // We may need to wait for the group, if this is a broadcast, for source |
780 | // discovery. |
781 | if (cp->instance.type == BROADCAST_COLLECTIVE) { |
782 | WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) { |
783 | Status s; |
784 | if (ir != irec) { |
785 | s = errors::Internal("Expected ir " , ir, " and irec " , irec, |
786 | " to be equal" ); |
787 | } else { |
788 | mutex_lock l(irec->mu); |
789 | s = irec->status; |
790 | cp->source_rank = irec->source_rank; |
791 | } |
792 | if (s.ok()) { |
793 | s = col_impl->InitializeCollectiveParams(cp); |
794 | } |
795 | done(s); |
796 | }); |
797 | } else { |
798 | done(col_impl->InitializeCollectiveParams(cp)); |
799 | } |
800 | } |
801 | |
802 | void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir, |
803 | CollectiveParams* cp, |
804 | const IRConsumer& f) { |
805 | std::vector<IRConsumer> ready_waiters; |
806 | do { |
807 | mutex_lock l(ir->mu); |
808 | if (!ir->status.ok()) { |
809 | break; |
810 | } |
811 | CHECK_EQ(cp->group.group_size, ir->known.size()); |
812 | CHECK_GE(cp->default_rank, 0); |
813 | if (!ir->known[cp->default_rank]) { |
814 | ir->known[cp->default_rank] = true; |
815 | ++ir->known_count; |
816 | if (cp->is_source) { |
817 | // Initialize source rank. |
818 | if (ir->source_rank >= 0) { |
819 | ir->status = errors::Internal("Instance " , cp->instance.instance_key, |
820 | " already has source " , ir->source_rank, |
821 | ", received second claim from " , |
822 | cp->default_rank); |
823 | } else { |
824 | ir->source_rank = cp->default_rank; |
825 | } |
826 | } |
827 | } |
828 | if (ir->known_count < cp->group.group_size) { |
829 | ir->known_waiters.push_back(f); |
830 | return; |
831 | } |
832 | CHECK_EQ(ir->known_count, cp->group.group_size); |
833 | if (ir->source_rank < 0) { |
834 | // NOTE(ayushd): changing the error message below would also require |
835 | // updating CompleteParamsBroadcastForgotSend test in |
836 | // CollectiveParamResolverLocalTest. |
837 | ir->status = |
838 | errors::Internal("Instance " , cp->instance.instance_key, |
839 | " found no source for broadcast. This " |
840 | "could mean that there were group_size=" , |
841 | ir->known_count, " BcastRecvs but no BcastSend." ); |
842 | } |
843 | if (!ir->known_waiters.empty()) { |
844 | ready_waiters = std::move(ir->known_waiters); |
845 | } |
846 | } while (false); |
847 | f(ir); |
848 | for (auto& f : ready_waiters) { |
849 | f(ir); |
850 | } |
851 | } |
852 | |
853 | void CollectiveParamResolverLocal::StartAbort(const Status& s) { |
854 | { |
855 | mutex_lock l(status_mu_); |
856 | if (!status_.ok()) { |
857 | VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring " |
858 | "subsequent abortion with status: " |
859 | << s; |
860 | return; |
861 | } |
862 | status_ = s; |
863 | } |
864 | StartAbortLocal(s); |
865 | } |
866 | |
867 | void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) { |
868 | std::vector<StatusCallback> pending_done; |
869 | { |
870 | mutex_lock l(group_mu_); |
871 | for (const auto& item : group_table_) { |
872 | GroupRec* gr = item.second.get(); |
873 | { |
874 | mutex_lock gl(gr->mu); |
875 | gr->status = s; |
876 | for (auto& done : gr->pending_done) { |
877 | pending_done.push_back(std::move(done)); |
878 | } |
879 | gr->pending_done.clear(); |
880 | gr->pending_params.clear(); |
881 | } |
882 | } |
883 | } |
884 | for (const StatusCallback& done : pending_done) { |
885 | done(s); |
886 | } |
887 | std::vector<InstanceRec*> instances; |
888 | { |
889 | mutex_lock l(instance_mu_); |
890 | for (const auto& group_entry : instance_table_) { |
891 | for (const auto& item : group_entry.second) { |
892 | instances.push_back(item.second.get()); |
893 | } |
894 | } |
895 | } |
896 | for (InstanceRec* ir : instances) { |
897 | std::vector<IRConsumer> known_waiters; |
898 | { |
899 | mutex_lock il(ir->mu); |
900 | ir->status = s; |
901 | known_waiters.swap(ir->known_waiters); |
902 | } |
903 | for (const IRConsumer& done : known_waiters) { |
904 | done(ir); |
905 | } |
906 | } |
907 | } |
908 | |
909 | } // namespace tensorflow |
910 | |