1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include "gloo/algorithm.h"
12#include "gloo/cuda.h"
13
14namespace gloo {
15
16template <typename T>
17class CudaAllreduceLocal : public Algorithm {
18 public:
19 CudaAllreduceLocal(
20 const std::shared_ptr<Context>& context,
21 const std::vector<T*>& ptrs,
22 const int count,
23 const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>());
24
25 virtual ~CudaAllreduceLocal() = default;
26
27 virtual void run() override;
28
29 protected:
30 std::vector<CudaDevicePointer<T> > devicePtrs_;
31 std::vector<CudaStream> streams_;
32 const int count_;
33 const int bytes_;
34 const CudaReductionFunction<T>* fn_;
35 const bool synchronizeDeviceOutputs_;
36
37 std::unique_ptr<LocalOp<T> > localReduceOp_;
38 std::unique_ptr<LocalOp<T> > localBroadcastOp_;
39};
40
41} // namespace gloo
42