1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #pragma once |
10 | |
11 | #include <memory> |
12 | |
13 | #include "gloo/context.h" |
14 | #include "gloo/math.h" |
15 | |
16 | namespace gloo { |
17 | |
// Threshold constant shared across translation units: only declared
// here (`extern`, no initializer), so its value is set where it is
// defined.
// NOTE(review): the name suggests a size cutoff for choosing on-device
// processing; units and exact use are not visible in this header --
// confirm at the definition.
extern const size_t kOnDeviceThreshold;
19 | |
20 | class Algorithm { |
21 | public: |
22 | explicit Algorithm(const std::shared_ptr<Context>&); |
23 | virtual ~Algorithm() noexcept(false) = 0; |
24 | |
25 | virtual void run() = 0; |
26 | |
27 | protected: |
28 | std::shared_ptr<Context> context_; |
29 | |
30 | const int ; |
31 | const int contextSize_; |
32 | |
33 | std::unique_ptr<transport::Pair>& getPair(int i); |
34 | |
35 | // Helpers for ring algorithms |
36 | std::unique_ptr<transport::Pair>& getLeftPair(); |
37 | std::unique_ptr<transport::Pair>& getRightPair(); |
38 | }; |
39 | |
// Identifies the kind of reduction a ReductionFunction performs.
//
// When the type is one of the built-ins, algorithm implementations may
// substitute an accelerated version if one is available. For example, a
// CUDA aware Allreduce that is handed a ReductionFunction whose
// ReductionType is equal to SUM knows it can use a NCCL implementation
// instead of the specified function.
enum ReductionType {
  // Built-in reductions.
  SUM = 1,
  PRODUCT = 2,
  MAX = 3,
  MIN = 4,

  // Anything that is not a built-in. The value is deliberately large so
  // there is plenty of room to add more built-ins before it.
  CUSTOM = 1000,
};
58 | |
59 | template <typename T> |
60 | class ReductionFunction { |
61 | public: |
62 | using Function = void(T*, const T*, size_t n); |
63 | |
64 | static const ReductionFunction<T>* sum; |
65 | static const ReductionFunction<T>* product; |
66 | static const ReductionFunction<T>* min; |
67 | static const ReductionFunction<T>* max; |
68 | |
69 | ReductionFunction(ReductionType type, Function* fn) |
70 | : type_(type), fn_(fn) {} |
71 | |
72 | ReductionType type() const { |
73 | return type_; |
74 | } |
75 | |
76 | void call(T* x, const T* y, size_t n) const { |
77 | fn_(x, y, n); |
78 | } |
79 | |
80 | protected: |
81 | ReductionType type_; |
82 | Function* fn_; |
83 | }; |
84 | |
85 | template <typename T> |
86 | const ReductionFunction<T>* ReductionFunction<T>::sum = |
87 | new ReductionFunction<T>(SUM, &::gloo::sum<T>); |
88 | template <typename T> |
89 | const ReductionFunction<T>* ReductionFunction<T>::product = |
90 | new ReductionFunction<T>(PRODUCT, &::gloo::product<T>); |
91 | template <typename T> |
92 | const ReductionFunction<T>* ReductionFunction<T>::min = |
93 | new ReductionFunction<T>(MIN, &::gloo::min<T>); |
94 | template <typename T> |
95 | const ReductionFunction<T>* ReductionFunction<T>::max = |
96 | new ReductionFunction<T>(MAX, &::gloo::max<T>); |
97 | |
// Local operation.
//
// Algorithms that work with multiple local pointers can use local
// operations for local reduction, broadcast, gathering, etc.
// Implementations supply the asynchronous start (runAsync) and the
// matching completion step (wait).
template <typename T>
class LocalOp {
 public:
  // Declared noexcept(false), so destruction is allowed to propagate
  // exceptions. Kept user-provided (not defaulted) on purpose.
  virtual ~LocalOp() noexcept(false) {}

  // Starts the operation.
  virtual void runAsync() = 0;

  // Completes an operation previously started with runAsync().
  virtual void wait() = 0;

  // Synchronous convenience wrapper: starts the operation and then
  // immediately waits for it.
  void run() {
    this->runAsync();
    this->wait();
  }
};
114 | |
115 | } // namespace gloo |
116 | |