/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "gloo/common/common.h"
#include "gloo/common/logging.h"
#include "gloo/nccl/nccl.h"

namespace gloo {

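// Builds one NCCLElement per device pointer, pairing each pointer with its
// stream. The same sub-range [offset, offset + count) is passed as both the
// source and the destination of each element, so the resulting NCCL
// operation runs in place on every device.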
template <typename T>
std::vector<nccl::NCCLElement<T> > toDeviceElements(
    std::vector<CudaStream>& streams,
    const std::vector<CudaDevicePointer<T> >& ptrs,
    size_t offset,
    size_t count) {
  std::vector<nccl::NCCLElement<T> > elements;
  elements.reserve(ptrs.size());
  for (auto i = 0; i < ptrs.size(); i++) {
    elements.emplace_back(
        ptrs[i].range(offset, count),
        streams[i],
        ptrs[i].range(offset, count),
        streams[i]);
  }
  return elements;
}

// Forward declaration
template <typename T, typename Dst>
class CudaLocalNCCLReduce;

// Partial specialization for device pointer target
template <typename T>
class CudaLocalNCCLReduce<T, CudaDevicePointer<T> > : public LocalOp<T> {
 public:
  CudaLocalNCCLReduce(
      std::vector<CudaStream>& streams,
      std::vector<CudaDevicePointer<T> >& devicePtrs,
      CudaDevicePointer<T>& targetPtr,
      const CudaReductionFunction<T>* fn,
      size_t offset,
      size_t count) {
    // The targetPtr must be one of devicePtrs.
    auto root = -1;
    for (auto i = 0; i < devicePtrs.size(); i++) {
      if (devicePtrs[i] == targetPtr) {
        root = i;
        break;
      }
    }
    GLOO_ENFORCE_GE(root, 0, "targetPtr must be one of devicePtrs");

    // Only if we have multiple device pointers does this
    // operation need to execute.
    if (devicePtrs.size() > 1) {
      reduceOp_ = make_unique<nccl::ReduceOp<T> >(
          toDeviceElements(streams, devicePtrs, offset, count),
          fn,
          root);
    }
  }

  virtual ~CudaLocalNCCLReduce() {}

  virtual void runAsync() {
    if (reduceOp_) {
      reduceOp_->runAsync();
    }
  }

  virtual void wait() {
    if (reduceOp_) {
      reduceOp_->wait();
    }
  }

 protected:
  std::unique_ptr<nccl::ReduceOp<T> > reduceOp_;
};

// Partial specialization for host pointer target
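//
// The reduction runs in two phases: an NCCL reduce across all device
// pointers with devicePtrs[0] as the root, followed by an asynchronous copy
// of the reduced range from that device into the host target. The copy is
// issued on the root's stream, so waiting on that stream covers both
// operations.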
template <typename T>
class CudaLocalNCCLReduce<T, CudaHostPointer<T> > : public LocalOp<T> {
 public:
  CudaLocalNCCLReduce(
      std::vector<CudaStream>& streams,
      std::vector<CudaDevicePointer<T> >& devicePtrs,
      CudaHostPointer<T>& targetPtr,
      const CudaReductionFunction<T>* fn,
      size_t offset,
      size_t count)
      : root_(0),
        stream_(streams[root_]),
        devicePtr_(devicePtrs[root_].range(offset, count)),
        hostPtr_(targetPtr.range(offset, count)) {
    if (devicePtrs.size() > 1) {
      reduceOp_ = make_unique<nccl::ReduceOp<T> >(
          toDeviceElements(streams, devicePtrs, offset, count),
          fn,
          root_);
    }
  }

  virtual ~CudaLocalNCCLReduce() {}

  virtual void runAsync() {
    if (reduceOp_) {
      reduceOp_->runAsync();
    }

    // The stream for operations on devicePtrs[0] now includes an
    // asynchronous wait for completion of the reduce operation, if it
    // was executed. This means we can sequence an asynchronous memory
    // copy and wait on completion of that to signal completion of
    // both operations.
    stream_.copyAsync(hostPtr_, devicePtr_);
  }

  virtual void wait() {
    stream_.wait();
  }

 protected:
  const int root_;
  CudaStream& stream_;
  CudaDevicePointer<T> devicePtr_;
  CudaHostPointer<T> hostPtr_;
  std::unique_ptr<nccl::ReduceOp<T> > reduceOp_;
};

// Forward declaration
template <typename T, typename Src>
class CudaLocalNCCLBroadcast;

// Specialization for device pointer source
template <typename T>
class CudaLocalNCCLBroadcast<T, CudaDevicePointer<T> > : public LocalOp<T> {
 public:
  CudaLocalNCCLBroadcast(
      std::vector<CudaStream>& streams,
      std::vector<CudaDevicePointer<T> >& devicePtrs,
      CudaDevicePointer<T>& sourcePtr,
      size_t offset,
      size_t count) {
    // The sourcePtr must be one of devicePtrs.
    auto root = -1;
    for (auto i = 0; i < devicePtrs.size(); i++) {
      if (devicePtrs[i] == sourcePtr) {
        root = i;
        break;
      }
    }
    GLOO_ENFORCE_GE(root, 0, "sourcePtr must be one of devicePtrs");

    // Only if we have multiple device pointers does this
    // operation need to execute.
    if (devicePtrs.size() > 1) {
      broadcastOp_ = make_unique<nccl::BroadcastOp<T> >(
          toDeviceElements(streams, devicePtrs, offset, count),
          root);
    }
  }

  virtual ~CudaLocalNCCLBroadcast() {}

  virtual void runAsync() {
    if (broadcastOp_) {
      broadcastOp_->runAsync();
    }
  }

  virtual void wait() {
    if (broadcastOp_) {
      broadcastOp_->wait();
    }
  }

 protected:
  std::unique_ptr<nccl::BroadcastOp<T> > broadcastOp_;
};

// Specialization for host pointer source
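//
// This mirrors the host pointer reduction: the source range is first copied
// asynchronously from the host into devicePtrs[0] on that device's stream,
// and the root's portion of the NCCL broadcast is issued on the same stream,
// so the broadcast starts only after the copy has completed.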
template <typename T>
class CudaLocalNCCLBroadcast<T, CudaHostPointer<T> > : public LocalOp<T> {
 public:
  CudaLocalNCCLBroadcast(
      std::vector<CudaStream>& streams,
      std::vector<CudaDevicePointer<T> >& devicePtrs,
      CudaHostPointer<T>& sourcePtr,
      size_t offset,
      size_t count)
      : root_(0),
        stream_(streams[root_]),
        devicePtr_(devicePtrs[root_].range(offset, count)),
        sourcePtr_(sourcePtr.range(offset, count)) {
    if (devicePtrs.size() > 1) {
      broadcastOp_ = make_unique<nccl::BroadcastOp<T> >(
          toDeviceElements(streams, devicePtrs, offset, count),
          root_);
    }
  }

  virtual ~CudaLocalNCCLBroadcast() {}

  virtual void runAsync() {
    // Since we run an asynchronous memcpy to devicePtr_, which is
    // executed on the stream associated with that device pointer, the
    // broadcast operation will only start after the memcpy completes.
    stream_.copyAsync(devicePtr_, sourcePtr_);
    if (broadcastOp_) {
      broadcastOp_->runAsync();
    }
  }

  virtual void wait() {
    stream_.wait();
    if (broadcastOp_) {
      broadcastOp_->wait();
    }
  }

 protected:
  const int root_;
  CudaStream& stream_;
  CudaDevicePointer<T> devicePtr_;
  CudaHostPointer<T> sourcePtr_;
  std::unique_ptr<nccl::BroadcastOp<T> > broadcastOp_;
};

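// Factory functions that pick the specialization matching the type of the
// target (reduce) or source (broadcast) pointer. They expect exactly one
// stream per device pointer.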
template <typename T, typename Dst>
std::unique_ptr<LocalOp<T> > cudaNCCLReduce(
    std::vector<CudaStream>& streams,
    std::vector<CudaDevicePointer<T> >& devicePtrs,
    Dst& targetPtr,
    const CudaReductionFunction<T>* fn,
    size_t offset,
    size_t count) {
  GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size());
  return make_unique<CudaLocalNCCLReduce<T, Dst> >(
      streams, devicePtrs, targetPtr, fn, offset, count);
}

template <typename T, typename Src>
std::unique_ptr<LocalOp<T> > cudaNCCLBroadcast(
    std::vector<CudaStream>& streams,
    std::vector<CudaDevicePointer<T> >& devicePtrs,
    Src& sourcePtr,
    size_t offset,
    size_t count) {
  GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size());
  return make_unique<CudaLocalNCCLBroadcast<T, Src> >(
      streams, devicePtrs, sourcePtr, offset, count);
}
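
// A rough usage sketch (not part of this header; CudaDevicePointer<T>::create
// and CudaReductionFunction<T>::sum are assumed to come from gloo/cuda.h, and
// `count` is illustrative):
//
//   std::vector<CudaStream> streams;             // one stream per device
//   std::vector<CudaDevicePointer<float> > ptrs; // one buffer per device
//   // ... populate streams and ptrs ...
//
//   // Reduce the first `count` elements of every device buffer into ptrs[0].
//   auto op = cudaNCCLReduce<float, CudaDevicePointer<float> >(
//       streams, ptrs, ptrs[0], CudaReductionFunction<float>::sum, 0, count);
//   op->runAsync();
//   op->wait();
//
//   // cudaNCCLBroadcast follows the same pattern, with a source pointer in
//   // place of the target pointer and no reduction function.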

} // namespace gloo