1/**
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/alltoallv.h"
10
11#include <cstring>
12#include <numeric>
13
14#include "gloo/common/logging.h"
15#include "gloo/types.h"
16
17namespace gloo {
18
// Converts per-rank element counts into per-rank byte offsets and byte
// lengths within one contiguous buffer.
//
// For rank i, lengths[i] is elementsPerRank[i] * elementSize and offsets[i]
// is the sum of lengths[0..i). Both output vectors are overwritten: any
// previous contents are discarded first, so invoking this again on the same
// vectors (for example when an options setter is called a second time)
// replaces the old layout instead of appending behind stale entries.
//
// Counts are converted to size_t without validation; a negative count wraps
// around to a very large value.
static void splitOffsetsAndLengths(
    const std::vector<int64_t>& elementsPerRank,
    size_t elementSize,
    std::vector<size_t>& offsets,
    std::vector<size_t>& lengths) {
  // Drop any layout left over from an earlier call; clear() keeps capacity.
  offsets.clear();
  lengths.clear();
  offsets.reserve(elementsPerRank.size());
  lengths.reserve(elementsPerRank.size());

  size_t offset = 0;
  for (const int64_t elements : elementsPerRank) {
    const size_t length = static_cast<size_t>(elements) * elementSize;
    offsets.push_back(offset);
    lengths.push_back(length);
    offset += length;
  }
}
32
33void AlltoallvOptions::setElementSize(size_t elementSize) {
34 if (this->elementSize == 0) {
35 this->elementSize = elementSize;
36 } else {
37 GLOO_ENFORCE_EQ(
38 elementSize,
39 this->elementSize,
40 "Element size does not match existing value. ",
41 "Please double check that the input and output types match.");
42 }
43}
44
45void AlltoallvOptions::setInput(
46 std::unique_ptr<transport::UnboundBuffer> buf,
47 std::vector<int64_t> elementsPerRank,
48 size_t elementSize) {
49 const auto totalElements = std::accumulate(
50 elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
51 this->setElementSize(elementSize);
52 GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
53 this->inOffsetPerRank.reserve(elementsPerRank.size());
54 this->inLengthPerRank.reserve(elementsPerRank.size());
55 splitOffsetsAndLengths(
56 elementsPerRank,
57 elementSize,
58 this->inOffsetPerRank,
59 this->inLengthPerRank);
60 GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size);
61 this->in = std::move(buf);
62}
63
64void AlltoallvOptions::setInput(
65 void* ptr,
66 std::vector<int64_t> elementsPerRank,
67 size_t elementSize) {
68 const auto totalElements = std::accumulate(
69 elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
70 this->setElementSize(elementSize);
71 GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
72 this->inOffsetPerRank.reserve(elementsPerRank.size());
73 this->inLengthPerRank.reserve(elementsPerRank.size());
74 splitOffsetsAndLengths(
75 elementsPerRank,
76 elementSize,
77 this->inOffsetPerRank,
78 this->inLengthPerRank);
79 this->in = context->createUnboundBuffer(ptr, totalElements * elementSize);
80}
81
82void AlltoallvOptions::setOutput(
83 std::unique_ptr<transport::UnboundBuffer> buf,
84 std::vector<int64_t> elementsPerRank,
85 size_t elementSize) {
86 const auto totalElements = std::accumulate(
87 elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
88 this->setElementSize(elementSize);
89 GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
90 this->outOffsetPerRank.reserve(elementsPerRank.size());
91 this->outLengthPerRank.reserve(elementsPerRank.size());
92 splitOffsetsAndLengths(
93 elementsPerRank,
94 elementSize,
95 this->outOffsetPerRank,
96 this->outLengthPerRank);
97 GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size);
98 this->out = std::move(buf);
99}
100
101void AlltoallvOptions::setOutput(
102 void* ptr,
103 std::vector<int64_t> elementsPerRank,
104 size_t elementSize) {
105 const auto totalElements = std::accumulate(
106 elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
107 this->setElementSize(elementSize);
108 GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
109 this->outOffsetPerRank.reserve(elementsPerRank.size());
110 this->outLengthPerRank.reserve(elementsPerRank.size());
111 splitOffsetsAndLengths(
112 elementsPerRank,
113 elementSize,
114 this->outOffsetPerRank,
115 this->outLengthPerRank);
116 this->out = context->createUnboundBuffer(ptr, totalElements * elementSize);
117}
118
119void alltoallv(AlltoallvOptions& opts) {
120 const auto& context = opts.context;
121 transport::UnboundBuffer* in = opts.in.get();
122 transport::UnboundBuffer* out = opts.out.get();
123 std::vector<size_t>& inOffsetPerRank = opts.inOffsetPerRank;
124 std::vector<size_t>& inLengthPerRank = opts.inLengthPerRank;
125 std::vector<size_t>& outOffsetPerRank = opts.outOffsetPerRank;
126 std::vector<size_t>& outLengthPerRank = opts.outLengthPerRank;
127 const auto slot = Slot::build(kAlltoallSlotPrefix, opts.tag);
128
129 // Sanity checks.
130 GLOO_ENFORCE(opts.elementSize > 0);
131 GLOO_ENFORCE(in != nullptr);
132 GLOO_ENFORCE(out != nullptr);
133
134 int myRank = context->rank;
135 int worldSize = context->size;
136
137 // Local copy.
138 GLOO_ENFORCE(inLengthPerRank[myRank] == outLengthPerRank[myRank]);
139 size_t myInOffset = inOffsetPerRank[myRank];
140 size_t myOutOffset = outOffsetPerRank[myRank];
141 size_t myChunkSize = inLengthPerRank[myRank];
142 memcpy(
143 static_cast<char*>(out->ptr) + myOutOffset,
144 static_cast<char*>(in->ptr) + myInOffset,
145 myChunkSize);
146
147 // Remote copy.
148 for (int i = 1; i < worldSize; i++) {
149 int sendRank = (myRank + i) % worldSize;
150 int recvRank = (myRank + worldSize - i) % worldSize;
151 in->send(
152 sendRank, slot, inOffsetPerRank[sendRank], inLengthPerRank[sendRank]);
153 out->recv(
154 recvRank, slot, outOffsetPerRank[recvRank], outLengthPerRank[recvRank]);
155 }
156
157 for (int i = 1; i < worldSize; i++) {
158 in->waitSend(opts.timeout);
159 out->waitRecv(opts.timeout);
160 }
161}
162
163} // namespace gloo
164