/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/cuda_allreduce_halving_doubling.h"

#include "gloo/cuda_collectives_device.h"
#include "gloo/cuda_collectives_host.h"
#include "gloo/cuda_private.h"

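// Minimal usage sketch (illustrative only; `context`, `devicePtr`, and `count`
// are placeholder names). Each pointer must reference CUDA device memory with
// `count` elements; passing an empty stream vector makes the constructor
// create one new stream per pointer (see below).
//
//   std::vector<float*> ptrs = {devicePtr};
//   std::vector<cudaStream_t> streams; // empty: algorithm creates streams
//   gloo::CudaAllreduceHalvingDoubling<float, gloo::CudaHostWorkspace<float>>
//       allreduce(context, ptrs, count, streams, true);
//   allreduce.run(); // every pointer in ptrs now holds the summed result
//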
namespace gloo {

namespace {
// returns the last n bits of ctr reversed
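// e.g. reverseLastNBits(0b0011, 4) == 0b1100 and reverseLastNBits(0b001, 3)
// == 0b100; bits of ctr above the low n bits are ignored.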
uint32_t reverseLastNBits(uint32_t ctr, uint32_t n) {
  uint32_t bitMask = 1;
  uint32_t reversed = 0;
  while (bitMask < (static_cast<uint32_t>(1) << n)) {
    reversed <<= 1;
    if (ctr & bitMask) {
      reversed |= 1;
    }
    bitMask <<= 1;
  }
  return reversed;
}
} // namespace

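// Splits the ranks into binary blocks, one block per set bit of contextSize_,
// with smaller blocks placed at higher ranks. For example, contextSize_ = 6
// (0b110) yields a block of 4 (ranks 0-3, nextSmallerBlockSize_ = 2) and a
// block of 2 (ranks 4-5, nextLargerBlockSize_ = 4).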
template <typename T, typename W>
void CudaAllreduceHalvingDoubling<T, W>::initBinaryBlocks() {
  uint32_t offset = this->contextSize_;
  uint32_t blockSize = 1;
  uint32_t currentBlockSize = 0;
  uint32_t prevBlockSize = 0;
  do {
    if (this->contextSize_ & blockSize) {
      prevBlockSize = currentBlockSize;
      currentBlockSize = blockSize;
      offset -= blockSize;
      if (myBinaryBlockSize_ != 0) {
        nextLargerBlockSize_ = currentBlockSize;
        break;
      }
      if (offset <= this->context_->rank) {
        offsetToMyBinaryBlock_ = offset;
        myBinaryBlockSize_ = currentBlockSize;
        nextSmallerBlockSize_ = prevBlockSize;
      }
    }
    blockSize <<= 1;
  } while (offset != 0);

  stepsWithinBlock_ = log2(myBinaryBlockSize_);
  rankInBinaryBlock_ = this->context_->rank % myBinaryBlockSize_;
}

template <typename T, typename W>
CudaAllreduceHalvingDoubling<T, W>::CudaAllreduceHalvingDoubling(
    const std::shared_ptr<Context>& context,
    const std::vector<T*>& ptrs,
    const int count,
    const std::vector<cudaStream_t>& streams,
    bool pipelineBroadcastAndReduce)
    : Algorithm(context),
      count_(count),
      bytes_(count_ * sizeof(T)),
      steps_(log2(this->contextSize_)),
      chunks_(1 << steps_),
      chunkSize_((count_ + chunks_ - 1) / chunks_),
      chunkBytes_(chunkSize_ * sizeof(T)),
      fn_(CudaReductionFunction<T>::sum),
      sendOffsets_(steps_),
      recvOffsets_(steps_),
      sendCounts_(steps_, 0),
      recvCounts_(steps_, 0),
      sendCountToLargerBlock_(0),
      devicePtrsForBroadcast_(steps_),
      pipelined_(pipelineBroadcastAndReduce),
      offsetToMyBinaryBlock_(0),
      myBinaryBlockSize_(0),
      stepsWithinBlock_(0),
      rankInBinaryBlock_(0),
      nextSmallerBlockSize_(0),
      nextLargerBlockSize_(0) {
  initBinaryBlocks();
  sendDataBufs_.reserve(stepsWithinBlock_);
  recvDataBufs_.reserve(stepsWithinBlock_);
  auto newStream = true;
  if (streams.size() > 0) {
    GLOO_ENFORCE_EQ(streams.size(), ptrs.size());
    newStream = false;
  }

  for (auto i = 0; i < ptrs.size(); i++) {
    auto ptr = CudaDevicePointer<T>::create(ptrs[i], count_);
    if (newStream) {
      streams_.push_back(CudaStream(ptr.getDeviceID()));
    } else {
      streams_.push_back(CudaStream(ptr.getDeviceID(), streams[i]));
    }
    devicePtrs_.push_back(std::move(ptr));
  }

  // Workspace-specific initialization
  init();

  if (this->contextSize_ == 1) {
    return;
  }

  // Reserve max needed number of context slots. Up to 2 slots per process
  // pair are needed (one for regular sends and one for notifications). For
  // simplicity, the same mapping is used on all processes so that the slots
  // trivially match across processes
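  // For example, with contextSize_ == 4, the pair (1, 3) always uses slots
  // slotOffset_ + 2 * (1 * 4 + 3) = slotOffset_ + 14 for data and
  // slotOffset_ + 15 for notifications, regardless of which side computes it.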
  slotOffset_ = this->context_->nextSlot(
      2 * this->contextSize_ * (this->contextSize_ - 1));

  size_t bitmask = 1;
  size_t stepChunkSize = chunkSize_ << (steps_ - 1);
  size_t stepChunkBytes = stepChunkSize * sizeof(T);
  size_t sendOffset = 0;
  size_t recvOffset = 0;
  size_t bufferOffset = 0; // offset into recvBuf_
  for (int i = 0; i < stepsWithinBlock_; i++) {
    const int destRank = static_cast<int>((this->context_->rank) ^ bitmask);
    auto& pair = this->context_->getPair(destRank);
    const auto myRank = this->context_->rank;
    auto slot = slotOffset_ +
        2 * (std::min(myRank, destRank) * this->contextSize_ +
             std::max(myRank, destRank));
    sendOffsets_[i] = sendOffset + ((destRank & bitmask) ? stepChunkSize : 0);
    recvOffsets_[i] =
        recvOffset + ((this->context_->rank & bitmask) ? stepChunkSize : 0);
    if (sendOffsets_[i] < count_) {
      // specifies number of elements of scratch_ buffer to send in each step
      if (sendOffsets_[i] + stepChunkSize > count_) {
        sendCounts_[i] = count_ - sendOffsets_[i];
      } else {
        sendCounts_[i] = stepChunkSize;
      }
    }
    sendDataBufs_.push_back(pair->createSendBuffer(slot, *scratch_, bytes_));
    if (recvOffsets_[i] < count_) {
      // specifies number of elements received in each step
      if (recvOffsets_[i] + stepChunkSize > count_) {
        recvCounts_[i] = count_ - recvOffsets_[i];
      } else {
        recvCounts_[i] = stepChunkSize;
      }
    }
    recvDataBufs_.push_back(
        pair->createRecvBuffer(
            slot, &recvBuf_[bufferOffset], stepChunkBytes));
    bufferOffset += stepChunkSize;
    if (this->context_->rank & bitmask) {
      sendOffset += stepChunkSize;
      recvOffset += stepChunkSize;
    }
    bitmask <<= 1;
    stepChunkSize >>= 1;
    stepChunkBytes >>= 1;

    ++slot;
    sendNotificationBufs_.push_back(
        pair->createSendBuffer(slot, &dummy_, sizeof(dummy_)));
    recvNotificationBufs_.push_back(
        pair->createRecvBuffer(slot, &dummy_, sizeof(dummy_)));
  }

  if (nextSmallerBlockSize_ != 0) {
    const auto offsetToSmallerBlock =
        offsetToMyBinaryBlock_ + myBinaryBlockSize_;
    const int destRank = static_cast<int>(
        offsetToSmallerBlock + rankInBinaryBlock_ % nextSmallerBlockSize_);
    auto& destPair = this->context_->getPair(destRank);
    const auto myRank = this->context_->rank;
    const auto slot = slotOffset_ +
        2 * (std::min(myRank, destRank) * this->contextSize_ +
             std::max(myRank, destRank));
    smallerBlockSendDataBuf_ = destPair->createSendBuffer(
        slot, *scratch_, bytes_);
    const auto itemCount = recvCounts_[stepsWithinBlock_ - 1];
    if (itemCount > 0) {
      smallerBlockRecvDataBuf_ = destPair->createRecvBuffer(
          slot, &recvBuf_[bufferOffset], itemCount * sizeof(T));
    }
  }
  if (nextLargerBlockSize_ != 0) {
    // Due to the design decision of sending large messages to nearby ranks,
    // after the reduce-scatter the reduced chunks end up in an order
    // according to the reversed bit pattern of each proc's rank within the
    // block. So, instead of ranks 0, 1, 2, ... 7 having blocks A, B, C, D, E,
    // F, G, H etc. what you get is A, E, C, G, B, F, D, H. Taking this
    // example further, if there is also a smaller binary block of size 2
    // (with the reduced blocks A - D, E - H), rank 0 within the smaller block
    // will need to send chunks of its buffer to ranks 0, 4, 2, 6 within the
    // larger block (in that order) and rank 1 will send to 1, 5, 3, 7. Within
    // the reversed bit patterns, this communication is actually 0 to [0, 1,
    // 2, 3] and 1 to [4, 5, 6, 7].
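    // In code this is exactly what reverseLastNBits computes below: for rank
    // 0 of the size-2 block, srcOrdinal = reverseLastNBits(0, 1) = 0, the
    // destOrdinals are 0..3, and the destination ranks come out as
    // reverseLastNBits({0, 1, 2, 3}, 3) = {0, 4, 2, 6}; for rank 1 the
    // destOrdinals 4..7 map to ranks {1, 5, 3, 7}.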
    const auto offsetToLargerBlock =
        offsetToMyBinaryBlock_ - nextLargerBlockSize_;
    const auto numSendsAndReceivesToLargerBlock =
        nextLargerBlockSize_ / myBinaryBlockSize_;
    const auto totalItemsToSend =
        stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
    sendCountToLargerBlock_ = stepChunkSize >>
        (static_cast<size_t>(log2(numSendsAndReceivesToLargerBlock)) - 1);
    auto srcOrdinal =
        reverseLastNBits(rankInBinaryBlock_, log2(myBinaryBlockSize_));
    auto destOrdinal = srcOrdinal * numSendsAndReceivesToLargerBlock;
    for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
      const int destRank = offsetToLargerBlock +
          reverseLastNBits(destOrdinal, log2(nextLargerBlockSize_));
      auto& destPair = this->context_->getPair(destRank);
      const auto myRank = this->context_->rank;
      const auto slot = slotOffset_ +
          2 * (std::min(myRank, destRank) * this->contextSize_ +
               std::max(myRank, destRank));
      largerBlockSendDataBufs_.push_back(
          destPair->createSendBuffer(slot, *scratch_, bytes_));
      if (sendCountToLargerBlock_ * i < totalItemsToSend) {
        const auto toSend = std::min(
            sendCountToLargerBlock_,
            totalItemsToSend - sendCountToLargerBlock_ * i);
        largerBlockRecvDataBufs_.push_back(
            destPair->createRecvBuffer(
                slot, &recvBuf_[bufferOffset], toSend * sizeof(T)));
        bufferOffset += toSend;
      }
      destOrdinal++;
    }
  }

  if (pipelined_) {
    devicePointerInit();
    // Workspace-specific initialization for pipelined reductions/broadcasts
    initReductionsAndBroadcasts();
  }
}

template <typename T, typename W>
void CudaAllreduceHalvingDoubling<T, W>::run() {
  CudaDeviceGuard guard;
  CudaStream& stream = *scratchStream_;
  size_t bufferOffset = 0;
  size_t numItems =
      stepsWithinBlock_ > 0 ? chunkSize_ << (steps_ - 1) : count_;

  if (pipelined_ && reduceBeforeFirstSend_) {
    reduceBeforeFirstSend_->run();
  } else if (localReduceOp_) {
    localReduceOp_->run();
  }

  if (this->contextSize_ == 1) {
    GLOO_ENFORCE(localBroadcastOp_,
        "localBroadcastOp must be initialized for single machine");
    localBroadcastOp_->run();
    return;
  }

  // Reduce-scatter
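  // Each of the stepsWithinBlock_ iterations exchanges halves with the rank
  // that differs in bit i: this rank sends the half starting at
  // sendOffsets_[i], receives into recvBuf_, and reduces the received chunk
  // into scratch_ at recvOffsets_[i], halving the active range every step.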
  for (int i = 0; i < stepsWithinBlock_; i++) {
    if (sendOffsets_[i] < count_) {
      sendDataBufs_[i]->send(
          sendOffsets_[i] * sizeof(T), sendCounts_[i] * sizeof(T));
    }
    if (recvOffsets_[i] < count_) {
      if (pipelined_ && i == 0 && reduceBeforeFirstRecv_) {
        reduceBeforeFirstRecv_->runAsync();
      }
      recvDataBufs_[i]->waitRecv();
      if (pipelined_ && i == 0 && reduceBeforeFirstRecv_) {
        reduceBeforeFirstRecv_->wait();
      }
      auto recvBufAtOffset = recvBuf_.range(bufferOffset, recvCounts_[i]);
      auto scratchAtOffset = scratch_.range(recvOffsets_[i], recvCounts_[i]);
      fn_->call(scratchAtOffset, recvBufAtOffset, recvCounts_[i], stream);
      stream.wait();
    }
    sendNotificationBufs_[i]->send();
    bufferOffset += numItems;
    if (i != stepsWithinBlock_ - 1) {
      numItems >>= 1;
    }
  }

  // Communication across binary blocks for non-power-of-two number of
  // processes

  // receive from smaller block
  // data sizes same as in the last step of intrablock reduce-scatter above
  if (nextSmallerBlockSize_ != 0 && smallerBlockRecvDataBuf_ != nullptr) {
    smallerBlockRecvDataBuf_->waitRecv();
    auto recvBufAtOffset =
        recvBuf_.range(bufferOffset, recvCounts_[stepsWithinBlock_ - 1]);
    auto scratchAtOffset = scratch_.range(
        recvOffsets_[stepsWithinBlock_ - 1],
        recvCounts_[stepsWithinBlock_ - 1]);
    fn_->call(
        scratchAtOffset,
        recvBufAtOffset,
        recvCounts_[stepsWithinBlock_ - 1],
        stream);
    stream.wait();
  }

  const auto totalItemsToSend =
      stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
  if (nextLargerBlockSize_ != 0 && totalItemsToSend != 0) {
    // scatter to larger block
    const auto offset =
        stepsWithinBlock_ > 0 ? recvOffsets_[stepsWithinBlock_ - 1] : 0;
    const auto numSendsAndReceivesToLargerBlock =
        nextLargerBlockSize_ / myBinaryBlockSize_;
    for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
      if (sendCountToLargerBlock_ * i < totalItemsToSend) {
        largerBlockSendDataBufs_[i]->send(
            (offset + i * sendCountToLargerBlock_) * sizeof(T),
            std::min(
                sendCountToLargerBlock_,
                totalItemsToSend - sendCountToLargerBlock_ * i) *
                sizeof(T));
      }
    }
    // no notification is needed because the forward and backward messages
    // across blocks are serialized in relation to each other

    // receive from larger blocks
    for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
      if (sendCountToLargerBlock_ * i < totalItemsToSend) {
        largerBlockRecvDataBufs_[i]->waitRecv();
      }
    }
    auto recvBufAtOffset = recvBuf_.range(bufferOffset, totalItemsToSend);
    auto scratchAtOffset = scratch_.range(offset, totalItemsToSend);
    // msg from larger block is the final result, no reduce needed
    stream.copyAsync(scratchAtOffset, recvBufAtOffset);
    stream.wait();
  }

  // Send to smaller block (technically the beginning of allgather)
  bool sentToSmallerBlock = false;
  if (nextSmallerBlockSize_ != 0) {
    if (recvOffsets_[stepsWithinBlock_ - 1] < count_) {
      sentToSmallerBlock = true;
      smallerBlockSendDataBuf_->send(
          recvOffsets_[stepsWithinBlock_ - 1] * sizeof(T),
          recvCounts_[stepsWithinBlock_ - 1] * sizeof(T));
    }
  }

  // Allgather
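  // Walk the reduce-scatter steps in reverse: each rank sends the chunk it
  // now owns (at recvOffsets_[i]) back to its step-i partner and copies the
  // received chunk into scratch_ at sendOffsets_[i], doubling the shared
  // range until every rank holds the full result.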
  numItems = chunkSize_ << (steps_ - stepsWithinBlock_);
  for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
    // verify that destination rank has received and processed this rank's
    // message during the reduce-scatter phase
    recvNotificationBufs_[i]->waitRecv();
    if (recvOffsets_[i] < count_) {
      sendDataBufs_[i]->send(
          recvOffsets_[i] * sizeof(T), recvCounts_[i] * sizeof(T));
    }
    bufferOffset -= numItems;
    if (sendOffsets_[i] < count_) {
      recvDataBufs_[i]->waitRecv();
      auto recvBufAtOffset = recvBuf_.range(bufferOffset, sendCounts_[i]);
      auto scratchAtOffset = scratch_.range(sendOffsets_[i], sendCounts_[i]);
      stream.copyAsync(scratchAtOffset, recvBufAtOffset);
      stream.wait();
    }
    if (pipelined_ && broadcastOps_[i]) {
      broadcastOps_[i]->runAsync();
    }
    numItems <<= 1;

    // Send notification to the pair we just received from that
    // we're done dealing with the receive buffer.
    sendNotificationBufs_[i]->send();
  }

  if (pipelined_ && stepsWithinBlock_ > 0) {
    for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
      if (broadcastOps_[i]) {
        broadcastOps_[i]->wait();
      }
    }
  } else if (localBroadcastOp_) {
    localBroadcastOp_->runAsync();
    localBroadcastOp_->wait();
  }

  // Wait for notifications from our peers within the block so that we can
  // send data again immediately without risking overwriting data in a peer's
  // receive buffer before that peer has consumed it.
  for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
    recvNotificationBufs_[i]->waitRecv();
  }

  // We have to be sure the send to the smaller block (if any) has
  // completed before returning. If we don't, the buffer contents may
  // be modified by our caller.
  if (sentToSmallerBlock) {
    smallerBlockSendDataBuf_->waitSend();
  }
}

template <typename T, typename W>
void CudaAllreduceHalvingDoubling<T, W>::devicePointerInit() {
  size_t offset, numElements;

  for (int i = 0; i < stepsWithinBlock_; i++) {
    // in the first broadcast (with step 'steps_ - 1'), include both the local
    // chunk result from reduce-scatter and the first received chunk
    offset = i == stepsWithinBlock_ - 1
        ? std::min(recvOffsets_[i], sendOffsets_[i])
        : sendOffsets_[i];
    numElements = i == stepsWithinBlock_ - 1 ? recvCounts_[i] + sendCounts_[i]
                                             : sendCounts_[i];
    if (offset > count_) {
      scratchPtrForBroadcast_.push_back(typename W::Pointer());
      continue;
    }
    if (offset + numElements > count_) {
      numElements = count_ - offset;
    }

    scratchPtrForBroadcast_.push_back(scratch_.range(offset, numElements));
    for (int j = 0; j < devicePtrs_.size(); j++) {
      devicePtrsForBroadcast_[i].push_back(
          devicePtrs_[j].range(offset, numElements));
    }
  }
  if (sendOffsets_[0] < count_) {
    scratchPtrForFirstSend_ = scratch_.range(sendOffsets_[0], sendCounts_[0]);
  }
  if (recvOffsets_[0] < count_) {
    scratchPtrForFirstRecv_ = scratch_.range(recvOffsets_[0], recvCounts_[0]);
  }

  for (int i = 0; i < devicePtrs_.size(); i++) {
    if (sendOffsets_[0] < count_) {
      devicePtrsForFirstSend_.push_back(
          devicePtrs_[i].range(sendOffsets_[0], sendCounts_[0]));
    }
    if (recvOffsets_[0] < count_) {
      devicePtrsForFirstRecv_.push_back(
          devicePtrs_[i].range(recvOffsets_[0], recvCounts_[0]));
    }
  }
}

template <typename T, typename W>
template <typename U>
void CudaAllreduceHalvingDoubling<T, W>::init(
    typename std::enable_if<
        std::is_same<U, CudaHostWorkspace<T>>::value,
        typename U::Pointer>::type*) {
  // Since the reduction is executed on the CPU, the scratch space where
  // partial results are accumulated is a new host-side buffer.
  scratch_ = W::Pointer::alloc(count_);
  scratchStream_ = &streams_[0];

  // pad receive buffer size to nearest power of 2 to ensure sufficient space
  recvBuf_ = W::Pointer::alloc(chunkSize_ << steps_);
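  // (chunkSize_ << steps_ >= count_; e.g. count_ = 10 with contextSize_ = 4
  // gives chunkSize_ = 3 and a 12-element receive buffer)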

  // Set up local reduction and broadcast operations on the host.
  // If devicePtrs_.size() == 1 these functions construct an op that
  // executes a memcpy such that scratch_ always holds the result.

  // local reduce and broadcast ops are only used in the non-pipelined case and
  // for blocks of size 1
  if (pipelined_ && stepsWithinBlock_ > 0) {
    return;
  }
  if (bytes_ < kOnDeviceThreshold) {
    localReduceOp_ =
        cudaHostReduce(streams_, devicePtrs_, scratch_, fn_, 0, count_);
    localBroadcastOp_ =
        cudaHostBroadcast(streams_, devicePtrs_, scratch_, 0, count_);
  } else {
    localReduceOp_ =
        cudaDeviceReduce(streams_, devicePtrs_, scratch_, fn_, 0, count_);
    localBroadcastOp_ =
        cudaDeviceBroadcast(streams_, devicePtrs_, scratch_, 0, count_);
  }
}

template <typename T, typename W>
template <typename U>
void CudaAllreduceHalvingDoubling<T, W>::init(
    typename std::enable_if<
        std::is_same<U, CudaDeviceWorkspace<T>>::value,
        typename U::Pointer>::type*) {
  // The networking adapter does DMA to/from GPU memory, so we should reduce
  // onto the device that's closest to the networking adapter bound to our
  // context. This uses PCI distance to find the closest GPU.
  auto index = findCudaDevicePointerClosestToDevice(
      devicePtrs_, this->context_->getDevice());
  scratch_ = CudaDevicePointer<T>::create(devicePtrs_[index]);
  scratchStream_ = &streams_[index];

  // Inbox/outbox must be colocated with scratch buffer to avoid
  // cross device copies while accumulating the reduction.
  {
    CudaDeviceScope scope(scratch_.getDeviceID());
    // pad receive buffer size to nearest power of 2 to ensure sufficient space
    recvBuf_ = W::Pointer::alloc(chunkSize_ << steps_);
  }

  // Set up local reduction and broadcast operations on the device.
  // When running with a device workspace we intend to never leave the device.

  // local reduce and broadcast ops are only used in the non-pipelined case and
  // for blocks of size 1
  if (pipelined_ && stepsWithinBlock_ > 0) {
    return;
  }
  if (devicePtrs_.size() > 1) {
    localReduceOp_ =
        cudaDeviceReduce(streams_, devicePtrs_, scratch_, fn_, 0, count_);
    localBroadcastOp_ =
        cudaDeviceBroadcast(streams_, devicePtrs_, scratch_, 0, count_);
  }
}

template <typename T, typename W>
template <typename U>
void CudaAllreduceHalvingDoubling<T, W>::initReductionsAndBroadcasts(
    typename std::enable_if<
        std::is_same<U, CudaHostWorkspace<T>>::value,
        typename U::Pointer>::type*) {
  if (stepsWithinBlock_ == 0) {
    return;
  }
  if (sendCounts_[0] * sizeof(T) < kOnDeviceThreshold) {
    if (!devicePtrsForFirstSend_.empty()) {
      reduceBeforeFirstSend_ = cudaHostReduce(
          streams_,
          devicePtrsForFirstSend_,
          scratchPtrForFirstSend_,
          fn_,
          0,
          sendCounts_[0]);
    }
    if (!devicePtrsForFirstRecv_.empty()) {
      reduceBeforeFirstRecv_ = cudaHostReduce(
          streams_,
          devicePtrsForFirstRecv_,
          scratchPtrForFirstRecv_,
          fn_,
          0,
          recvCounts_[0]);
    }
  } else {
    if (!devicePtrsForFirstSend_.empty()) {
      reduceBeforeFirstSend_ = cudaDeviceReduce(
          streams_,
          devicePtrsForFirstSend_,
          scratchPtrForFirstSend_,
          fn_,
          0,
          sendCounts_[0]);
    }
    if (!devicePtrsForFirstRecv_.empty()) {
      reduceBeforeFirstRecv_ = cudaDeviceReduce(
          streams_,
          devicePtrsForFirstRecv_,
          scratchPtrForFirstRecv_,
          fn_,
          0,
          recvCounts_[0]);
    }
  }
  for (int i = 0; i < stepsWithinBlock_; i++) {
    if (devicePtrsForBroadcast_[i].empty()) {
      broadcastOps_.push_back(nullptr);
      continue;
    }
    const size_t numElementsInBcast = i == stepsWithinBlock_ - 1
        ? sendCounts_[i] + recvCounts_[i]
        : sendCounts_[i];
    if (numElementsInBcast * sizeof(T) < kOnDeviceThreshold) {
      broadcastOps_.push_back(cudaHostBroadcast(
          streams_,
          devicePtrsForBroadcast_[i],
          scratchPtrForBroadcast_[i],
          0,
          numElementsInBcast));
    } else {
      broadcastOps_.push_back(cudaDeviceBroadcast(
          streams_,
          devicePtrsForBroadcast_[i],
          scratchPtrForBroadcast_[i],
          0,
          numElementsInBcast));
    }
  }
}

template <typename T, typename W>
template <typename U>
void CudaAllreduceHalvingDoubling<T, W>::initReductionsAndBroadcasts(
    typename std::enable_if<
        std::is_same<U, CudaDeviceWorkspace<T>>::value,
        typename U::Pointer>::type*) {
  if (stepsWithinBlock_ == 0) {
    return;
  }
  if (!devicePtrsForFirstSend_.empty()) {
    reduceBeforeFirstSend_ = cudaDeviceReduce(
        streams_,
        devicePtrsForFirstSend_,
        scratchPtrForFirstSend_,
        fn_,
        0,
        sendCounts_[0]);
  }
  if (!devicePtrsForFirstRecv_.empty()) {
    reduceBeforeFirstRecv_ = cudaDeviceReduce(
        streams_,
        devicePtrsForFirstRecv_,
        scratchPtrForFirstRecv_,
        fn_,
        0,
        recvCounts_[0]);
  }
  for (int i = 0; i < stepsWithinBlock_; i++) {
    if (devicePtrsForBroadcast_[i].empty()) {
      broadcastOps_.push_back(nullptr);
      continue;
    }
    broadcastOps_.push_back(cudaDeviceBroadcast(
        streams_,
        devicePtrsForBroadcast_[i],
        scratchPtrForBroadcast_[i],
        0,
        i == stepsWithinBlock_ - 1 ? sendCounts_[i] + recvCounts_[i]
                                   : sendCounts_[i]));
  }
}

#define INSTANTIATE_TEMPLATE(T)                                           \
  template class CudaAllreduceHalvingDoubling<T, CudaHostWorkspace<T>>;   \
  template class CudaAllreduceHalvingDoubling<T, CudaDeviceWorkspace<T>>;

INSTANTIATE_TEMPLATE(int8_t);
INSTANTIATE_TEMPLATE(uint8_t);
INSTANTIATE_TEMPLATE(int32_t);
INSTANTIATE_TEMPLATE(int64_t);
INSTANTIATE_TEMPLATE(uint64_t);
INSTANTIATE_TEMPLATE(float);
INSTANTIATE_TEMPLATE(double);
INSTANTIATE_TEMPLATE(float16);

} // namespace gloo