1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include "gloo/algorithm.h"
12#include "gloo/cuda.h"
13#include "gloo/cuda_workspace.h"
14
15namespace gloo {
16
17template <typename T, typename W = CudaHostWorkspace<T> >
18class CudaAllreduceRing : public Algorithm {
19 public:
20 CudaAllreduceRing(
21 const std::shared_ptr<Context>& context,
22 const std::vector<T*>& ptrs,
23 const int count,
24 const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>());
25
26 virtual ~CudaAllreduceRing() = default;
27
28 virtual void run() override;
29
30 protected:
31 // Both workspace types have their own initialization function.
32 template <typename U = W>
33 void init(
34 typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
35 typename U::Pointer>::type* = 0);
36
37 template <typename U = W>
38 void init(
39 typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
40 typename U::Pointer>::type* = 0);
41
42 std::vector<CudaDevicePointer<T> > devicePtrs_;
43 std::vector<CudaStream> streams_;
44 typename W::Pointer scratch_;
45 CudaStream* scratchStream_;
46
47 const int count_;
48 const int bytes_;
49 const bool synchronizeDeviceOutputs_;
50 const CudaReductionFunction<T>* fn_;
51
52 std::unique_ptr<LocalOp<T> > localReduceOp_;
53 std::unique_ptr<LocalOp<T> > localBroadcastOp_;
54
55 typename W::Pointer inbox_;
56 typename W::Pointer outbox_;
57 std::unique_ptr<transport::Buffer> sendDataBuf_;
58 std::unique_ptr<transport::Buffer> recvDataBuf_;
59
60 int dummy_;
61 std::unique_ptr<transport::Buffer> sendNotificationBuf_;
62 std::unique_ptr<transport::Buffer> recvNotificationBuf_;
63};
64
65} // namespace gloo
66