1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include "gloo/algorithm.h" |
12 | #include "gloo/cuda.h" |
13 | |
14 | namespace gloo { |
15 | |
16 | template <typename T> |
17 | class CudaAllreduceLocal : public Algorithm { |
18 | public: |
19 | CudaAllreduceLocal( |
20 | const std::shared_ptr<Context>& context, |
21 | const std::vector<T*>& ptrs, |
22 | const int count, |
23 | const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>()); |
24 | |
25 | virtual ~CudaAllreduceLocal() = default; |
26 | |
27 | virtual void run() override; |
28 | |
29 | protected: |
30 | std::vector<CudaDevicePointer<T> > devicePtrs_; |
31 | std::vector<CudaStream> streams_; |
32 | const int count_; |
33 | const int bytes_; |
34 | const CudaReductionFunction<T>* fn_; |
35 | const bool synchronizeDeviceOutputs_; |
36 | |
37 | std::unique_ptr<LocalOp<T> > localReduceOp_; |
38 | std::unique_ptr<LocalOp<T> > localBroadcastOp_; |
39 | }; |
40 | |
41 | } // namespace gloo |
42 | |