#include <torch/csrc/cuda/comm.h>

#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/utils/tensor_flatten.h>

#ifdef USE_NCCL
#include <torch/csrc/cuda/nccl.h>
#endif

#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace cuda {
using namespace at;
using namespace torch::autograd;

// Some operations can be performed more efficiently if we're handling tensors
// of a single type only. Adding this logic directly in the loop makes it a bit
// ugly, so here's a helper for it.
struct unique_type_checker {
  void show(size_t type_id) {
    if (!unique) {
      return;
    }
    if (!type_id_) {
      type_id_ = type_id;
    }

    unique = type_id_.value() == type_id;
  }

  c10::optional<size_t> type_id_;
  bool unique = true;
};
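
// For illustration only, a hypothetical sequence of calls and the state it
// leaves the checker in (not part of the original code):
//
//   unique_type_checker checker;
//   checker.show(/*type_id=*/0); // first type seen;  unique stays true
//   checker.show(/*type_id=*/0); // same type again;  unique stays true
//   checker.show(/*type_id=*/1); // a second type;    unique becomes false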

// ***************** Broadcast *******************
//
// Broadcast a source tensor (CPU or CUDA) to a list of CUDA devices, or CUDA
// tensors on one or more devices.
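//
// A minimal usage sketch (hypothetical; assumes at least two CUDA devices):
//
//   auto src = at::randn({2, 3}, at::kCUDA);
//   // Copy `src` to cuda:0 and cuda:1; the entry for the source device
//   // aliases `src` rather than copying it.
//   std::vector<at::Tensor> outs = torch::cuda::broadcast(src, {0, 1});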

// no checks
static inline std::vector<Tensor>& _broadcast_out_impl(
    const Tensor& tensor,
    std::vector<Tensor>& out_tensors) {
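  // When built with NCCL and nccl::is_available accepts this set of tensors,
  // broadcast with a single NCCL call; otherwise fall back to one
  // asynchronous copy_ per output tensor.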
#ifdef USE_NCCL
  std::vector<Tensor> nccl_list;
  nccl_list.reserve(out_tensors.size() + 1);
  nccl_list.emplace_back(tensor);
  for (auto& out_tensor : out_tensors) {
    nccl_list.emplace_back(out_tensor);
  }
  if (nccl::is_available(nccl_list)) {
    nccl::broadcast(nccl_list);
  } else {
#else
  {
#endif
    for (auto& out_tensor : out_tensors) {
      out_tensor.copy_(tensor, /*non_blocking=*/true);
    }
  }
  return out_tensors;
}

std::vector<Tensor>& broadcast_out(
    const Tensor& tensor,
    std::vector<Tensor>& out_tensors) {
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    TORCH_CHECK(
        out_tensors[i].sizes() == tensor.sizes(),
        "Expected all output tensors to have same shape as the source tensor ",
        tensor.sizes(),
        ", but output tensor at index ",
        i,
        " has shape ",
        out_tensors[i].sizes());
  }
  return _broadcast_out_impl(tensor, out_tensors);
}

std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<Tensor> diff_device_dst_tensors;
  diff_device_dst_tensors.reserve(devices.size());
  for (auto device : devices) {
    TORCH_CHECK(
        device >= 0, "Expected non-negative device index, but got ", device);
    if (device != tensor.get_device()) {
      diff_device_dst_tensors.emplace_back(at::empty(
          tensor.sizes(),
          tensor.options().device(
              at::Device(DeviceType::CUDA, device)))); // preserve memory format
    }
  }
  _broadcast_out_impl(tensor, diff_device_dst_tensors);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<Tensor> dst_tensors;
  dst_tensors.reserve(devices.size());
  auto it = diff_device_dst_tensors.begin();
  for (auto device : devices) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (device != tensor.get_device()) {
      dst_tensors.emplace_back(*it++);
    } else {
      dst_tensors.emplace_back(tensor);
    }
  }
  TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
  return dst_tensors;
}

// NOTE [ Version Counter in comm.*_coalesced ]
//
// broadcast_coalesced
// ~~~~~~~~~~~~~~~~~~~
//
// In broadcast_coalesced, multiple variables may be coalesced into a single
// large one, broadcast to other devices, and then get split according to the
// original shapes.
//
// When splitting, the view operations will make all Variables broadcast
// together share a single version counter, because they are all views of the
// large Variable. However, that large Variable is immediately discarded and
// all these Variables do not share storage at all.
//
// For example, when two buffers are broadcast together in `DataParallel` and
// one of them is modified in-place during `forward` but the other is needed
// in backward, the autograd engine will complain.
//
// We thus re-wrap these Variables after broadcasting (i.e., effectively doing
// what is equivalent to .data in Python), and give them individual version
// counters.
//
// NB: Just calling detach() on the variables is not sufficient.
//
// NB: For `devices[0]` in broadcast_coalesced, the input Variables are always
// returned as-is, so **do not** re-wrap them.
//
// reduce_add_coalesced
// ~~~~~~~~~~~~~~~~~~~~
//
// Similarly for reduce_add_coalesced, where the outputs are newly created
// Variables.
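//
// As a concrete, hypothetical illustration of the shared-version-counter
// problem described above (not part of the original code):
//
//   auto flat = at::zeros({4});
//   auto parts = flat.split(2); // parts[0] and parts[1] are views of `flat`
//   parts[0].add_(1);           // bumps the version counter shared by both views
//   // If parts[1] had been saved for backward, autograd would now report a
//   // version-counter mismatch even though parts[1]'s data never changed.
//
// Re-wrapping each output below with make_variable(var.tensor_data(), ...)
// gives it a fresh version counter and avoids this false sharing.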
tensor_list2d broadcast_coalesced(
    TensorList tensors,
    IntArrayRef devices,
    size_t buffer_size) {
  TORCH_CHECK(
      std::all_of(
          tensors.begin(),
          tensors.end(),
          [&](const at::Tensor& t) { return t.get_device() == devices[0]; }),
      "All tensors must be on devices[0]: ",
      devices[0]);
#ifdef USE_NCCL
  buffer_size = std::min(torch::cuda::nccl::get_max_count(), buffer_size);
#endif

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  tensor_list2d outputs(devices.size());
  outputs[0] = tensors.vec();
  for (auto& o : outputs)
    o.reserve(tensors.size());

  unique_type_checker type_checker;
  at::cuda::CUDAGuard device_guard(devices[0]);
  for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
    auto type_id = chunk.type_id();
    type_checker.show(type_id);
    if (chunk.options().is_sparse()) {
      auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
      auto broadcast_indices = broadcast(flat_tuple.first, devices);
      auto broadcast_values = broadcast(flat_tuple.second, devices);
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(devices[i]);
        auto& device_outputs = outputs[i];
        auto& inds = broadcast_indices[i];
        auto& vals = broadcast_values[i];
        for (const auto& var : torch::utils::unflatten_sparse_tensors(
                 inds, vals, chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.emplace_back(make_variable(var.tensor_data(), false));
        }
      }
    } else {
      auto results = broadcast(
          torch::utils::flatten_dense_tensors(chunk.tensors), devices);
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(devices[i]);
        auto& device_outputs = outputs[i];
        for (auto& var :
             torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.emplace_back(make_variable(var.tensor_data(), false));
        }
      }
    }
  }

  // If we only saw a single tensor type, then we can skip expensive reordering;
  // otherwise restore the original input order on every device.
  if (!type_checker.unique) {
    for (auto& o : outputs)
      torch::utils::reorder_tensors_like(o, tensors);
  }
  return outputs;
}

// ***************** Scatter *******************
//
// Scatter a source tensor (CPU or CUDA) to a list of CUDA tensors on one or
// more devices.
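//
// A minimal usage sketch (hypothetical; assumes at least two CUDA devices):
//
//   auto src = at::arange(8, at::kCUDA).reshape({4, 2}); // on the current CUDA device
//   // Split `src` into equal chunks along dim 0, one per device.
//   auto chunks = torch::cuda::scatter(
//       src,
//       /*devices=*/{0, 1},
//       /*chunk_sizes=*/c10::nullopt,
//       /*dim=*/0,
//       /*streams=*/c10::nullopt);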

std::vector<at::Tensor>& scatter_out(
    const at::Tensor& tensor,
    std::vector<at::Tensor>& out_tensors,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(
      !out_tensors.empty(),
      "Expected at least one output tensor to scatter to");
  dim = at::maybe_wrap_dim(dim, tensor);
  int64_t total_size = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(out_tensors.size());
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    auto out_sizes = out_tensors[i].sizes().vec();
    // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
    bool same_ndim = out_sizes.size() == tensor.dim();
    if (same_ndim) {
      total_size += out_sizes[dim];
      chunk_sizes.emplace_back(out_sizes[dim]);
      out_sizes[dim] = tensor.size(dim);
    }
    TORCH_CHECK(
        same_ndim && out_sizes == tensor.sizes(),
        "Output tensor at index ",
        i,
        " has incorrect shape: ",
        out_tensors[i].sizes(),
        ". Expected same "
        "shape except for scatter dim ",
        dim,
        " as the source tensor: ",
        at::IntArrayRef(tensor.sizes()));
  }
  TORCH_CHECK(
      total_size == tensor.size(dim),
      "Total size for output tensors along scatter dim ",
      dim,
      " does not match "
      "the source tensor size at dim ",
      dim,
      ". Expected ",
      tensor.size(dim),
      ", but got total size ",
      total_size);

  auto chunks =
      tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
      const auto device_index =
          static_cast<int16_t>(out_tensors[i].get_device());
      TORCH_CHECK(
          (*streams)[i]->device_index() == device_index,
          "Expected the device associated with the stream at index ",
          i,
          " (was ",
          (*streams)[i]->device_index(),
          ") ",
          "to match the device supplied at that index ",
          "(expected ",
          device_index,
          ")");
      cuda_guard.reset_stream(*(*streams)[i]);
    }
    // NB: We don't detect the case where `out_tensor` is already the correct
    // view of `tensor` since that would be nontrivial and involve checking
    // ptr, offset, and strides. So `scatter_out(src, src.chunk(...))` does
    // more copying than `scatter(src)`.
    out_tensors[i].copy_(chunks[i], /*non_blocking=*/true);
  }
  return out_tensors;
}

std::vector<at::Tensor> scatter(
    const at::Tensor& tensor,
    at::IntArrayRef devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(!devices.empty(), "Expected at least one device to scatter to");
  if (chunk_sizes.has_value()) {
    TORCH_CHECK(
        chunk_sizes->size() == devices.size(),
        "Expected devices and chunk_sizes to be of same length, but got "
        "len(devices) = ",
        devices.size(),
        " and len(chunk_sizes) = ",
        chunk_sizes->size());
  }
  dim = at::maybe_wrap_dim(dim, tensor);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Tensor> chunks = chunk_sizes
      ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim)
      : tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    const auto device_index = static_cast<int16_t>(devices[i]);
    if (device_index != tensor.get_device()) {
      if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
        TORCH_CHECK(
            (*streams)[i]->device_index() == device_index,
            "Expected the device associated with the stream at index ",
            i,
            " (was ",
            (*streams)[i]->device_index(),
            ") ",
            "to match the device supplied at that index ",
            "(expected ",
            device_index,
            ")");
        cuda_guard.reset_stream(*(*streams)[i]);
      }
      TORCH_CHECK(
          device_index >= 0,
          "Expected non-negative device index, but got ",
          device_index);
      chunks[i] = chunks[i].to(
          {DeviceType::CUDA, device_index},
          /*non_blocking=*/true,
          /*copy=*/false,
          /*memory_format=*/at::MemoryFormat::Preserve);
    }
  }
  return chunks;
}

// ***************** Gather *******************
//
// Gather a list of CUDA tensors on one or more devices to a target tensor or
// device, either CPU or CUDA.
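//
// A minimal usage sketch (hypothetical): given CUDA tensors `parts[0]` on
// cuda:0 and `parts[1]` on cuda:1 whose shapes match except along `dim`,
//
//   // Concatenate along dim 0 onto cuda:0 ...
//   auto on_gpu =
//       torch::cuda::gather(parts, /*dim=*/0, /*destination_index=*/0);
//   // ... or onto the CPU (a destination_index of -1 selects the CPU).
//   auto on_cpu =
//       torch::cuda::gather(parts, /*dim=*/0, /*destination_index=*/-1);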

// no checks
static inline at::Tensor& _gather_out_impl(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(tensors.size());
  for (auto& tensor : tensors) {
    chunk_sizes.emplace_back(tensor.size(dim));
  }
  auto chunks =
      out_tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  for (const auto i : c10::irange(tensors.size())) {
    chunks[i].copy_(tensors[i], /*non_blocking=*/out_tensor.is_cuda());
  }
  return out_tensor;
}

at::Tensor& gather_out(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device '",
        tensor.device(),
        "'");
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions (expected ",
        expected_size.size(),
        ")");
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
  }
  expected_size[dim] = total_size;
  TORCH_CHECK(
      out_tensor.sizes() == expected_size,
      "Expected out tensor to have shape ",
      at::IntArrayRef(expected_size),
      ", but got ",
      out_tensor.sizes());

  return _gather_out_impl(tensors, out_tensor, dim);
}

at::Tensor gather(
    at::TensorList tensors,
    int64_t dim,
    c10::optional<int32_t> destination_index) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  auto memory_format = first.suggest_memory_format();
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device ",
        tensor.device());
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions (expected ",
        expected_size.size(),
        ")");
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
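    // If the inputs disagree on a (non-contiguous) suggested memory format,
    // fall back to a contiguous result.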
    if (memory_format != MemoryFormat::Contiguous &&
        tensor.suggest_memory_format() != memory_format) {
      memory_format = MemoryFormat::Contiguous;
    }
  }
  expected_size[dim] = total_size;
  at::Device device(DeviceType::CPU);
  if (!destination_index || *destination_index != -1) {
    device = at::Device(
        DeviceType::CUDA, destination_index ? *destination_index : -1);
  }

  at::Tensor result =
      at::empty(expected_size, first.options().device(device), memory_format);
  return _gather_out_impl(tensors, result, dim);
}

} // namespace cuda
} // namespace torch