1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
16 | |
17 | namespace gloo { |
18 | |
19 | class ReduceOptions { |
20 | public: |
21 | using Func = std::function<void(void*, const void*, const void*, size_t)>; |
22 | |
23 | explicit ReduceOptions(const std::shared_ptr<Context>& context) |
24 | : context(context), timeout(context->getTimeout()) {} |
25 | |
26 | template <typename T> |
27 | void setInput(std::unique_ptr<transport::UnboundBuffer> buf) { |
28 | this->elements = buf->size / sizeof(T); |
29 | this->elementSize = sizeof(T); |
30 | this->in = std::move(buf); |
31 | } |
32 | |
33 | template <typename T> |
34 | void setInput(T* ptr, size_t elements) { |
35 | this->elements = elements; |
36 | this->elementSize = sizeof(T); |
37 | this->in = context->createUnboundBuffer(ptr, elements * sizeof(T)); |
38 | } |
39 | |
40 | template <typename T> |
41 | void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) { |
42 | this->elements = buf->size / sizeof(T); |
43 | this->elementSize = sizeof(T); |
44 | this->out = std::move(buf); |
45 | } |
46 | |
47 | template <typename T> |
48 | void setOutput(T* ptr, size_t elements) { |
49 | this->elements = elements; |
50 | this->elementSize = sizeof(T); |
51 | this->out = context->createUnboundBuffer(ptr, elements * sizeof(T)); |
52 | } |
53 | |
54 | void setRoot(int root) { |
55 | this->root = root; |
56 | } |
57 | |
58 | void setReduceFunction(Func fn) { |
59 | this->reduce = fn; |
60 | } |
61 | |
62 | void setTag(uint32_t tag) { |
63 | this->tag = tag; |
64 | } |
65 | |
66 | void setMaxSegmentSize(size_t maxSegmentSize) { |
67 | this->maxSegmentSize = maxSegmentSize; |
68 | } |
69 | |
70 | void setTimeout(std::chrono::milliseconds timeout) { |
71 | this->timeout = timeout; |
72 | } |
73 | |
74 | protected: |
75 | std::shared_ptr<Context> context; |
76 | std::unique_ptr<transport::UnboundBuffer> in; |
77 | std::unique_ptr<transport::UnboundBuffer> out; |
78 | |
79 | // Number of elements. |
80 | size_t elements = 0; |
81 | |
82 | // Number of bytes per element. |
83 | size_t elementSize = 0; |
84 | |
85 | // Rank of process to reduce to. |
86 | int root = -1; |
87 | |
88 | // Reduction function. |
89 | Func reduce; |
90 | |
91 | // Tag for this operation. |
92 | // Must be unique across operations executing in parallel. |
93 | uint32_t tag = 0; |
94 | |
95 | // This is the maximum size of each I/O operation (send/recv) of which |
96 | // two are in flight at all times. A smaller value leads to more |
97 | // overhead and a larger value leads to poor cache behavior. |
98 | static constexpr size_t kMaxSegmentSize = 1024 * 1024; |
99 | |
100 | // Internal use only. This is used to exercise code paths where we |
101 | // have more than 2 segments per rank without making the tests slow |
102 | // (because they would require millions of elements if the default |
103 | // were not configurable). |
104 | size_t maxSegmentSize = kMaxSegmentSize; |
105 | |
106 | // End-to-end timeout for this operation. |
107 | std::chrono::milliseconds timeout; |
108 | |
109 | friend void reduce(ReduceOptions&); |
110 | }; |
111 | |
112 | void reduce(ReduceOptions& opts); |
113 | |
114 | } // namespace gloo |
115 | |