1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include "gloo/algorithm.h" |
12 | #include "gloo/cuda.h" |
13 | #include "gloo/cuda_workspace.h" |
14 | |
15 | namespace gloo { |
16 | |
17 | template <typename T, typename W = CudaHostWorkspace<T> > |
18 | class CudaAllreduceRing : public Algorithm { |
19 | public: |
20 | CudaAllreduceRing( |
21 | const std::shared_ptr<Context>& context, |
22 | const std::vector<T*>& ptrs, |
23 | const int count, |
24 | const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>()); |
25 | |
26 | virtual ~CudaAllreduceRing() = default; |
27 | |
28 | virtual void run() override; |
29 | |
30 | protected: |
31 | // Both workspace types have their own initialization function. |
32 | template <typename U = W> |
33 | void init( |
34 | typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value, |
35 | typename U::Pointer>::type* = 0); |
36 | |
37 | template <typename U = W> |
38 | void init( |
39 | typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value, |
40 | typename U::Pointer>::type* = 0); |
41 | |
42 | std::vector<CudaDevicePointer<T> > devicePtrs_; |
43 | std::vector<CudaStream> streams_; |
44 | typename W::Pointer scratch_; |
45 | CudaStream* scratchStream_; |
46 | |
47 | const int count_; |
48 | const int bytes_; |
49 | const bool synchronizeDeviceOutputs_; |
50 | const CudaReductionFunction<T>* fn_; |
51 | |
52 | std::unique_ptr<LocalOp<T> > localReduceOp_; |
53 | std::unique_ptr<LocalOp<T> > localBroadcastOp_; |
54 | |
55 | typename W::Pointer inbox_; |
56 | typename W::Pointer outbox_; |
57 | std::unique_ptr<transport::Buffer> sendDataBuf_; |
58 | std::unique_ptr<transport::Buffer> recvDataBuf_; |
59 | |
60 | int dummy_; |
61 | std::unique_ptr<transport::Buffer> sendNotificationBuf_; |
62 | std::unique_ptr<transport::Buffer> recvNotificationBuf_; |
63 | }; |
64 | |
65 | } // namespace gloo |
66 | |