#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <cstdint>
#include <vector>

// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for
// HIP 3.1+
#if defined(__CUDA_BF16_TYPES_EXIST__)
#define HAS_NCCL_BF16_DATATYPE \
  ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)))
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
#define HAS_NCCL_BF16_DATATYPE 1
#else
#define HAS_NCCL_BF16_DATATYPE 0
#endif
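
// Example (a sketch): downstream code can guard BF16 handling on this macro
// at compile time, e.g. when mapping an at::ScalarType to a NCCL datatype:
//
//   #if HAS_NCCL_BF16_DATATYPE
//     // ... map at::kBFloat16 to a BF16 NCCL datatype ...
//   #else
//     TORCH_CHECK(false, "BFloat16 needs NCCL 2.10+ (or HIP 3.1+)");
//   #endif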

namespace torch {
namespace cuda {
namespace nccl {

/* The following are copied from <nccl.h> and redefined in the
 * torch::cuda::nccl namespace */
/* PyTorch should only use the following definitions within PyTorch scope */

/* Opaque handle to a communicator; this is reinterpreted as ncclComm* in
 * nccl.cpp */
typedef void* ncclComm_t;

/** Redefinition of the NCCL unique ID in torch scope. This should be
 * identical to the native NCCL implementation. */
#define NCCL_UNIQUE_ID_BYTES 128
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
typedef struct {
  char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;

/* Error type */
enum class ncclResult {
  Success = 0,
  UnhandledCudaError = 1,
  SystemError = 2,
  InternalError = 3,
  InvalidArgument = 4,
  InvalidUsage = 5,
  NumResults = 6
};

/* Reduction operation selector */
enum class ncclRedOp { Sum = 0, Prod = 1, Max = 2, Min = 3, NumOps = 4 };

/* Data types */
enum class ncclDataType {
  Int8 = 0,
  Char = 0,
  Uint8 = 1,
  Int32 = 2,
  Int = 2,
  Uint32 = 3,
  Int64 = 4,
  Uint64 = 5,
  Float16 = 6,
  Half = 6,
  Float32 = 7,
  Float = 7,
  Float64 = 8,
  Double = 8,
  Bfloat16 = 9,
  NumTypes = 10
};
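
// Example (a sketch; `to_nccl_data_type` is a hypothetical helper, not
// necessarily the mapping used in nccl.cpp): converting an at::ScalarType
// to the torch-scoped ncclDataType above could look like:
//
//   inline ncclDataType to_nccl_data_type(c10::ScalarType type) {
//     switch (type) {
//       case at::kChar: return ncclDataType::Int8;
//       case at::kByte: return ncclDataType::Uint8;
//       case at::kInt: return ncclDataType::Int32;
//       case at::kLong: return ncclDataType::Int64;
//       case at::kHalf: return ncclDataType::Half;
//       case at::kFloat: return ncclDataType::Float;
//       case at::kDouble: return ncclDataType::Double;
//   #if HAS_NCCL_BF16_DATATYPE
//       case at::kBFloat16: return ncclDataType::Bfloat16;
//   #endif
//       default: TORCH_CHECK(false, "Unsupported data type for NCCL");
//     }
//   }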

// RAII helper class to manage NCCL group API and CUDA free mutex.
// The destructor is allowed to throw since this helper class only
// manages group and lock lifetimes.
struct AutoNcclGroup {
  AutoNcclGroup();
  ~AutoNcclGroup() noexcept(false);
};
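
// Example (a sketch): batching several collective calls so NCCL can issue
// them as one group; note the destructor may throw on NCCL errors:
//
//   {
//     AutoNcclGroup nccl_group; // starts the group, takes the mutex
//     // ... enqueue broadcast/reduce/send/recv calls here ...
//   } // ends the group, releases the mutex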

// NOTE: this is exposed only so that python_nccl.cpp can use some of these
// helpers. Don't use them outside of these files.
namespace detail {

TORCH_CUDA_CPP_API void throw_nccl_error(ncclResult status);

static inline void NCCL_CHECK(ncclResult status) {
  if (status != ncclResult::Success) {
    throw_nccl_error(status);
  }
}
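
// Example (a sketch; `status` is a hypothetical ncclResult obtained from a
// wrapped NCCL call):
//
//   detail::NCCL_CHECK(status); // throws via throw_nccl_error for any
//                               // status other than ncclResult::Success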

TORCH_CUDA_CPP_API at::ArrayRef<ncclComm_t> get_communicators(
    at::TensorList inputs);
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    at::TensorList outputs,
    int input_multiplier,
    int output_multiplier);
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier);

} // namespace detail

using comm_list = std::vector<ncclComm_t>;
using stream_list = std::vector<c10::optional<at::cuda::CUDAStream>>;

TORCH_CUDA_CPP_API std::uint64_t version();

bool is_available(at::TensorList tensors);

TORCH_CUDA_CPP_API void get_unique_id(ncclUniqueId& id);
TORCH_CUDA_CPP_API ncclComm_t
comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank);
TORCH_CUDA_CPP_API void comm_destroy(ncclComm_t comm);
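
// Example (a sketch of the usual bootstrap flow; `world_size`, `rank`, and
// the out-of-band exchange of the ID are assumptions, not provided here):
//
//   ncclUniqueId id;
//   if (rank == 0) {
//     torch::cuda::nccl::get_unique_id(id); // rank 0 creates the ID
//   }
//   // ... share `id` with all ranks out of band (e.g. via a store) ...
//   ncclComm_t comm =
//       torch::cuda::nccl::comm_init_rank(world_size, id, rank);
//   // ... run collectives on `comm` ...
//   torch::cuda::nccl::comm_destroy(comm);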

TORCH_CUDA_CPP_API void broadcast(
    at::TensorList tensors,
    const stream_list& streams = {},
    const comm_list& user_comms = {});
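
// Example (a sketch; `tensors` is a hypothetical at::TensorList holding one
// same-shaped tensor per participating device): broadcasts the contents of
// tensors[0] to every other entry, using default streams and communicators:
//
//   torch::cuda::nccl::broadcast(tensors);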

size_t get_max_count();

TORCH_CUDA_CPP_API void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_CPP_API void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
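
// Example (a sketch; `inputs` is a hypothetical per-device tensor vector):
// element-wise sum across devices; the in-place overload above leaves the
// result in the entry on the root device:
//
//   torch::cuda::nccl::reduce(inputs, /*root=*/0); // op defaults to Sum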

TORCH_CUDA_CPP_API void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
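
// Example (a sketch; `inputs` and `outputs` are hypothetical per-device
// tensor vectors of matching shapes): element-wise max across devices,
// with every entry of `outputs` receiving the reduced result:
//
//   torch::cuda::nccl::all_reduce(
//       inputs, outputs, static_cast<int>(ncclRedOp::Max));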

TORCH_CUDA_CPP_API void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_CPP_API void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);

TORCH_CUDA_CPP_API void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams = {},
    const comm_list& user_comms = {});
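
// Example (a sketch): each outputs[i] receives the inputs from all devices
// concatenated, so it must hold input-size times the device count elements:
//
//   torch::cuda::nccl::all_gather(inputs, outputs);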

TORCH_CUDA_CPP_API void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);

TORCH_CUDA_CPP_API void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

TORCH_CUDA_CPP_API void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType type,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);
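
// Example (a sketch; `world_size`, `input`, `output`, `comm`, and `stream`
// are assumptions, and `size` is taken here to be the per-element size in
// bytes used to scale the displacements):
//
//   std::vector<size_t> sendcounts(world_size), senddispls(world_size);
//   std::vector<size_t> recvcounts(world_size), recvdispls(world_size);
//   // ... fill counts; displs are exclusive prefix sums of the counts ...
//   torch::cuda::nccl::all2all_single_unequal_split(
//       input.data_ptr(), sendcounts.data(), senddispls.data(),
//       output.data_ptr(), recvcounts.data(), recvdispls.data(),
//       input.element_size(), input.scalar_type(), comm, stream);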

TORCH_CUDA_CPP_API void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream);

TORCH_CUDA_CPP_API void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst);

TORCH_CUDA_CPP_API void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src);
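
// Example (a sketch of a two-rank exchange; `t`, `comm`, `stream`, and
// `rank` are assumptions set up by the caller):
//
//   if (rank == 0) {
//     torch::cuda::nccl::send(t, comm, stream, /*dst=*/1);
//   } else {
//     torch::cuda::nccl::recv(t, comm, stream, /*src=*/0);
//   }
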
} // namespace nccl
} // namespace cuda
} // namespace torch