1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
17
18namespace gloo {
19
20namespace detail {
21
22struct AllreduceOptionsImpl {
23 // This type describes the function to use for element wise reduction.
24 //
25 // Its arguments are:
26 // 1. non-const output pointer
27 // 2. const input pointer 1 (may be equal to 1)
28 // 3. const input pointer 2 (may be equal to 1)
29 // 4. number of elements to reduce.
30 //
31 // Note that this function is not strictly typed and takes void pointers.
32 // This is specifically done to avoid the need for a templated options class
33 // and templated algorithm implementations. We found this adds very little
34 // value for the increase in compilation time and code size.
35 //
36 using Func = std::function<void(void*, const void*, const void*, size_t)>;
37
38 enum Algorithm {
39 UNSPECIFIED = 0,
40 RING = 1,
41 BCUBE = 2,
42 };
43
44 explicit AllreduceOptionsImpl(const std::shared_ptr<Context>& context)
45 : context(context),
46 timeout(context->getTimeout()),
47 algorithm(UNSPECIFIED) {}
48
49 std::shared_ptr<Context> context;
50
51 // End-to-end timeout for this operation.
52 std::chrono::milliseconds timeout;
53
54 // Algorithm selection.
55 Algorithm algorithm;
56
57 // Input and output buffers.
58 // The output is used as input if input is not specified.
59 std::vector<std::unique_ptr<transport::UnboundBuffer>> in;
60 std::vector<std::unique_ptr<transport::UnboundBuffer>> out;
61
62 // Number of elements.
63 size_t elements = 0;
64
65 // Number of bytes per element.
66 size_t elementSize = 0;
67
68 // Reduction function.
69 Func reduce;
70
71 // Tag for this operation.
72 // Must be unique across operations executing in parallel.
73 uint32_t tag = 0;
74
75 // This is the maximum size of each I/O operation (send/recv) of which
76 // two are in flight at all times. A smaller value leads to more
77 // overhead and a larger value leads to poor cache behavior.
78 static constexpr size_t kMaxSegmentSize = 1024 * 1024;
79
80 // Internal use only. This is used to exercise code paths where we
81 // have more than 2 segments per rank without making the tests slow
82 // (because they would require millions of elements if the default
83 // were not configurable).
84 size_t maxSegmentSize = kMaxSegmentSize;
85};
86
87} // namespace detail
88
89class AllreduceOptions {
90 public:
91 using Func = detail::AllreduceOptionsImpl::Func;
92 using Algorithm = detail::AllreduceOptionsImpl::Algorithm;
93
94 explicit AllreduceOptions(const std::shared_ptr<Context>& context)
95 : impl_(context) {}
96
97 void setAlgorithm(Algorithm algorithm) {
98 impl_.algorithm = algorithm;
99 }
100
101 template <typename T>
102 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
103 std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1);
104 bufs[0] = std::move(buf);
105 setInputs<T>(std::move(bufs));
106 }
107
108 template <typename T>
109 void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
110 impl_.elements = bufs[0]->size / sizeof(T);
111 impl_.elementSize = sizeof(T);
112 impl_.in = std::move(bufs);
113 }
114
115 template <typename T>
116 void setInput(T* ptr, size_t elements) {
117 setInputs(&ptr, 1, elements);
118 }
119
120 template <typename T>
121 void setInputs(std::vector<T*> ptrs, size_t elements) {
122 setInputs(ptrs.data(), ptrs.size(), elements);
123 }
124
125 template <typename T>
126 void setInputs(T** ptrs, size_t len, size_t elements) {
127 impl_.elements = elements;
128 impl_.elementSize = sizeof(T);
129 impl_.in.reserve(len);
130 for (size_t i = 0; i < len; i++) {
131 impl_.in.push_back(
132 impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
133 }
134 }
135
136 template <typename T>
137 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
138 std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1);
139 bufs[0] = std::move(buf);
140 setOutputs<T>(std::move(bufs));
141 }
142
143 template <typename T>
144 void setOutputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
145 impl_.elements = bufs[0]->size / sizeof(T);
146 impl_.elementSize = sizeof(T);
147 impl_.out = std::move(bufs);
148 }
149
150 template <typename T>
151 void setOutput(T* ptr, size_t elements) {
152 setOutputs(&ptr, 1, elements);
153 }
154
155 template <typename T>
156 void setOutputs(std::vector<T*> ptrs, size_t elements) {
157 setOutputs(ptrs.data(), ptrs.size(), elements);
158 }
159
160 template <typename T>
161 void setOutputs(T** ptrs, size_t len, size_t elements) {
162 impl_.elements = elements;
163 impl_.elementSize = sizeof(T);
164 impl_.out.reserve(len);
165 for (size_t i = 0; i < len; i++) {
166 impl_.out.push_back(
167 impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
168 }
169 }
170
171 void setReduceFunction(Func fn) {
172 impl_.reduce = fn;
173 }
174
175 void setTag(uint32_t tag) {
176 impl_.tag = tag;
177 }
178
179 void setMaxSegmentSize(size_t maxSegmentSize) {
180 impl_.maxSegmentSize = maxSegmentSize;
181 }
182
183 void setTimeout(std::chrono::milliseconds timeout) {
184 impl_.timeout = timeout;
185 }
186
187 protected:
188 detail::AllreduceOptionsImpl impl_;
189
190 friend void allreduce(const AllreduceOptions&);
191};
192
193void allreduce(const AllreduceOptions& opts);
194
195} // namespace gloo
196