1 | /** |
2 | * Copyright (c) 2018-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/alltoall.h" |
10 | |
11 | #include <cstring> |
12 | |
13 | #include "gloo/common/logging.h" |
14 | #include "gloo/types.h" |
15 | |
16 | namespace gloo { |
17 | |
18 | void alltoall(AlltoallOptions& opts) { |
19 | const auto& context = opts.context; |
20 | transport::UnboundBuffer* in = opts.in.get(); |
21 | transport::UnboundBuffer* out = opts.out.get(); |
22 | const auto slot = Slot::build(kAlltoallSlotPrefix, opts.tag); |
23 | |
24 | // Sanity checks. |
25 | // Number of elements should be evenly split in input and output buffers. |
26 | GLOO_ENFORCE(opts.elementSize > 0); |
27 | GLOO_ENFORCE(in != nullptr); |
28 | GLOO_ENFORCE(out != nullptr); |
29 | GLOO_ENFORCE(in->size % context->size == 0); |
30 | GLOO_ENFORCE(in->size == out->size); |
31 | |
32 | size_t chunkSize = in->size / context->size; |
33 | int myRank = context->rank; |
34 | int worldSize = context->size; |
35 | |
36 | // Local copy. |
37 | memcpy( |
38 | static_cast<char*>(out->ptr) + myRank * chunkSize, |
39 | static_cast<char*>(in->ptr) + myRank * chunkSize, |
40 | chunkSize); |
41 | |
42 | // Remote copy. |
43 | for (int i = 1; i < worldSize; i++) { |
44 | int sendRank = (myRank + i) % worldSize; |
45 | int recvRank = (myRank + worldSize - i) % worldSize; |
46 | in->send(sendRank, slot, sendRank * chunkSize, chunkSize); |
47 | out->recv(recvRank, slot, recvRank * chunkSize, chunkSize); |
48 | } |
49 | |
50 | for (int i = 1; i < worldSize; i++) { |
51 | in->waitSend(opts.timeout); |
52 | out->waitRecv(opts.timeout); |
53 | } |
54 | } |
55 | |
56 | } // namespace gloo |
57 | |