/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
#pragma once

#include <cstddef>
#include <memory>

#include "gloo/context.h"
#include "gloo/math.h"
15
16namespace gloo {
17
// Declaration only; the definition lives in another translation unit.
// NOTE(review): judging by the name this is a size threshold related to
// on-device processing; its value, units and exact meaning are not
// visible from this header -- confirm against the definition.
extern const size_t kOnDeviceThreshold;
19
20class Algorithm {
21 public:
22 explicit Algorithm(const std::shared_ptr<Context>&);
23 virtual ~Algorithm() noexcept(false) = 0;
24
25 virtual void run() = 0;
26
27 protected:
28 std::shared_ptr<Context> context_;
29
30 const int contextRank_;
31 const int contextSize_;
32
33 std::unique_ptr<transport::Pair>& getPair(int i);
34
35 // Helpers for ring algorithms
36 std::unique_ptr<transport::Pair>& getLeftPair();
37 std::unique_ptr<transport::Pair>& getRightPair();
38};
39
// Type of reduction function.
//
// If the reduction type is one of the built-ins, algorithm
// implementations may use accelerated versions if available.
//
// For example, if a ReductionFunction with ReductionType equal to
// SUM is passed to a CUDA-aware Allreduce, it knows it can
// use an NCCL implementation instead of the specified function.
//
enum ReductionType {
  SUM = 1,
  PRODUCT = 2,
  MAX = 3,
  MIN = 4,

  // Use larger number so we have plenty of room to add built-ins
  CUSTOM = 1000,
};
58
59template <typename T>
60class ReductionFunction {
61 public:
62 using Function = void(T*, const T*, size_t n);
63
64 static const ReductionFunction<T>* sum;
65 static const ReductionFunction<T>* product;
66 static const ReductionFunction<T>* min;
67 static const ReductionFunction<T>* max;
68
69 ReductionFunction(ReductionType type, Function* fn)
70 : type_(type), fn_(fn) {}
71
72 ReductionType type() const {
73 return type_;
74 }
75
76 void call(T* x, const T* y, size_t n) const {
77 fn_(x, y, n);
78 }
79
80 protected:
81 ReductionType type_;
82 Function* fn_;
83};
84
85template <typename T>
86const ReductionFunction<T>* ReductionFunction<T>::sum =
87 new ReductionFunction<T>(SUM, &::gloo::sum<T>);
88template <typename T>
89const ReductionFunction<T>* ReductionFunction<T>::product =
90 new ReductionFunction<T>(PRODUCT, &::gloo::product<T>);
91template <typename T>
92const ReductionFunction<T>* ReductionFunction<T>::min =
93 new ReductionFunction<T>(MIN, &::gloo::min<T>);
94template <typename T>
95const ReductionFunction<T>* ReductionFunction<T>::max =
96 new ReductionFunction<T>(MAX, &::gloo::max<T>);
97
// Local operation.
// If an algorithm uses multiple local pointers, local operations
// can be used for local reduction, broadcast, gathering, etc.
//
// Subclasses implement runAsync() and wait(); run() combines the two.
template <typename T>
class LocalOp {
 public:
  // Declared noexcept(false), i.e. destruction is permitted to throw.
  virtual ~LocalOp() noexcept(false) {}

  // Both steps are pure virtual and must be provided by subclasses.
  virtual void runAsync() = 0;
  virtual void wait() = 0;

  // Synchronous run is equal to asynchronous run and wait.
  void run() {
    runAsync();
    wait();
  }
};
114
115} // namespace gloo
116