1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/scatter.h"
10
11#include <algorithm>
12#include <cstring>
13
14#include "gloo/common/logging.h"
15#include "gloo/types.h"
16
17namespace gloo {
18
19void scatter(ScatterOptions& opts) {
20 const auto& context = opts.context;
21 std::vector<std::unique_ptr<transport::UnboundBuffer>>& in = opts.in;
22 std::unique_ptr<transport::UnboundBuffer>& out = opts.out;
23 const auto slot = Slot::build(kScatterSlotPrefix, opts.tag);
24
25 // Sanity checks
26 GLOO_ENFORCE(opts.elementSize > 0);
27 GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size);
28 GLOO_ENFORCE(out);
29 if (context->rank == opts.root) {
30 // Assert there are as many inputs as ranks to send to.
31 GLOO_ENFORCE_EQ(in.size(), context->size);
32 // Assert the size of all inputs is identical to the output.
33 for (size_t i = 0; i < in.size(); i++) {
34 GLOO_ENFORCE_EQ(in[i]->size, out->size);
35 }
36 }
37
38 if (context->rank == opts.root) {
39 // Post send operations to peers.
40 for (size_t i = 0; i < context->size; i++) {
41 if (i == context->rank) {
42 continue;
43 }
44 in[i]->send(i, slot);
45 }
46
47 // Copy local input to output
48 memcpy(out->ptr, in[context->rank]->ptr, out->size);
49
50 // Wait for send operations to complete
51 for (size_t i = 0; i < context->size; i++) {
52 if (i == context->rank) {
53 continue;
54 }
55 in[i]->waitSend(opts.timeout);
56 }
57 } else {
58 out->recv(opts.root, slot);
59 out->waitRecv(opts.timeout);
60 }
61}
62
63} // namespace gloo
64