1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/cuda/ATenCUDAGeneral.h> |
5 | #include <ATen/cuda/CUDAContext.h> |
6 | #include <c10/util/Optional.h> |
7 | #include <torch/csrc/Export.h> |
8 | |
9 | #include <cstddef> |
10 | #include <vector> |
11 | |
12 | namespace torch { |
13 | namespace cuda { |
14 | |
15 | using tensor_list2d = std::vector<std::vector<at::Tensor>>; |
16 | |
17 | TORCH_CUDA_CU_API std::vector<at::Tensor>& broadcast_out( |
18 | const at::Tensor& tensor, |
19 | std::vector<at::Tensor>& out_tensors); |
20 | TORCH_CUDA_CU_API std::vector<at::Tensor> broadcast( |
21 | const at::Tensor& tensor, |
22 | at::IntArrayRef devices); |
23 | TORCH_CUDA_CU_API tensor_list2d broadcast_coalesced( |
24 | at::TensorList tensors, |
25 | at::IntArrayRef devices, |
26 | size_t buffer_size); |
27 | |
28 | TORCH_CUDA_CU_API std::vector<at::Tensor>& scatter_out( |
29 | const at::Tensor& tensor, |
30 | std::vector<at::Tensor>& out_tensors, |
31 | int64_t dim = 0, |
32 | const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& |
33 | streams = c10::nullopt); |
34 | |
35 | TORCH_CUDA_CU_API std::vector<at::Tensor> scatter( |
36 | const at::Tensor& tensor, |
37 | at::IntArrayRef devices, |
38 | const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt, |
39 | int64_t dim = 0, |
40 | const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& |
41 | streams = c10::nullopt); |
42 | |
43 | TORCH_CUDA_CU_API at::Tensor& gather_out( |
44 | at::TensorList tensors, |
45 | at::Tensor& out_tensor, |
46 | int64_t dim); |
47 | |
48 | TORCH_CUDA_CU_API at::Tensor gather( |
49 | at::TensorList tensors, |
50 | int64_t dim, |
51 | c10::optional<int32_t> destination_index); |
52 | } // namespace cuda |
53 | } // namespace torch |
54 | |