1/**
2 * Copyright (c) 2019-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/gatherv.h"
10
11#include <cstring>
12#include <numeric>
13
14#include "gloo/common/logging.h"
15#include "gloo/types.h"
16
17namespace gloo {
18
19void GathervOptions::setElementSize(size_t elementSize) {
20 if (this->elementSize == 0) {
21 this->elementSize = elementSize;
22 } else {
23 GLOO_ENFORCE_EQ(
24 elementSize,
25 this->elementSize,
26 "Element size does not match existing value. ",
27 "Please double check that the input and output types match.");
28 }
29}
30
31void GathervOptions::setInput(
32 std::unique_ptr<transport::UnboundBuffer> buf,
33 size_t elementSize) {
34 this->setElementSize(elementSize);
35 this->in = std::move(buf);
36}
37
38void GathervOptions::setInput(
39 void* ptr,
40 size_t elements,
41 size_t elementSize) {
42 this->setElementSize(elementSize);
43 this->in = context->createUnboundBuffer(ptr, elements * elementSize);
44}
45
46void GathervOptions::setOutput(
47 std::unique_ptr<transport::UnboundBuffer> buf,
48 std::vector<size_t> elementsPerRank,
49 size_t elementSize) {
50 const auto totalElements =
51 std::accumulate(
52 elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
53 this->setElementSize(elementSize);
54 GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
55 this->elementsPerRank = std::move(elementsPerRank);
56 GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size);
57 this->out = std::move(buf);
58}
59
60void GathervOptions::setOutput(
61 void* ptr,
62 std::vector<size_t> elementsPerRank,
63 size_t elementSize) {
64 const auto totalElements =
65 std::accumulate(
66 elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
67 this->setElementSize(elementSize);
68 GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
69 this->elementsPerRank = std::move(elementsPerRank);
70 this->out = context->createUnboundBuffer(ptr, totalElements * elementSize);
71}
72
73void gatherv(GathervOptions& opts) {
74 const auto& context = opts.context;
75 transport::UnboundBuffer* in = opts.in.get();
76 transport::UnboundBuffer* out = opts.out.get();
77 const auto slot = Slot::build(kGatherSlotPrefix, opts.tag);
78
79 // Sanity checks
80 GLOO_ENFORCE(opts.elementSize > 0);
81 GLOO_ENFORCE(in != nullptr);
82
83 if (context->rank == opts.root) {
84 size_t offset = 0;
85 for (int i = 0; i < context->size; i++) {
86 size_t copyLength = opts.elementSize * opts.elementsPerRank[i];
87 if (i != context->rank) {
88 // Remote memory copy
89 out->recv(i, slot, offset, copyLength);
90 } else {
91 // Local memory copy
92 GLOO_ENFORCE_EQ(copyLength, in->size);
93 if (copyLength > 0) {
94 memcpy(
95 static_cast<char*>(out->ptr) + offset,
96 in->ptr,
97 in->size);
98 }
99 }
100 offset += copyLength;
101 }
102 // Wait for receive operations to complete
103 for (int i = 0; i < context->size - 1; i++) {
104 out->waitRecv(opts.timeout);
105 }
106 } else {
107 size_t sendLength = opts.elementSize * opts.elementsPerRank[context->rank];
108 GLOO_ENFORCE_GE(in->size, sendLength);
109 in->send(opts.root, slot, 0, sendLength);
110 in->waitSend(opts.timeout);
111 }
112}
113
114} // namespace gloo
115