1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include "gloo/allreduce_local.h" |
10 | |
11 | #include <string.h> |
12 | |
13 | namespace gloo { |
14 | |
15 | template <typename T> |
16 | AllreduceLocal<T>::AllreduceLocal( |
17 | const std::shared_ptr<Context>& context, |
18 | const std::vector<T*>& ptrs, |
19 | const int count, |
20 | const ReductionFunction<T>* fn) |
21 | : Algorithm(context), |
22 | ptrs_(ptrs), |
23 | count_(count), |
24 | bytes_(count_ * sizeof(T)), |
25 | fn_(fn) { |
26 | } |
27 | |
28 | template <typename T> |
29 | void AllreduceLocal<T>::run() { |
30 | // Reduce specified pointers into ptrs_[0] |
31 | for (int i = 1; i < ptrs_.size(); i++) { |
32 | fn_->call(ptrs_[0], ptrs_[i], count_); |
33 | } |
34 | // Broadcast ptrs_[0] |
35 | for (int i = 1; i < ptrs_.size(); i++) { |
36 | memcpy(ptrs_[i], ptrs_[0], bytes_); |
37 | } |
38 | } |
39 | |
40 | // Instantiate templates |
41 | #define INSTANTIATE_TEMPLATE(T) template class AllreduceLocal<T>; |
42 | |
43 | INSTANTIATE_TEMPLATE(int8_t); |
44 | INSTANTIATE_TEMPLATE(uint8_t); |
45 | INSTANTIATE_TEMPLATE(int32_t); |
46 | INSTANTIATE_TEMPLATE(int64_t); |
47 | INSTANTIATE_TEMPLATE(uint64_t); |
48 | INSTANTIATE_TEMPLATE(float); |
49 | INSTANTIATE_TEMPLATE(double); |
50 | INSTANTIATE_TEMPLATE(float16); |
51 | // Needed for benchmark (main.cc) to build, should not get used |
52 | INSTANTIATE_TEMPLATE(char); |
53 | |
54 | } // namespace gloo |
55 | |