1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include <array> |
12 | |
13 | #include "gloo/algorithm.h" |
14 | #include "gloo/cuda.h" |
15 | #include "gloo/cuda_workspace.h" |
16 | |
17 | namespace gloo { |
18 | |
19 | template <typename T, typename W = CudaHostWorkspace<T> > |
20 | class CudaAllreduceRingChunked : public Algorithm { |
21 | public: |
22 | CudaAllreduceRingChunked( |
23 | const std::shared_ptr<Context>& context, |
24 | const std::vector<T*>& ptrs, |
25 | const int count, |
26 | const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>()); |
27 | |
28 | virtual ~CudaAllreduceRingChunked(); |
29 | |
30 | virtual void run() override; |
31 | |
32 | protected: |
33 | int getChunkOffset(int round); |
34 | void copyChunkAtOffset(int chunkOffset); |
35 | |
36 | // Both workspace types have their own initialization function. |
37 | template <typename U = W> |
38 | void init( |
39 | typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value, |
40 | typename U::Pointer>::type* = 0); |
41 | |
42 | template <typename U = W> |
43 | void init( |
44 | typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value, |
45 | typename U::Pointer>::type* = 0); |
46 | |
47 | std::vector<CudaDevicePointer<T> > devicePtrs_; |
48 | std::vector<CudaStream> streams_; |
49 | typename W::Pointer scratch_; |
50 | CudaStream* scratchStream_; |
51 | |
52 | const int count_; |
53 | const int bytes_; |
54 | const bool synchronizeDeviceOutputs_; |
55 | const CudaReductionFunction<T>* fn_; |
56 | |
57 | size_t chunks_; |
58 | size_t chunkSize_; |
59 | size_t chunkBytes_; |
60 | |
61 | struct ChunkContext; |
62 | std::vector<ChunkContext> chunkContext_; |
63 | |
64 | std::array<typename W::Pointer, 2> inbox_; |
65 | std::array<std::unique_ptr<transport::Buffer>, 2> sendDataBuf_; |
66 | std::array<std::unique_ptr<transport::Buffer>, 2> recvDataBuf_; |
67 | |
68 | int dummy_; |
69 | std::unique_ptr<transport::Buffer> sendNotificationBuf_; |
70 | std::unique_ptr<transport::Buffer> recvNotificationBuf_; |
71 | }; |
72 | |
73 | } // namespace gloo |
74 | |