#include <torch/csrc/cuda/comm.h>

#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/utils/tensor_flatten.h>

#ifdef USE_NCCL
#include <torch/csrc/cuda/nccl.h>
#endif

#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace cuda {
using namespace at;
using namespace torch::autograd;

// Some operations can be performed more efficiently if we're handling tensors
// of a single type only. Adding this logic directly in the loop makes it a bit
// ugly, so here's a helper for it.
struct unique_type_checker {
  void show(size_t type_id) {
    if (!unique) {
      return;
    }
    if (!type_id_) {
      type_id_ = type_id;
    }

    unique = type_id_.value() == type_id;
  }

  c10::optional<size_t> type_id_;
  bool unique = true;
};
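
// Illustrative sketch of the intended use (mirrors broadcast_coalesced below;
// `chunks` is a hypothetical name, not part of this file):
//
//   unique_type_checker type_checker;
//   for (auto& chunk : chunks) {
//     type_checker.show(chunk.type_id());  // feed each chunk's type id
//   }
//   // type_checker.unique is true iff every chunk reported the same type id.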

// ***************** Broadcast *******************
//
// Broadcast a source tensor (CPU or CUDA) to a list of CUDA devices, or CUDA
// tensors on one or more devices.
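//
// Illustrative sketch (hypothetical caller code, not part of this file):
// replicating a tensor that lives on device 0 onto devices 0 and 1.
//
//   at::Tensor src = at::randn({2, 2}, at::kCUDA);  // on the current device
//   auto copies = torch::cuda::broadcast(src, /*devices=*/{0, 1});
//   // copies[i] lives on device i; the entry matching src's device is src
//   // itself, the others are freshly allocated copies.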

// no checks
static inline std::vector<Tensor>& _broadcast_out_impl(
    const Tensor& tensor,
    std::vector<Tensor>& out_tensors) {
#ifdef USE_NCCL
  std::vector<Tensor> nccl_list;
  nccl_list.reserve(out_tensors.size() + 1);
  nccl_list.emplace_back(tensor);
  for (auto& out_tensor : out_tensors) {
    nccl_list.emplace_back(out_tensor);
  }
  if (nccl::is_available(nccl_list)) {
    nccl::broadcast(nccl_list);
  } else {
#else
  {
#endif
    for (auto& out_tensor : out_tensors) {
      out_tensor.copy_(tensor, /*non_blocking=*/true);
    }
  }
  return out_tensors;
}

std::vector<Tensor>& broadcast_out(
    const Tensor& tensor,
    std::vector<Tensor>& out_tensors) {
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    TORCH_CHECK(
        out_tensors[i].sizes() == tensor.sizes(),
        "Expected all output tensors to have same shape as the source tensor ",
        tensor.sizes(),
        ", but output tensor at index ",
        i,
        " has shape ",
        out_tensors[i].sizes());
  }
  return _broadcast_out_impl(tensor, out_tensors);
}

std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<Tensor> diff_device_dst_tensors;
  diff_device_dst_tensors.reserve(devices.size());
  for (auto device : devices) {
    TORCH_CHECK(
        device >= 0, "Expected non-negative device index, but got ", device);
    if (device != tensor.get_device()) {
      diff_device_dst_tensors.emplace_back(at::empty(
          tensor.sizes(),
          tensor.options().device(
              at::Device(DeviceType::CUDA, device)))); // preserve memory format
    }
  }
  _broadcast_out_impl(tensor, diff_device_dst_tensors);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<Tensor> dst_tensors;
  dst_tensors.reserve(devices.size());
  auto it = diff_device_dst_tensors.begin();
  for (auto device : devices) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (device != tensor.get_device()) {
      dst_tensors.emplace_back(*it++);
    } else {
      dst_tensors.emplace_back(tensor);
    }
  }
  TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
  return dst_tensors;
}

// NOTE [ Version Counter in comm.*_coalesced ]
//
// broadcast_coalesced
// ~~~~~~~~~~~~~~~~~~~
//
// In broadcast_coalesced, multiple variables may be coalesced into a single
// large one, broadcast to other devices, and then get split according to the
// original shapes.
//
// When splitting, the view operations will make all Variables broadcast
// together share a single version counter, because they are all views of the
// large Variable. However, that large Variable is immediately discarded and
// these Variables do not share storage at all.
//
// For example, when two buffers are broadcast together in `DataParallel` and
// one of them is modified in-place during `forward` but the other is needed in
// backward, the autograd engine will complain.
//
// We thus re-wrap these Variables after broadcasting (i.e., effectively doing
// what is equivalent to .data in Python), and give them individual version
// counters.
//
// NB: Just calling detach() on the variables is not sufficient.
//
// NB: For `device[0]` in broadcast_coalesced, the input Variables are always
// returned as-is, so **do not** re-wrap them.
//
// reduce_add_coalesced
// ~~~~~~~~~~~~~~~~~~~~
//
// The same applies to reduce_add_coalesced, where the outputs are newly
// created Variables.
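//
// Illustrative sketch of the hazard (hypothetical caller code, not part of
// this file): two buffers broadcast together would, without the re-wrapping
// done below, end up as views of one flat temporary and share its version
// counter.
//
//   auto copies = broadcast_coalesced({b1, b2}, devices, buffer_size);
//   copies[1][0].add_(1);  // in-place update of the first copy on devices[1]
//   // If copies[1][1] still shared the version counter, using it in backward
//   // would raise a spurious "modified by an in-place operation" error.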
tensor_list2d broadcast_coalesced(
    TensorList tensors,
    IntArrayRef devices,
    size_t buffer_size) {
  TORCH_CHECK(
      std::all_of(
          tensors.begin(),
          tensors.end(),
          [&](const at::Tensor& t) { return t.get_device() == devices[0]; }),
      "All tensors must be on devices[0]: ",
      devices[0]);
#ifdef USE_NCCL
  buffer_size = std::min(torch::cuda::nccl::get_max_count(), buffer_size);
#endif

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  tensor_list2d outputs(devices.size());
  outputs[0] = tensors.vec();
  for (auto& o : outputs)
    o.reserve(tensors.size());

  unique_type_checker type_checker;
  at::cuda::CUDAGuard device_guard(devices[0]);
  for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
    auto type_id = chunk.type_id();
    type_checker.show(type_id);
    std::vector<at::Tensor> results;
    if (chunk.options().is_sparse()) {
      auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
      auto broadcast_indices = broadcast(flat_tuple.first, devices);
      auto broadcast_values = broadcast(flat_tuple.second, devices);
      results.reserve(devices.size());
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(devices[i]);
        auto& device_outputs = outputs[i];
        auto& inds = broadcast_indices[i];
        auto& vals = broadcast_values[i];
        for (const auto& var : torch::utils::unflatten_sparse_tensors(
                 inds, vals, chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.emplace_back(make_variable(var.tensor_data(), false));
        }
      }
    } else {
      auto results = broadcast(
          torch::utils::flatten_dense_tensors(chunk.tensors), devices);
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(devices[i]);
        auto& device_outputs = outputs[i];
        for (auto& var :
             torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.emplace_back(make_variable(var.tensor_data(), false));
        }
      }
    }
  }

  // If we only saw a single tensor type, then we can skip expensive reordering
  if (!type_checker.unique) {
    for (auto& o : outputs)
      torch::utils::reorder_tensors_like(o, tensors);
  }
  return outputs;
}

// ***************** Scatter *******************
//
// Scatter a source tensor (CPU or CUDA) to a list of CUDA tensors on one or
// more devices.
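//
// Illustrative sketch (hypothetical caller code, not part of this file):
// splitting a batch along dim 0 across two devices, assuming `input` lives on
// device 0.
//
//   at::Tensor input = at::randn({8, 3}, at::kCUDA);
//   auto chunks = torch::cuda::scatter(
//       input, /*devices=*/{0, 1}, /*chunk_sizes=*/c10::nullopt, /*dim=*/0);
//   // chunks[0] stays a view of `input` on device 0; chunks[1] is a copy of
//   // the remaining rows on device 1.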

std::vector<at::Tensor>& scatter_out(
    const at::Tensor& tensor,
    std::vector<at::Tensor>& out_tensors,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(
      !out_tensors.empty(),
      "Expected at least one output tensor to scatter to");
  dim = at::maybe_wrap_dim(dim, tensor);
  int64_t total_size = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(out_tensors.size());
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    auto out_sizes = out_tensors[i].sizes().vec();
    // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
    bool same_ndim = out_sizes.size() == tensor.dim();
    if (same_ndim) {
      total_size += out_sizes[dim];
      chunk_sizes.emplace_back(out_sizes[dim]);
      out_sizes[dim] = tensor.size(dim);
    }
    TORCH_CHECK(
        same_ndim && out_sizes == tensor.sizes(),
        "Output tensor at index ",
        i,
        " has incorrect shape: ",
        out_tensors[i].sizes(),
        ". Expected same "
        "shape except for scatter dim ",
        dim,
        " as the source tensor: ",
        at::IntArrayRef(tensor.sizes()));
  }
  TORCH_CHECK(
      total_size == tensor.size(dim),
      "Total size for output tensors along scatter dim ",
      dim,
      " does not match "
      "the source tensor size at dim ",
      dim,
      ". Expected ",
      tensor.size(dim),
      ", but got total size ",
      total_size);

  auto chunks =
      tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
      const auto device_index =
          static_cast<int16_t>(out_tensors[i].get_device());
      TORCH_CHECK(
          (*streams)[i]->device_index() == device_index,
          "Expected the device associated with the stream at index ",
          i,
          " (was ",
          (*streams)[i]->device_index(),
          ") ",
          "to match the device supplied at that index ",
          "(expected ",
          device_index,
          ")");
      cuda_guard.reset_stream(*(*streams)[i]);
    }
    // NB: We don't detect the case where `out_tensor` is already the correct
    // view of `tensor` since that would be nontrivial and involve checking
    // ptr, offset, and strides. So `scatter_out(src, src.chunk(...))` does
    // more copying than `scatter(src)`.
    out_tensors[i].copy_(chunks[i], /*non_blocking=*/true);
  }
  return out_tensors;
}

std::vector<at::Tensor> scatter(
    const at::Tensor& tensor,
    at::IntArrayRef devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(!devices.empty(), "Expected at least one device to scatter to");
  if (chunk_sizes.has_value()) {
    TORCH_CHECK(
        chunk_sizes->size() == devices.size(),
        "Expected devices and chunk_sizes to be of same length, but got "
        "len(devices) = ",
        devices.size(),
        " and len(chunk_sizes) = ",
        chunk_sizes->size());
  }
  dim = at::maybe_wrap_dim(dim, tensor);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Tensor> chunks = chunk_sizes
      ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim)
      : tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    const auto device_index = static_cast<int16_t>(devices[i]);
    if (device_index != tensor.get_device()) {
      if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
        TORCH_CHECK(
            (*streams)[i]->device_index() == device_index,
            "Expected the device associated with the stream at index ",
            i,
            " (was ",
            (*streams)[i]->device_index(),
            ") ",
            "to match the device supplied at that index ",
            "(expected ",
            device_index,
            ")");
        cuda_guard.reset_stream(*(*streams)[i]);
      }
      TORCH_CHECK(
          device_index >= 0,
          "Expected non-negative device index, but got ",
          device_index);
      chunks[i] = chunks[i].to(
          {DeviceType::CUDA, device_index},
          /*non_blocking=*/true,
          /*copy=*/false,
          /*memory_format=*/at::MemoryFormat::Preserve);
    }
  }
  return chunks;
}

// ***************** Gather *******************
//
// Gather a list of CUDA tensors on one or more devices to a target tensor or
// device, either CPU or CUDA.
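//
// Illustrative sketch (hypothetical caller code, not part of this file):
// concatenating per-device results along dim 0 onto the CPU.
//
//   std::vector<at::Tensor> parts = ...;  // one CUDA tensor per device
//   at::Tensor merged = torch::cuda::gather(
//       parts, /*dim=*/0, /*destination_index=*/-1);  // -1 gathers to CPU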

// no checks
static inline at::Tensor& _gather_out_impl(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(tensors.size());
  for (auto& tensor : tensors) {
    chunk_sizes.emplace_back(tensor.size(dim));
  }
  auto chunks =
      out_tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  for (const auto i : c10::irange(tensors.size())) {
    chunks[i].copy_(tensors[i], /*non_blocking=*/out_tensor.is_cuda());
  }
  return out_tensor;
}

at::Tensor& gather_out(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device '",
        tensor.device(),
        "'");
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions (expected ",
        expected_size.size(),
        ")");
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
  }
  expected_size[dim] = total_size;
  TORCH_CHECK(
      out_tensor.sizes() == expected_size,
      "Expected out tensor to have shape ",
      at::IntArrayRef(expected_size),
      ", but got ",
      out_tensor.sizes());

  return _gather_out_impl(tensors, out_tensor, dim);
}

at::Tensor gather(
    at::TensorList tensors,
    int64_t dim,
    c10::optional<int32_t> destination_index) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  auto memory_format = first.suggest_memory_format();
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device ",
        tensor.device());
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions (expected ",
        expected_size.size(),
        ")");
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
    if (memory_format != MemoryFormat::Contiguous &&
        tensor.suggest_memory_format() != memory_format) {
      memory_format = MemoryFormat::Contiguous;
    }
  }
  expected_size[dim] = total_size;
  at::Device device(DeviceType::CPU);
  if (!destination_index || *destination_index != -1) {
    device = at::Device(
        DeviceType::CUDA, destination_index ? *destination_index : -1);
  }

  at::Tensor result =
      at::empty(expected_size, first.options().device(device), memory_format);
  return _gather_out_impl(tensors, result, dim);
}

} // namespace cuda
} // namespace torch