1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
13
14namespace gloo {
15
16class BroadcastOptions {
17 public:
18 explicit BroadcastOptions(const std::shared_ptr<Context>& context)
19 : context(context), timeout(context->getTimeout()) {}
20
21 template <typename T>
22 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
23 this->elements = buf->size / sizeof(T);
24 this->elementSize = sizeof(T);
25 this->in = std::move(buf);
26 }
27
28 template <typename T>
29 void setInput(T* ptr, size_t elements) {
30 this->elements = elements;
31 this->elementSize = sizeof(T);
32 this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
33 }
34
35 template <typename T>
36 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
37 this->elements = buf->size / sizeof(T);
38 this->elementSize = sizeof(T);
39 this->out = std::move(buf);
40 }
41
42 template <typename T>
43 void setOutput(T* ptr, size_t elements) {
44 this->elements = elements;
45 this->elementSize = sizeof(T);
46 this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
47 }
48
49 void setRoot(int root) {
50 this->root = root;
51 }
52
53 void setTag(uint32_t tag) {
54 this->tag = tag;
55 }
56
57 void setTimeout(std::chrono::milliseconds timeout) {
58 this->timeout = timeout;
59 }
60
61 protected:
62 std::shared_ptr<Context> context;
63
64 // Broadcast has an optional input buffer for the root.
65 std::unique_ptr<transport::UnboundBuffer> in;
66
67 // Broadcast has a mandatory output buffer for all ranks.
68 std::unique_ptr<transport::UnboundBuffer> out;
69
70 // Number of elements.
71 size_t elements = 0;
72
73 // Number of bytes per element.
74 size_t elementSize = 0;
75
76 // Rank of process to broadcast from.
77 int root = -1;
78
79 // Tag for this operation.
80 // Must be unique across operations executing in parallel.
81 uint32_t tag = 0;
82
83 // End-to-end timeout for this operation.
84 std::chrono::milliseconds timeout;
85
86 friend void broadcast(BroadcastOptions&);
87};
88
89void broadcast(BroadcastOptions& opts);
90
91} // namespace gloo
92