1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/reduce.h" |
10 | |
11 | #include <array> |
12 | #include <algorithm> |
13 | #include <cstring> |
14 | |
15 | #include "gloo/common/logging.h" |
16 | #include "gloo/math.h" |
17 | #include "gloo/types.h" |
18 | |
19 | namespace gloo { |
20 | |
21 | void reduce(ReduceOptions& opts) { |
22 | if (opts.elements == 0) { |
23 | return; |
24 | } |
25 | const auto& context = opts.context; |
26 | transport::UnboundBuffer* in = opts.in.get(); |
27 | transport::UnboundBuffer* out = opts.out.get(); |
28 | const auto slot = Slot::build(kReduceSlotPrefix, opts.tag); |
29 | |
30 | // Sanity checks |
31 | GLOO_ENFORCE(opts.elementSize > 0); |
32 | GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size); |
33 | GLOO_ENFORCE(opts.reduce != nullptr); |
34 | const auto recvRank = (context->size + context->rank + 1) % context->size; |
35 | GLOO_ENFORCE( |
36 | recvRank == context->rank || context->getPair(recvRank), |
37 | "missing connection between rank " + std::to_string(context->rank) + |
38 | " (this process) and rank " + std::to_string(recvRank)); |
39 | const auto sendRank = (context->size + context->rank - 1) % context->size; |
40 | GLOO_ENFORCE( |
41 | sendRank == context->rank || context->getPair(sendRank), |
42 | "missing connection between rank " + std::to_string(context->rank) + |
43 | " (this process) and rank " + std::to_string(sendRank)); |
44 | |
45 | // If input buffer is not specified, the output is also the input |
46 | if (in == nullptr) { |
47 | in = out; |
48 | } |
49 | |
50 | GLOO_ENFORCE_EQ(in->size, opts.elements * opts.elementSize); |
51 | GLOO_ENFORCE_EQ(out->size, opts.elements * opts.elementSize); |
52 | |
53 | // Short circuit if there is only a single process. |
54 | if (context->size == 1) { |
55 | if (in != out) { |
56 | memcpy(out->ptr, in->ptr, opts.elements * opts.elementSize); |
57 | } |
58 | return; |
59 | } |
60 | |
61 | // The ring algorithm works as follows. |
62 | // |
63 | // The given input is split into a number of chunks equal to the |
64 | // number of processes. Once the algorithm has finished, every |
65 | // process hosts one chunk of reduced output, in sequential order |
66 | // (rank 0 has chunk 0, rank 1 has chunk 1, etc.). As the input may |
67 | // not be divisible by the number of processes, the chunk on the |
68 | // final ranks may have partial output or may be empty. |
69 | // |
70 | // As a chunk is passed along the ring and contains the reduction of |
71 | // successively more ranks, we have to alternate between performing |
72 | // I/O for that chunk and computing the reduction between the |
73 | // received chunk and the local chunk. To avoid this alternating |
74 | // pattern, we split up a chunk into multiple segments (>= 2), and |
75 | // ensure we have one segment in flight while computing a reduction |
76 | // on the other. The segment size has an upper bound to minimize |
77 | // memory usage and avoid poor cache behavior. This means we may |
78 | // have many segments per chunk when dealing with very large inputs. |
79 | // |
80 | // The nomenclature here is reflected in the variable naming below |
81 | // (one chunk per rank and many segments per chunk). |
82 | // |
83 | const size_t totalBytes = opts.elements * opts.elementSize; |
84 | |
85 | // Ensure that maximum segment size is a multiple of the element size. |
86 | // Otherwise, the segment size can exceed the maximum segment size after |
87 | // rounding it up to the nearest multiple of the element size. |
88 | // For example, if maxSegmentSize = 10, and elementSize = 4, |
89 | // then after rounding up: segmentSize = 12; |
90 | const size_t maxSegmentSize = |
91 | opts.elementSize * (opts.maxSegmentSize / opts.elementSize); |
92 | |
93 | // The number of bytes per segment must be a multiple of the bytes |
94 | // per element for the reduction to work; round up if necessary. |
95 | const size_t segmentBytes = roundUp( |
96 | std::min( |
97 | // Rounded division to have >= 2 segments per chunk. |
98 | (totalBytes + (context->size * 2 - 1)) / (context->size * 2), |
99 | // Configurable segment size limit |
100 | maxSegmentSize), |
101 | opts.elementSize); |
102 | |
103 | // Compute how many segments make up the input buffer. |
104 | // |
105 | // Round up to the nearest multiple of the context size such that |
106 | // there is an equal number of segments per process and execution is |
107 | // symmetric across processes. |
108 | // |
109 | // The minimum is twice the context size, because the algorithm |
110 | // below overlaps sending/receiving a segment with computing the |
111 | // reduction of the another segment. |
112 | // |
113 | const size_t numSegments = roundUp( |
114 | std::max( |
115 | (totalBytes + (segmentBytes - 1)) / segmentBytes, |
116 | (size_t)context->size * 2), |
117 | (size_t)context->size); |
118 | GLOO_ENFORCE_EQ(numSegments % context->size, 0); |
119 | GLOO_ENFORCE_GE(numSegments, context->size * 2); |
120 | const size_t numSegmentsPerRank = numSegments / context->size; |
121 | const size_t chunkBytes = numSegmentsPerRank * segmentBytes; |
122 | |
123 | // Allocate scratch space to hold two chunks |
124 | std::unique_ptr<uint8_t[]> tmpAllocation(new uint8_t[segmentBytes * 2]); |
125 | std::unique_ptr<transport::UnboundBuffer> tmpBuffer = |
126 | context->createUnboundBuffer(tmpAllocation.get(), segmentBytes * 2); |
127 | transport::UnboundBuffer* tmp = tmpBuffer.get(); |
128 | |
129 | // Use dynamic lookup for chunk offset in the temporary buffer. |
130 | // With two operations in flight we need two offsets. |
131 | // They can be indexed using the loop counter. |
132 | std::array<size_t, 2> segmentOffset; |
133 | segmentOffset[0] = 0; |
134 | segmentOffset[1] = segmentBytes; |
135 | |
136 | // Function computes the offsets and lengths of the chunks to be |
137 | // sent and received for a given chunk iteration. |
138 | auto computeReduceScatterOffsets = [&](size_t i) { |
139 | struct { |
140 | size_t sendOffset; |
141 | size_t recvOffset; |
142 | ssize_t sendLength; |
143 | ssize_t recvLength; |
144 | } result; |
145 | |
146 | // Compute segment index to send from (to rank - 1) and segment |
147 | // index to receive into (from rank + 1). Multiply by the number |
148 | // of bytes in a chunk to get to an offset. The offset is allowed |
149 | // to be out of range (>= totalBytes) and this is taken into |
150 | // account when computing the associated length. |
151 | result.sendOffset = |
152 | ((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) % |
153 | (numSegments * segmentBytes); |
154 | result.recvOffset = |
155 | ((((context->rank + 2) * numSegmentsPerRank) + i) * segmentBytes) % |
156 | (numSegments * segmentBytes); |
157 | |
158 | // If the segment is entirely in range, the following statement is |
159 | // equal to segmentBytes. If it isn't, it will be less, or even |
160 | // negative. This is why the ssize_t typecasts are needed. |
161 | result.sendLength = std::min( |
162 | (ssize_t)segmentBytes, |
163 | (ssize_t)totalBytes - (ssize_t)result.sendOffset); |
164 | result.recvLength = std::min( |
165 | (ssize_t)segmentBytes, |
166 | (ssize_t)totalBytes - (ssize_t)result.recvOffset); |
167 | |
168 | return result; |
169 | }; |
170 | |
171 | for (auto i = 0; i < numSegments; i++) { |
172 | if (i >= 2) { |
173 | // Compute send and receive offsets and lengths two iterations |
174 | // ago. Needed so we know when to wait for an operation and when |
175 | // to ignore (when the offset was out of bounds), and know where |
176 | // to reduce the contents of the temporary buffer. |
177 | auto prev = computeReduceScatterOffsets(i - 2); |
178 | if (prev.recvLength > 0) { |
179 | tmp->waitRecv(opts.timeout); |
180 | opts.reduce( |
181 | static_cast<uint8_t*>(out->ptr) + prev.recvOffset, |
182 | static_cast<const uint8_t*>(in->ptr) + prev.recvOffset, |
183 | static_cast<const uint8_t*>(tmp->ptr) + segmentOffset[i & 0x1], |
184 | prev.recvLength / opts.elementSize); |
185 | } |
186 | if (prev.sendLength > 0) { |
187 | if ((i - 2) < numSegmentsPerRank) { |
188 | in->waitSend(opts.timeout); |
189 | } else { |
190 | out->waitSend(opts.timeout); |
191 | } |
192 | } |
193 | } |
194 | |
195 | // Issue new send and receive operation in all but the final two |
196 | // iterations. At that point we have already sent all data we |
197 | // needed to and only have to wait for the final segments to be |
198 | // reduced into the output. |
199 | if (i < (numSegments - 2)) { |
200 | // Compute send and receive offsets and lengths for this iteration. |
201 | auto cur = computeReduceScatterOffsets(i); |
202 | if (cur.recvLength > 0) { |
203 | tmp->recv(recvRank, slot, segmentOffset[i & 0x1], cur.recvLength); |
204 | } |
205 | if (cur.sendLength > 0) { |
206 | if (i < numSegmentsPerRank) { |
207 | in->send(sendRank, slot, cur.sendOffset, cur.sendLength); |
208 | } else { |
209 | out->send(sendRank, slot, cur.sendOffset, cur.sendLength); |
210 | } |
211 | } |
212 | } |
213 | } |
214 | |
215 | // Gather to root rank. |
216 | // |
217 | // Beware: totalBytes <= (numSegments * segmentBytes), which is |
218 | // incompatible with the generic gather algorithm where the |
219 | // contribution is identical across processes. |
220 | // |
221 | if (context->rank == opts.root) { |
222 | size_t numRecv = 0; |
223 | for (size_t rank = 0; rank < context->size; rank++) { |
224 | if (rank == context->rank) { |
225 | continue; |
226 | } |
227 | size_t recvOffset = rank * numSegmentsPerRank * segmentBytes; |
228 | ssize_t recvLength = std::min( |
229 | (ssize_t)chunkBytes, (ssize_t)totalBytes - (ssize_t)recvOffset); |
230 | if (recvLength > 0) { |
231 | out->recv(rank, slot, recvOffset, recvLength); |
232 | numRecv++; |
233 | } |
234 | } |
235 | for (size_t i = 0; i < numRecv; i++) { |
236 | out->waitRecv(opts.timeout); |
237 | } |
238 | } else { |
239 | size_t sendOffset = context->rank * numSegmentsPerRank * segmentBytes; |
240 | ssize_t sendLength = std::min( |
241 | (ssize_t)chunkBytes, (ssize_t)totalBytes - (ssize_t)sendOffset); |
242 | if (sendLength > 0) { |
243 | out->send(opts.root, slot, sendOffset, sendLength); |
244 | out->waitSend(opts.timeout); |
245 | } |
246 | } |
247 | } |
248 | |
249 | } // namespace gloo |
250 | |