1#pragma once
2
#include <ATen/ATen.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <torch/csrc/Export.h>

#include <cstddef>
#include <cstdint>
#include <vector>
11
12namespace torch {
13namespace cuda {
14
15using tensor_list2d = std::vector<std::vector<at::Tensor>>;
16
17TORCH_CUDA_CU_API std::vector<at::Tensor>& broadcast_out(
18 const at::Tensor& tensor,
19 std::vector<at::Tensor>& out_tensors);
20TORCH_CUDA_CU_API std::vector<at::Tensor> broadcast(
21 const at::Tensor& tensor,
22 at::IntArrayRef devices);
23TORCH_CUDA_CU_API tensor_list2d broadcast_coalesced(
24 at::TensorList tensors,
25 at::IntArrayRef devices,
26 size_t buffer_size);
27
28TORCH_CUDA_CU_API std::vector<at::Tensor>& scatter_out(
29 const at::Tensor& tensor,
30 std::vector<at::Tensor>& out_tensors,
31 int64_t dim = 0,
32 const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
33 streams = c10::nullopt);
34
35TORCH_CUDA_CU_API std::vector<at::Tensor> scatter(
36 const at::Tensor& tensor,
37 at::IntArrayRef devices,
38 const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
39 int64_t dim = 0,
40 const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
41 streams = c10::nullopt);
42
43TORCH_CUDA_CU_API at::Tensor& gather_out(
44 at::TensorList tensors,
45 at::Tensor& out_tensor,
46 int64_t dim);
47
48TORCH_CUDA_CU_API at::Tensor gather(
49 at::TensorList tensors,
50 int64_t dim,
51 c10::optional<int32_t> destination_index);
52} // namespace cuda
53} // namespace torch
54