1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
#include "gloo/cuda_broadcast_one_to_all.h"

#include <cstddef>

#include "gloo/cuda_collectives_device.h"
#include "gloo/cuda_collectives_host.h"
#include "gloo/cuda_private.h"
14 | |
15 | namespace gloo { |
16 | |
17 | template <typename T, typename W> |
18 | CudaBroadcastOneToAll<T, W>::CudaBroadcastOneToAll( |
19 | const std::shared_ptr<Context>& context, |
20 | const std::vector<T*>& ptrs, |
21 | int count, |
22 | int rootRank, |
23 | int rootPointerRank, |
24 | const std::vector<cudaStream_t>& streams) |
25 | : Algorithm(context), |
26 | count_(count), |
27 | bytes_(count * sizeof(T)), |
28 | rootRank_(rootRank), |
29 | rootPointerRank_(rootPointerRank), |
30 | synchronizeDeviceOutputs_(streams.size() == 0) { |
31 | GLOO_ENFORCE_GE(rootRank_, 0); |
32 | GLOO_ENFORCE_LT(rootRank_, contextSize_); |
33 | |
34 | auto newStream = true; |
35 | if (streams.size() > 0) { |
36 | GLOO_ENFORCE_EQ(streams.size(), ptrs.size()); |
37 | newStream = false; |
38 | } |
39 | |
40 | for (auto i = 0; i < ptrs.size(); i++) { |
41 | auto ptr = CudaDevicePointer<T>::create(ptrs[i], count_); |
42 | if (newStream) { |
43 | streams_.push_back(CudaStream(ptr.getDeviceID())); |
44 | } else { |
45 | streams_.push_back(CudaStream(ptr.getDeviceID(), streams[i])); |
46 | } |
47 | devicePtrs_.push_back(std::move(ptr)); |
48 | } |
49 | |
50 | // Workspace specific initialization (see below) |
51 | init(); |
52 | |
53 | // Setup pairs/buffers for sender/receivers |
54 | if (contextSize_ > 1) { |
55 | auto slot = context_->nextSlot(); |
56 | if (contextRank_ == rootRank_) { |
57 | sender_.resize(contextSize_); |
58 | for (auto i = 0; i < contextSize_; i++) { |
59 | if (i == contextRank_) { |
60 | continue; |
61 | } |
62 | |
63 | sender_[i] = make_unique<forSender>(); |
64 | auto& pair = context_->getPair(i); |
65 | sender_[i]->clearToSendBuffer = pair->createRecvBuffer( |
66 | slot, &sender_[i]->dummy, sizeof(sender_[i]->dummy)); |
67 | sender_[i]->sendBuffer = pair->createSendBuffer( |
68 | slot, *scratch_, bytes_); |
69 | } |
70 | } else { |
71 | receiver_ = make_unique<forReceiver>(); |
72 | auto& rootPair = context_->getPair(rootRank_); |
73 | receiver_->clearToSendBuffer = rootPair->createSendBuffer( |
74 | slot, &receiver_->dummy, sizeof(receiver_->dummy)); |
75 | receiver_->recvBuffer = rootPair->createRecvBuffer( |
76 | slot, *scratch_, bytes_); |
77 | } |
78 | } |
79 | |
80 | // Setup local broadcast if needed |
81 | if (devicePtrs_.size() > 1) { |
82 | localBroadcastOp_ = |
83 | cudaDeviceBroadcast( |
84 | streams_, |
85 | devicePtrs_, |
86 | devicePtrs_[rootPointerRank], |
87 | 0, |
88 | count_); |
89 | } |
90 | } |
91 | |
/**
 * Runs the broadcast. The root rank stages its root device pointer into
 * scratch_, sends it to each peer once that peer signals clear-to-send,
 * and overlaps the sends with its local device broadcast. Non-root ranks
 * signal clear-to-send, receive into scratch_, copy the payload back to
 * their root device pointer, and then broadcast locally.
 */
template <typename T, typename W>
void CudaBroadcastOneToAll<T, W>::run() {
  if (contextSize_ == 1) {
    // Single-process context: no communication needed, only the local
    // broadcast across this process' device pointers (if any).
    if (localBroadcastOp_) {
      localBroadcastOp_->runAsync();
      if (synchronizeDeviceOutputs_) {
        // No caller-provided streams, so outputs must be ready on return.
        localBroadcastOp_->wait();
      }
    }
    return;
  }

  if (contextRank_ == rootRank_) {
    CudaStream& stream = streams_[rootPointerRank_];

    // Copy device buffer to host
    // (wait() blocks until scratch_ holds the payload before sending)
    stream.copyAsync(scratch_, devicePtrs_[rootPointerRank_]);
    stream.wait();

    // Fire off send operations after receiving clear to send
    for (auto i = 0; i < contextSize_; i++) {
      if (i == contextRank_) {
        continue;
      }
      sender_[i]->clearToSendBuffer->waitRecv();
      sender_[i]->sendBuffer->send();
    }

    // Broadcast locally while sends are happening
    if (localBroadcastOp_) {
      localBroadcastOp_->runAsync();
      if (synchronizeDeviceOutputs_) {
        localBroadcastOp_->wait();
      }
    }

    // Wait for all send operations to complete
    for (auto i = 0; i < contextSize_; i++) {
      if (i == contextRank_) {
        continue;
      }
      sender_[i]->sendBuffer->waitSend();
    }
  } else {
    CudaStream& stream = streams_[rootPointerRank_];
    // Ensure the H2D copy issued by a previous call to run() is complete
    // before telling the root that scratch_ may be overwritten.
    // NOTE: this only waits for the last copyAsync, not the whole stream
    stream.wait();

    receiver_->clearToSendBuffer->send();
    receiver_->recvBuffer->waitRecv();

    // Copy host buffer to device
    stream.copyAsync(devicePtrs_[rootPointerRank_], scratch_);

    // Broadcast locally after receiving from root
    if (localBroadcastOp_) {
      // Since broadcast synchronizes on root pointer, there is no
      // need to explicitly wait for the memcpy to complete.
      localBroadcastOp_->runAsync();
      if (synchronizeDeviceOutputs_) {
        localBroadcastOp_->wait();
      }
    } else {
      // Wait for memcpy to complete
      if (synchronizeDeviceOutputs_) {
        stream.wait();
      }
    }
  }
}
163 | |
164 | template <typename T, typename W> |
165 | template <typename U> |
166 | void CudaBroadcastOneToAll<T, W>::init( |
167 | typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value, |
168 | typename U::Pointer>::type*) { |
169 | // Allocate host side buffer if we need to communicate |
170 | if (contextSize_ > 1) { |
171 | // Since broadcast transmits from/to a buffer in system memory, the |
172 | // scratch space is a new host side buffer. |
173 | scratch_ = W::Pointer::alloc(count_); |
174 | } |
175 | } |
176 | |
177 | template <typename T, typename W> |
178 | template <typename U> |
179 | void CudaBroadcastOneToAll<T, W>::init( |
180 | typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value, |
181 | typename U::Pointer>::type*) { |
182 | if (contextSize_ > 1) { |
183 | // For GPUDirect, an additional buffer allocation is unnecessary. |
184 | // Instead, use the provided input buffer itself as the scratch space. |
185 | // The caller is the owner_ |
186 | scratch_ = CudaDevicePointer<T>::create(devicePtrs_[0]); |
187 | } |
188 | } |
189 | |
// Instantiate templates
// For each supported element type, both workspace variants are emitted:
// the host workspace (staging through system memory) and the device
// workspace (GPUDirect) declared above.
#define INSTANTIATE_TEMPLATE(T) \
template class CudaBroadcastOneToAll<T, CudaHostWorkspace<T> >; \
template class CudaBroadcastOneToAll<T, CudaDeviceWorkspace<T> >;


INSTANTIATE_TEMPLATE(int8_t);
INSTANTIATE_TEMPLATE(uint8_t);
INSTANTIATE_TEMPLATE(int32_t);
INSTANTIATE_TEMPLATE(int64_t);
INSTANTIATE_TEMPLATE(uint64_t);
INSTANTIATE_TEMPLATE(float);
INSTANTIATE_TEMPLATE(double);
INSTANTIATE_TEMPLATE(float16);
204 | |
205 | } // namespace gloo |
206 | |