1 | /** |
---|---|
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/scatter.h" |
10 | |
11 | #include <algorithm> |
12 | #include <cstring> |
13 | |
14 | #include "gloo/common/logging.h" |
15 | #include "gloo/types.h" |
16 | |
17 | namespace gloo { |
18 | |
19 | void scatter(ScatterOptions& opts) { |
20 | const auto& context = opts.context; |
21 | std::vector<std::unique_ptr<transport::UnboundBuffer>>& in = opts.in; |
22 | std::unique_ptr<transport::UnboundBuffer>& out = opts.out; |
23 | const auto slot = Slot::build(kScatterSlotPrefix, opts.tag); |
24 | |
25 | // Sanity checks |
26 | GLOO_ENFORCE(opts.elementSize > 0); |
27 | GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size); |
28 | GLOO_ENFORCE(out); |
29 | if (context->rank == opts.root) { |
30 | // Assert there are as many inputs as ranks to send to. |
31 | GLOO_ENFORCE_EQ(in.size(), context->size); |
32 | // Assert the size of all inputs is identical to the output. |
33 | for (size_t i = 0; i < in.size(); i++) { |
34 | GLOO_ENFORCE_EQ(in[i]->size, out->size); |
35 | } |
36 | } |
37 | |
38 | if (context->rank == opts.root) { |
39 | // Post send operations to peers. |
40 | for (size_t i = 0; i < context->size; i++) { |
41 | if (i == context->rank) { |
42 | continue; |
43 | } |
44 | in[i]->send(i, slot); |
45 | } |
46 | |
47 | // Copy local input to output |
48 | memcpy(out->ptr, in[context->rank]->ptr, out->size); |
49 | |
50 | // Wait for send operations to complete |
51 | for (size_t i = 0; i < context->size; i++) { |
52 | if (i == context->rank) { |
53 | continue; |
54 | } |
55 | in[i]->waitSend(opts.timeout); |
56 | } |
57 | } else { |
58 | out->recv(opts.root, slot); |
59 | out->waitRecv(opts.timeout); |
60 | } |
61 | } |
62 | |
63 | } // namespace gloo |
64 |