/**
 * Copyright (c) 2018-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/reduce.h"

#include <algorithm>
#include <array>
#include <cstring>

#include "gloo/common/logging.h"
#include "gloo/math.h"
#include "gloo/types.h"

namespace gloo {

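// Rooted reduce over all processes in the context: after the call,
// only the root rank holds the fully reduced result. A minimal usage
// sketch (illustrative only; it assumes the ReduceOptions setters
// declared in gloo/reduce.h and an already connected context):
//
//   gloo::ReduceOptions opts(context);
//   opts.setOutput(ptr, count); // in-place: output doubles as input
//   opts.setRoot(0);
//   opts.setReduceFunction(fn); // e.g. an elementwise sum
//   gloo::reduce(opts);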
void reduce(ReduceOptions& opts) {
  if (opts.elements == 0) {
    return;
  }
  const auto& context = opts.context;
  transport::UnboundBuffer* in = opts.in.get();
  transport::UnboundBuffer* out = opts.out.get();
  const auto slot = Slot::build(kReduceSlotPrefix, opts.tag);

  // Sanity checks
  GLOO_ENFORCE(opts.elementSize > 0);
  GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size);
  GLOO_ENFORCE(opts.reduce != nullptr);
  const auto recvRank = (context->size + context->rank + 1) % context->size;
  GLOO_ENFORCE(
      recvRank == context->rank || context->getPair(recvRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(recvRank));
  const auto sendRank = (context->size + context->rank - 1) % context->size;
  GLOO_ENFORCE(
      sendRank == context->rank || context->getPair(sendRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(sendRank));

  // If the input buffer is not specified, the output doubles as the input.
  if (in == nullptr) {
    in = out;
  }

  GLOO_ENFORCE_EQ(in->size, opts.elements * opts.elementSize);
  GLOO_ENFORCE_EQ(out->size, opts.elements * opts.elementSize);

  // Short circuit if there is only a single process.
  if (context->size == 1) {
    if (in != out) {
      memcpy(out->ptr, in->ptr, opts.elements * opts.elementSize);
    }
    return;
  }

  // The ring algorithm works as follows.
  //
  // The given input is split into a number of chunks equal to the
  // number of processes. Once the algorithm has finished, every
  // process hosts one chunk of reduced output, in sequential order
  // (rank 0 has chunk 0, rank 1 has chunk 1, etc.). As the input may
  // not be divisible by the number of processes, the chunk on the
  // final ranks may have partial output or may be empty.
  //
  // As a chunk is passed along the ring and contains the reduction of
  // successively more ranks, we have to alternate between performing
  // I/O for that chunk and computing the reduction between the
  // received chunk and the local chunk. To avoid this alternating
  // pattern, we split up a chunk into multiple segments (>= 2), and
  // ensure we have one segment in flight while computing a reduction
  // on the other. The segment size has an upper bound to minimize
  // memory usage and avoid poor cache behavior. This means we may
  // have many segments per chunk when dealing with very large inputs.
  //
  // The nomenclature here is reflected in the variable naming below
  // (one chunk per rank and many segments per chunk).
  //
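  // For illustration: with 4 processes the input is viewed as chunks
  // [c0|c1|c2|c3]; after the reduce-scatter phase below, rank 0 holds
  // the fully reduced c0, rank 1 holds c1, and so on, and the final
  // gather phase collects all chunks at the root.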
  const size_t totalBytes = opts.elements * opts.elementSize;

  // Ensure that the maximum segment size is a multiple of the element
  // size. Otherwise, the segment size can exceed the maximum segment
  // size after rounding it up to the nearest multiple of the element
  // size. For example, if maxSegmentSize = 10 and elementSize = 4,
  // then rounding up would yield segmentSize = 12; flooring to a
  // multiple of the element size instead yields maxSegmentSize = 8.
  const size_t maxSegmentSize =
      opts.elementSize * (opts.maxSegmentSize / opts.elementSize);

  // The number of bytes per segment must be a multiple of the bytes
  // per element for the reduction to work; round up if necessary.
  const size_t segmentBytes = roundUp(
      std::min(
          // Rounded-up division to have >= 2 segments per chunk.
          (totalBytes + (context->size * 2 - 1)) / (context->size * 2),
          // Configurable segment size limit
          maxSegmentSize),
      opts.elementSize);

  // Compute how many segments make up the input buffer.
  //
  // Round up to the nearest multiple of the context size such that
  // there is an equal number of segments per process and execution is
  // symmetric across processes.
  //
  // The minimum is twice the context size, because the algorithm
  // below overlaps sending/receiving a segment with computing the
  // reduction of another segment.
  //
  const size_t numSegments = roundUp(
      std::max(
          (totalBytes + (segmentBytes - 1)) / segmentBytes,
          (size_t)context->size * 2),
      (size_t)context->size);
  GLOO_ENFORCE_EQ(numSegments % context->size, 0);
  GLOO_ENFORCE_GE(numSegments, context->size * 2);
  const size_t numSegmentsPerRank = numSegments / context->size;
  const size_t chunkBytes = numSegmentsPerRank * segmentBytes;
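
  // Worked example (illustrative; it assumes opts.maxSegmentSize is
  // large enough not to bind): with 4 processes and 25 elements of 4
  // bytes each, totalBytes = 100. Then segmentBytes =
  // roundUp(min(ceil(100 / 8), maxSegmentSize), 4) = roundUp(13, 4) =
  // 16, numSegments = roundUp(max(ceil(100 / 16), 8), 4) = 8,
  // numSegmentsPerRank = 2, and chunkBytes = 32. Note that
  // numSegments * segmentBytes = 128 > 100, so segments that fall
  // (partly) past the end of the input get truncated or empty lengths
  // below.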

  // Allocate scratch space to hold two segments.
  std::unique_ptr<uint8_t[]> tmpAllocation(new uint8_t[segmentBytes * 2]);
  std::unique_ptr<transport::UnboundBuffer> tmpBuffer =
      context->createUnboundBuffer(tmpAllocation.get(), segmentBytes * 2);
  transport::UnboundBuffer* tmp = tmpBuffer.get();

  // Use dynamic lookup for the segment offset in the temporary buffer.
  // With two operations in flight we need two offsets.
  // They can be indexed using the loop counter.
  std::array<size_t, 2> segmentOffset;
  segmentOffset[0] = 0;
  segmentOffset[1] = segmentBytes;

  // Computes the offsets and lengths of the segments to be sent and
  // received for a given iteration.
  auto computeReduceScatterOffsets = [&](size_t i) {
    struct {
      size_t sendOffset;
      size_t recvOffset;
      ssize_t sendLength;
      ssize_t recvLength;
    } result;

    // Compute the segment index to send from (to rank - 1) and the
    // segment index to receive into (from rank + 1). Multiply by the
    // number of bytes in a segment to get to an offset. The offset is
    // allowed to be out of range (>= totalBytes) and this is taken
    // into account when computing the associated length.
    result.sendOffset =
        ((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);
    result.recvOffset =
        ((((context->rank + 2) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);

    // If the segment is entirely in range, the following statement is
    // equal to segmentBytes. If it isn't, it will be less, or even
    // negative. This is why the ssize_t typecasts are needed.
    result.sendLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.sendOffset);
    result.recvLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.recvOffset);

    return result;
  };
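
  // Continuing the worked example above: for rank 1 at iteration i = 1,
  // sendOffset = (((1 + 1) * 2 + 1) * 16) % 128 = 80 with sendLength =
  // min(16, 100 - 80) = 16, while recvOffset = (((1 + 2) * 2 + 1) * 16)
  // % 128 = 112 with recvLength = min(16, 100 - 112) = -12, so the
  // receive for that segment is skipped entirely.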

  for (size_t i = 0; i < numSegments; i++) {
    if (i >= 2) {
      // Compute the send and receive offsets and lengths from two
      // iterations ago. Needed so we know when to wait for an operation
      // and when to ignore it (when the offset was out of bounds), and
      // know where to reduce the contents of the temporary buffer.
      auto prev = computeReduceScatterOffsets(i - 2);
      if (prev.recvLength > 0) {
        tmp->waitRecv(opts.timeout);
        opts.reduce(
            static_cast<uint8_t*>(out->ptr) + prev.recvOffset,
            static_cast<const uint8_t*>(in->ptr) + prev.recvOffset,
            static_cast<const uint8_t*>(tmp->ptr) + segmentOffset[i & 0x1],
            prev.recvLength / opts.elementSize);
      }
      if (prev.sendLength > 0) {
        if ((i - 2) < numSegmentsPerRank) {
          in->waitSend(opts.timeout);
        } else {
          out->waitSend(opts.timeout);
        }
      }
    }

    // Issue new send and receive operations in all but the final two
    // iterations. At that point we have already sent all the data we
    // need to and only have to wait for the final segments to be
    // reduced into the output.
    if (i < (numSegments - 2)) {
      // Compute the send and receive offsets and lengths for this
      // iteration.
      auto cur = computeReduceScatterOffsets(i);
      if (cur.recvLength > 0) {
        tmp->recv(recvRank, slot, segmentOffset[i & 0x1], cur.recvLength);
      }
      if (cur.sendLength > 0) {
        if (i < numSegmentsPerRank) {
          in->send(sendRank, slot, cur.sendOffset, cur.sendLength);
        } else {
          out->send(sendRank, slot, cur.sendOffset, cur.sendLength);
        }
      }
    }
  }
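
  // At this point the reduce-scatter phase is complete: every rank
  // holds the fully reduced chunk corresponding to its own rank in its
  // output buffer, while the other chunk positions in that buffer hold
  // partially reduced (stale) data.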

  // Gather to the root rank.
  //
  // Beware: totalBytes <= (numSegments * segmentBytes), so the
  // trailing ranks may contribute fewer bytes than the others. This is
  // incompatible with the generic gather algorithm, which assumes an
  // identical contribution from every process.
  //
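  // In the worked example above, rank 3's chunk starts at byte offset
  // 96, so it contributes only min(32, 100 - 96) = 4 valid bytes; a
  // rank whose chunk started at or past totalBytes would send nothing
  // at all.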
  if (context->rank == opts.root) {
    size_t numRecv = 0;
    for (size_t rank = 0; rank < context->size; rank++) {
      if (rank == context->rank) {
        continue;
      }
      size_t recvOffset = rank * numSegmentsPerRank * segmentBytes;
      ssize_t recvLength = std::min(
          (ssize_t)chunkBytes, (ssize_t)totalBytes - (ssize_t)recvOffset);
      if (recvLength > 0) {
        out->recv(rank, slot, recvOffset, recvLength);
        numRecv++;
      }
    }
    for (size_t i = 0; i < numRecv; i++) {
      out->waitRecv(opts.timeout);
    }
  } else {
    size_t sendOffset = context->rank * numSegmentsPerRank * segmentBytes;
    ssize_t sendLength = std::min(
        (ssize_t)chunkBytes, (ssize_t)totalBytes - (ssize_t)sendOffset);
    if (sendLength > 0) {
      out->send(opts.root, slot, sendOffset, sendLength);
      out->waitSend(opts.timeout);
    }
  }
}

} // namespace gloo