1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#pragma once
10
11#include <algorithm>
12#include <cmath>
13
14#include "gloo/common/common.h"
15#include "gloo/common/logging.h"
16#include "gloo/config.h"
17#include "gloo/cuda.h"
18#include "gloo/cuda_private.h"
19
20#if GLOO_USE_NCCL
21#include "gloo/cuda_collectives_nccl.h"
22#else
23#include "gloo/cuda_collectives_native.h"
24#endif
25
26namespace gloo {
27
28template <typename T, typename Dst>
29std::unique_ptr<LocalOp<T> > cudaDeviceReduce(
30 std::vector<CudaStream>& streams,
31 std::vector<CudaDevicePointer<T> >& devicePtrs,
32 Dst& targetPtr,
33 const CudaReductionFunction<T>* fn,
34 size_t offset,
35 size_t count) {
36 GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size());
37
38 // Simple copy operation if there is only a single device pointer.
39 if (devicePtrs.size() == 1) {
40 return make_unique<
41 CudaLocalMemcpy<T, CudaDevicePointer<T>, Dst> >(
42 streams[0],
43 devicePtrs[0],
44 targetPtr,
45 offset,
46 count);
47 }
48
49#if GLOO_USE_NCCL
50 return make_unique<CudaLocalNCCLReduce<T, Dst> >(
51 streams, devicePtrs, targetPtr, fn, offset, count);
52#else
53 return make_unique<CudaLocalNativeReduce<T, Dst> >(
54 streams, devicePtrs, targetPtr, fn, offset, count);
55#endif
56}
57
58template <typename T, typename Src>
59std::unique_ptr<LocalOp<T> > cudaDeviceBroadcast(
60 std::vector<CudaStream>& streams,
61 std::vector<CudaDevicePointer<T> >& devicePtrs,
62 Src& sourcePtr,
63 size_t offset,
64 size_t count) {
65 GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size());
66
67 // Simple copy operation if there is only a single device pointer.
68 if (devicePtrs.size() == 1) {
69 return make_unique<
70 CudaLocalMemcpy<T, Src, CudaDevicePointer<T> > >(
71 streams[0],
72 sourcePtr,
73 devicePtrs[0],
74 offset,
75 count);
76 }
77
78#if GLOO_USE_NCCL
79 return make_unique<CudaLocalNCCLBroadcast<T, Src> >(
80 streams, devicePtrs, sourcePtr, offset, count);
81#else
82 return make_unique<CudaLocalNativeBroadcast<T, Src> >(
83 streams, devicePtrs, sourcePtr, offset, count);
84#endif
85}
86
87} // namespace gloo
88