1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
16
17namespace gloo {
18
19class ScatterOptions {
20 public:
21 explicit ScatterOptions(const std::shared_ptr<Context>& context)
22 : context(context), timeout(context->getTimeout()) {}
23
24 template <typename T>
25 void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
26 this->elementSize = sizeof(T);
27 this->in = std::move(bufs);
28 }
29
30 template <typename T>
31 void setInputs(std::vector<T*> ptrs, size_t elements) {
32 setInputs(ptrs.data(), ptrs.size(), elements);
33 }
34
35 template <typename T>
36 void setInputs(T** ptrs, size_t len, size_t elements) {
37 this->elementSize = sizeof(T);
38 this->in.reserve(len);
39 for (size_t i = 0; i < len; i++) {
40 this->in.push_back(
41 context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
42 }
43 }
44
45 template <typename T>
46 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
47 this->elementSize = sizeof(T);
48 this->out = std::move(buf);
49 }
50
51 template <typename T>
52 void setOutput(T* ptr, size_t elements) {
53 this->elementSize = sizeof(T);
54 this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
55 }
56
57 void setRoot(int root) {
58 this->root = root;
59 }
60
61 void setTag(uint32_t tag) {
62 this->tag = tag;
63 }
64
65 void setTimeout(std::chrono::milliseconds timeout) {
66 this->timeout = timeout;
67 }
68
69 protected:
70 std::shared_ptr<Context> context;
71
72 // Scatter has N input buffers where each one in its
73 // entirety gets sent to a rank. The input(s) only need to
74 // be set on the root process.
75 std::vector<std::unique_ptr<transport::UnboundBuffer>> in;
76
77 // Scatter only has a single output buffer per rank.
78 std::unique_ptr<transport::UnboundBuffer> out;
79
80 // Number of bytes per element.
81 size_t elementSize = 0;
82
83 // Rank of process to scatter from.
84 int root = -1;
85
86 // Tag for this operation.
87 // Must be unique across operations executing in parallel.
88 uint32_t tag = 0;
89
90 // End-to-end timeout for this operation.
91 std::chrono::milliseconds timeout;
92
93 friend void scatter(ScatterOptions&);
94};
95
96void scatter(ScatterOptions& opts);
97
98} // namespace gloo
99