1 | /** |
2 | * Copyright (c) 2019-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
13 | |
14 | namespace gloo { |
15 | |
16 | class GathervOptions { |
17 | public: |
18 | explicit GathervOptions(const std::shared_ptr<Context>& context) |
19 | : context(context), timeout(context->getTimeout()) {} |
20 | |
21 | template <typename T> |
22 | void setInput(std::unique_ptr<transport::UnboundBuffer> buf) { |
23 | setInput(std::move(buf), sizeof(T)); |
24 | } |
25 | |
26 | template <typename T> |
27 | void setInput(T* ptr, size_t elements) { |
28 | setInput(static_cast<void*>(ptr), elements, sizeof(T)); |
29 | } |
30 | |
31 | template <typename T> |
32 | void setOutput( |
33 | std::unique_ptr<transport::UnboundBuffer> buf, |
34 | std::vector<size_t> elementsPerRank) { |
35 | setOutput(std::move(buf), std::move(elementsPerRank), sizeof(T)); |
36 | } |
37 | |
38 | template <typename T> |
39 | void setOutput(T* ptr, std::vector<size_t> elementsPerRank) { |
40 | setOutput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T)); |
41 | } |
42 | |
43 | void setRoot(int root) { |
44 | this->root = root; |
45 | } |
46 | |
47 | void setTag(uint32_t tag) { |
48 | this->tag = tag; |
49 | } |
50 | |
51 | void setTimeout(std::chrono::milliseconds timeout) { |
52 | this->timeout = timeout; |
53 | } |
54 | |
55 | protected: |
56 | std::shared_ptr<Context> context; |
57 | std::unique_ptr<transport::UnboundBuffer> in; |
58 | std::unique_ptr<transport::UnboundBuffer> out; |
59 | |
60 | // Number of elements per rank in the output. |
61 | std::vector<size_t> elementsPerRank; |
62 | |
63 | // Number of bytes per element. |
64 | size_t elementSize = 0; |
65 | |
66 | // Rank of receiving process. |
67 | int root = -1; |
68 | |
69 | // Tag for this operation. |
70 | // Must be unique across operations executing in parallel. |
71 | uint32_t tag = 0; |
72 | |
73 | // End-to-end timeout for this operation. |
74 | std::chrono::milliseconds timeout; |
75 | |
76 | // Set element size, or check the argument is equal to the current value. |
77 | void setElementSize(size_t elementSize); |
78 | |
79 | // Untemplated implementation of setInput on unbound buffer. |
80 | void setInput( |
81 | std::unique_ptr<transport::UnboundBuffer> buf, |
82 | size_t elementSize); |
83 | |
84 | // Untemplated implementation of setInput on opaque pointer. |
85 | void setInput(void* ptr, size_t elements, size_t elementSize); |
86 | |
87 | // Untemplated implementation of setOutput on unbound buffer. |
88 | void setOutput( |
89 | std::unique_ptr<transport::UnboundBuffer> buf, |
90 | std::vector<size_t> elements, |
91 | size_t elementSize); |
92 | |
93 | // Untemplated implementation of setOutput on opaque pointer. |
94 | void setOutput(void* ptr, std::vector<size_t> elements, size_t elementSize); |
95 | |
96 | friend void gatherv(GathervOptions&); |
97 | }; |
98 | |
99 | void gatherv(GathervOptions& opts); |
100 | |
101 | } // namespace gloo |
102 | |