1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include <algorithm> |
12 | #include <cmath> |
13 | |
14 | #include "gloo/common/common.h" |
15 | #include "gloo/common/logging.h" |
16 | #include "gloo/config.h" |
17 | #include "gloo/cuda.h" |
18 | #include "gloo/cuda_private.h" |
19 | |
20 | #if GLOO_USE_NCCL |
21 | #include "gloo/cuda_collectives_nccl.h" |
22 | #else |
23 | #include "gloo/cuda_collectives_native.h" |
24 | #endif |
25 | |
26 | namespace gloo { |
27 | |
28 | template <typename T, typename Dst> |
29 | std::unique_ptr<LocalOp<T> > cudaDeviceReduce( |
30 | std::vector<CudaStream>& streams, |
31 | std::vector<CudaDevicePointer<T> >& devicePtrs, |
32 | Dst& targetPtr, |
33 | const CudaReductionFunction<T>* fn, |
34 | size_t offset, |
35 | size_t count) { |
36 | GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size()); |
37 | |
38 | // Simple copy operation if there is only a single device pointer. |
39 | if (devicePtrs.size() == 1) { |
40 | return make_unique< |
41 | CudaLocalMemcpy<T, CudaDevicePointer<T>, Dst> >( |
42 | streams[0], |
43 | devicePtrs[0], |
44 | targetPtr, |
45 | offset, |
46 | count); |
47 | } |
48 | |
49 | #if GLOO_USE_NCCL |
50 | return make_unique<CudaLocalNCCLReduce<T, Dst> >( |
51 | streams, devicePtrs, targetPtr, fn, offset, count); |
52 | #else |
53 | return make_unique<CudaLocalNativeReduce<T, Dst> >( |
54 | streams, devicePtrs, targetPtr, fn, offset, count); |
55 | #endif |
56 | } |
57 | |
58 | template <typename T, typename Src> |
59 | std::unique_ptr<LocalOp<T> > cudaDeviceBroadcast( |
60 | std::vector<CudaStream>& streams, |
61 | std::vector<CudaDevicePointer<T> >& devicePtrs, |
62 | Src& sourcePtr, |
63 | size_t offset, |
64 | size_t count) { |
65 | GLOO_ENFORCE_EQ(streams.size(), devicePtrs.size()); |
66 | |
67 | // Simple copy operation if there is only a single device pointer. |
68 | if (devicePtrs.size() == 1) { |
69 | return make_unique< |
70 | CudaLocalMemcpy<T, Src, CudaDevicePointer<T> > >( |
71 | streams[0], |
72 | sourcePtr, |
73 | devicePtrs[0], |
74 | offset, |
75 | count); |
76 | } |
77 | |
78 | #if GLOO_USE_NCCL |
79 | return make_unique<CudaLocalNCCLBroadcast<T, Src> >( |
80 | streams, devicePtrs, sourcePtr, offset, count); |
81 | #else |
82 | return make_unique<CudaLocalNativeBroadcast<T, Src> >( |
83 | streams, devicePtrs, sourcePtr, offset, count); |
84 | #endif |
85 | } |
86 | |
87 | } // namespace gloo |
88 | |