1#pragma once
2
3#include <torch/csrc/distributed/c10d/Store.hpp>
4
5#include <chrono>
6#include <cstdint>
7
8#include <ATen/core/ivalue.h>
9#include <ATen/core/Tensor.h>
10
11#include <c10/macros/Macros.h>
12#include <c10/util/intrusive_ptr.h>
13
14namespace c10d {
15
16// Base class for supplementary data potentially needed by ReduceOps
17struct TORCH_API _SupplementBase : torch::CustomClassHolder {
18 ~_SupplementBase() override = default;
19};
20
21// Supplementary data specific to NCCL PREMUL_SUM
22// The point of use in ProcessGroupNCCL knows how to unpack it.
23struct NCCLPreMulSumSupplement : _SupplementBase {
24 double double_factor{0.0};
25 at::Tensor tensor_factor;
26 NCCLPreMulSumSupplement(double f) : double_factor{f} {}
27 NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
28 TORCH_CHECK_EQ(tensor_factor.numel(), 1);
29 }
30};
31
32// Other ReduceOps that need different supplementary data can also
33// derive from _SupplementBase.
34struct TORCH_API ReduceOp : torch::CustomClassHolder {
35 // note(crcrpar): RedOpType could be defined outside of `ReduceOp`
36 enum RedOpType : uint8_t {
37 SUM = 0,
38 AVG = 1,
39 PRODUCT = 2,
40 MIN = 3,
41 MAX = 4,
42 BAND = 5, // Bitwise AND
43 BOR = 6, // Bitwise OR
44 BXOR = 7, // Bitwise XOR
45 PREMUL_SUM = 8, // Multiply by a user-supplied constant before summing.
46 UNUSED = 9
47 };
48
49 ReduceOp() = default;
50
51 ReduceOp(RedOpType op) : op_(op) {
52 TORCH_INTERNAL_ASSERT(
53 op_ != PREMUL_SUM,
54 "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM"
55 );
56 }
57
58 ReduceOp(RedOpType op, c10::intrusive_ptr<_SupplementBase> optional_supplement) {
59 if (optional_supplement.get()) {
60 op_ = op;
61 } else {
62 supplement_ = optional_supplement;
63 }
64 }
65
66 // The heap resource supplement_, if it exists, is managed by a c10::intrusive_ptr,
67 // so constructors and operator= can be simple
68 ReduceOp(const ReduceOp& other) :
69 op_(other.op_), supplement_(other.supplement_) {}
70
71 const ReduceOp& operator=(const ReduceOp& other) {
72 op_ = other.op_;
73 supplement_ = other.supplement_;
74 return *this;
75 }
76
77 operator RedOpType() const { return op_; }
78
79 bool operator==(const std::uint8_t other) {
80 TORCH_INTERNAL_ASSERT(other < 9, "Invalid other op value");
81 return other == op_;
82 }
83
84 bool operator==(const ReduceOp::RedOpType other) {
85 return *this == static_cast<std::uint8_t>(other);
86 }
87
88 // todo(crcrpar): Handle `RedOpType::PREMUL_SUM` with its scaling factor.
89 bool operator==(const ReduceOp& other) {
90 return *this == other.op_;
91 }
92
93 RedOpType op_ = SUM;
94 // supplement_ is "type-erased" storage for optional supplementary
95 // data the op might need.
96 // The point of use will know the derived type supplement_ really is,
97 // and downcast its pointer to extract the data as the needed type(s).
98 // Right now, only PREMUL_SUM needs supplementary data, but the same
99 // mechanism could extend to support other nontrivial reduce ops with
100 // different supplementary payloads.
101 c10::intrusive_ptr<_SupplementBase> supplement_;
102};
103
104template<typename T> ReduceOp makeNCCLPreMulSum(const T& factor) {
105 ReduceOp rop;
106 rop.op_ = ReduceOp::PREMUL_SUM;
107 rop.supplement_ = c10::make_intrusive<NCCLPreMulSumSupplement>(factor);
108 return rop;
109}
110
111constexpr auto kUnsetTimeout = std::chrono::milliseconds(-1);
112
113struct BroadcastOptions {
114 int64_t rootRank = 0;
115 int64_t rootTensor = 0;
116 std::chrono::milliseconds timeout = kUnsetTimeout;
117};
118
119struct AllreduceOptions {
120 ReduceOp reduceOp = ReduceOp::SUM;
121 std::chrono::milliseconds timeout = kUnsetTimeout;
122};
123
124struct AllreduceCoalescedOptions : AllreduceOptions {};
125
126struct ReduceOptions {
127 ReduceOp reduceOp = ReduceOp::SUM;
128 int64_t rootRank = 0;
129 int64_t rootTensor = 0;
130 std::chrono::milliseconds timeout = kUnsetTimeout;
131};
132
133struct AllgatherOptions {
134 std::chrono::milliseconds timeout = kUnsetTimeout;
135};
136
137struct GatherOptions {
138 int64_t rootRank = 0;
139 std::chrono::milliseconds timeout = kUnsetTimeout;
140};
141
142struct ScatterOptions {
143 int64_t rootRank = 0;
144 std::chrono::milliseconds timeout = kUnsetTimeout;
145};
146
147struct ReduceScatterOptions {
148 ReduceOp reduceOp = ReduceOp::SUM;
149 std::chrono::milliseconds timeout = kUnsetTimeout;
150};
151
152struct AllToAllOptions {
153 std::chrono::milliseconds timeout = kUnsetTimeout;
154};
155
156struct BarrierOptions {
157 std::vector<int64_t> device_ids;
158 std::chrono::milliseconds timeout = kUnsetTimeout;
159};
160
161struct DistributedBackendOptions {
162 c10::intrusive_ptr<::c10d::Store> store;
163 int group_rank;
164 int group_size;
165 std::chrono::duration<float> timeout;
166 std::string group_id;
167 std::vector<int64_t> global_ranks_in_group;
168};
169
170} // namespace c10d
171