1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h" |
16 | |
17 | #include <functional> |
18 | #include <memory> |
19 | #include <string> |
20 | #include <utility> |
21 | |
22 | #include "tensorflow/core/common_runtime/collective_rma_local.h" |
23 | #include "tensorflow/core/common_runtime/collective_util.h" |
24 | #include "tensorflow/core/common_runtime/device_mgr.h" |
25 | #include "tensorflow/core/common_runtime/dma_helper.h" |
26 | #include "tensorflow/core/framework/device_base.h" |
27 | #include "tensorflow/core/framework/op_kernel.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/lib/core/notification.h" |
30 | #include "tensorflow/core/lib/core/status.h" |
31 | #include "tensorflow/core/lib/strings/str_util.h" |
32 | #include "tensorflow/core/lib/strings/strcat.h" |
33 | #include "tensorflow/core/platform/env.h" |
34 | #include "tensorflow/core/platform/types.h" |
35 | #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" |
36 | #include "tensorflow/core/profiler/lib/traceme.h" |
37 | |
// Set true for greater intelligibility of debug mode log messages.
// When true, BroadcastBufKey() produces a labeled key of the form
// "broadcast(<exec_key>):subdiv(<n>):src(<n>):dst(<n>)" instead of the compact
// "<exec_key>:<subdiv>:<src>:<dst>" form. Both DispatchSend() and
// DispatchRecv() derive their rendezvous keys from BroadcastBufKey(), so the
// sending and receiving sides always build keys in the same format.
#define READABLE_KEYS false
40 | |
41 | namespace tensorflow { |
42 | |
43 | namespace { |
44 | // Key to be used for BufRendezvous by Broadcaster. |
45 | string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank, |
46 | int dst_rank) { |
47 | if (READABLE_KEYS) { |
48 | return strings::StrCat("broadcast(" , exec_key, "):subdiv(" , subdiv, |
49 | "):src(" , src_rank, "):dst(" , dst_rank, ")" ); |
50 | } else { |
51 | // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash. |
52 | return strings::StrCat(exec_key, ":" , subdiv, ":" , src_rank, ":" , dst_rank); |
53 | } |
54 | } |
55 | } // namespace |
56 | |
57 | HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster() |
58 | : col_ctx_(nullptr), |
59 | col_params_(nullptr), |
60 | done_(nullptr), |
61 | is_source_(false) {} |
62 | |
63 | int HierarchicalTreeBroadcaster::GetDeviceTask( |
64 | int device_rank, const std::vector<int>& dev_per_task) { |
65 | int num_tasks = static_cast<int>(dev_per_task.size()); |
66 | int task_lo = 0; |
67 | int task_hi = -1; |
68 | for (int ti = 0; ti < num_tasks; ti++) { |
69 | task_hi = task_lo + dev_per_task[ti]; |
70 | if (task_lo <= device_rank && device_rank < task_hi) return ti; |
71 | task_lo = task_hi; |
72 | } |
73 | LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi |
74 | << " devices" ; |
75 | return -1; |
76 | } |
77 | |
78 | Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( |
79 | CollectiveParams* col_params) { |
80 | CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE); |
81 | CHECK_EQ(col_params->instance.impl_details.collective_name, |
82 | "HierarchicalTreeBroadcast" ); |
83 | const string& device_name = |
84 | col_params->group.members[col_params->default_rank].device.name(); |
85 | // Start by counting the devices in each task. |
86 | // Precondition: device_names must be sorted so that all devices in |
87 | // the same task are adjacent. |
88 | std::vector<int> dev_per_task; |
89 | const string* prior_task_name = &col_params->group.members[0].task; |
90 | int dev_count = 1; |
91 | for (int di = 1; di < col_params->group.group_size; ++di) { |
92 | if (col_params->group.members[di].task != *prior_task_name) { |
93 | dev_per_task.push_back(dev_count); |
94 | dev_count = 1; |
95 | prior_task_name = &col_params->group.members[di].task; |
96 | } else { |
97 | ++dev_count; |
98 | } |
99 | } |
100 | dev_per_task.push_back(dev_count); |
101 | CHECK_EQ(col_params->group.num_tasks, dev_per_task.size()); |
102 | |
103 | if (VLOG_IS_ON(2)) { |
104 | string dpt_buf; |
105 | for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";" ); |
106 | VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device=" |
107 | << device_name << " source_rank=" << col_params->source_rank |
108 | << " dev_per_task=" << dpt_buf; |
109 | } |
110 | int num_tasks = col_params->group.num_tasks; |
111 | // If there is just 1 task, then execute binary tree broadcast over all |
112 | // devices. Otherwise, the first subdiv is inter-task broadcast, and then |
113 | // there are N more subdivs, where N is #task. |
114 | int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0); |
115 | int total_num_devices = 0; |
116 | for (int num_dev : dev_per_task) total_num_devices += num_dev; |
117 | |
118 | col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs); |
119 | col_params->subdiv_rank.reserve(num_subdivs); |
120 | col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs); |
121 | |
122 | // Inter-task subdiv. Pick one device from each task - this is the source |
123 | // device if it belongs to that task, or device 0 for that task. If a device |
124 | // does not participate in the subdiv, set subdiv_rank to -1. |
125 | if (num_tasks > 1) { |
126 | const int sdi = 0; |
127 | std::vector<int>& perm = |
128 | col_params->instance.impl_details.subdiv_permutations[sdi]; |
129 | CHECK_EQ(perm.size(), 0); |
130 | int device_count = 0; |
131 | int source_task = GetDeviceTask(col_params->source_rank, dev_per_task); |
132 | for (int ti = 0; ti < col_params->group.num_tasks; ti++) { |
133 | bool participate = false; |
134 | if (source_task == ti) { |
135 | // Source device belongs to this task. |
136 | perm.push_back(col_params->source_rank); |
137 | participate = |
138 | col_params->group.members[col_params->source_rank].device.name() == |
139 | device_name; |
140 | } else { |
141 | // Source does not belong to this task, choose dev 0. |
142 | perm.push_back(device_count); |
143 | participate = col_params->group.members[device_count].device.name() == |
144 | device_name; |
145 | } |
146 | if (participate) col_params->subdiv_rank.push_back(ti); |
147 | device_count += dev_per_task[ti]; |
148 | } |
149 | if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1); |
150 | col_params->instance.impl_details.subdiv_source_rank.push_back(source_task); |
151 | } |
152 | VLOG(2) << collective_util::SubdivPermDebugString(*col_params); |
153 | |
154 | // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set |
155 | // source to dev 0 for that task if it does not contain original source, else |
156 | // set to rank of original source. If a device does not participate in |
157 | // the subdiv, set subdiv_rank to -1; |
158 | int abs_di = 0; |
159 | for (int ti = 0; ti < col_params->group.num_tasks; ti++) { |
160 | const int sdi = ti + (num_tasks > 1 ? 1 : 0); |
161 | std::vector<int>& perm = |
162 | col_params->instance.impl_details.subdiv_permutations[sdi]; |
163 | CHECK_EQ(perm.size(), 0); |
164 | bool participate = false; |
165 | int subdiv_source = 0; |
166 | for (int di = 0; di < dev_per_task[ti]; di++) { |
167 | perm.push_back(abs_di); |
168 | if (col_params->group.members[abs_di].device.name() == device_name) { |
169 | participate = true; |
170 | col_params->subdiv_rank.push_back(di); |
171 | } |
172 | if (abs_di == col_params->source_rank) subdiv_source = di; |
173 | abs_di++; |
174 | } |
175 | if (!participate) col_params->subdiv_rank.push_back(-1); |
176 | col_params->instance.impl_details.subdiv_source_rank.push_back( |
177 | subdiv_source); |
178 | } |
179 | |
180 | for (int sri = 0; sri < num_subdivs; sri++) { |
181 | CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0); |
182 | } |
183 | |
184 | VLOG(2) << collective_util::SubdivPermDebugString(*col_params); |
185 | return OkStatus(); |
186 | } |
187 | |
188 | Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( |
189 | std::shared_ptr<CollectiveContext> col_ctx) { |
190 | CHECK(col_ctx->dev_mgr); |
191 | col_ctx_ = col_ctx; |
192 | col_params_ = col_ctx->col_params.get(); |
193 | return collective_util::InitializeDeviceAndLocality( |
194 | col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, |
195 | &col_ctx->device_locality); |
196 | } |
197 | |
198 | void HierarchicalTreeBroadcaster::Run(StatusCallback done) { |
199 | CHECK(col_ctx_); |
200 | CHECK(col_params_); |
201 | done_ = std::move(done); |
202 | is_source_ = col_params_->is_source; |
203 | RunTree(); |
204 | } |
205 | |
// Binary tree parent/child relations are trivial to calculate, i.e.
// device at rank r is the parent of 2r+1 and 2r+2. The one exception
// is if the source is not rank 0. We treat that case as though the
// source is prepended to the front of the rank ordering as well as
// continuing to occupy its current position. Hence we calculate as
// though each device's rank is actually r+1, then subtract 1 again to
// get the descendant ranks. If the source is not rank 0 then its
// descendants include both {0,1} and the descendants of its current
// position. Where a non-0-rank source is a descendant of another
// device, no send to it is necessary.
216 | |
217 | /* static*/ |
218 | int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp, |
219 | int subdiv) { |
220 | DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size())); |
221 | int my_rank = cp.subdiv_rank[subdiv]; |
222 | if (-1 == my_rank) return -1; |
223 | |
224 | const auto& impl = cp.instance.impl_details; |
225 | DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size())); |
226 | int source_rank = impl.subdiv_source_rank[subdiv]; |
227 | if (my_rank == source_rank) return -1; |
228 | if (source_rank == 0) { |
229 | return (my_rank - 1) / 2; |
230 | } else { |
231 | int predecessor_rank = (my_rank / 2) - 1; |
232 | return (predecessor_rank < 0) ? source_rank : predecessor_rank; |
233 | } |
234 | } |
235 | |
236 | /* static */ |
237 | void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp, |
238 | int subdiv, |
239 | std::vector<int>* targets) { |
240 | DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size())); |
241 | int my_rank = cp.subdiv_rank[subdiv]; |
242 | if (-1 == my_rank) return; |
243 | |
244 | const auto& impl = cp.instance.impl_details; |
245 | DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size())); |
246 | int source_rank = impl.subdiv_source_rank[subdiv]; |
247 | |
248 | int group_size = 0; |
249 | for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) { |
250 | if (impl.subdiv_permutations[subdiv][i] >= 0) { |
251 | group_size++; |
252 | } |
253 | } |
254 | |
255 | targets->clear(); |
256 | int successor_rank = 0; |
257 | if (source_rank == 0) { |
258 | successor_rank = (2 * my_rank) + 1; |
259 | } else { |
260 | successor_rank = (2 * (my_rank + 1)); |
261 | } |
262 | DCHECK_NE(successor_rank, my_rank); |
263 | if (cp.is_source && source_rank != 0) { |
264 | // The source sends to rank 0,1 in addition to its positional |
265 | // descendants. |
266 | if (group_size > 1) { |
267 | targets->push_back(0); |
268 | } |
269 | if (group_size > 2 && source_rank != 1) { |
270 | targets->push_back(1); |
271 | } |
272 | } |
273 | for (int i = 0; i < 2; ++i) { |
274 | if (successor_rank < group_size && successor_rank != source_rank) { |
275 | targets->push_back(successor_rank); |
276 | } |
277 | ++successor_rank; |
278 | } |
279 | } |
280 | |
281 | // Executes a hierarchical tree broadcast. |
282 | // Each subdiv is a broadcast between a subset of the devices. |
283 | // If there is only one task, there is one subdiv comprising a broadcast between |
284 | // all devices belonging to the task. |
285 | // If there are n tasks, n>1, then there are n+1 subdivs. In the first (global) |
286 | // subdiv, one device from each task participates in a binary tree broadcast. |
287 | // Each task receives a copy of the tensor on one device via this broadcast. |
288 | // Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1 |
289 | // corresponds to broadcast between all devices on task i. Thus, each task |
290 | // participates in at most 2 subdivs. |
291 | void HierarchicalTreeBroadcaster::RunTree() { |
292 | int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); |
293 | // TODO(b/78352018): this is easily improved when a node participates in both |
294 | // first and second subdivision. It would first send to its descendents in |
295 | // the first subdiv, then wait until all pending ops are finished before |
296 | // sending to descendents in second subdiv. A better implementation would |
297 | // collapse the two send blocks. |
298 | for (int si = 0; si < num_subdivs; si++) { |
299 | int my_rank = col_params_->subdiv_rank[si]; |
300 | // If rank is -1, this device does not participate in this subdiv. |
301 | if (-1 == my_rank) continue; |
302 | int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si]; |
303 | if (VLOG_IS_ON(1)) { |
304 | string subdiv_buf; |
305 | for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) { |
306 | strings::StrAppend(&subdiv_buf, r, "," ); |
307 | } |
308 | VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name |
309 | << " subdiv=" << si << " perm=" << subdiv_buf |
310 | << " my_rank=" << my_rank << " source_rank=" << source_rank; |
311 | } |
312 | |
313 | mutex mu; // also guards status_ while callbacks are pending |
314 | int pending_count = 0; // TF_GUARDED_BY(mu) |
315 | condition_variable all_done; |
316 | |
317 | if (my_rank >= 0 && my_rank != source_rank) { |
318 | // Begin by receiving the value. |
319 | profiler::TraceMe activity( |
320 | [&] { return strings::StrCat("ReceiveValue:" , si); }, |
321 | profiler::TraceMeLevel::kInfo); |
322 | int recv_from_rank = TreeRecvFrom(*col_params_, si); |
323 | Notification note; |
324 | DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output, |
325 | [this, &mu, ¬e](const Status& s) { |
326 | mutex_lock l(mu); |
327 | status_.Update(s); |
328 | note.Notify(); |
329 | }); |
330 | note.WaitForNotification(); |
331 | } |
332 | |
333 | // Then forward value to all descendent devices. |
334 | { |
335 | profiler::TraceMe activity( |
336 | [&] { return strings::StrCat("ForwardValue:" , si); }, |
337 | profiler::TraceMeLevel::kInfo); |
338 | if (my_rank >= 0 && status_.ok()) { |
339 | std::vector<int> send_to_ranks; |
340 | TreeSendTo(*col_params_, si, &send_to_ranks); |
341 | for (int i = 0; i < send_to_ranks.size(); ++i) { |
342 | int target_rank = send_to_ranks[i]; |
343 | { |
344 | mutex_lock l(mu); |
345 | ++pending_count; |
346 | } |
347 | DispatchSend(si, target_rank, my_rank, |
348 | (is_source_ ? col_ctx_->input : col_ctx_->output), |
349 | [this, &mu, &pending_count, &all_done](const Status& s) { |
350 | mutex_lock l(mu); |
351 | status_.Update(s); |
352 | --pending_count; |
353 | if (pending_count == 0) { |
354 | all_done.notify_all(); |
355 | } |
356 | }); |
357 | } |
358 | } |
359 | |
360 | // For the original source device, we copy input to output if they are |
361 | // different. |
362 | // If there is only 1 subdiv, we do this in that subdiv. If there is more |
363 | // than 1 subdiv, then the original source device will participate in 2 |
364 | // subdivs - the global inter-task broadcast and one local intra-task |
365 | // broadcast. In this case, we perform the copy in the second subdiv for |
366 | // this device. |
367 | if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) { |
368 | VLOG(2) << "copying input to output for device=" |
369 | << col_ctx_->device_name << " subdiv=" << si; |
370 | if (col_ctx_->input != col_ctx_->output && |
371 | (DMAHelper::base(col_ctx_->input) != |
372 | DMAHelper::base(col_ctx_->output))) { |
373 | { |
374 | mutex_lock l(mu); |
375 | ++pending_count; |
376 | } |
377 | DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context(); |
378 | CollectiveRemoteAccessLocal::MemCpyAsync( |
379 | op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device, |
380 | col_ctx_->op_ctx->input_alloc_attr(0), |
381 | col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, |
382 | col_ctx_->output, 0, /*stream_index*/ |
383 | [this, &mu, &pending_count, &all_done](const Status& s) { |
384 | mutex_lock l(mu); |
385 | status_.Update(s); |
386 | --pending_count; |
387 | if (0 == pending_count) { |
388 | all_done.notify_all(); |
389 | } |
390 | }); |
391 | } |
392 | } |
393 | |
394 | // Then wait for all pending actions to complete. |
395 | { |
396 | mutex_lock l(mu); |
397 | if (pending_count > 0) { |
398 | all_done.wait(l); |
399 | } |
400 | } |
401 | } |
402 | } |
403 | VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_; |
404 | done_(status_); |
405 | } |
406 | |
407 | void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank, |
408 | int src_rank, |
409 | const Tensor* src_tensor, |
410 | const StatusCallback& done) { |
411 | profiler::ScopedMemoryDebugAnnotation op_annotation( |
412 | col_params_->name.data(), col_ctx_->step_id, "dynamic" , |
413 | src_tensor->dtype(), |
414 | [src_tensor]() { return src_tensor->shape().DebugString(); }); |
415 | string send_buf_key = |
416 | BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank); |
417 | int dst_idx = |
418 | col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank]; |
419 | VLOG(3) << "DispatchSend " << send_buf_key << " from_device " |
420 | << col_ctx_->device_name << " to_device " |
421 | << col_params_->group.members[dst_idx].device.name() |
422 | << " subdiv=" << subdiv << " dst_rank=" << dst_rank |
423 | << " dst_idx=" << dst_idx; |
424 | col_ctx_->col_exec->remote_access()->PostToPeer( |
425 | col_params_->group.members[dst_idx].device.name(), |
426 | col_params_->group.members[dst_idx].task, send_buf_key, col_ctx_->device, |
427 | col_ctx_->op_ctx->op_device_context(), |
428 | col_ctx_->op_ctx->output_alloc_attr(0), src_tensor, |
429 | col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(), |
430 | done); |
431 | } |
432 | |
433 | void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank, |
434 | int dst_rank, Tensor* dst_tensor, |
435 | const StatusCallback& done) { |
436 | string recv_buf_key = |
437 | BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank); |
438 | int src_idx = |
439 | col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank]; |
440 | VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device " |
441 | << col_params_->group.members[src_idx].device.name() << " to_device " |
442 | << col_ctx_->device_name << " subdiv=" << subdiv |
443 | << " src_rank=" << src_rank << " src_idx=" << src_idx; |
444 | col_ctx_->col_exec->remote_access()->RecvFromPeer( |
445 | col_params_->group.members[src_idx].device.name(), |
446 | col_params_->group.members[src_idx].task, |
447 | col_params_->group.members[src_idx].is_local, recv_buf_key, |
448 | col_ctx_->device, col_ctx_->op_ctx->op_device_context(), |
449 | col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, |
450 | col_ctx_->device_locality, 0 /*stream_index*/, |
451 | col_ctx_->op_ctx->cancellation_manager(), done); |
452 | } |
453 | |
namespace {
// Registers HierarchicalTreeBroadcaster as the implementation named
// "HierarchicalTreeBroadcast". This is the same name that
// InitializeCollectiveParams() CHECKs against
// instance.impl_details.collective_name. Presumably this is how the collective
// executor looks the implementation up by name; confirm against the
// REGISTER_COLLECTIVE definition.
REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
}  // namespace
457 | |
458 | } // namespace tensorflow |
459 | |