1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/broadcast.h" |
10 | |
11 | #include <algorithm> |
12 | #include <cstring> |
13 | |
14 | #include "gloo/common/logging.h" |
15 | #include "gloo/math.h" |
16 | #include "gloo/types.h" |
17 | |
18 | namespace gloo { |
19 | |
20 | void broadcast(BroadcastOptions& opts) { |
21 | const auto& context = opts.context; |
22 | transport::UnboundBuffer* in = opts.in.get(); |
23 | transport::UnboundBuffer* out = opts.out.get(); |
24 | const auto slot = Slot::build(kBroadcastSlotPrefix, opts.tag); |
25 | |
26 | // Sanity checks |
27 | GLOO_ENFORCE(opts.elementSize > 0); |
28 | GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size); |
29 | GLOO_ENFORCE(out); |
30 | if (context->rank == opts.root) { |
31 | if (in) { |
32 | GLOO_ENFORCE_EQ(in->size, out->size); |
33 | } else { |
34 | // Broadcast in place |
35 | in = out; |
36 | } |
37 | } else { |
38 | GLOO_ENFORCE(!in, "Non-root may not specify input" ); |
39 | |
40 | // Broadcast in place (for forwarding) |
41 | in = out; |
42 | } |
43 | |
44 | // Map rank to new rank where root process has rank 0. |
45 | const size_t vsize = context->size; |
46 | const size_t vrank = (context->rank + vsize - opts.root) % vsize; |
47 | const size_t dim = log2ceil(vsize); |
48 | |
49 | // Track number of pending send operations. |
50 | // Send operations can complete asynchronously because there is dependency |
51 | // between iterations. This unlike recv operations that must complete |
52 | // before any send operations can be queued. |
53 | size_t numSends = 0; |
54 | |
55 | // Create mask with all 1's where we progressively set bits to 0 |
56 | // starting with the LSB. When the mask applied to the virtual rank |
57 | // equals 0 we know the process must participate. This results in |
58 | // exponential participation starting with virtual ranks 0 and 1. |
59 | size_t mask = (1 << dim) - 1; |
60 | |
61 | for (size_t i = 0; i < dim; i++) { |
62 | // Clear bit `i`. In the first iteration, virtual ranks 0 and 1 participate. |
63 | // In the second iteration 0, 1, 2, and 3 participate, and so on. |
64 | mask ^= (1 << i); |
65 | if ((vrank & mask) != 0) { |
66 | continue; |
67 | } |
68 | |
69 | // The virtual rank of the peer in this iteration has opposite bit `i`. |
70 | auto vpeer = vrank ^ (1 << i); |
71 | if (vpeer >= vsize) { |
72 | continue; |
73 | } |
74 | |
75 | // Map virtual rank of peer to actual rank of peer. |
76 | auto peer = (vpeer + opts.root) % vsize; |
77 | if ((vrank & (1 << i)) == 0) { |
78 | in->send(peer, slot); |
79 | numSends++; |
80 | } else { |
81 | out->recv(peer, slot); |
82 | out->waitRecv(opts.timeout); |
83 | } |
84 | } |
85 | |
86 | // Copy local input to output if applicable. |
87 | if (context->rank == opts.root && in != out) { |
88 | memcpy(out->ptr, in->ptr, out->size); |
89 | } |
90 | |
91 | // Wait on pending sends. |
92 | for (auto i = 0; i < numSends; i++) { |
93 | in->waitSend(opts.timeout); |
94 | } |
95 | } |
96 | |
97 | } // namespace gloo |
98 | |