1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/broadcast.h"
10
11#include <algorithm>
12#include <cstring>
13
14#include "gloo/common/logging.h"
15#include "gloo/math.h"
16#include "gloo/types.h"
17
18namespace gloo {
19
20void broadcast(BroadcastOptions& opts) {
21 const auto& context = opts.context;
22 transport::UnboundBuffer* in = opts.in.get();
23 transport::UnboundBuffer* out = opts.out.get();
24 const auto slot = Slot::build(kBroadcastSlotPrefix, opts.tag);
25
26 // Sanity checks
27 GLOO_ENFORCE(opts.elementSize > 0);
28 GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size);
29 GLOO_ENFORCE(out);
30 if (context->rank == opts.root) {
31 if (in) {
32 GLOO_ENFORCE_EQ(in->size, out->size);
33 } else {
34 // Broadcast in place
35 in = out;
36 }
37 } else {
38 GLOO_ENFORCE(!in, "Non-root may not specify input");
39
40 // Broadcast in place (for forwarding)
41 in = out;
42 }
43
44 // Map rank to new rank where root process has rank 0.
45 const size_t vsize = context->size;
46 const size_t vrank = (context->rank + vsize - opts.root) % vsize;
47 const size_t dim = log2ceil(vsize);
48
49 // Track number of pending send operations.
50 // Send operations can complete asynchronously because there is dependency
51 // between iterations. This unlike recv operations that must complete
52 // before any send operations can be queued.
53 size_t numSends = 0;
54
55 // Create mask with all 1's where we progressively set bits to 0
56 // starting with the LSB. When the mask applied to the virtual rank
57 // equals 0 we know the process must participate. This results in
58 // exponential participation starting with virtual ranks 0 and 1.
59 size_t mask = (1 << dim) - 1;
60
61 for (size_t i = 0; i < dim; i++) {
62 // Clear bit `i`. In the first iteration, virtual ranks 0 and 1 participate.
63 // In the second iteration 0, 1, 2, and 3 participate, and so on.
64 mask ^= (1 << i);
65 if ((vrank & mask) != 0) {
66 continue;
67 }
68
69 // The virtual rank of the peer in this iteration has opposite bit `i`.
70 auto vpeer = vrank ^ (1 << i);
71 if (vpeer >= vsize) {
72 continue;
73 }
74
75 // Map virtual rank of peer to actual rank of peer.
76 auto peer = (vpeer + opts.root) % vsize;
77 if ((vrank & (1 << i)) == 0) {
78 in->send(peer, slot);
79 numSends++;
80 } else {
81 out->recv(peer, slot);
82 out->waitRecv(opts.timeout);
83 }
84 }
85
86 // Copy local input to output if applicable.
87 if (context->rank == opts.root && in != out) {
88 memcpy(out->ptr, in->ptr, out->size);
89 }
90
91 // Wait on pending sends.
92 for (auto i = 0; i < numSends; i++) {
93 in->waitSend(opts.timeout);
94 }
95}
96
97} // namespace gloo
98