1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/cuda_allreduce_local.h"
10
11#include "gloo/common/logging.h"
12#include "gloo/cuda_collectives_device.h"
13#include "gloo/cuda_private.h"
14
15namespace gloo {
16
17template <typename T>
18CudaAllreduceLocal<T>::CudaAllreduceLocal(
19 const std::shared_ptr<Context>& context,
20 const std::vector<T*>& ptrs,
21 const int count,
22 const std::vector<cudaStream_t>& streams)
23 : Algorithm(context),
24 count_(count),
25 bytes_(count_ * sizeof(T)),
26 fn_(CudaReductionFunction<T>::sum),
27 synchronizeDeviceOutputs_(streams.size() == 0) {
28 auto newStream = true;
29 if (streams.size() > 0) {
30 GLOO_ENFORCE_EQ(streams.size(), ptrs.size());
31 newStream = false;
32 }
33
34 for (auto i = 0; i < ptrs.size(); i++) {
35 auto ptr = CudaDevicePointer<T>::create(ptrs[i], count_);
36 if (newStream) {
37 streams_.push_back(CudaStream(ptr.getDeviceID()));
38 } else {
39 streams_.push_back(CudaStream(ptr.getDeviceID(), streams[i]));
40 }
41 devicePtrs_.push_back(std::move(ptr));
42 }
43
44 // Initialize local reduce / local broadcast
45 // TODO(pietern): Optimize this to use real direct allreduce if possible
46 if (devicePtrs_.size() > 1) {
47 localReduceOp_ =
48 cudaDeviceReduce(streams_, devicePtrs_, devicePtrs_[0], fn_, 0, count_);
49 localBroadcastOp_ =
50 cudaDeviceBroadcast(streams_, devicePtrs_, devicePtrs_[0], 0, count_);
51 }
52}
53
54template <typename T>
55void CudaAllreduceLocal<T>::run() {
56 CudaDeviceGuard guard;
57
58 if (devicePtrs_.size() > 1) {
59 localReduceOp_->runAsync();
60 localBroadcastOp_->runAsync();
61 if (synchronizeDeviceOutputs_) {
62 localBroadcastOp_->wait();
63 }
64 }
65}
66
67// Instantiate templates
68#define INSTANTIATE_TEMPLATE(T) template class CudaAllreduceLocal<T>;
69
70INSTANTIATE_TEMPLATE(int8_t);
71INSTANTIATE_TEMPLATE(uint8_t);
72INSTANTIATE_TEMPLATE(int32_t);
73INSTANTIATE_TEMPLATE(int64_t);
74INSTANTIATE_TEMPLATE(uint64_t);
75INSTANTIATE_TEMPLATE(float);
76INSTANTIATE_TEMPLATE(double);
77INSTANTIATE_TEMPLATE(float16);
78
79} // namespace gloo
80