1/**
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include "gloo/allreduce_local.h"
10
11#include <string.h>
12
13namespace gloo {
14
15template <typename T>
16AllreduceLocal<T>::AllreduceLocal(
17 const std::shared_ptr<Context>& context,
18 const std::vector<T*>& ptrs,
19 const int count,
20 const ReductionFunction<T>* fn)
21 : Algorithm(context),
22 ptrs_(ptrs),
23 count_(count),
24 bytes_(count_ * sizeof(T)),
25 fn_(fn) {
26}
27
28template <typename T>
29void AllreduceLocal<T>::run() {
30 // Reduce specified pointers into ptrs_[0]
31 for (int i = 1; i < ptrs_.size(); i++) {
32 fn_->call(ptrs_[0], ptrs_[i], count_);
33 }
34 // Broadcast ptrs_[0]
35 for (int i = 1; i < ptrs_.size(); i++) {
36 memcpy(ptrs_[i], ptrs_[0], bytes_);
37 }
38}
39
40// Instantiate templates
41#define INSTANTIATE_TEMPLATE(T) template class AllreduceLocal<T>;
42
43INSTANTIATE_TEMPLATE(int8_t);
44INSTANTIATE_TEMPLATE(uint8_t);
45INSTANTIATE_TEMPLATE(int32_t);
46INSTANTIATE_TEMPLATE(int64_t);
47INSTANTIATE_TEMPLATE(uint64_t);
48INSTANTIATE_TEMPLATE(float);
49INSTANTIATE_TEMPLATE(double);
50INSTANTIATE_TEMPLATE(float16);
51// Needed for benchmark (main.cc) to build, should not get used
52INSTANTIATE_TEMPLATE(char);
53
54} // namespace gloo
55