1 | /** |
---|---|
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/gather.h" |
10 | |
11 | #include <cstring> |
12 | |
13 | #include "gloo/common/logging.h" |
14 | #include "gloo/types.h" |
15 | |
16 | namespace gloo { |
17 | |
18 | void gather(GatherOptions& opts) { |
19 | const auto& context = opts.context; |
20 | transport::UnboundBuffer* in = opts.in.get(); |
21 | transport::UnboundBuffer* out = opts.out.get(); |
22 | const auto slot = Slot::build(kGatherSlotPrefix, opts.tag); |
23 | |
24 | // Sanity checks |
25 | GLOO_ENFORCE(opts.elementSize > 0); |
26 | GLOO_ENFORCE(in != nullptr); |
27 | |
28 | if (context->rank == opts.root) { |
29 | const size_t chunkSize = in->size; |
30 | |
31 | // Ensure the output buffer has the right size. |
32 | GLOO_ENFORCE(out != nullptr); |
33 | GLOO_ENFORCE(in->size * context->size == out->size); |
34 | |
35 | // Post receive operations from peers into out buffer |
36 | for (size_t i = 0; i < context->size; i++) { |
37 | if (i == context->rank) { |
38 | continue; |
39 | } |
40 | out->recv(i, slot, i * chunkSize, chunkSize); |
41 | } |
42 | |
43 | // Copy local input to output |
44 | memcpy( |
45 | static_cast<char*>(out->ptr) + (context->rank * chunkSize), |
46 | in->ptr, |
47 | chunkSize); |
48 | |
49 | // Wait for receive operations to complete |
50 | for (size_t i = 0; i < context->size; i++) { |
51 | if (i == context->rank) { |
52 | continue; |
53 | } |
54 | out->waitRecv(opts.timeout); |
55 | } |
56 | } else { |
57 | in->send(opts.root, slot); |
58 | in->waitSend(opts.timeout); |
59 | } |
60 | } |
61 | |
62 | } // namespace gloo |
63 |