1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include "gloo/common/common.h" |
12 | #include "gloo/common/logging.h" |
13 | #include "gloo/nccl/nccl.h" |
14 | |
15 | namespace gloo { |
16 | |
17 | template <typename T> |
18 | std::vector<nccl::NCCLElement<T> > toDeviceElements( |
19 | std::vector<CudaStream>& streams, |
20 | const std::vector<CudaDevicePointer<T> >& ptrs, |
21 | size_t offset, |
22 | size_t count) { |
23 | std::vector<nccl::NCCLElement<T> > elements; |
24 | elements.reserve(ptrs.size()); |
25 | for (auto i = 0; i < ptrs.size(); i++) { |
26 | elements.emplace_back( |
27 | ptrs[i].range(offset, count), |
28 | streams[i], |
29 | ptrs[i].range(offset, count), |
30 | streams[i]); |
31 | } |
32 | return elements; |
33 | } |
34 | |
35 | // Forward declaration |
36 | template <typename T, typename Dst> |
37 | class CudaLocalNCCLReduce; |
38 | |
39 | // Partial specialization for device pointer target |
40 | template <typename T> |
41 | class CudaLocalNCCLReduce<T, CudaDevicePointer<T> > : public LocalOp<T> { |
42 | public: |
43 | CudaLocalNCCLReduce( |
44 | std::vector<CudaStream>& streams, |
45 | std::vector<CudaDevicePointer<T> >& devicePtrs, |
46 | CudaDevicePointer<T>& targetPtr, |
47 | const CudaReductionFunction<T>* fn, |
48 | size_t offset, |
49 | size_t count) { |
50 | // The targetPtr must be one of devicePtrs. |
51 | auto root = -1; |
52 | for (auto i = 0; i < devicePtrs.size(); i++) { |
53 | if (devicePtrs[i] == targetPtr) { |
54 | root = i; |
55 | break; |
56 | } |
57 | } |
58 | GLOO_ENFORCE_GE(root, 0, "targetPtr must be one of devicePtrs" ); |
59 | |
60 | // Only if we have multiple device pointers does this |
61 | // operation need to execute. |
62 | if (devicePtrs.size() > 1) { |
63 | reduceOp_ = make_unique<nccl::ReduceOp<T> >( |
64 | toDeviceElements(streams, devicePtrs, offset, count), |
65 | fn, |
66 | root); |
67 | } |
68 | } |
69 | |
70 | virtual ~CudaLocalNCCLReduce() {} |
71 | |
72 | virtual void runAsync() { |
73 | if (reduceOp_) { |
74 | reduceOp_->runAsync(); |
75 | } |
76 | } |
77 | |
78 | virtual void wait() { |
79 | if (reduceOp_) { |
80 | reduceOp_->wait(); |
81 | } |
82 | } |
83 | |
84 | protected: |
85 | std::unique_ptr<nccl::ReduceOp<T> > reduceOp_; |
86 | }; |
87 | |
// Partial specialization for host pointer target
template <typename T>
class CudaLocalNCCLReduce<T, CudaHostPointer<T> > : public LocalOp<T> {
 public:
  // Reduces `devicePtrs` with NCCL rooted at device 0, then copies the
  // reduced range from device 0 to the host `targetPtr`.
  //
  // streams: one stream per device pointer; streams[0] sequences both the
  //   reduction result and the device-to-host copy.
  // devicePtrs: device buffers to reduce; index 0 is always the root.
  // targetPtr: host destination for the reduced range.
  // fn: reduction function applied across device buffers.
  // offset/count: element range within each buffer to operate on.
  CudaLocalNCCLReduce(
      std::vector<CudaStream>& streams,
      std::vector<CudaDevicePointer<T> >& devicePtrs,
      CudaHostPointer<T>& targetPtr,
      const CudaReductionFunction<T>* fn,
      size_t offset,
      size_t count)
      : root_(0),
        stream_(streams[root_]),
        devicePtr_(devicePtrs[root_].range(offset, count)),
        hostPtr_(targetPtr.range(offset, count)) {
    // With a single device pointer there is nothing to reduce; only the
    // device-to-host copy in runAsync() is needed, so reduceOp_ stays null.
    if (devicePtrs.size() > 1) {
      reduceOp_ = make_unique<nccl::ReduceOp<T> >(
          toDeviceElements(streams, devicePtrs, offset, count),
          fn,
          root_);
    }
  }

  virtual ~CudaLocalNCCLReduce() {}

  virtual void runAsync() {
    if (reduceOp_) {
      reduceOp_->runAsync();
    }

    // The stream for operations on devicePtrs_[0] now includes an
    // asynchronous wait for completion of the reduce operation, if it
    // was executed. This means we can sequence an asynchronous memory
    // copy and wait on completion of that to signal completion of
    // both operations.
    stream_.copyAsync(hostPtr_, devicePtr_);
  }

  virtual void wait() {
    // Waiting on the root stream covers both the reduction (sequenced
    // before the copy on this stream) and the copy itself.
    stream_.wait();
  }

 protected:
  // Root rank for the reduction; fixed at device 0.
  const int root_;
  // Stream of the root device; sequences reduce + copy.
  CudaStream& stream_;
  // Root device range [offset, offset + count) holding the reduced value.
  CudaDevicePointer<T> devicePtr_;
  // Host range that receives the result.
  CudaHostPointer<T> hostPtr_;
  // Null when there is a single device pointer (nothing to reduce).
  std::unique_ptr<nccl::ReduceOp<T> > reduceOp_;
};
137 | |
138 | // Forward declaration |
139 | template <typename T, typename Src> |
140 | class CudaLocalNCCLBroadcast; |
141 | |
142 | // Specialization for device pointer source |
143 | template <typename T> |
144 | class CudaLocalNCCLBroadcast<T, CudaDevicePointer<T> > : public LocalOp<T> { |
145 | public: |
146 | CudaLocalNCCLBroadcast( |
147 | std::vector<CudaStream>& streams, |
148 | std::vector<CudaDevicePointer<T> >& devicePtrs, |
149 | CudaDevicePointer<T>& sourcePtr, |
150 | size_t offset, |
151 | size_t count) { |
152 | // The sourcePtr must be one of devicePtrs. |
153 | auto root = -1; |
154 | for (auto i = 0; i < devicePtrs.size(); i++) { |
155 | if (devicePtrs[i] == sourcePtr) { |
156 | root = i; |
157 | break; |
158 | } |
159 | } |
160 | GLOO_ENFORCE_GE(root, 0, "sourcePtr must be one of devicePtrs" ); |
161 | |
162 | // Only if we have multiple device pointers does this |
163 | // operation need to execute. |
164 | if (devicePtrs.size() > 1) { |
165 | broadcastOp_ = make_unique<nccl::BroadcastOp<T> >( |
166 | toDeviceElements(streams, devicePtrs, offset, count), |
167 | root); |
168 | } |
169 | } |
170 | |
171 | virtual ~CudaLocalNCCLBroadcast() {} |
172 | |
173 | virtual void runAsync() { |
174 | if (broadcastOp_) { |
175 | broadcastOp_->runAsync(); |
176 | } |
177 | } |
178 | |
179 | virtual void wait() { |
180 | if (broadcastOp_) { |
181 | broadcastOp_->wait(); |
182 | } |
183 | } |
184 | |
185 | protected: |
186 | std::unique_ptr<nccl::BroadcastOp<T> > broadcastOp_; |
187 | }; |
188 | |
// Specialization for host pointer source
template <typename T>
class CudaLocalNCCLBroadcast<T, CudaHostPointer<T> > : public LocalOp<T> {
 public:
  // Copies the host `sourcePtr` range to device 0, then broadcasts it
  // to the remaining device pointers with NCCL.
  //
  // streams: one stream per device pointer; streams[0] sequences the
  //   host-to-device copy before the broadcast.
  // devicePtrs: device buffers to populate; index 0 is always the root.
  // sourcePtr: host buffer holding the data to broadcast.
  // offset/count: element range within each buffer to operate on.
  CudaLocalNCCLBroadcast(
      std::vector<CudaStream>& streams,
      std::vector<CudaDevicePointer<T> >& devicePtrs,
      CudaHostPointer<T>& sourcePtr,
      size_t offset,
      size_t count)
      : root_(0),
        stream_(streams[root_]),
        devicePtr_(devicePtrs[root_].range(offset, count)),
        sourcePtr_(sourcePtr.range(offset, count)) {
    // With a single device pointer only the host-to-device copy in
    // runAsync() is needed, so broadcastOp_ stays null.
    if (devicePtrs.size() > 1) {
      broadcastOp_ = make_unique<nccl::BroadcastOp<T> >(
          toDeviceElements(streams, devicePtrs, offset, count),
          root_);
    }
  }

  virtual ~CudaLocalNCCLBroadcast() {}

  virtual void runAsync() {
    // Since we run an asynchronous memcpy to devicePtr_ which is
    // executed on the stream associated with that device pointer, the
    // broadcast operation will only start after the memcpy completes.
    stream_.copyAsync(devicePtr_, sourcePtr_);
    if (broadcastOp_) {
      broadcastOp_->runAsync();
    }
  }

  virtual void wait() {
    // Wait for the copy on the root stream, then for the broadcast to
    // reach the non-root devices (if one was constructed).
    stream_.wait();
    if (broadcastOp_) {
      broadcastOp_->wait();
    }
  }

 protected:
  // Root rank for the broadcast; fixed at device 0.
  const int root_;
  // Stream of the root device; sequences copy before broadcast.
  CudaStream& stream_;
  // Root device range [offset, offset + count) receiving the host data.
  CudaDevicePointer<T> devicePtr_;
  // Host range holding the data to broadcast.
  CudaHostPointer<T> sourcePtr_;
  // Null when there is a single device pointer (nothing to broadcast).
  std::unique_ptr<nccl::BroadcastOp<T> > broadcastOp_;
};
236 | |
237 | template <typename T, typename Dst> |
238 | std::unique_ptr<LocalOp<T> > cudaNCCLReduce( |
239 | std::vector<CudaStream>& streams, |
240 | std::vector<CudaDevicePointer<T> >& devicePtrs, |
241 | Dst& targetPtr, |
242 | const CudaReductionFunction<T>* fn, |
243 | size_t offset, |
244 | size_t count) { |
245 | GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size()); |
246 | return make_unique<CudaLocalNCCLReduce<T, Dst> >( |
247 | streams, devicePtrs, targetPtr, fn, offset, count); |
248 | } |
249 | |
250 | template <typename T, typename Src> |
251 | std::unique_ptr<LocalOp<T> > cudaNCCLBroadcast( |
252 | std::vector<CudaStream>& streams, |
253 | std::vector<CudaDevicePointer<T> >& devicePtrs, |
254 | Src& sourcePtr, |
255 | size_t offset, |
256 | size_t count) { |
257 | GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size()); |
258 | return make_unique<CudaLocalNCCLBroadcast<T, Src> >( |
259 | streams, devicePtrs, sourcePtr, offset, count); |
260 | } |
261 | |
262 | } // namespace gloo |
263 | |