1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include "gloo/common/logging.h" |
12 | #include "gloo/context.h" |
13 | #include "gloo/transport/unbound_buffer.h" |
14 | |
15 | namespace gloo { |
16 | |
17 | class AlltoallvOptions { |
18 | public: |
19 | explicit AlltoallvOptions(const std::shared_ptr<Context>& context) |
20 | : context(context), timeout(context->getTimeout()) {} |
21 | |
22 | template <typename T> |
23 | void setInput( |
24 | std::unique_ptr<transport::UnboundBuffer> buf, |
25 | std::vector<int64_t> elementsPerRank) { |
26 | setInput(std::move(buf), std::move(elementsPerRank), sizeof(T)); |
27 | } |
28 | |
29 | template <typename T> |
30 | void setInput(T* ptr, std::vector<int64_t> elementsPerRank) { |
31 | setInput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T)); |
32 | } |
33 | |
34 | template <typename T> |
35 | void setOutput( |
36 | std::unique_ptr<transport::UnboundBuffer> buf, |
37 | std::vector<int64_t> elementsPerRank) { |
38 | setOutput(std::move(buf), std::move(elementsPerRank), sizeof(T)); |
39 | } |
40 | |
41 | template <typename T> |
42 | void setOutput(T* ptr, std::vector<int64_t> elementsPerRank) { |
43 | setOutput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T)); |
44 | } |
45 | |
46 | void setTag(uint32_t tag) { |
47 | this->tag = tag; |
48 | } |
49 | |
50 | void setTimeout(std::chrono::milliseconds timeout) { |
51 | GLOO_ENFORCE(timeout.count() > 0); |
52 | this->timeout = timeout; |
53 | } |
54 | |
55 | protected: |
56 | std::shared_ptr<Context> context; |
57 | std::unique_ptr<transport::UnboundBuffer> in; |
58 | std::unique_ptr<transport::UnboundBuffer> out; |
59 | std::vector<size_t> inOffsetPerRank; |
60 | std::vector<size_t> inLengthPerRank; |
61 | std::vector<size_t> outOffsetPerRank; |
62 | std::vector<size_t> outLengthPerRank; |
63 | |
64 | // Number of bytes per element. |
65 | size_t elementSize = 0; |
66 | |
67 | // Tag for this operation. |
68 | // Must be unique across operations executing in parallel. |
69 | uint32_t tag = 0; |
70 | |
71 | // Set element size, or check the argument is equal to the current value. |
72 | void setElementSize(size_t elementSize); |
73 | |
74 | // Untemplated implementation of setInput on unbound buffer. |
75 | void setInput( |
76 | std::unique_ptr<transport::UnboundBuffer> buf, |
77 | std::vector<int64_t> elementsPerRank, |
78 | size_t elementSize); |
79 | |
80 | // Untemplated implementation of setInput on opaque pointer. |
81 | void |
82 | setInput(void* ptr, std::vector<int64_t> elementsPerRank, size_t elementSize); |
83 | |
84 | // Untemplated implementation of setOutput on unbound buffer. |
85 | void setOutput( |
86 | std::unique_ptr<transport::UnboundBuffer> buf, |
87 | std::vector<int64_t> elementsPerRank, |
88 | size_t elementSize); |
89 | |
90 | // Untemplated implementation of setOutput on opaque pointer. |
91 | void |
92 | setOutput(void* ptr, std::vector<int64_t> elementsPerRank, size_t elementSize); |
93 | |
94 | // End-to-end timeout for this operation. |
95 | std::chrono::milliseconds timeout; |
96 | |
97 | friend void alltoallv(AlltoallvOptions&); |
98 | }; |
99 | |
// Entry point that consumes an AlltoallvOptions instance. Only the
// declaration appears in this header; no definition is visible here.
// The options are taken by non-const reference, and AlltoallvOptions
// declares this function a friend so it can access the protected state.
// NOTE(review): presumably exchanges per-rank, variable-sized segments
// between all ranks of the context, as the name suggests; confirm against
// the definition.
void alltoallv(AlltoallvOptions& opts);
101 | |
102 | } // namespace gloo |
103 | |