1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include "gloo/algorithm.h" |
12 | #include "gloo/cuda.h" |
13 | #include "gloo/cuda_workspace.h" |
14 | |
15 | namespace gloo { |
16 | |
17 | template <typename T, typename W = CudaHostWorkspace<T>> |
18 | |
19 | class CudaBroadcastOneToAll : public Algorithm { |
20 | public: |
21 | CudaBroadcastOneToAll( |
22 | const std::shared_ptr<Context>& context, |
23 | const std::vector<T*>& ptrs, |
24 | int count, |
25 | int rootRank = 0, |
26 | int rootPointerRank = 0, |
27 | const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>()); |
28 | |
29 | virtual ~CudaBroadcastOneToAll() = default; |
30 | |
31 | virtual void run() override; |
32 | |
33 | protected: |
34 | // Both workspace types have their own initialization function. |
35 | template <typename U = W> |
36 | void init( |
37 | typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value, |
38 | typename U::Pointer>::type* = 0); |
39 | |
40 | template <typename U = W> |
41 | void init( |
42 | typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value, |
43 | typename U::Pointer>::type* = 0); |
44 | |
45 | std::vector<CudaDevicePointer<T> > devicePtrs_; |
46 | std::vector<CudaStream> streams_; |
47 | typename W::Pointer scratch_; |
48 | const int count_; |
49 | const int bytes_; |
50 | const int rootRank_; |
51 | const int rootPointerRank_; |
52 | const bool synchronizeDeviceOutputs_; |
53 | |
54 | // For the sender (root) |
55 | struct forSender { |
56 | int dummy; |
57 | std::unique_ptr<transport::Buffer> clearToSendBuffer; |
58 | std::unique_ptr<transport::Buffer> sendBuffer; |
59 | }; |
60 | |
61 | std::vector<std::unique_ptr<forSender>> sender_; |
62 | |
63 | // For all receivers |
64 | struct forReceiver { |
65 | int dummy; |
66 | std::unique_ptr<transport::Buffer> clearToSendBuffer; |
67 | std::unique_ptr<transport::Buffer> recvBuffer; |
68 | }; |
69 | |
70 | std::unique_ptr<forReceiver> receiver_; |
71 | |
72 | // For local broadcast |
73 | std::unique_ptr<LocalOp<T> > localBroadcastOp_; |
74 | }; |
75 | |
76 | } // namespace gloo |
77 | |