/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/cuda_broadcast_one_to_all.h"

#include "gloo/cuda_collectives_device.h"
#include "gloo/cuda_collectives_host.h"
#include "gloo/cuda_private.h"

namespace gloo {

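// Overview: CudaBroadcastOneToAll broadcasts the buffer held by the root
// rank to every other rank in the context. With a CudaHostWorkspace the
// payload is staged through a host-side scratch buffer; with a
// CudaDeviceWorkspace (GPUDirect) the device buffer itself serves as the
// transfer buffer. Within a process, a device-to-device broadcast fans the
// result out across multiple local GPUs.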
template <typename T, typename W>
CudaBroadcastOneToAll<T, W>::CudaBroadcastOneToAll(
    const std::shared_ptr<Context>& context,
    const std::vector<T*>& ptrs,
    int count,
    int rootRank,
    int rootPointerRank,
    const std::vector<cudaStream_t>& streams)
    : Algorithm(context),
      count_(count),
      bytes_(count * sizeof(T)),
      rootRank_(rootRank),
      rootPointerRank_(rootPointerRank),
      synchronizeDeviceOutputs_(streams.size() == 0) {
  GLOO_ENFORCE_GE(rootRank_, 0);
  GLOO_ENFORCE_LT(rootRank_, contextSize_);

  auto newStream = true;
  if (streams.size() > 0) {
    GLOO_ENFORCE_EQ(streams.size(), ptrs.size());
    newStream = false;
  }

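  // Wrap each raw device pointer and pair it with a stream on the same
  // device. Caller-provided streams are adopted as-is; otherwise a fresh
  // stream is created per device.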
  for (auto i = 0; i < ptrs.size(); i++) {
    auto ptr = CudaDevicePointer<T>::create(ptrs[i], count_);
    if (newStream) {
      streams_.push_back(CudaStream(ptr.getDeviceID()));
    } else {
      streams_.push_back(CudaStream(ptr.getDeviceID(), streams[i]));
    }
    devicePtrs_.push_back(std::move(ptr));
  }

  // Workspace specific initialization (see below)
  init();

  // Set up pairs/buffers for sender/receivers.
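  // The wire protocol is a simple clear-to-send handshake: every receiver
  // posts a small notification buffer that the root waits on before sending,
  // so the root never overwrites a receiver's scratch space while the
  // previous payload is still being consumed.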
  if (contextSize_ > 1) {
    auto slot = context_->nextSlot();
    if (contextRank_ == rootRank_) {
      sender_.resize(contextSize_);
      for (auto i = 0; i < contextSize_; i++) {
        if (i == contextRank_) {
          continue;
        }

        sender_[i] = make_unique<forSender>();
        auto& pair = context_->getPair(i);
        sender_[i]->clearToSendBuffer = pair->createRecvBuffer(
            slot, &sender_[i]->dummy, sizeof(sender_[i]->dummy));
        sender_[i]->sendBuffer = pair->createSendBuffer(
            slot, *scratch_, bytes_);
      }
    } else {
      receiver_ = make_unique<forReceiver>();
      auto& rootPair = context_->getPair(rootRank_);
      receiver_->clearToSendBuffer = rootPair->createSendBuffer(
          slot, &receiver_->dummy, sizeof(receiver_->dummy));
      receiver_->recvBuffer = rootPair->createRecvBuffer(
          slot, *scratch_, bytes_);
    }
  }

  // Set up local broadcast if needed
  if (devicePtrs_.size() > 1) {
    localBroadcastOp_ =
      cudaDeviceBroadcast(
        streams_,
        devicePtrs_,
        devicePtrs_[rootPointerRank_],
        0,
        count_);
  }
}
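
// Example usage (illustrative sketch only; assumes `context` is a fully
// connected std::shared_ptr<gloo::Context> and `devPtr` points to `count`
// floats on the current CUDA device):
//
//   std::vector<float*> ptrs = {devPtr};
//   CudaBroadcastOneToAll<float, CudaHostWorkspace<float> > broadcast(
//       context, ptrs, count, /*rootRank=*/0, /*rootPointerRank=*/0,
//       /*streams=*/{});
//   broadcast.run();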

template <typename T, typename W>
void CudaBroadcastOneToAll<T, W>::run() {
  if (contextSize_ == 1) {
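    // Single process: there is nothing to send over the wire; at most fan
    // the data out across the local GPUs.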
    if (localBroadcastOp_) {
      localBroadcastOp_->runAsync();
      if (synchronizeDeviceOutputs_) {
        localBroadcastOp_->wait();
      }
    }
    return;
  }

  if (contextRank_ == rootRank_) {
    CudaStream& stream = streams_[rootPointerRank_];

    // Copy device buffer to host
    stream.copyAsync(scratch_, devicePtrs_[rootPointerRank_]);
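    // The send buffers created in the constructor alias scratch_, so the
    // copy must complete (hence the wait below) before any send fires.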
    stream.wait();

    // Fire off send operations after receiving clear to send
    for (auto i = 0; i < contextSize_; i++) {
      if (i == contextRank_) {
        continue;
      }
      sender_[i]->clearToSendBuffer->waitRecv();
      sender_[i]->sendBuffer->send();
    }

    // Broadcast locally while sends are happening
    if (localBroadcastOp_) {
      localBroadcastOp_->runAsync();
      if (synchronizeDeviceOutputs_) {
        localBroadcastOp_->wait();
      }
    }

    // Wait for all send operations to complete
    for (auto i = 0; i < contextSize_; i++) {
      if (i == contextRank_) {
        continue;
      }
      sender_[i]->sendBuffer->waitSend();
    }
  } else {
    CudaStream& stream = streams_[rootPointerRank_];
    // Ensure the previous H2D copy out of scratch_ has completed before
    // notifying the root that it may overwrite scratch_ with new data.
    // NOTE: this only waits for the last copyAsync, not for the whole stream.
    stream.wait();

    receiver_->clearToSendBuffer->send();
    receiver_->recvBuffer->waitRecv();
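    // waitRecv() returns once the root's payload has fully arrived in
    // scratch_.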

    // Copy host buffer to device
    stream.copyAsync(devicePtrs_[rootPointerRank_], scratch_);

    // Broadcast locally after receiving from root
    if (localBroadcastOp_) {
      // Since the local broadcast synchronizes on the root pointer, there is
      // no need to explicitly wait for the memcpy to complete.
      localBroadcastOp_->runAsync();
      if (synchronizeDeviceOutputs_) {
        localBroadcastOp_->wait();
      }
    } else {
      // Wait for memcpy to complete
      if (synchronizeDeviceOutputs_) {
        stream.wait();
      }
    }
  }
}

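// The two init() overloads below are selected at compile time via SFINAE on
// the workspace type (U defaults to W in the declaration): the enable_if
// parameter leaves exactly one overload viable for a given workspace, so the
// host-staged and GPUDirect paths never coexist in the same instantiation.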
template <typename T, typename W>
template <typename U>
void CudaBroadcastOneToAll<T, W>::init(
    typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
        typename U::Pointer>::type*) {
  // Allocate host side buffer if we need to communicate
  if (contextSize_ > 1) {
    // Since broadcast transmits from/to a buffer in system memory, the
    // scratch space is a new host side buffer.
    scratch_ = W::Pointer::alloc(count_);
  }
}

template <typename T, typename W>
template <typename U>
void CudaBroadcastOneToAll<T, W>::init(
    typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
        typename U::Pointer>::type*) {
  if (contextSize_ > 1) {
    // For GPUDirect, an additional buffer allocation is unnecessary.
    // Instead, use the provided input buffer itself as the scratch space;
    // the caller retains ownership of the underlying memory.
    scratch_ = CudaDevicePointer<T>::create(devicePtrs_[0]);
  }
}

// Instantiate templates
#define INSTANTIATE_TEMPLATE(T) \
template class CudaBroadcastOneToAll<T, CudaHostWorkspace<T> >; \
template class CudaBroadcastOneToAll<T, CudaDeviceWorkspace<T> >;

INSTANTIATE_TEMPLATE(int8_t);
INSTANTIATE_TEMPLATE(uint8_t);
INSTANTIATE_TEMPLATE(int32_t);
INSTANTIATE_TEMPLATE(int64_t);
INSTANTIATE_TEMPLATE(uint64_t);
INSTANTIATE_TEMPLATE(float);
INSTANTIATE_TEMPLATE(double);
INSTANTIATE_TEMPLATE(float16);

} // namespace gloo