1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
16
17#include <functional>
18#include <memory>
19#include <string>
20#include <utility>
21
22#include "tensorflow/core/common_runtime/collective_rma_local.h"
23#include "tensorflow/core/common_runtime/collective_util.h"
24#include "tensorflow/core/common_runtime/device_mgr.h"
25#include "tensorflow/core/common_runtime/dma_helper.h"
26#include "tensorflow/core/framework/device_base.h"
27#include "tensorflow/core/framework/op_kernel.h"
28#include "tensorflow/core/framework/tensor.h"
29#include "tensorflow/core/lib/core/notification.h"
30#include "tensorflow/core/lib/core/status.h"
31#include "tensorflow/core/lib/strings/str_util.h"
32#include "tensorflow/core/lib/strings/strcat.h"
33#include "tensorflow/core/platform/env.h"
34#include "tensorflow/core/platform/types.h"
35#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
36#include "tensorflow/core/profiler/lib/traceme.h"
37
// Compile-time switch for the key format produced by BroadcastBufKey().
// When true, keys are spelled out verbosely ("broadcast(...):subdiv(...)...")
// for greater intelligibility of debug-mode log messages; when false, the
// compact colon-separated form is used.
#define READABLE_KEYS false
40
41namespace tensorflow {
42
43namespace {
44// Key to be used for BufRendezvous by Broadcaster.
45string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
46 int dst_rank) {
47 if (READABLE_KEYS) {
48 return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
49 "):src(", src_rank, "):dst(", dst_rank, ")");
50 } else {
51 // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
52 return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
53 }
54}
55} // namespace
56
57HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
58 : col_ctx_(nullptr),
59 col_params_(nullptr),
60 done_(nullptr),
61 is_source_(false) {}
62
63int HierarchicalTreeBroadcaster::GetDeviceTask(
64 int device_rank, const std::vector<int>& dev_per_task) {
65 int num_tasks = static_cast<int>(dev_per_task.size());
66 int task_lo = 0;
67 int task_hi = -1;
68 for (int ti = 0; ti < num_tasks; ti++) {
69 task_hi = task_lo + dev_per_task[ti];
70 if (task_lo <= device_rank && device_rank < task_hi) return ti;
71 task_lo = task_hi;
72 }
73 LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
74 << " devices";
75 return -1;
76}
77
78Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
79 CollectiveParams* col_params) {
80 CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
81 CHECK_EQ(col_params->instance.impl_details.collective_name,
82 "HierarchicalTreeBroadcast");
83 const string& device_name =
84 col_params->group.members[col_params->default_rank].device.name();
85 // Start by counting the devices in each task.
86 // Precondition: device_names must be sorted so that all devices in
87 // the same task are adjacent.
88 std::vector<int> dev_per_task;
89 const string* prior_task_name = &col_params->group.members[0].task;
90 int dev_count = 1;
91 for (int di = 1; di < col_params->group.group_size; ++di) {
92 if (col_params->group.members[di].task != *prior_task_name) {
93 dev_per_task.push_back(dev_count);
94 dev_count = 1;
95 prior_task_name = &col_params->group.members[di].task;
96 } else {
97 ++dev_count;
98 }
99 }
100 dev_per_task.push_back(dev_count);
101 CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
102
103 if (VLOG_IS_ON(2)) {
104 string dpt_buf;
105 for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
106 VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
107 << device_name << " source_rank=" << col_params->source_rank
108 << " dev_per_task=" << dpt_buf;
109 }
110 int num_tasks = col_params->group.num_tasks;
111 // If there is just 1 task, then execute binary tree broadcast over all
112 // devices. Otherwise, the first subdiv is inter-task broadcast, and then
113 // there are N more subdivs, where N is #task.
114 int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
115 int total_num_devices = 0;
116 for (int num_dev : dev_per_task) total_num_devices += num_dev;
117
118 col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
119 col_params->subdiv_rank.reserve(num_subdivs);
120 col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
121
122 // Inter-task subdiv. Pick one device from each task - this is the source
123 // device if it belongs to that task, or device 0 for that task. If a device
124 // does not participate in the subdiv, set subdiv_rank to -1.
125 if (num_tasks > 1) {
126 const int sdi = 0;
127 std::vector<int>& perm =
128 col_params->instance.impl_details.subdiv_permutations[sdi];
129 CHECK_EQ(perm.size(), 0);
130 int device_count = 0;
131 int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
132 for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
133 bool participate = false;
134 if (source_task == ti) {
135 // Source device belongs to this task.
136 perm.push_back(col_params->source_rank);
137 participate =
138 col_params->group.members[col_params->source_rank].device.name() ==
139 device_name;
140 } else {
141 // Source does not belong to this task, choose dev 0.
142 perm.push_back(device_count);
143 participate = col_params->group.members[device_count].device.name() ==
144 device_name;
145 }
146 if (participate) col_params->subdiv_rank.push_back(ti);
147 device_count += dev_per_task[ti];
148 }
149 if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
150 col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
151 }
152 VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
153
154 // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
155 // source to dev 0 for that task if it does not contain original source, else
156 // set to rank of original source. If a device does not participate in
157 // the subdiv, set subdiv_rank to -1;
158 int abs_di = 0;
159 for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
160 const int sdi = ti + (num_tasks > 1 ? 1 : 0);
161 std::vector<int>& perm =
162 col_params->instance.impl_details.subdiv_permutations[sdi];
163 CHECK_EQ(perm.size(), 0);
164 bool participate = false;
165 int subdiv_source = 0;
166 for (int di = 0; di < dev_per_task[ti]; di++) {
167 perm.push_back(abs_di);
168 if (col_params->group.members[abs_di].device.name() == device_name) {
169 participate = true;
170 col_params->subdiv_rank.push_back(di);
171 }
172 if (abs_di == col_params->source_rank) subdiv_source = di;
173 abs_di++;
174 }
175 if (!participate) col_params->subdiv_rank.push_back(-1);
176 col_params->instance.impl_details.subdiv_source_rank.push_back(
177 subdiv_source);
178 }
179
180 for (int sri = 0; sri < num_subdivs; sri++) {
181 CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
182 }
183
184 VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
185 return OkStatus();
186}
187
188Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
189 std::shared_ptr<CollectiveContext> col_ctx) {
190 CHECK(col_ctx->dev_mgr);
191 col_ctx_ = col_ctx;
192 col_params_ = col_ctx->col_params.get();
193 return collective_util::InitializeDeviceAndLocality(
194 col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
195 &col_ctx->device_locality);
196}
197
198void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
199 CHECK(col_ctx_);
200 CHECK(col_params_);
201 done_ = std::move(done);
202 is_source_ = col_params_->is_source;
203 RunTree();
204}
205
// Binary tree parent/child relations are trivial to calculate, i.e. the
// device at rank r is the parent of 2r+1 and 2r+2. The one exception is
// when the source is not rank 0. We treat that case as though the source
// were prepended to the rank ordering while also continuing to occupy its
// current position. Hence we calculate as though each device's rank is
// actually r+1, then subtract 1 again to get the descendant ranks. If the
// source is not rank 0, its descendants include both {0,1} and the
// descendants of its current position. Where a non-0-rank source is a
// descendant of another device, no send to it is necessary.
216
217/* static*/
218int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
219 int subdiv) {
220 DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
221 int my_rank = cp.subdiv_rank[subdiv];
222 if (-1 == my_rank) return -1;
223
224 const auto& impl = cp.instance.impl_details;
225 DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
226 int source_rank = impl.subdiv_source_rank[subdiv];
227 if (my_rank == source_rank) return -1;
228 if (source_rank == 0) {
229 return (my_rank - 1) / 2;
230 } else {
231 int predecessor_rank = (my_rank / 2) - 1;
232 return (predecessor_rank < 0) ? source_rank : predecessor_rank;
233 }
234}
235
236/* static */
237void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
238 int subdiv,
239 std::vector<int>* targets) {
240 DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
241 int my_rank = cp.subdiv_rank[subdiv];
242 if (-1 == my_rank) return;
243
244 const auto& impl = cp.instance.impl_details;
245 DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
246 int source_rank = impl.subdiv_source_rank[subdiv];
247
248 int group_size = 0;
249 for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
250 if (impl.subdiv_permutations[subdiv][i] >= 0) {
251 group_size++;
252 }
253 }
254
255 targets->clear();
256 int successor_rank = 0;
257 if (source_rank == 0) {
258 successor_rank = (2 * my_rank) + 1;
259 } else {
260 successor_rank = (2 * (my_rank + 1));
261 }
262 DCHECK_NE(successor_rank, my_rank);
263 if (cp.is_source && source_rank != 0) {
264 // The source sends to rank 0,1 in addition to its positional
265 // descendants.
266 if (group_size > 1) {
267 targets->push_back(0);
268 }
269 if (group_size > 2 && source_rank != 1) {
270 targets->push_back(1);
271 }
272 }
273 for (int i = 0; i < 2; ++i) {
274 if (successor_rank < group_size && successor_rank != source_rank) {
275 targets->push_back(successor_rank);
276 }
277 ++successor_rank;
278 }
279}
280
281// Executes a hierarchical tree broadcast.
282// Each subdiv is a broadcast between a subset of the devices.
283// If there is only one task, there is one subdiv comprising a broadcast between
284// all devices belonging to the task.
285// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
286// subdiv, one device from each task participates in a binary tree broadcast.
287// Each task receives a copy of the tensor on one device via this broadcast.
288// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
289// corresponds to broadcast between all devices on task i. Thus, each task
290// participates in at most 2 subdivs.
291void HierarchicalTreeBroadcaster::RunTree() {
292 int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
293 // TODO(b/78352018): this is easily improved when a node participates in both
294 // first and second subdivision. It would first send to its descendents in
295 // the first subdiv, then wait until all pending ops are finished before
296 // sending to descendents in second subdiv. A better implementation would
297 // collapse the two send blocks.
298 for (int si = 0; si < num_subdivs; si++) {
299 int my_rank = col_params_->subdiv_rank[si];
300 // If rank is -1, this device does not participate in this subdiv.
301 if (-1 == my_rank) continue;
302 int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
303 if (VLOG_IS_ON(1)) {
304 string subdiv_buf;
305 for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
306 strings::StrAppend(&subdiv_buf, r, ",");
307 }
308 VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
309 << " subdiv=" << si << " perm=" << subdiv_buf
310 << " my_rank=" << my_rank << " source_rank=" << source_rank;
311 }
312
313 mutex mu; // also guards status_ while callbacks are pending
314 int pending_count = 0; // TF_GUARDED_BY(mu)
315 condition_variable all_done;
316
317 if (my_rank >= 0 && my_rank != source_rank) {
318 // Begin by receiving the value.
319 profiler::TraceMe activity(
320 [&] { return strings::StrCat("ReceiveValue:", si); },
321 profiler::TraceMeLevel::kInfo);
322 int recv_from_rank = TreeRecvFrom(*col_params_, si);
323 Notification note;
324 DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
325 [this, &mu, &note](const Status& s) {
326 mutex_lock l(mu);
327 status_.Update(s);
328 note.Notify();
329 });
330 note.WaitForNotification();
331 }
332
333 // Then forward value to all descendent devices.
334 {
335 profiler::TraceMe activity(
336 [&] { return strings::StrCat("ForwardValue:", si); },
337 profiler::TraceMeLevel::kInfo);
338 if (my_rank >= 0 && status_.ok()) {
339 std::vector<int> send_to_ranks;
340 TreeSendTo(*col_params_, si, &send_to_ranks);
341 for (int i = 0; i < send_to_ranks.size(); ++i) {
342 int target_rank = send_to_ranks[i];
343 {
344 mutex_lock l(mu);
345 ++pending_count;
346 }
347 DispatchSend(si, target_rank, my_rank,
348 (is_source_ ? col_ctx_->input : col_ctx_->output),
349 [this, &mu, &pending_count, &all_done](const Status& s) {
350 mutex_lock l(mu);
351 status_.Update(s);
352 --pending_count;
353 if (pending_count == 0) {
354 all_done.notify_all();
355 }
356 });
357 }
358 }
359
360 // For the original source device, we copy input to output if they are
361 // different.
362 // If there is only 1 subdiv, we do this in that subdiv. If there is more
363 // than 1 subdiv, then the original source device will participate in 2
364 // subdivs - the global inter-task broadcast and one local intra-task
365 // broadcast. In this case, we perform the copy in the second subdiv for
366 // this device.
367 if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
368 VLOG(2) << "copying input to output for device="
369 << col_ctx_->device_name << " subdiv=" << si;
370 if (col_ctx_->input != col_ctx_->output &&
371 (DMAHelper::base(col_ctx_->input) !=
372 DMAHelper::base(col_ctx_->output))) {
373 {
374 mutex_lock l(mu);
375 ++pending_count;
376 }
377 DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
378 CollectiveRemoteAccessLocal::MemCpyAsync(
379 op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
380 col_ctx_->op_ctx->input_alloc_attr(0),
381 col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
382 col_ctx_->output, 0, /*stream_index*/
383 [this, &mu, &pending_count, &all_done](const Status& s) {
384 mutex_lock l(mu);
385 status_.Update(s);
386 --pending_count;
387 if (0 == pending_count) {
388 all_done.notify_all();
389 }
390 });
391 }
392 }
393
394 // Then wait for all pending actions to complete.
395 {
396 mutex_lock l(mu);
397 if (pending_count > 0) {
398 all_done.wait(l);
399 }
400 }
401 }
402 }
403 VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
404 done_(status_);
405}
406
407void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
408 int src_rank,
409 const Tensor* src_tensor,
410 const StatusCallback& done) {
411 profiler::ScopedMemoryDebugAnnotation op_annotation(
412 col_params_->name.data(), col_ctx_->step_id, "dynamic",
413 src_tensor->dtype(),
414 [src_tensor]() { return src_tensor->shape().DebugString(); });
415 string send_buf_key =
416 BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
417 int dst_idx =
418 col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
419 VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
420 << col_ctx_->device_name << " to_device "
421 << col_params_->group.members[dst_idx].device.name()
422 << " subdiv=" << subdiv << " dst_rank=" << dst_rank
423 << " dst_idx=" << dst_idx;
424 col_ctx_->col_exec->remote_access()->PostToPeer(
425 col_params_->group.members[dst_idx].device.name(),
426 col_params_->group.members[dst_idx].task, send_buf_key, col_ctx_->device,
427 col_ctx_->op_ctx->op_device_context(),
428 col_ctx_->op_ctx->output_alloc_attr(0), src_tensor,
429 col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
430 done);
431}
432
433void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
434 int dst_rank, Tensor* dst_tensor,
435 const StatusCallback& done) {
436 string recv_buf_key =
437 BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
438 int src_idx =
439 col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
440 VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
441 << col_params_->group.members[src_idx].device.name() << " to_device "
442 << col_ctx_->device_name << " subdiv=" << subdiv
443 << " src_rank=" << src_rank << " src_idx=" << src_idx;
444 col_ctx_->col_exec->remote_access()->RecvFromPeer(
445 col_params_->group.members[src_idx].device.name(),
446 col_params_->group.members[src_idx].task,
447 col_params_->group.members[src_idx].is_local, recv_buf_key,
448 col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
449 col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
450 col_ctx_->device_locality, 0 /*stream_index*/,
451 col_ctx_->op_ctx->cancellation_manager(), done);
452}
453
namespace {
// Registers HierarchicalTreeBroadcaster under the name
// HierarchicalTreeBroadcast, the same collective_name that
// InitializeCollectiveParams() checks for.
REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
}  // namespace
457
458} // namespace tensorflow
459