1 | #pragma once |
2 | |
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <cstdint>
#include <vector>
9 | |
// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for
// HIP 3.1+. (__CUDA_BF16_TYPES_EXIST__ is only defined by CUDA 11+ headers.)
#if defined(__CUDA_BF16_TYPES_EXIST__)
// Explicit parentheses: && binds tighter than ||, so the old unparenthesized
// form was equivalent, but it tripped -Wparentheses and was easy to misread.
#define HAS_NCCL_BF16_DATATYPE \
  ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)))
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
#define HAS_NCCL_BF16_DATATYPE 1
#else
#define HAS_NCCL_BF16_DATATYPE 0
#endif
20 | |
21 | namespace torch { |
22 | namespace cuda { |
23 | namespace nccl { |
24 | |
25 | /* The following are copied from <nccl.h> and redefined in torch::cuda::nccl |
26 | * namespace */ |
27 | /* pytorch should only use the following definition within pytorch scope */ |
28 | |
/* Opaque handle standing in for the native ncclComm_t (a ncclComm* pointer);
 * nccl.cpp reinterprets between this and the real type. */
typedef void* ncclComm_t;

/** Redefinition of the NCCL unique ID in torch scope. This must stay
 * identical in size and layout to the native NCCL implementation's
 * ncclUniqueId so the two can be used interchangeably. */
#define NCCL_UNIQUE_ID_BYTES 128
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
typedef struct {
  char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;
40 | |
/* Error type: redefinition of ncclResult_t from <nccl.h> in torch scope.
 * The numeric values must mirror the native enum, since nccl.cpp converts
 * between the two representations. */
enum class ncclResult {
  Success = 0,
  UnhandledCudaError = 1,
  SystemError = 2,
  InternalError = 3,
  InvalidArgument = 4,
  InvalidUsage = 5,
  NumResults = 6
};
51 | |
/* Reduction operation selector: redefinition of ncclRedOp_t from <nccl.h>;
 * numeric values must mirror the native enum. */
enum class ncclRedOp { Sum = 0, Prod = 1, Max = 2, Min = 3, NumOps = 4 };
54 | |
/* Data types: redefinition of ncclDataType_t from <nccl.h>; numeric values
 * must mirror the native enum. Several enumerators are deliberate aliases
 * sharing one value (Char/Int8, Int/Int32, Half/Float16, Float/Float32,
 * Double/Float64). NOTE(review): Bfloat16 is presumably only valid when
 * HAS_NCCL_BF16_DATATYPE is true — confirm in nccl.cpp. */
enum class ncclDataType {
  Int8 = 0,
  Char = 0,
  Uint8 = 1,
  Int32 = 2,
  Int = 2,
  Uint32 = 3,
  Int64 = 4,
  Uint64 = 5,
  Float16 = 6,
  Half = 6,
  Float32 = 7,
  Float = 7,
  Float64 = 8,
  Double = 8,
  Bfloat16 = 9,
  NumTypes = 10
};
74 | |
// RAII helper class to manage NCCL group API and CUDA free mutex.
// The destructor is allowed to throw since this helper class only
// manages group and lock lifetimes.
// NOTE(review): presumably the constructor starts an NCCL group (and takes
// the CUDA free mutex) and the destructor ends it — confirm in nccl.cpp.
struct AutoNcclGroup {
  AutoNcclGroup();
  ~AutoNcclGroup() noexcept(false); // noexcept(false): group-end may fail
};
82 | |
// NOTE: this is exposed only so that python_nccl.cpp can use some of these
// helpers. Don't use them outside of these files.
85 | namespace detail { |
86 | |
// Throws a C++ exception describing the given (non-success) NCCL result;
// used by NCCL_CHECK below. Defined in nccl.cpp.
TORCH_CUDA_CPP_API void throw_nccl_error(ncclResult status);
88 | |
89 | static inline void NCCL_CHECK(ncclResult status) { |
90 | if (status != ncclResult::Success) { |
91 | throw_nccl_error(status); |
92 | } |
93 | } |
94 | |
// Returns the communicators associated with the devices of `inputs`.
// NOTE(review): presumably cached/created per device set in nccl.cpp — confirm.
TORCH_CUDA_CPP_API at::ArrayRef<ncclComm_t> get_communicators(
    at::TensorList inputs);
// Validates that `inputs`/`outputs` are suitable for a collective whose
// per-rank element counts scale by the given multipliers.
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    at::TensorList outputs,
    int input_multiplier,
    int output_multiplier);
// Overload taking a single output tensor plus a `root` rank (for rooted
// collectives such as reduce/gather).
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier);
108 | |
109 | } // namespace detail |
110 | |
111 | using comm_list = std::vector<ncclComm_t>; |
112 | using stream_list = std::vector<c10::optional<at::cuda::CUDAStream>>; |
113 | |
114 | TORCH_CUDA_CPP_API std::uint64_t version(); |
115 | |
116 | bool is_available(at::TensorList tensors); |
117 | |
118 | TORCH_CUDA_CPP_API void get_unique_id(ncclUniqueId& id); |
119 | TORCH_CUDA_CPP_API ncclComm_t |
120 | comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank); |
121 | TORCH_CUDA_CPP_API void comm_destroy(ncclComm_t comm); |
122 | |
123 | TORCH_CUDA_CPP_API void broadcast( |
124 | at::TensorList tensors, |
125 | const stream_list& streams = {}, |
126 | const comm_list& user_comms = {}); |
127 | |
128 | size_t get_max_count(); |
129 | |
// Reduces `inputs` (one tensor per participating device) into `output`,
// combining elements with `op` (a ncclRedOp value passed as int32_t);
// `root` selects the reducing rank.
// NOTE(review): precise root/output-device semantics live in nccl.cpp.
TORCH_CUDA_CPP_API void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Overload without a separate output tensor; given the mutable `inputs`,
// the result is presumably written back into inputs[root].
// NOTE(review): inferred from the signature — confirm against nccl.cpp.
TORCH_CUDA_CPP_API void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
144 | |
// All-reduce: combines `inputs` across devices with `op` and leaves the full
// result in every tensor of `outputs` (one per device).
TORCH_CUDA_CPP_API void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Reduce-scatter: combines `inputs` across devices with `op`, each device's
// `outputs` tensor receiving its share of the reduced result.
// NOTE(review): chunking/layout details are defined in nccl.cpp — confirm.
TORCH_CUDA_CPP_API void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
158 | |
// Scatter: distributes the `inputs` tensors from rank `root` so that each
// rank's `outputs` tensor receives its piece; issued on `stream` over `comm`.
// NOTE(review): exact slicing semantics are defined in nccl.cpp — confirm.
TORCH_CUDA_CPP_API void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);

// All-gather: every device contributes its input and each `outputs` tensor
// ends up holding the gathered result from all devices.
// NOTE(review): concatenation layout lives in nccl.cpp — confirm.
TORCH_CUDA_CPP_API void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Gather: rank `root` collects each rank's `inputs` tensor into `outputs`;
// issued on `stream` over `comm`.
TORCH_CUDA_CPP_API void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);
178 | |
// All-to-all with equal per-rank splits: `input` is divided into `size` equal
// chunks, one exchanged with each rank; received chunks land in `output`.
// Issued on `stream` over `comm`.
TORCH_CUDA_CPP_API void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

// All-to-all over raw buffers with per-rank counts and displacements
// (sendcounts/senddispls for the outgoing side, recvcounts/recvdispls for the
// incoming side), elements typed as `type`.
// NOTE(review): `size` looks like the per-element byte size — confirm in
// nccl.cpp before relying on that.
TORCH_CUDA_CPP_API void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType type,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

// Tensor-list all-to-all: presumably inputTensors[i] is sent to rank i and
// the tensor received from rank i is written into outputTensors[i].
// NOTE(review): inferred from the signature — confirm in nccl.cpp.
TORCH_CUDA_CPP_API void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream);
203 | |
// Point-to-point send of `input` to rank `dst` over `comm`, enqueued on
// `stream`. (Note: `stream` is taken by value here, unlike the collectives
// above which take a reference.)
TORCH_CUDA_CPP_API void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst);

// Point-to-point receive into `output` from rank `src` over `comm`, enqueued
// on `stream`.
TORCH_CUDA_CPP_API void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src);
215 | } // namespace nccl |
216 | } // namespace cuda |
217 | } // namespace torch |
218 | |