1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
16 | |
17 | namespace gloo { |
18 | |
19 | class ScatterOptions { |
20 | public: |
21 | explicit ScatterOptions(const std::shared_ptr<Context>& context) |
22 | : context(context), timeout(context->getTimeout()) {} |
23 | |
24 | template <typename T> |
25 | void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) { |
26 | this->elementSize = sizeof(T); |
27 | this->in = std::move(bufs); |
28 | } |
29 | |
30 | template <typename T> |
31 | void setInputs(std::vector<T*> ptrs, size_t elements) { |
32 | setInputs(ptrs.data(), ptrs.size(), elements); |
33 | } |
34 | |
35 | template <typename T> |
36 | void setInputs(T** ptrs, size_t len, size_t elements) { |
37 | this->elementSize = sizeof(T); |
38 | this->in.reserve(len); |
39 | for (size_t i = 0; i < len; i++) { |
40 | this->in.push_back( |
41 | context->createUnboundBuffer(ptrs[i], elements * sizeof(T))); |
42 | } |
43 | } |
44 | |
45 | template <typename T> |
46 | void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) { |
47 | this->elementSize = sizeof(T); |
48 | this->out = std::move(buf); |
49 | } |
50 | |
51 | template <typename T> |
52 | void setOutput(T* ptr, size_t elements) { |
53 | this->elementSize = sizeof(T); |
54 | this->out = context->createUnboundBuffer(ptr, elements * sizeof(T)); |
55 | } |
56 | |
57 | void setRoot(int root) { |
58 | this->root = root; |
59 | } |
60 | |
61 | void setTag(uint32_t tag) { |
62 | this->tag = tag; |
63 | } |
64 | |
65 | void setTimeout(std::chrono::milliseconds timeout) { |
66 | this->timeout = timeout; |
67 | } |
68 | |
69 | protected: |
70 | std::shared_ptr<Context> context; |
71 | |
72 | // Scatter has N input buffers where each one in its |
73 | // entirety gets sent to a rank. The input(s) only need to |
74 | // be set on the root process. |
75 | std::vector<std::unique_ptr<transport::UnboundBuffer>> in; |
76 | |
77 | // Scatter only has a single output buffer per rank. |
78 | std::unique_ptr<transport::UnboundBuffer> out; |
79 | |
80 | // Number of bytes per element. |
81 | size_t elementSize = 0; |
82 | |
83 | // Rank of process to scatter from. |
84 | int root = -1; |
85 | |
86 | // Tag for this operation. |
87 | // Must be unique across operations executing in parallel. |
88 | uint32_t tag = 0; |
89 | |
90 | // End-to-end timeout for this operation. |
91 | std::chrono::milliseconds timeout; |
92 | |
93 | friend void scatter(ScatterOptions&); |
94 | }; |
95 | |
// Runs the scatter operation configured by `opts`. The root rank
// (opts.root) holds N input buffers, each sent in its entirety to a rank;
// every rank has a single output buffer (opts.out) to receive into.
// NOTE(review): presumably input i is delivered to rank i and root/tag/
// timeout/buffer sizes are validated here — confirm against the implementation.
void scatter(ScatterOptions& opts);
97 | |
98 | } // namespace gloo |
99 | |