1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/allgather.h"
10
11#include <array>
12#include <cstring>
13
14#include "gloo/common/logging.h"
15#include "gloo/types.h"
16
17namespace gloo {
18
19void allgather(AllgatherOptions& opts) {
20 const auto& context = opts.context;
21 transport::UnboundBuffer* in = opts.in.get();
22 transport::UnboundBuffer* out = opts.out.get();
23 const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag);
24
25 // Sanity checks
26 GLOO_ENFORCE(opts.elementSize > 0);
27 const auto recvRank = (context->size + context->rank - 1) % context->size;
28 GLOO_ENFORCE(
29 recvRank == context->rank || context->getPair(recvRank),
30 "missing connection between rank " + std::to_string(context->rank) +
31 " (this process) and rank " + std::to_string(recvRank));
32 const auto sendRank = (context->size + context->rank + 1) % context->size;
33 GLOO_ENFORCE(
34 sendRank == context->rank || context->getPair(sendRank),
35 "missing connection between rank " + std::to_string(context->rank) +
36 " (this process) and rank " + std::to_string(sendRank));
37
38 if (in != nullptr) {
39 GLOO_ENFORCE_EQ(out->size, in->size * context->size);
40 } else {
41 GLOO_ENFORCE_EQ(out->size % context->size, 0);
42 }
43
44 const size_t inBytes = out->size / context->size;
45 const size_t outBytes = out->size;
46
47 // If the input buffer is specified, this is NOT an in place operation,
48 // and the output buffer needs to be primed with the input.
49 if (in != nullptr) {
50 memcpy(
51 static_cast<uint8_t*>(out->ptr) + context->rank * in->size,
52 static_cast<uint8_t*>(in->ptr),
53 in->size);
54 }
55
56 // Short circuit if there is only a single process.
57 if (context->size == 1) {
58 return;
59 }
60
61 // The chunk size may not be divisible by 2; use dynamic lookup.
62 std::array<size_t, 2> chunkSize;
63 chunkSize[0] = inBytes / 2;
64 chunkSize[1] = inBytes - chunkSize[0];
65 std::array<size_t, 2> chunkOffset;
66 chunkOffset[0] = 0;
67 chunkOffset[1] = chunkSize[0];
68
69 for (auto i = 0; i < (context->size - 1) * 2; i++) {
70 const size_t sendSegment = context->size + context->rank - (i / 2);
71 const size_t recvSegment = sendSegment - 1;
72 size_t sendOffset =
73 ((sendSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes;
74 size_t recvOffset =
75 ((recvSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes;
76 size_t size = chunkSize[i & 0x1];
77 if (i < 2) {
78 out->send(sendRank, slot, sendOffset, size);
79 out->recv(recvRank, slot, recvOffset, size);
80 continue;
81 }
82
83 // Wait for pending operations to complete to synchronize with the
84 // previous iteration. Because we kick off two operations before
85 // getting here we always wait for the next-to-last operation.
86 out->waitSend(opts.timeout);
87 out->waitRecv(opts.timeout);
88 out->send(sendRank, slot, sendOffset, size);
89 out->recv(recvRank, slot, recvOffset, size);
90 }
91
92 // Wait for completes
93 for (auto i = 0; i < 2; i++) {
94 out->waitSend(opts.timeout);
95 out->waitRecv(opts.timeout);
96 }
97}
98
99} // namespace gloo
100