1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
16
17namespace gloo {
18
19class ReduceOptions {
20 public:
21 using Func = std::function<void(void*, const void*, const void*, size_t)>;
22
23 explicit ReduceOptions(const std::shared_ptr<Context>& context)
24 : context(context), timeout(context->getTimeout()) {}
25
26 template <typename T>
27 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
28 this->elements = buf->size / sizeof(T);
29 this->elementSize = sizeof(T);
30 this->in = std::move(buf);
31 }
32
33 template <typename T>
34 void setInput(T* ptr, size_t elements) {
35 this->elements = elements;
36 this->elementSize = sizeof(T);
37 this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
38 }
39
40 template <typename T>
41 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
42 this->elements = buf->size / sizeof(T);
43 this->elementSize = sizeof(T);
44 this->out = std::move(buf);
45 }
46
47 template <typename T>
48 void setOutput(T* ptr, size_t elements) {
49 this->elements = elements;
50 this->elementSize = sizeof(T);
51 this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
52 }
53
54 void setRoot(int root) {
55 this->root = root;
56 }
57
58 void setReduceFunction(Func fn) {
59 this->reduce = fn;
60 }
61
62 void setTag(uint32_t tag) {
63 this->tag = tag;
64 }
65
66 void setMaxSegmentSize(size_t maxSegmentSize) {
67 this->maxSegmentSize = maxSegmentSize;
68 }
69
70 void setTimeout(std::chrono::milliseconds timeout) {
71 this->timeout = timeout;
72 }
73
74 protected:
75 std::shared_ptr<Context> context;
76 std::unique_ptr<transport::UnboundBuffer> in;
77 std::unique_ptr<transport::UnboundBuffer> out;
78
79 // Number of elements.
80 size_t elements = 0;
81
82 // Number of bytes per element.
83 size_t elementSize = 0;
84
85 // Rank of process to reduce to.
86 int root = -1;
87
88 // Reduction function.
89 Func reduce;
90
91 // Tag for this operation.
92 // Must be unique across operations executing in parallel.
93 uint32_t tag = 0;
94
95 // This is the maximum size of each I/O operation (send/recv) of which
96 // two are in flight at all times. A smaller value leads to more
97 // overhead and a larger value leads to poor cache behavior.
98 static constexpr size_t kMaxSegmentSize = 1024 * 1024;
99
100 // Internal use only. This is used to exercise code paths where we
101 // have more than 2 segments per rank without making the tests slow
102 // (because they would require millions of elements if the default
103 // were not configurable).
104 size_t maxSegmentSize = kMaxSegmentSize;
105
106 // End-to-end timeout for this operation.
107 std::chrono::milliseconds timeout;
108
109 friend void reduce(ReduceOptions&);
110};
111
// Entry point for the reduce operation configured by `opts`.
//
// Only the declaration lives in this header. ReduceOptions names this
// function as a friend, so the implementation can read the protected option
// fields (buffers, root, reduction function, tag, timeout, ...) directly.
// `opts` is taken by non-const reference.
112void reduce(ReduceOptions& opts);
113
114} // namespace gloo
115