1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include "gloo/common/logging.h"
12#include "gloo/context.h"
13#include "gloo/transport/unbound_buffer.h"
14
15namespace gloo {
16
17class AlltoallvOptions {
18 public:
19 explicit AlltoallvOptions(const std::shared_ptr<Context>& context)
20 : context(context), timeout(context->getTimeout()) {}
21
22 template <typename T>
23 void setInput(
24 std::unique_ptr<transport::UnboundBuffer> buf,
25 std::vector<int64_t> elementsPerRank) {
26 setInput(std::move(buf), std::move(elementsPerRank), sizeof(T));
27 }
28
29 template <typename T>
30 void setInput(T* ptr, std::vector<int64_t> elementsPerRank) {
31 setInput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T));
32 }
33
34 template <typename T>
35 void setOutput(
36 std::unique_ptr<transport::UnboundBuffer> buf,
37 std::vector<int64_t> elementsPerRank) {
38 setOutput(std::move(buf), std::move(elementsPerRank), sizeof(T));
39 }
40
41 template <typename T>
42 void setOutput(T* ptr, std::vector<int64_t> elementsPerRank) {
43 setOutput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T));
44 }
45
46 void setTag(uint32_t tag) {
47 this->tag = tag;
48 }
49
50 void setTimeout(std::chrono::milliseconds timeout) {
51 GLOO_ENFORCE(timeout.count() > 0);
52 this->timeout = timeout;
53 }
54
55 protected:
56 std::shared_ptr<Context> context;
57 std::unique_ptr<transport::UnboundBuffer> in;
58 std::unique_ptr<transport::UnboundBuffer> out;
59 std::vector<size_t> inOffsetPerRank;
60 std::vector<size_t> inLengthPerRank;
61 std::vector<size_t> outOffsetPerRank;
62 std::vector<size_t> outLengthPerRank;
63
64 // Number of bytes per element.
65 size_t elementSize = 0;
66
67 // Tag for this operation.
68 // Must be unique across operations executing in parallel.
69 uint32_t tag = 0;
70
71 // Set element size, or check the argument is equal to the current value.
72 void setElementSize(size_t elementSize);
73
74 // Untemplated implementation of setInput on unbound buffer.
75 void setInput(
76 std::unique_ptr<transport::UnboundBuffer> buf,
77 std::vector<int64_t> elementsPerRank,
78 size_t elementSize);
79
80 // Untemplated implementation of setInput on opaque pointer.
81 void
82 setInput(void* ptr, std::vector<int64_t> elementsPerRank, size_t elementSize);
83
84 // Untemplated implementation of setOutput on unbound buffer.
85 void setOutput(
86 std::unique_ptr<transport::UnboundBuffer> buf,
87 std::vector<int64_t> elementsPerRank,
88 size_t elementSize);
89
90 // Untemplated implementation of setOutput on opaque pointer.
91 void
92 setOutput(void* ptr, std::vector<int64_t> elementsPerRank, size_t elementSize);
93
94 // End-to-end timeout for this operation.
95 std::chrono::milliseconds timeout;
96
97 friend void alltoallv(AlltoallvOptions&);
98};
99
// Runs the alltoallv collective using the configuration held in `opts`
// (input/output buffers, per-rank element counts, tag, timeout).
// AlltoallvOptions declares this function a friend, so it can read the
// protected configuration members. The definition is not in this header.
// NOTE(review): `opts` is taken by non-const reference; whether the call
// mutates it is not visible from here.
void alltoallv(AlltoallvOptions& opts);
101
102} // namespace gloo
103