1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/allgather.h" |
10 | |
11 | #include <array> |
12 | #include <cstring> |
13 | |
14 | #include "gloo/common/logging.h" |
15 | #include "gloo/types.h" |
16 | |
17 | namespace gloo { |
18 | |
19 | void allgather(AllgatherOptions& opts) { |
20 | const auto& context = opts.context; |
21 | transport::UnboundBuffer* in = opts.in.get(); |
22 | transport::UnboundBuffer* out = opts.out.get(); |
23 | const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag); |
24 | |
25 | // Sanity checks |
26 | GLOO_ENFORCE(opts.elementSize > 0); |
27 | const auto recvRank = (context->size + context->rank - 1) % context->size; |
28 | GLOO_ENFORCE( |
29 | recvRank == context->rank || context->getPair(recvRank), |
30 | "missing connection between rank " + std::to_string(context->rank) + |
31 | " (this process) and rank " + std::to_string(recvRank)); |
32 | const auto sendRank = (context->size + context->rank + 1) % context->size; |
33 | GLOO_ENFORCE( |
34 | sendRank == context->rank || context->getPair(sendRank), |
35 | "missing connection between rank " + std::to_string(context->rank) + |
36 | " (this process) and rank " + std::to_string(sendRank)); |
37 | |
38 | if (in != nullptr) { |
39 | GLOO_ENFORCE_EQ(out->size, in->size * context->size); |
40 | } else { |
41 | GLOO_ENFORCE_EQ(out->size % context->size, 0); |
42 | } |
43 | |
44 | const size_t inBytes = out->size / context->size; |
45 | const size_t outBytes = out->size; |
46 | |
47 | // If the input buffer is specified, this is NOT an in place operation, |
48 | // and the output buffer needs to be primed with the input. |
49 | if (in != nullptr) { |
50 | memcpy( |
51 | static_cast<uint8_t*>(out->ptr) + context->rank * in->size, |
52 | static_cast<uint8_t*>(in->ptr), |
53 | in->size); |
54 | } |
55 | |
56 | // Short circuit if there is only a single process. |
57 | if (context->size == 1) { |
58 | return; |
59 | } |
60 | |
61 | // The chunk size may not be divisible by 2; use dynamic lookup. |
62 | std::array<size_t, 2> chunkSize; |
63 | chunkSize[0] = inBytes / 2; |
64 | chunkSize[1] = inBytes - chunkSize[0]; |
65 | std::array<size_t, 2> chunkOffset; |
66 | chunkOffset[0] = 0; |
67 | chunkOffset[1] = chunkSize[0]; |
68 | |
69 | for (auto i = 0; i < (context->size - 1) * 2; i++) { |
70 | const size_t sendSegment = context->size + context->rank - (i / 2); |
71 | const size_t recvSegment = sendSegment - 1; |
72 | size_t sendOffset = |
73 | ((sendSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes; |
74 | size_t recvOffset = |
75 | ((recvSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes; |
76 | size_t size = chunkSize[i & 0x1]; |
77 | if (i < 2) { |
78 | out->send(sendRank, slot, sendOffset, size); |
79 | out->recv(recvRank, slot, recvOffset, size); |
80 | continue; |
81 | } |
82 | |
83 | // Wait for pending operations to complete to synchronize with the |
84 | // previous iteration. Because we kick off two operations before |
85 | // getting here we always wait for the next-to-last operation. |
86 | out->waitSend(opts.timeout); |
87 | out->waitRecv(opts.timeout); |
88 | out->send(sendRank, slot, sendOffset, size); |
89 | out->recv(recvRank, slot, recvOffset, size); |
90 | } |
91 | |
92 | // Wait for completes |
93 | for (auto i = 0; i < 2; i++) { |
94 | out->waitSend(opts.timeout); |
95 | out->waitRecv(opts.timeout); |
96 | } |
97 | } |
98 | |
99 | } // namespace gloo |
100 | |