1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/cuda_allreduce_local.h" |
10 | |
11 | #include "gloo/common/logging.h" |
12 | #include "gloo/cuda_collectives_device.h" |
13 | #include "gloo/cuda_private.h" |
14 | |
15 | namespace gloo { |
16 | |
17 | template <typename T> |
18 | CudaAllreduceLocal<T>::CudaAllreduceLocal( |
19 | const std::shared_ptr<Context>& context, |
20 | const std::vector<T*>& ptrs, |
21 | const int count, |
22 | const std::vector<cudaStream_t>& streams) |
23 | : Algorithm(context), |
24 | count_(count), |
25 | bytes_(count_ * sizeof(T)), |
26 | fn_(CudaReductionFunction<T>::sum), |
27 | synchronizeDeviceOutputs_(streams.size() == 0) { |
28 | auto newStream = true; |
29 | if (streams.size() > 0) { |
30 | GLOO_ENFORCE_EQ(streams.size(), ptrs.size()); |
31 | newStream = false; |
32 | } |
33 | |
34 | for (auto i = 0; i < ptrs.size(); i++) { |
35 | auto ptr = CudaDevicePointer<T>::create(ptrs[i], count_); |
36 | if (newStream) { |
37 | streams_.push_back(CudaStream(ptr.getDeviceID())); |
38 | } else { |
39 | streams_.push_back(CudaStream(ptr.getDeviceID(), streams[i])); |
40 | } |
41 | devicePtrs_.push_back(std::move(ptr)); |
42 | } |
43 | |
44 | // Initialize local reduce / local broadcast |
45 | // TODO(pietern): Optimize this to use real direct allreduce if possible |
46 | if (devicePtrs_.size() > 1) { |
47 | localReduceOp_ = |
48 | cudaDeviceReduce(streams_, devicePtrs_, devicePtrs_[0], fn_, 0, count_); |
49 | localBroadcastOp_ = |
50 | cudaDeviceBroadcast(streams_, devicePtrs_, devicePtrs_[0], 0, count_); |
51 | } |
52 | } |
53 | |
54 | template <typename T> |
55 | void CudaAllreduceLocal<T>::run() { |
56 | CudaDeviceGuard guard; |
57 | |
58 | if (devicePtrs_.size() > 1) { |
59 | localReduceOp_->runAsync(); |
60 | localBroadcastOp_->runAsync(); |
61 | if (synchronizeDeviceOutputs_) { |
62 | localBroadcastOp_->wait(); |
63 | } |
64 | } |
65 | } |
66 | |
67 | // Instantiate templates |
68 | #define INSTANTIATE_TEMPLATE(T) template class CudaAllreduceLocal<T>; |
69 | |
70 | INSTANTIATE_TEMPLATE(int8_t); |
71 | INSTANTIATE_TEMPLATE(uint8_t); |
72 | INSTANTIATE_TEMPLATE(int32_t); |
73 | INSTANTIATE_TEMPLATE(int64_t); |
74 | INSTANTIATE_TEMPLATE(uint64_t); |
75 | INSTANTIATE_TEMPLATE(float); |
76 | INSTANTIATE_TEMPLATE(double); |
77 | INSTANTIATE_TEMPLATE(float16); |
78 | |
79 | } // namespace gloo |
80 | |