1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include "gloo/context.h"
12#include "gloo/transport/unbound_buffer.h"
13
14namespace gloo {
15
16class GatherOptions {
17 public:
18 explicit GatherOptions(const std::shared_ptr<Context>& context)
19 : context(context), timeout(context->getTimeout()) {}
20
21 template <typename T>
22 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
23 elementSize = sizeof(T);
24 in = std::move(buf);
25 }
26
27 template <typename T>
28 void setInput(T* ptr, size_t elements) {
29 elementSize = sizeof(T);
30 in = context->createUnboundBuffer(ptr, elements * sizeof(T));
31 }
32
33 template <typename T>
34 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
35 elementSize = sizeof(T);
36 out = std::move(buf);
37 }
38
39 template <typename T>
40 void setOutput(T* ptr, size_t elements) {
41 elementSize = sizeof(T);
42 out = context->createUnboundBuffer(ptr, elements * sizeof(T));
43 }
44
45 void setRoot(int root) {
46 this->root = root;
47 }
48
49 void setTag(uint32_t tag) {
50 this->tag = tag;
51 }
52
53 void setTimeout(std::chrono::milliseconds timeout) {
54 this->timeout = timeout;
55 }
56
57 protected:
58 std::shared_ptr<Context> context;
59 std::unique_ptr<transport::UnboundBuffer> in;
60 std::unique_ptr<transport::UnboundBuffer> out;
61
62 // Number of bytes per element.
63 size_t elementSize = 0;
64
65 // Rank of receiving process.
66 int root = -1;
67
68 // Tag for this operation.
69 // Must be unique across operations executing in parallel.
70 uint32_t tag = 0;
71
72 // End-to-end timeout for this operation.
73 std::chrono::milliseconds timeout;
74
75 friend void gather(GatherOptions&);
76};
77
78void gather(GatherOptions& opts);
79
80} // namespace gloo
81