/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/cuda_allreduce_ring.h"

#include "gloo/cuda_collectives_device.h"
#include "gloo/cuda_collectives_host.h"
#include "gloo/cuda_private.h"

namespace gloo {

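// CudaAllreduceRing implements a ring allreduce: every rank first
// reduces its local device buffers into a scratch buffer, then passes
// the full payload around the ring contextSize-1 times, accumulating
// each incoming inbox into scratch. A final local broadcast copies the
// result back into every device buffer.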
template <typename T, typename W>
CudaAllreduceRing<T, W>::CudaAllreduceRing(
    const std::shared_ptr<Context>& context,
    const std::vector<T*>& ptrs,
    const int count,
    const std::vector<cudaStream_t>& streams)
    : Algorithm(context),
      count_(count),
      bytes_(count_ * sizeof(T)),
      synchronizeDeviceOutputs_(streams.size() == 0),
      fn_(CudaReductionFunction<T>::sum) {
  auto newStream = true;
  if (streams.size() > 0) {
    GLOO_ENFORCE_EQ(streams.size(), ptrs.size());
    newStream = false;
  }

  // Use size_t for the index to match ptrs.size() and avoid a
  // signed/unsigned comparison in the loop condition.
  for (size_t i = 0; i < ptrs.size(); i++) {
    auto ptr = CudaDevicePointer<T>::create(ptrs[i], count_);
    if (newStream) {
      streams_.push_back(CudaStream(ptr.getDeviceID()));
    } else {
      streams_.push_back(CudaStream(ptr.getDeviceID(), streams[i]));
    }
    devicePtrs_.push_back(std::move(ptr));
  }

  // Workspace-specific initialization (see the init() overloads below).
  init();

  // With a single process there is nothing to send or receive;
  // the local reduce/broadcast in run() is sufficient.
  if (this->contextSize_ == 1) {
    return;
  }

  auto& leftPair = this->getLeftPair();
  auto& rightPair = this->getRightPair();
  auto slot = this->context_->nextSlot();

  // Buffer to send to (rank+1).
  sendDataBuf_ = rightPair->createSendBuffer(slot, *outbox_, bytes_);

  // Buffer that (rank-1) writes to.
  recvDataBuf_ = leftPair->createRecvBuffer(slot, *inbox_, bytes_);

  // Dummy buffers for localized barrier.
  // Before sending to the right, we only need to know that the node
  // on the right is done using the inbox that's about to be written
  // into. No need for a global barrier.
  auto notificationSlot = this->context_->nextSlot();
  sendNotificationBuf_ =
      leftPair->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
  recvNotificationBuf_ =
      rightPair->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
}
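
// A minimal usage sketch (illustrative, not part of this file). It
// assumes a gloo::Context that has already been connected through a
// rendezvous store; `context`, `devicePtr`, and `elementCount` are
// hypothetical names:
//
//   std::vector<float*> ptrs = {devicePtr};
//   gloo::CudaAllreduceRing<float, gloo::CudaHostWorkspace<float> >
//       allreduce(context, ptrs, elementCount, {});  // no caller streams
//   allreduce.run();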

template <typename T, typename W>
void CudaAllreduceRing<T, W>::run() {
  CudaDeviceGuard guard;
  CudaStream& stream = *scratchStream_;

  // Reduce the local device buffers into the scratch buffer, if needed.
  if (localReduceOp_) {
    localReduceOp_->run();
  }

  // Initialize the outbox with the locally reduced values.
  stream.copyAsync(outbox_, scratch_);
  stream.wait();

  int numRounds = this->contextSize_ - 1;
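
  // Worked example for contextSize_ == 3, ranks A/B/C holding locally
  // reduced values a/b/c:
  //   round 0: A->B sends a, B->C sends b, C->A sends c;
  //            scratch becomes a+c on A, b+a on B, c+b on C.
  //   round 1: each rank forwards the value it just received;
  //            every scratch ends up holding a+b+c.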
  for (int round = 0; round < numRounds; round++) {
    // Initiate a write to the inbox of the node on the right.
    sendDataBuf_->send();

    // Wait for the inbox write from the node on the left.
    recvDataBuf_->waitRecv();

    // Reduce the received values into the scratch buffer.
    fn_->call(scratch_, inbox_, count_, stream);
    stream.wait();

    // Wait for the outbox write to complete.
    sendDataBuf_->waitSend();

    // Prepare for the next round if necessary.
    if (round < (numRounds - 1)) {
      stream.copyAsync(outbox_, inbox_);
      stream.wait();
    }

    // Notify the node on the left that this node is ready
    // for an inbox write.
    sendNotificationBuf_->send();

    // Wait for notification from the node on the right.
    recvNotificationBuf_->waitRecv();
  }

  // Asynchronously copy the result buffer to all device buffers.
  if (localBroadcastOp_) {
    localBroadcastOp_->runAsync();
    if (synchronizeDeviceOutputs_) {
      localBroadcastOp_->wait();
    }
  }
}
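
// Note on cost: every rank sends and receives the full payload
// (contextSize-1) times, so roughly (contextSize-1) * bytes_ crosses
// the wire per rank. This keeps the algorithm simple, but for large
// buffers a chunked ring (see cuda_allreduce_ring_chunked) that
// pipelines segments makes better use of bandwidth.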

template <typename T, typename W>
template <typename U>
void CudaAllreduceRing<T, W>::init(
    typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
                            typename U::Pointer>::type*) {
  // Since the reduction is executed on the CPU, the scratch space
  // where partial results are accumulated is a new host-side buffer.
  scratch_ = W::Pointer::alloc(count_);
  scratchStream_ = &streams_[0];

  // Execute the local reduction and broadcast from the host.
  // If devicePtrs_.size() == 1, these functions construct an op that
  // executes a memcpy such that scratch_ always holds the result.
  localReduceOp_ =
      cudaHostReduce(streams_, devicePtrs_, scratch_, fn_, 0, count_);
  localBroadcastOp_ =
      cudaHostBroadcast(streams_, devicePtrs_, scratch_, 0, count_);

  inbox_ = W::Pointer::alloc(count_);
  outbox_ = W::Pointer::alloc(count_);
}
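
// In both workspaces, inbox_ and outbox_ are distinct buffers: a rank
// sends outbox_ to its right neighbor while the left neighbor writes
// into inbox_, so the two transfers of a round never alias.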

template <typename T, typename W>
template <typename U>
void CudaAllreduceRing<T, W>::init(
    typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
                            typename U::Pointer>::type*) {
  // The networking adapter does DMA to/from GPU memory, so we should
  // reduce onto the device that's closest to the networking adapter
  // bound to our context. This uses PCI distance to find the closest GPU.
  auto index = findCudaDevicePointerClosestToDevice(
      devicePtrs_, this->context_->getDevice());
  scratch_ = CudaDevicePointer<T>::create(devicePtrs_[index]);
  scratchStream_ = &streams_[index];

  // Run the local reduction and broadcast on the device.
  // When running with a device workspace, we intend to never leave the
  // device. With a single device pointer there is nothing to reduce or
  // broadcast locally, so these ops are left unset.
  if (devicePtrs_.size() > 1) {
    localReduceOp_ =
        cudaDeviceReduce(streams_, devicePtrs_, scratch_, fn_, 0, count_);
    localBroadcastOp_ =
        cudaDeviceBroadcast(streams_, devicePtrs_, scratch_, 0, count_);
  }

  // The inbox/outbox must be colocated with the scratch buffer to avoid
  // cross-device copies while accumulating the reduction.
  {
    CudaDeviceScope scope(scratch_.getDeviceID());
    inbox_ = W::Pointer::alloc(count_);
    outbox_ = W::Pointer::alloc(count_);
  }
}

// Instantiate templates.
#define INSTANTIATE_TEMPLATE(T)                                 \
  template class CudaAllreduceRing<T, CudaHostWorkspace<T> >;   \
  template class CudaAllreduceRing<T, CudaDeviceWorkspace<T> >;

INSTANTIATE_TEMPLATE(int8_t);
INSTANTIATE_TEMPLATE(uint8_t);
INSTANTIATE_TEMPLATE(int32_t);
INSTANTIATE_TEMPLATE(int64_t);
INSTANTIATE_TEMPLATE(uint64_t);
INSTANTIATE_TEMPLATE(float);
INSTANTIATE_TEMPLATE(double);
INSTANTIATE_TEMPLATE(float16);

} // namespace gloo