1 | /** |
---|---|
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include "gloo/context.h" |
12 | #include "gloo/transport/unbound_buffer.h" |
13 | |
14 | namespace gloo { |
15 | |
16 | class BroadcastOptions { |
17 | public: |
18 | explicit BroadcastOptions(const std::shared_ptr<Context>& context) |
19 | : context(context), timeout(context->getTimeout()) {} |
20 | |
21 | template <typename T> |
22 | void setInput(std::unique_ptr<transport::UnboundBuffer> buf) { |
23 | this->elements = buf->size / sizeof(T); |
24 | this->elementSize = sizeof(T); |
25 | this->in = std::move(buf); |
26 | } |
27 | |
28 | template <typename T> |
29 | void setInput(T* ptr, size_t elements) { |
30 | this->elements = elements; |
31 | this->elementSize = sizeof(T); |
32 | this->in = context->createUnboundBuffer(ptr, elements * sizeof(T)); |
33 | } |
34 | |
35 | template <typename T> |
36 | void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) { |
37 | this->elements = buf->size / sizeof(T); |
38 | this->elementSize = sizeof(T); |
39 | this->out = std::move(buf); |
40 | } |
41 | |
42 | template <typename T> |
43 | void setOutput(T* ptr, size_t elements) { |
44 | this->elements = elements; |
45 | this->elementSize = sizeof(T); |
46 | this->out = context->createUnboundBuffer(ptr, elements * sizeof(T)); |
47 | } |
48 | |
49 | void setRoot(int root) { |
50 | this->root = root; |
51 | } |
52 | |
53 | void setTag(uint32_t tag) { |
54 | this->tag = tag; |
55 | } |
56 | |
57 | void setTimeout(std::chrono::milliseconds timeout) { |
58 | this->timeout = timeout; |
59 | } |
60 | |
61 | protected: |
62 | std::shared_ptr<Context> context; |
63 | |
64 | // Broadcast has an optional input buffer for the root. |
65 | std::unique_ptr<transport::UnboundBuffer> in; |
66 | |
67 | // Broadcast has a mandatory output buffer for all ranks. |
68 | std::unique_ptr<transport::UnboundBuffer> out; |
69 | |
70 | // Number of elements. |
71 | size_t elements = 0; |
72 | |
73 | // Number of bytes per element. |
74 | size_t elementSize = 0; |
75 | |
76 | // Rank of process to broadcast from. |
77 | int root = -1; |
78 | |
79 | // Tag for this operation. |
80 | // Must be unique across operations executing in parallel. |
81 | uint32_t tag = 0; |
82 | |
83 | // End-to-end timeout for this operation. |
84 | std::chrono::milliseconds timeout; |
85 | |
86 | friend void broadcast(BroadcastOptions&); |
87 | }; |
88 | |
89 | void broadcast(BroadcastOptions& opts); |
90 | |
91 | } // namespace gloo |
92 |