1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
#include <array>
#include <memory>
#include <type_traits>
#include <vector>

#include "gloo/algorithm.h"
#include "gloo/cuda.h"
#include "gloo/cuda_workspace.h"
16
17namespace gloo {
18
19template <typename T, typename W = CudaHostWorkspace<T> >
20class CudaAllreduceRingChunked : public Algorithm {
21 public:
22 CudaAllreduceRingChunked(
23 const std::shared_ptr<Context>& context,
24 const std::vector<T*>& ptrs,
25 const int count,
26 const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>());
27
28 virtual ~CudaAllreduceRingChunked();
29
30 virtual void run() override;
31
32 protected:
33 int getChunkOffset(int round);
34 void copyChunkAtOffset(int chunkOffset);
35
36 // Both workspace types have their own initialization function.
37 template <typename U = W>
38 void init(
39 typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
40 typename U::Pointer>::type* = 0);
41
42 template <typename U = W>
43 void init(
44 typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
45 typename U::Pointer>::type* = 0);
46
47 std::vector<CudaDevicePointer<T> > devicePtrs_;
48 std::vector<CudaStream> streams_;
49 typename W::Pointer scratch_;
50 CudaStream* scratchStream_;
51
52 const int count_;
53 const int bytes_;
54 const bool synchronizeDeviceOutputs_;
55 const CudaReductionFunction<T>* fn_;
56
57 size_t chunks_;
58 size_t chunkSize_;
59 size_t chunkBytes_;
60
61 struct ChunkContext;
62 std::vector<ChunkContext> chunkContext_;
63
64 std::array<typename W::Pointer, 2> inbox_;
65 std::array<std::unique_ptr<transport::Buffer>, 2> sendDataBuf_;
66 std::array<std::unique_ptr<transport::Buffer>, 2> recvDataBuf_;
67
68 int dummy_;
69 std::unique_ptr<transport::Buffer> sendNotificationBuf_;
70 std::unique_ptr<transport::Buffer> recvNotificationBuf_;
71};
72
73} // namespace gloo
74