/**
 * Copyright (c) 2018-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/allreduce.h"

#include <algorithm>
#include <array>
#include <cstring>

#include "gloo/common/logging.h"
#include "gloo/math.h"
#include "gloo/types.h"

namespace gloo {

namespace {

using BufferVector = std::vector<std::unique_ptr<transport::UnboundBuffer>>;
using ReductionFunction = AllreduceOptions::Func;
using ReduceRangeFunction = std::function<void(size_t, size_t)>;
using BroadcastRangeFunction = std::function<void(size_t, size_t)>;

// Forward declaration of ring algorithm implementation.
void ring(
    const detail::AllreduceOptionsImpl& opts,
    ReduceRangeFunction reduceInputs,
    BroadcastRangeFunction broadcastOutputs);

// Forward declaration of bcube algorithm implementation.
void bcube(
    const detail::AllreduceOptionsImpl& opts,
    ReduceRangeFunction reduceInputs,
    BroadcastRangeFunction broadcastOutputs);

// Returns function that computes local reduction over inputs and
// stores it in the output for a given range in those buffers.
// This is done prior to either sending a region to a neighbor, or
// reducing a region received from a neighbor.
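//
// As a concrete sketch (illustrative values only, assuming the reduction
// function computes a sum): with in = {A, B, C} and out = {O}, the returned
// function computes, element-wise over the given range, O = A + B and then
// O = O + C. With a single input it degenerates to a memcpy into out[0],
// and with no inputs the extra outputs out[1..N] are reduced into out[0]
// instead.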
ReduceRangeFunction genLocalReduceFunction(
    const BufferVector& in,
    const BufferVector& out,
    size_t elementSize,
    ReductionFunction fn) {
  if (in.size() > 0) {
    if (in.size() == 1) {
      return [&in, &out](size_t offset, size_t length) {
        memcpy(
            static_cast<uint8_t*>(out[0]->ptr) + offset,
            static_cast<const uint8_t*>(in[0]->ptr) + offset,
            length);
      };
    } else {
      return [&in, &out, elementSize, fn](size_t offset, size_t length) {
        fn(static_cast<uint8_t*>(out[0]->ptr) + offset,
           static_cast<const uint8_t*>(in[0]->ptr) + offset,
           static_cast<const uint8_t*>(in[1]->ptr) + offset,
           length / elementSize);
        for (size_t i = 2; i < in.size(); i++) {
          fn(static_cast<uint8_t*>(out[0]->ptr) + offset,
             static_cast<const uint8_t*>(out[0]->ptr) + offset,
             static_cast<const uint8_t*>(in[i]->ptr) + offset,
             length / elementSize);
        }
      };
    }
  } else {
    return [&out, elementSize, fn](size_t offset, size_t length) {
      for (size_t i = 1; i < out.size(); i++) {
        fn(static_cast<uint8_t*>(out[0]->ptr) + offset,
           static_cast<const uint8_t*>(out[0]->ptr) + offset,
           static_cast<const uint8_t*>(out[i]->ptr) + offset,
           length / elementSize);
      }
    };
  }
}

// Returns function that performs a local broadcast over the outputs for a
// given range in the buffers. This is executed after receiving every
// globally reduced chunk.
BroadcastRangeFunction genLocalBroadcastFunction(const BufferVector& out) {
  return [&out](size_t offset, size_t length) {
    for (size_t i = 1; i < out.size(); i++) {
      memcpy(
          static_cast<uint8_t*>(out[i]->ptr) + offset,
          static_cast<const uint8_t*>(out[0]->ptr) + offset,
          length);
    }
  };
}

void allreduce(const detail::AllreduceOptionsImpl& opts) {
  if (opts.elements == 0) {
    return;
  }

  const auto& context = opts.context;
  const std::vector<std::unique_ptr<transport::UnboundBuffer>>& in = opts.in;
  const std::vector<std::unique_ptr<transport::UnboundBuffer>>& out = opts.out;
  const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag);

  // Sanity checks
  GLOO_ENFORCE_GT(out.size(), 0);
  GLOO_ENFORCE(opts.elementSize > 0);
  GLOO_ENFORCE(opts.reduce != nullptr);

  // Assert the size of all inputs and outputs is identical.
  const size_t totalBytes = opts.elements * opts.elementSize;
  for (size_t i = 0; i < out.size(); i++) {
    GLOO_ENFORCE_EQ(out[i]->size, totalBytes);
  }
  for (size_t i = 0; i < in.size(); i++) {
    GLOO_ENFORCE_EQ(in[i]->size, totalBytes);
  }

  // Initialize local reduction and broadcast functions.
  // Note that these are a no-op if only a single output is specified
  // and it is used as both input and output.
  const auto reduceInputs =
      genLocalReduceFunction(in, out, opts.elementSize, opts.reduce);
  const auto broadcastOutputs = genLocalBroadcastFunction(out);

  // Short circuit if there is only a single process.
  if (context->size == 1) {
    reduceInputs(0, totalBytes);
    broadcastOutputs(0, totalBytes);
    return;
  }

  switch (opts.algorithm) {
    case detail::AllreduceOptionsImpl::UNSPECIFIED:
    case detail::AllreduceOptionsImpl::RING:
      ring(opts, reduceInputs, broadcastOutputs);
      break;
    case detail::AllreduceOptionsImpl::BCUBE:
      bcube(opts, reduceInputs, broadcastOutputs);
      break;
    default:
      GLOO_ENFORCE(false, "Algorithm not handled.");
  }
}

void ring(
    const detail::AllreduceOptionsImpl& opts,
    ReduceRangeFunction reduceInputs,
    BroadcastRangeFunction broadcastOutputs) {
  const auto& context = opts.context;
  const std::vector<std::unique_ptr<transport::UnboundBuffer>>& out = opts.out;
  const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag);
  const size_t totalBytes = opts.elements * opts.elementSize;

  // Note: context->size > 1
  const auto recvRank = (context->size + context->rank + 1) % context->size;
  const auto sendRank = (context->size + context->rank - 1) % context->size;
  GLOO_ENFORCE(
      context->getPair(recvRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(recvRank));
  GLOO_ENFORCE(
      context->getPair(sendRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(sendRank));

  // The ring algorithm works as follows.
  //
  // The given input is split into a number of chunks equal to the
  // number of processes. Once the algorithm has finished, every
  // process hosts one chunk of reduced output, in sequential order
  // (rank 0 has chunk 0, rank 1 has chunk 1, etc.). As the input may
  // not be divisible by the number of processes, the chunk on the
  // final ranks may have partial output or may be empty.
  //
  // As a chunk is passed along the ring and contains the reduction of
  // successively more ranks, we have to alternate between performing
  // I/O for that chunk and computing the reduction between the
  // received chunk and the local chunk. To avoid this alternating
  // pattern, we split up a chunk into multiple segments (>= 2), and
  // ensure we have one segment in flight while computing a reduction
  // on the other. The segment size has an upper bound to minimize
  // memory usage and avoid poor cache behavior. This means we may
  // have many segments per chunk when dealing with very large inputs.
  //
  // The nomenclature here is reflected in the variable naming below
  // (one chunk per rank and many segments per chunk).
  //
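  // As an illustration (numbers chosen for this comment only): with 4
  // processes and 2 segments per chunk there are 8 segments in total.
  // After the reduce/scatter phase below, rank 0 holds the fully reduced
  // chunk 0 (segments 0 and 1), rank 1 holds chunk 1 (segments 2 and 3),
  // and so on; the allgather phase then circulates the reduced chunks
  // until every rank holds all of them.
  //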

  // Ensure that the maximum segment size is a multiple of the element size.
  // Otherwise, the segment size can exceed the maximum segment size after
  // rounding it up to the nearest multiple of the element size.
  // For example, if maxSegmentSize = 10 and elementSize = 4,
  // then after rounding up: segmentSize = 12.
  const size_t maxSegmentBytes = opts.elementSize *
      std::max((size_t)1, opts.maxSegmentSize / opts.elementSize);

  // Compute how many segments make up the input buffer.
  //
  // Round up to the nearest multiple of the context size such that
  // there is an equal number of segments per process and execution is
  // symmetric across processes.
  //
  // The minimum is twice the context size, because the algorithm
  // below overlaps sending/receiving a segment with computing the
  // reduction of another segment.
  //
  const size_t numSegments = roundUp(
      std::max(
          (totalBytes + (maxSegmentBytes - 1)) / maxSegmentBytes,
          (size_t)context->size * 2),
      (size_t)context->size);
  GLOO_ENFORCE_EQ(numSegments % context->size, 0);
  GLOO_ENFORCE_GE(numSegments, context->size * 2);
  const size_t numSegmentsPerRank = numSegments / context->size;
  const size_t segmentBytes =
      roundUp((totalBytes + numSegments - 1) / numSegments, opts.elementSize);
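
  // Worked example (hypothetical values, not derived from any caller):
  // with 4 processes, elementSize = 4, totalBytes = 1000, and
  // maxSegmentBytes = 256, the size-based segment count is
  // ceil(1000 / 256) = 4, the minimum is 2 * 4 = 8, and roundUp(8, 4) = 8,
  // so numSegments = 8, numSegmentsPerRank = 2, and
  // segmentBytes = roundUp(ceil(1000 / 8), 4) = 128. Since 8 * 128 > 1000,
  // the final segment is partial (104 bytes).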

  // Allocate scratch space to hold two segments.
  std::unique_ptr<uint8_t[]> tmpAllocation(new uint8_t[segmentBytes * 2]);
  std::unique_ptr<transport::UnboundBuffer> tmpBuffer =
      context->createUnboundBuffer(tmpAllocation.get(), segmentBytes * 2);
  transport::UnboundBuffer* tmp = tmpBuffer.get();

  // Use dynamic lookup for the segment offset in the temporary buffer.
  // With two operations in flight we need two offsets.
  // They can be indexed using the loop counter.
  std::array<size_t, 2> segmentOffset;
  segmentOffset[0] = 0;
  segmentOffset[1] = segmentBytes;

  // Function that computes the offsets and lengths of the segments to be
  // sent and received for a given iteration during reduce/scatter.
  auto computeReduceScatterOffsets = [&](size_t i) {
    struct {
      size_t sendOffset;
      size_t recvOffset;
      ssize_t sendLength;
      ssize_t recvLength;
    } result;

    // Compute the segment index to send from (to rank - 1) and the segment
    // index to receive into (from rank + 1). Multiply by the number
    // of bytes in a segment to get to an offset. The offset is allowed
    // to be out of range (>= totalBytes) and this is taken into
    // account when computing the associated length.
    result.sendOffset =
        ((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);
    result.recvOffset =
        ((((context->rank + 2) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);

    // If the segment is entirely in range, the following statement is
    // equal to segmentBytes. If it isn't, it will be less, or even
    // negative. This is why the ssize_t typecasts are needed.
    result.sendLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.sendOffset);
    result.recvLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.recvOffset);

    return result;
  };
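
  // Continuing the illustrative example above (8 segments of 128 bytes,
  // 4 processes): on rank 0, iteration i = 0 yields
  //   sendOffset = ((1 * 2 + 0) * 128) % 1024 = 256 (start of rank 1's chunk)
  //   recvOffset = ((2 * 2 + 0) * 128) % 1024 = 512 (start of rank 2's chunk)
  // and both lengths equal the full 128 bytes because both offsets are in
  // range.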

  // Ring reduce/scatter.
  //
  // The number of iterations is computed as follows:
  // - Take `numSegments` for the total number of segments,
  // - Subtract `numSegmentsPerRank` because the final segments hold
  //   the partial result and must not be forwarded in this phase.
  // - Add 2 because we pipeline send and receive operations (we issue
  //   send/recv operations on iterations 0 and 1 and wait for them to
  //   complete on iterations 2 and 3).
  //
  for (auto i = 0; i < (numSegments - numSegmentsPerRank + 2); i++) {
    if (i >= 2) {
      // Compute the send and receive offsets and lengths from two iterations
      // ago. Needed so we know when to wait for an operation and when to
      // ignore it (when the offset was out of bounds), and know where to
      // reduce the contents of the temporary buffer.
      auto prev = computeReduceScatterOffsets(i - 2);
      if (prev.recvLength > 0) {
        // Prepare out[0]->ptr to hold the local reduction.
        reduceInputs(prev.recvOffset, prev.recvLength);
        // Wait for the segment from our neighbor.
        tmp->waitRecv(opts.timeout);
        // Reduce the segment from our neighbor into out[0]->ptr.
        opts.reduce(
            static_cast<uint8_t*>(out[0]->ptr) + prev.recvOffset,
            static_cast<const uint8_t*>(out[0]->ptr) + prev.recvOffset,
            static_cast<const uint8_t*>(tmp->ptr) + segmentOffset[i & 0x1],
            prev.recvLength / opts.elementSize);
      }
      if (prev.sendLength > 0) {
        out[0]->waitSend(opts.timeout);
      }
    }

    // Issue new send and receive operations in all but the final two
    // iterations. At that point we have already sent all data we
    // needed to and only have to wait for the final segments to be
    // reduced into the output.
    if (i < (numSegments - numSegmentsPerRank)) {
      // Compute the send and receive offsets and lengths for this iteration.
      auto cur = computeReduceScatterOffsets(i);
      if (cur.recvLength > 0) {
        tmp->recv(recvRank, slot, segmentOffset[i & 0x1], cur.recvLength);
      }
      if (cur.sendLength > 0) {
        // Prepare out[0]->ptr to hold the local reduction for this segment.
        if (i < numSegmentsPerRank) {
          reduceInputs(cur.sendOffset, cur.sendLength);
        }
        out[0]->send(sendRank, slot, cur.sendOffset, cur.sendLength);
      }
    }
  }

  // Function that computes the offsets and lengths of the segments to be
  // sent and received for a given iteration during allgather.
  auto computeAllgatherOffsets = [&](size_t i) {
    struct {
      size_t sendOffset;
      size_t recvOffset;
      ssize_t sendLength;
      ssize_t recvLength;
    } result;

    result.sendOffset =
        ((((context->rank) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);
    result.recvOffset =
        ((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);

    // If the segment is entirely in range, the following statement is
    // equal to segmentBytes. If it isn't, it will be less, or even
    // negative. This is why the ssize_t typecasts are needed.
    result.sendLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.sendOffset);
    result.recvLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.recvOffset);

    return result;
  };

  // Ring allgather.
  //
  // Beware: totalBytes <= (numSegments * segmentBytes), which is
  // incompatible with the generic allgather algorithm where the
  // contribution is identical across processes.
  //
  // See the comment prior to the reduce/scatter loop on how the number of
  // iterations for this loop is computed.
  //
  for (auto i = 0; i < (numSegments - numSegmentsPerRank + 2); i++) {
    if (i >= 2) {
      auto prev = computeAllgatherOffsets(i - 2);
      if (prev.recvLength > 0) {
        out[0]->waitRecv(opts.timeout);
        // Broadcast received segments to the other output buffers.
        broadcastOutputs(prev.recvOffset, prev.recvLength);
      }
      if (prev.sendLength > 0) {
        out[0]->waitSend(opts.timeout);
      }
    }

    // Issue new send and receive operations in all but the final two
    // iterations. At that point we have already sent all data we
    // needed to and only have to wait for the final segments to be
    // sent to the output.
    if (i < (numSegments - numSegmentsPerRank)) {
      auto cur = computeAllgatherOffsets(i);
      if (cur.recvLength > 0) {
        out[0]->recv(recvRank, slot, cur.recvOffset, cur.recvLength);
      }
      if (cur.sendLength > 0) {
        out[0]->send(sendRank, slot, cur.sendOffset, cur.sendLength);
        // Broadcast the first segments to the other output buffers.
        if (i < numSegmentsPerRank) {
          broadcastOutputs(cur.sendOffset, cur.sendLength);
        }
      }
    }
  }
}

// For a given context size and desired group size, compute the actual group
// size per step. Note that the group size is n for every step only if
// n^(#steps) == size. Otherwise, the final group size differs from n.
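//
// For example, computeGroupSizePerStep(8, 2) returns {2, 2, 2},
// computeGroupSizePerStep(12, 2) returns {2, 2, 3}, and for a prime size
// such as 7 it returns {7} (a single step spanning all processes).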
std::vector<size_t> computeGroupSizePerStep(size_t size, const size_t n) {
  std::vector<size_t> result;
  GLOO_ENFORCE_GT(n, 1);
  while (size % n == 0) {
    result.push_back(n);
    size /= n;
  }
  if (size > 1) {
    result.push_back(size);
  }
  return result;
}

// The bcube algorithm implements a hypercube-like strategy for reduction. The
// constraint is that the number of processes can be factorized. If the minimum
// component in the factorization is 2, and the number of processes is equal to
// a power of 2, the algorithm is identical to recursive halving/doubling. The
// number of elements in the factorization determines the number of steps of
// the algorithm. Each element of the factorization determines the number of
// processes each process communicates with at that particular step of the
// algorithm. If the number of processes is not factorizable (i.e., it is
// prime), the algorithm is identical to a direct reduce-scatter followed by an
// allgather.
//
// For example, if #processes == 8, and we factorize it as 4 * 2, the algorithm
// runs in 2 steps. In the first step, 2 groups of 4 processes exchange data
// such that all processes have 1/4th of the partial result (with process 0
// having the first quarter, process 1 the second quarter, and so forth). In
// the second step, 4 groups of 2 processes exchange their partial result such
// that all processes have 1/8th of the result. Then, the same factorization is
// followed in reverse to perform an allgather.
//
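// In the implementation below the group size n is fixed at 2, so for
// #processes == 8 the factorization computed by computeGroupSizePerStep is
// 2 * 2 * 2 and the algorithm runs in 3 pairwise steps (recursive halving
// for the reduce/scatter, followed by recursive doubling for the allgather).
//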
void bcube(
    const detail::AllreduceOptionsImpl& opts,
    ReduceRangeFunction reduceInputs,
    BroadcastRangeFunction broadcastOutputs) {
  const auto& context = opts.context;
  const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag);
  const auto elementSize = opts.elementSize;
  auto& out = opts.out[0];

  constexpr auto n = 2;

  // Figure out the number of steps in this algorithm.
  const auto groupSizePerStep = computeGroupSizePerStep(context->size, n);

  struct group {
    // Distance between peers in this group.
    size_t peerDistance;

    // Segment that this group is responsible for reducing.
    size_t bufferOffset;
    size_t bufferLength;

    // The process ranks that are a member of this group.
    std::vector<size_t> ranks;

    // Upper bound of the length of the chunk that each process has the
    // reduced values for by the end of the reduction for this group.
    size_t chunkLength;

    // Chunk within the segment that this process is responsible for reducing.
    size_t myChunkOffset;
    size_t myChunkLength;
  };

  // Compute the details of a group at every algorithm step.
  // We keep this in a vector because we iterate through it in forward order in
  // the reduce/scatter phase and in backward order in the allgather phase.
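  //
  // As an example trace (illustrative only): with 8 processes,
  // groupSizePerStep = {2, 2, 2}, and rank 5:
  //   step 0: peerDistance 1, ranks {4, 5}; this rank reduces the upper
  //           half of the full buffer,
  //   step 1: peerDistance 2, ranks {5, 7}; this rank reduces the lower
  //           half of that half,
  //   step 2: peerDistance 4, ranks {1, 5}; this rank reduces the upper
  //           half of the remaining quarter, i.e. 1/8th of the buffer.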
  std::vector<struct group> groups;
  {
    struct group group;
    group.peerDistance = 1;
    group.bufferOffset = 0;
    group.bufferLength = opts.elements;
    for (const size_t groupSize : groupSizePerStep) {
      const size_t groupRank = (context->rank / group.peerDistance) % groupSize;
      const size_t baseRank = context->rank - (groupRank * group.peerDistance);
      group.ranks.reserve(groupSize);
      for (size_t i = 0; i < groupSize; i++) {
        group.ranks.push_back(baseRank + i * group.peerDistance);
      }

      // Compute the length of the chunk we're exchanging at this step.
      group.chunkLength = ((group.bufferLength + (groupSize - 1)) / groupSize);

      // This process is computing the reduction of the chunk positioned at
      // <rank>/<size> within the current segment.
      group.myChunkOffset =
          group.bufferOffset + (groupRank * group.chunkLength);
      group.myChunkLength = std::min(
          size_t(group.chunkLength),
          size_t(std::max(
              int64_t(0),
              int64_t(group.bufferLength) -
                  int64_t(groupRank * group.chunkLength))));

      // Store a const copy of this group in the vector.
      groups.push_back(group);

      // Initialize the next group with an updated peer distance and segment
      // offset and length.
      struct group nextGroup;
      nextGroup.peerDistance = group.peerDistance * groupSize;
      nextGroup.bufferOffset = group.myChunkOffset;
      nextGroup.bufferLength = group.myChunkLength;
      std::swap(group, nextGroup);
    }
  }

  // The chunk length is rounded up, so the maximum scratch space we need
  // might be larger than the size of the output buffer. Compute the maximum
  // amount of scratch space needed across all steps.
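  // For example (hypothetical sizes): with opts.elements = 7 and a first
  // group of size 2, chunkLength = ceil(7 / 2) = 4, so
  // ranks.size() * chunkLength = 8 exceeds the 7-element output buffer.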
  size_t bufferLength = opts.elements;
  for (const auto& group : groups) {
    bufferLength =
        std::max(bufferLength, group.ranks.size() * group.chunkLength);
  }

  // Allocate scratch space to receive data from peers.
  const size_t bufferSize = bufferLength * elementSize;
  std::unique_ptr<uint8_t[]> buffer(new uint8_t[bufferSize]);
  std::unique_ptr<transport::UnboundBuffer> tmp =
      context->createUnboundBuffer(buffer.get(), bufferSize);

  // Reduce/scatter.
  for (size_t step = 0; step < groups.size(); step++) {
    const auto& group = groups[step];

    // Issue receive operations for chunks from peers.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto src = group.ranks[i];
      if (src == context->rank) {
        continue;
      }
      tmp->recv(
          src,
          slot,
          i * group.chunkLength * elementSize,
          group.myChunkLength * elementSize);
    }

    // Issue send operations for local chunks to peers.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto dst = group.ranks[i];
      if (dst == context->rank) {
        continue;
      }
      const size_t currentChunkOffset =
          group.bufferOffset + i * group.chunkLength;
      const size_t currentChunkLength = std::min(
          size_t(group.chunkLength),
          size_t(std::max(
              int64_t(0),
              int64_t(group.bufferLength) - int64_t(i * group.chunkLength))));
      // Compute the local reduction only in the first step of the algorithm.
      // In subsequent steps, we already have a partially reduced result.
      if (step == 0) {
        reduceInputs(
            currentChunkOffset * elementSize, currentChunkLength * elementSize);
      }
      out->send(
          dst,
          slot,
          currentChunkOffset * elementSize,
          currentChunkLength * elementSize);
    }

    // Wait for send and receive operations to complete.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto peer = group.ranks[i];
      if (peer == context->rank) {
        continue;
      }
      tmp->waitRecv();
      out->waitSend();
    }

    // In the first step, prepare the chunk this process is responsible for
    // with the reduced version of its inputs (if multiple are specified).
    if (step == 0) {
      reduceInputs(
          group.myChunkOffset * elementSize, group.myChunkLength * elementSize);
    }

    // Reduce chunks from peers.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto src = group.ranks[i];
      if (src == context->rank) {
        continue;
      }
      opts.reduce(
          static_cast<uint8_t*>(out->ptr) + (group.myChunkOffset * elementSize),
          static_cast<const uint8_t*>(out->ptr) +
              (group.myChunkOffset * elementSize),
          static_cast<const uint8_t*>(tmp->ptr) +
              (i * group.chunkLength * elementSize),
          group.myChunkLength);
    }
  }

  // There is one chunk that contains the final result and this chunk
  // can already be broadcast locally to out[1..N], if applicable.
  // Doing so means we only have to broadcast locally to out[1..N] all
  // chunks as we receive them from our peers during the allgather phase.
  {
    const auto& group = groups.back();
    broadcastOutputs(
        group.myChunkOffset * elementSize, group.myChunkLength * elementSize);
  }

  // Allgather.
  for (auto it = groups.rbegin(); it != groups.rend(); it++) {
    const auto& group = *it;

    // Issue receive operations for reduced chunks from peers.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto src = group.ranks[i];
      if (src == context->rank) {
        continue;
      }
      const size_t currentChunkOffset =
          group.bufferOffset + i * group.chunkLength;
      const size_t currentChunkLength = std::min(
          size_t(group.chunkLength),
          size_t(std::max(
              int64_t(0),
              int64_t(group.bufferLength) - int64_t(i * group.chunkLength))));
      out->recv(
          src,
          slot,
          currentChunkOffset * elementSize,
          currentChunkLength * elementSize);
    }

    // Issue send operations for this process's reduced chunk to peers.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto dst = group.ranks[i];
      if (dst == context->rank) {
        continue;
      }
      out->send(
          dst,
          slot,
          group.myChunkOffset * elementSize,
          group.myChunkLength * elementSize);
    }

    // Wait for operations to complete.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto peer = group.ranks[i];
      if (peer == context->rank) {
        continue;
      }
      out->waitRecv();
      out->waitSend();
    }

    // Broadcast the result to the other output buffers, if applicable.
    for (size_t i = 0; i < group.ranks.size(); i++) {
      const auto peer = group.ranks[i];
      if (peer == context->rank) {
        continue;
      }
      const size_t currentChunkOffset =
          group.bufferOffset + i * group.chunkLength;
      const size_t currentChunkLength = std::min(
          size_t(group.chunkLength),
          size_t(std::max(
              int64_t(0),
              int64_t(group.bufferLength) - int64_t(i * group.chunkLength))));
      broadcastOutputs(
          currentChunkOffset * elementSize, currentChunkLength * elementSize);
    }
  }
}

} // namespace

void allreduce(const AllreduceOptions& opts) {
  allreduce(opts.impl_);
}

} // namespace gloo