1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
13
14namespace gloo {
15
16class AllgatherOptions {
17 public:
18 explicit AllgatherOptions(const std::shared_ptr<Context>& context)
19 : context(context), timeout(context->getTimeout()) {}
20
21 template <typename T>
22 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
23 elementSize = sizeof(T);
24 in = std::move(buf);
25 }
26
27 template <typename T>
28 void setInput(T* ptr, size_t elements) {
29 elementSize = sizeof(T);
30 in = context->createUnboundBuffer(ptr, elements * sizeof(T));
31 }
32
33 template <typename T>
34 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
35 elementSize = sizeof(T);
36 out = std::move(buf);
37 }
38
39 template <typename T>
40 void setOutput(T* ptr, size_t elements) {
41 elementSize = sizeof(T);
42 out = context->createUnboundBuffer(ptr, elements * sizeof(T));
43 }
44
45 void setTag(uint32_t tag) {
46 this->tag = tag;
47 }
48
49 void setTimeout(std::chrono::milliseconds timeout) {
50 this->timeout = timeout;
51 }
52
53 protected:
54 std::shared_ptr<Context> context;
55 std::unique_ptr<transport::UnboundBuffer> in;
56 std::unique_ptr<transport::UnboundBuffer> out;
57
58 // Number of bytes per element.
59 size_t elementSize = 0;
60
61 // Tag for this operation.
62 // Must be unique across operations executing in parallel.
63 uint32_t tag = 0;
64
65 // End-to-end timeout for this operation.
66 std::chrono::milliseconds timeout;
67
68 friend void allgather(AllgatherOptions&);
69};
70
71void allgather(AllgatherOptions& opts);
72
73} // namespace gloo
74