1 | /** |
2 | * Copyright (c) 2019-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/gatherv.h" |
10 | |
11 | #include <cstring> |
12 | #include <numeric> |
13 | |
14 | #include "gloo/common/logging.h" |
15 | #include "gloo/types.h" |
16 | |
17 | namespace gloo { |
18 | |
19 | void GathervOptions::setElementSize(size_t elementSize) { |
20 | if (this->elementSize == 0) { |
21 | this->elementSize = elementSize; |
22 | } else { |
23 | GLOO_ENFORCE_EQ( |
24 | elementSize, |
25 | this->elementSize, |
26 | "Element size does not match existing value. " , |
27 | "Please double check that the input and output types match." ); |
28 | } |
29 | } |
30 | |
31 | void GathervOptions::setInput( |
32 | std::unique_ptr<transport::UnboundBuffer> buf, |
33 | size_t elementSize) { |
34 | this->setElementSize(elementSize); |
35 | this->in = std::move(buf); |
36 | } |
37 | |
38 | void GathervOptions::setInput( |
39 | void* ptr, |
40 | size_t elements, |
41 | size_t elementSize) { |
42 | this->setElementSize(elementSize); |
43 | this->in = context->createUnboundBuffer(ptr, elements * elementSize); |
44 | } |
45 | |
46 | void GathervOptions::setOutput( |
47 | std::unique_ptr<transport::UnboundBuffer> buf, |
48 | std::vector<size_t> elementsPerRank, |
49 | size_t elementSize) { |
50 | const auto totalElements = |
51 | std::accumulate( |
52 | elementsPerRank.begin(), elementsPerRank.end(), size_t(0)); |
53 | this->setElementSize(elementSize); |
54 | GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size); |
55 | this->elementsPerRank = std::move(elementsPerRank); |
56 | GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size); |
57 | this->out = std::move(buf); |
58 | } |
59 | |
60 | void GathervOptions::setOutput( |
61 | void* ptr, |
62 | std::vector<size_t> elementsPerRank, |
63 | size_t elementSize) { |
64 | const auto totalElements = |
65 | std::accumulate( |
66 | elementsPerRank.begin(), elementsPerRank.end(), size_t(0)); |
67 | this->setElementSize(elementSize); |
68 | GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size); |
69 | this->elementsPerRank = std::move(elementsPerRank); |
70 | this->out = context->createUnboundBuffer(ptr, totalElements * elementSize); |
71 | } |
72 | |
73 | void gatherv(GathervOptions& opts) { |
74 | const auto& context = opts.context; |
75 | transport::UnboundBuffer* in = opts.in.get(); |
76 | transport::UnboundBuffer* out = opts.out.get(); |
77 | const auto slot = Slot::build(kGatherSlotPrefix, opts.tag); |
78 | |
79 | // Sanity checks |
80 | GLOO_ENFORCE(opts.elementSize > 0); |
81 | GLOO_ENFORCE(in != nullptr); |
82 | |
83 | if (context->rank == opts.root) { |
84 | size_t offset = 0; |
85 | for (int i = 0; i < context->size; i++) { |
86 | size_t copyLength = opts.elementSize * opts.elementsPerRank[i]; |
87 | if (i != context->rank) { |
88 | // Remote memory copy |
89 | out->recv(i, slot, offset, copyLength); |
90 | } else { |
91 | // Local memory copy |
92 | GLOO_ENFORCE_EQ(copyLength, in->size); |
93 | if (copyLength > 0) { |
94 | memcpy( |
95 | static_cast<char*>(out->ptr) + offset, |
96 | in->ptr, |
97 | in->size); |
98 | } |
99 | } |
100 | offset += copyLength; |
101 | } |
102 | // Wait for receive operations to complete |
103 | for (int i = 0; i < context->size - 1; i++) { |
104 | out->waitRecv(opts.timeout); |
105 | } |
106 | } else { |
107 | size_t sendLength = opts.elementSize * opts.elementsPerRank[context->rank]; |
108 | GLOO_ENFORCE_GE(in->size, sendLength); |
109 | in->send(opts.root, slot, 0, sendLength); |
110 | in->waitSend(opts.timeout); |
111 | } |
112 | } |
113 | |
114 | } // namespace gloo |
115 | |