1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/c10d/Store.hpp> |
4 | |
5 | #include <chrono> |
6 | #include <cstdint> |
7 | |
8 | #include <ATen/core/ivalue.h> |
9 | #include <ATen/core/Tensor.h> |
10 | |
11 | #include <c10/macros/Macros.h> |
12 | #include <c10/util/intrusive_ptr.h> |
13 | |
14 | namespace c10d { |
15 | |
16 | // Base class for supplementary data potentially needed by ReduceOps |
17 | struct TORCH_API _SupplementBase : torch::CustomClassHolder { |
18 | ~_SupplementBase() override = default; |
19 | }; |
20 | |
21 | // Supplementary data specific to NCCL PREMUL_SUM |
22 | // The point of use in ProcessGroupNCCL knows how to unpack it. |
23 | struct NCCLPreMulSumSupplement : _SupplementBase { |
24 | double double_factor{0.0}; |
25 | at::Tensor tensor_factor; |
26 | NCCLPreMulSumSupplement(double f) : double_factor{f} {} |
27 | NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} { |
28 | TORCH_CHECK_EQ(tensor_factor.numel(), 1); |
29 | } |
30 | }; |
31 | |
32 | // Other ReduceOps that need different supplementary data can also |
33 | // derive from _SupplementBase. |
34 | struct TORCH_API ReduceOp : torch::CustomClassHolder { |
35 | // note(crcrpar): RedOpType could be defined outside of `ReduceOp` |
36 | enum RedOpType : uint8_t { |
37 | SUM = 0, |
38 | AVG = 1, |
39 | PRODUCT = 2, |
40 | MIN = 3, |
41 | MAX = 4, |
42 | BAND = 5, // Bitwise AND |
43 | BOR = 6, // Bitwise OR |
44 | BXOR = 7, // Bitwise XOR |
45 | PREMUL_SUM = 8, // Multiply by a user-supplied constant before summing. |
46 | UNUSED = 9 |
47 | }; |
48 | |
49 | ReduceOp() = default; |
50 | |
51 | ReduceOp(RedOpType op) : op_(op) { |
52 | TORCH_INTERNAL_ASSERT( |
53 | op_ != PREMUL_SUM, |
54 | "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM" |
55 | ); |
56 | } |
57 | |
58 | ReduceOp(RedOpType op, c10::intrusive_ptr<_SupplementBase> optional_supplement) { |
59 | if (optional_supplement.get()) { |
60 | op_ = op; |
61 | } else { |
62 | supplement_ = optional_supplement; |
63 | } |
64 | } |
65 | |
66 | // The heap resource supplement_, if it exists, is managed by a c10::intrusive_ptr, |
67 | // so constructors and operator= can be simple |
68 | ReduceOp(const ReduceOp& other) : |
69 | op_(other.op_), supplement_(other.supplement_) {} |
70 | |
71 | const ReduceOp& operator=(const ReduceOp& other) { |
72 | op_ = other.op_; |
73 | supplement_ = other.supplement_; |
74 | return *this; |
75 | } |
76 | |
77 | operator RedOpType() const { return op_; } |
78 | |
79 | bool operator==(const std::uint8_t other) { |
80 | TORCH_INTERNAL_ASSERT(other < 9, "Invalid other op value" ); |
81 | return other == op_; |
82 | } |
83 | |
84 | bool operator==(const ReduceOp::RedOpType other) { |
85 | return *this == static_cast<std::uint8_t>(other); |
86 | } |
87 | |
88 | // todo(crcrpar): Handle `RedOpType::PREMUL_SUM` with its scaling factor. |
89 | bool operator==(const ReduceOp& other) { |
90 | return *this == other.op_; |
91 | } |
92 | |
93 | RedOpType op_ = SUM; |
94 | // supplement_ is "type-erased" storage for optional supplementary |
95 | // data the op might need. |
96 | // The point of use will know the derived type supplement_ really is, |
97 | // and downcast its pointer to extract the data as the needed type(s). |
98 | // Right now, only PREMUL_SUM needs supplementary data, but the same |
99 | // mechanism could extend to support other nontrivial reduce ops with |
100 | // different supplementary payloads. |
101 | c10::intrusive_ptr<_SupplementBase> supplement_; |
102 | }; |
103 | |
104 | template<typename T> ReduceOp makeNCCLPreMulSum(const T& factor) { |
105 | ReduceOp rop; |
106 | rop.op_ = ReduceOp::PREMUL_SUM; |
107 | rop.supplement_ = c10::make_intrusive<NCCLPreMulSumSupplement>(factor); |
108 | return rop; |
109 | } |
110 | |
111 | constexpr auto kUnsetTimeout = std::chrono::milliseconds(-1); |
112 | |
113 | struct BroadcastOptions { |
114 | int64_t rootRank = 0; |
115 | int64_t rootTensor = 0; |
116 | std::chrono::milliseconds timeout = kUnsetTimeout; |
117 | }; |
118 | |
119 | struct AllreduceOptions { |
120 | ReduceOp reduceOp = ReduceOp::SUM; |
121 | std::chrono::milliseconds timeout = kUnsetTimeout; |
122 | }; |
123 | |
124 | struct AllreduceCoalescedOptions : AllreduceOptions {}; |
125 | |
126 | struct ReduceOptions { |
127 | ReduceOp reduceOp = ReduceOp::SUM; |
128 | int64_t rootRank = 0; |
129 | int64_t rootTensor = 0; |
130 | std::chrono::milliseconds timeout = kUnsetTimeout; |
131 | }; |
132 | |
133 | struct AllgatherOptions { |
134 | std::chrono::milliseconds timeout = kUnsetTimeout; |
135 | }; |
136 | |
137 | struct GatherOptions { |
138 | int64_t rootRank = 0; |
139 | std::chrono::milliseconds timeout = kUnsetTimeout; |
140 | }; |
141 | |
142 | struct ScatterOptions { |
143 | int64_t rootRank = 0; |
144 | std::chrono::milliseconds timeout = kUnsetTimeout; |
145 | }; |
146 | |
147 | struct ReduceScatterOptions { |
148 | ReduceOp reduceOp = ReduceOp::SUM; |
149 | std::chrono::milliseconds timeout = kUnsetTimeout; |
150 | }; |
151 | |
152 | struct AllToAllOptions { |
153 | std::chrono::milliseconds timeout = kUnsetTimeout; |
154 | }; |
155 | |
156 | struct BarrierOptions { |
157 | std::vector<int64_t> device_ids; |
158 | std::chrono::milliseconds timeout = kUnsetTimeout; |
159 | }; |
160 | |
161 | struct DistributedBackendOptions { |
162 | c10::intrusive_ptr<::c10d::Store> store; |
163 | int group_rank; |
164 | int group_size; |
165 | std::chrono::duration<float> timeout; |
166 | std::string group_id; |
167 | std::vector<int64_t> global_ranks_in_group; |
168 | }; |
169 | |
170 | } // namespace c10d |
171 | |