1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include "gloo/common/logging.h"
12#include "gloo/context.h"
13#include "gloo/transport/unbound_buffer.h"
14
15namespace gloo {
16
17class AlltoallOptions {
18 public:
19 explicit AlltoallOptions(const std::shared_ptr<Context>& context)
20 : context(context), timeout(context->getTimeout()) {}
21
22 template <typename T>
23 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
24 elementSize = sizeof(T);
25 in = std::move(buf);
26 }
27
28 template <typename T>
29 void setInput(T* ptr, size_t elements) {
30 elementSize = sizeof(T);
31 in = context->createUnboundBuffer(ptr, elements * sizeof(T));
32 }
33
34 template <typename T>
35 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
36 elementSize = sizeof(T);
37 out = std::move(buf);
38 }
39
40 template <typename T>
41 void setOutput(T* ptr, size_t elements) {
42 elementSize = sizeof(T);
43 out = context->createUnboundBuffer(ptr, elements * sizeof(T));
44 }
45
46 void setTag(uint32_t tag) {
47 this->tag = tag;
48 }
49
50 void setTimeout(std::chrono::milliseconds timeout) {
51 GLOO_ENFORCE(timeout.count() > 0);
52 this->timeout = timeout;
53 }
54
55 protected:
56 std::shared_ptr<Context> context;
57 std::unique_ptr<transport::UnboundBuffer> in;
58 std::unique_ptr<transport::UnboundBuffer> out;
59
60 // Number of bytes per element.
61 size_t elementSize = 0;
62
63 // Tag for this operation.
64 // Must be unique across operations executing in parallel.
65 uint32_t tag = 0;
66
67 // End-to-end timeout for this operation.
68 std::chrono::milliseconds timeout;
69
70 friend void alltoall(AlltoallOptions&);
71};
72
// Runs the alltoall collective configured by `opts`.
// Declared a friend of AlltoallOptions, so it can access the options'
// protected state (context, in/out buffers, elementSize, tag, timeout).
// Only the declaration lives in this header; the definition is elsewhere.
// NOTE(review): presumably each rank sends an equal chunk of `in` to every
// peer and receives into `out` -- confirm against the implementation.
void alltoall(AlltoallOptions& opts);
74
75} // namespace gloo
76