1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include "gloo/algorithm.h"
12#include "gloo/cuda.h"
13#include "gloo/cuda_workspace.h"
14
15namespace gloo {
16
17template <typename T, typename W = CudaHostWorkspace<T>>
18
19class CudaBroadcastOneToAll : public Algorithm {
20 public:
21 CudaBroadcastOneToAll(
22 const std::shared_ptr<Context>& context,
23 const std::vector<T*>& ptrs,
24 int count,
25 int rootRank = 0,
26 int rootPointerRank = 0,
27 const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>());
28
29 virtual ~CudaBroadcastOneToAll() = default;
30
31 virtual void run() override;
32
33 protected:
34 // Both workspace types have their own initialization function.
35 template <typename U = W>
36 void init(
37 typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
38 typename U::Pointer>::type* = 0);
39
40 template <typename U = W>
41 void init(
42 typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
43 typename U::Pointer>::type* = 0);
44
45 std::vector<CudaDevicePointer<T> > devicePtrs_;
46 std::vector<CudaStream> streams_;
47 typename W::Pointer scratch_;
48 const int count_;
49 const int bytes_;
50 const int rootRank_;
51 const int rootPointerRank_;
52 const bool synchronizeDeviceOutputs_;
53
54 // For the sender (root)
55 struct forSender {
56 int dummy;
57 std::unique_ptr<transport::Buffer> clearToSendBuffer;
58 std::unique_ptr<transport::Buffer> sendBuffer;
59 };
60
61 std::vector<std::unique_ptr<forSender>> sender_;
62
63 // For all receivers
64 struct forReceiver {
65 int dummy;
66 std::unique_ptr<transport::Buffer> clearToSendBuffer;
67 std::unique_ptr<transport::Buffer> recvBuffer;
68 };
69
70 std::unique_ptr<forReceiver> receiver_;
71
72 // For local broadcast
73 std::unique_ptr<LocalOp<T> > localBroadcastOp_;
74};
75
76} // namespace gloo
77