1 | /** |
2 | * Copyright (c) 2019-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/allgatherv.h" |
10 | |
11 | #include <cstring> |
12 | #include <numeric> |
13 | |
14 | #include "gloo/common/logging.h" |
15 | #include "gloo/types.h" |
16 | |
17 | namespace gloo { |
18 | |
19 | void AllgathervOptions::setElementSize(size_t elementSize) { |
20 | if (this->elementSize == 0) { |
21 | this->elementSize = elementSize; |
22 | } else { |
23 | GLOO_ENFORCE_EQ( |
24 | elementSize, |
25 | this->elementSize, |
26 | "Element size does not match existing value. " , |
27 | "Please double check that the input and output types match." ); |
28 | } |
29 | } |
30 | |
31 | void AllgathervOptions::setInput( |
32 | std::unique_ptr<transport::UnboundBuffer> buf, |
33 | size_t elementSize) { |
34 | setElementSize(elementSize); |
35 | this->in = std::move(buf); |
36 | } |
37 | |
38 | void AllgathervOptions::setInput( |
39 | void* ptr, |
40 | size_t elements, |
41 | size_t elementSize) { |
42 | setElementSize(elementSize); |
43 | this->in = context->createUnboundBuffer(ptr, elements * elementSize); |
44 | } |
45 | |
46 | void AllgathervOptions::setOutput( |
47 | std::unique_ptr<transport::UnboundBuffer> buf, |
48 | std::vector<size_t> elements, |
49 | size_t elementSize) { |
50 | const auto totalElements = |
51 | std::accumulate(elements.begin(), elements.end(), size_t(0)); |
52 | setElementSize(elementSize); |
53 | GLOO_ENFORCE_EQ(elements.size(), context->size); |
54 | this->elements = std::move(elements); |
55 | GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size); |
56 | this->out = std::move(buf); |
57 | } |
58 | |
59 | void AllgathervOptions::setOutput( |
60 | void* ptr, |
61 | std::vector<size_t> elements, |
62 | size_t elementSize) { |
63 | const auto totalElements = |
64 | std::accumulate(elements.begin(), elements.end(), size_t(0)); |
65 | setElementSize(elementSize); |
66 | GLOO_ENFORCE_EQ(elements.size(), context->size); |
67 | this->elements = std::move(elements); |
68 | this->out = context->createUnboundBuffer(ptr, totalElements * elementSize); |
69 | } |
70 | |
71 | void allgatherv(AllgathervOptions& opts) { |
72 | const auto& context = opts.context; |
73 | transport::UnboundBuffer* in = opts.in.get(); |
74 | transport::UnboundBuffer* out = opts.out.get(); |
75 | const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag); |
76 | |
77 | // Sanity checks |
78 | GLOO_ENFORCE(opts.elementSize > 0); |
79 | const auto recvRank = (context->size + context->rank - 1) % context->size; |
80 | GLOO_ENFORCE( |
81 | recvRank == context->rank || context->getPair(recvRank), |
82 | "missing connection between rank " + std::to_string(context->rank) + |
83 | " (this process) and rank " + std::to_string(recvRank)); |
84 | const auto sendRank = (context->size + context->rank + 1) % context->size; |
85 | GLOO_ENFORCE( |
86 | sendRank == context->rank || context->getPair(sendRank), |
87 | "missing connection between rank " + std::to_string(context->rank) + |
88 | " (this process) and rank " + std::to_string(sendRank)); |
89 | |
90 | // Compute byte counts and offsets into output buffer. |
91 | std::vector<size_t> byteCounts; |
92 | std::vector<size_t> byteOffsets; |
93 | byteCounts.reserve(context->size); |
94 | byteOffsets.reserve(context->size); |
95 | size_t offset = 0; |
96 | for (const auto& elements : opts.elements) { |
97 | const auto bytes = elements * opts.elementSize; |
98 | byteCounts.push_back(bytes); |
99 | byteOffsets.push_back(offset); |
100 | offset += bytes; |
101 | } |
102 | |
103 | // If the input buffer is specified, the output buffer needs to be primed. |
104 | if (in != nullptr) { |
105 | GLOO_ENFORCE_EQ(byteCounts[context->rank], in->size); |
106 | if (byteCounts[context->rank] > 0) { |
107 | memcpy( |
108 | static_cast<uint8_t*>(out->ptr) + byteOffsets[context->rank], |
109 | static_cast<uint8_t*>(in->ptr), |
110 | in->size); |
111 | } |
112 | } |
113 | |
114 | // Short circuit if there is only a single process. |
115 | if (context->size == 1) { |
116 | return; |
117 | } |
118 | |
119 | const auto baseIndex = context->size + context->rank; |
120 | for (auto i = 0; i < context->size - 1; i++) { |
121 | const size_t sendIndex = (baseIndex - i) % context->size; |
122 | const size_t recvIndex = (baseIndex - i - 1) % context->size; |
123 | |
124 | if (i == 0) { |
125 | out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]); |
126 | out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]); |
127 | continue; |
128 | } |
129 | |
130 | // Wait for previous operations to complete before kicking off new ones. |
131 | out->waitSend(opts.timeout); |
132 | out->waitRecv(opts.timeout); |
133 | out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]); |
134 | out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]); |
135 | } |
136 | |
137 | // Wait for final operations to complete. |
138 | out->waitSend(opts.timeout); |
139 | out->waitRecv(opts.timeout); |
140 | } |
141 | |
142 | } // namespace gloo |
143 | |