1/**
2 * Copyright (c) 2019-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/allgatherv.h"
10
11#include <cstring>
12#include <numeric>
13
14#include "gloo/common/logging.h"
15#include "gloo/types.h"
16
17namespace gloo {
18
19void AllgathervOptions::setElementSize(size_t elementSize) {
20 if (this->elementSize == 0) {
21 this->elementSize = elementSize;
22 } else {
23 GLOO_ENFORCE_EQ(
24 elementSize,
25 this->elementSize,
26 "Element size does not match existing value. ",
27 "Please double check that the input and output types match.");
28 }
29}
30
31void AllgathervOptions::setInput(
32 std::unique_ptr<transport::UnboundBuffer> buf,
33 size_t elementSize) {
34 setElementSize(elementSize);
35 this->in = std::move(buf);
36}
37
38void AllgathervOptions::setInput(
39 void* ptr,
40 size_t elements,
41 size_t elementSize) {
42 setElementSize(elementSize);
43 this->in = context->createUnboundBuffer(ptr, elements * elementSize);
44}
45
46void AllgathervOptions::setOutput(
47 std::unique_ptr<transport::UnboundBuffer> buf,
48 std::vector<size_t> elements,
49 size_t elementSize) {
50 const auto totalElements =
51 std::accumulate(elements.begin(), elements.end(), size_t(0));
52 setElementSize(elementSize);
53 GLOO_ENFORCE_EQ(elements.size(), context->size);
54 this->elements = std::move(elements);
55 GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size);
56 this->out = std::move(buf);
57}
58
59void AllgathervOptions::setOutput(
60 void* ptr,
61 std::vector<size_t> elements,
62 size_t elementSize) {
63 const auto totalElements =
64 std::accumulate(elements.begin(), elements.end(), size_t(0));
65 setElementSize(elementSize);
66 GLOO_ENFORCE_EQ(elements.size(), context->size);
67 this->elements = std::move(elements);
68 this->out = context->createUnboundBuffer(ptr, totalElements * elementSize);
69}
70
71void allgatherv(AllgathervOptions& opts) {
72 const auto& context = opts.context;
73 transport::UnboundBuffer* in = opts.in.get();
74 transport::UnboundBuffer* out = opts.out.get();
75 const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag);
76
77 // Sanity checks
78 GLOO_ENFORCE(opts.elementSize > 0);
79 const auto recvRank = (context->size + context->rank - 1) % context->size;
80 GLOO_ENFORCE(
81 recvRank == context->rank || context->getPair(recvRank),
82 "missing connection between rank " + std::to_string(context->rank) +
83 " (this process) and rank " + std::to_string(recvRank));
84 const auto sendRank = (context->size + context->rank + 1) % context->size;
85 GLOO_ENFORCE(
86 sendRank == context->rank || context->getPair(sendRank),
87 "missing connection between rank " + std::to_string(context->rank) +
88 " (this process) and rank " + std::to_string(sendRank));
89
90 // Compute byte counts and offsets into output buffer.
91 std::vector<size_t> byteCounts;
92 std::vector<size_t> byteOffsets;
93 byteCounts.reserve(context->size);
94 byteOffsets.reserve(context->size);
95 size_t offset = 0;
96 for (const auto& elements : opts.elements) {
97 const auto bytes = elements * opts.elementSize;
98 byteCounts.push_back(bytes);
99 byteOffsets.push_back(offset);
100 offset += bytes;
101 }
102
103 // If the input buffer is specified, the output buffer needs to be primed.
104 if (in != nullptr) {
105 GLOO_ENFORCE_EQ(byteCounts[context->rank], in->size);
106 if (byteCounts[context->rank] > 0) {
107 memcpy(
108 static_cast<uint8_t*>(out->ptr) + byteOffsets[context->rank],
109 static_cast<uint8_t*>(in->ptr),
110 in->size);
111 }
112 }
113
114 // Short circuit if there is only a single process.
115 if (context->size == 1) {
116 return;
117 }
118
119 const auto baseIndex = context->size + context->rank;
120 for (auto i = 0; i < context->size - 1; i++) {
121 const size_t sendIndex = (baseIndex - i) % context->size;
122 const size_t recvIndex = (baseIndex - i - 1) % context->size;
123
124 if (i == 0) {
125 out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]);
126 out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]);
127 continue;
128 }
129
130 // Wait for previous operations to complete before kicking off new ones.
131 out->waitSend(opts.timeout);
132 out->waitRecv(opts.timeout);
133 out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]);
134 out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]);
135 }
136
137 // Wait for final operations to complete.
138 out->waitSend(opts.timeout);
139 out->waitRecv(opts.timeout);
140}
141
142} // namespace gloo
143