1#include <torch/csrc/autograd/functions/comm.h>
2
3#include <ATen/core/functional.h>
4#include <torch/csrc/autograd/function.h>
5#include <torch/csrc/autograd/functions/utils.h>
6#include <torch/csrc/autograd/variable.h>
7#include <torch/csrc/cuda/comm.h>
8
9#include <ATen/ATen.h>
10#include <ATen/cuda/CUDAContext.h>
11#include <c10/util/Optional.h>
12
13#include <cstddef>
14#include <memory>
15#include <vector>
16
17namespace torch {
18namespace autograd {
// Scatter splits a single variable across multiple devices (forward pass of
// torch::cuda::scatter); its backward function is a Gather back onto the
// source device.
//
// devices: target devices, one output chunk per entry.
// chunk_sizes: optional explicit split sizes along `dim` (when absent, the
//   split policy is left to torch::cuda::scatter).
// dim: dimension along which the input is split.
// streams: optional per-device CUDA streams used for the copies.
// unsqueeze_scalars: when true, each scattered chunk is expected to be a
//   1-element 1-D tensor that apply() indexes back down to a scalar (set by
//   Gather when the originally gathered inputs were scalars).
Scatter::Scatter(
    std::vector<at::Device> devices,
    c10::optional<std::vector<int64_t>> chunk_sizes,
    int64_t dim,
    c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(std::move(chunk_sizes)),
      dim_(dim),
      streams_(std::move(streams)),
      unsqueeze_scalars_(unsqueeze_scalars) {}

// Defaulted out-of-line so the destructor is emitted in this translation unit.
Scatter::~Scatter() = default;
32
33variable_list Scatter::apply(variable_list&& inputs) {
34 AT_ASSERT(inputs.size() == 1);
35 auto& input = inputs.front();
36
37 std::shared_ptr<Node> grad_fn;
38 if (compute_requires_grad(input)) {
39 grad_fn =
40 std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
41 grad_fn->set_next_edges(collect_next_edges(input));
42 }
43
44 auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
45 return device.index();
46 });
47 auto tensors = torch::cuda::scatter(
48 std::move(input), device_indices, chunk_sizes_, dim_, streams_);
49
50 std::vector<Variable> variables;
51 variables.reserve(tensors.size());
52 for (auto& tensor : tensors) {
53 AT_ASSERT(tensor.defined());
54 if (unsqueeze_scalars_) {
55 AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
56 variables.push_back(tensor[0]);
57 } else {
58 variables.push_back(std::move(tensor));
59 }
60 }
61
62 if (grad_fn) {
63 set_history(variables, grad_fn);
64 }
65
66 return variables;
67}
68
// Gather concatenates variables from multiple devices onto
// `destination_device` along `dim` (forward pass of torch::cuda::gather);
// its backward function is a Scatter back to the source devices.
Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

// Defaulted out-of-line so the destructor is emitted in this translation unit.
Gather::~Gather() = default;
73
// Gathers the input variables (one per source device) onto
// `destination_device_`, concatenating along `dim_`. Sets up a Scatter node
// as the backward function when any input requires grad. All inputs must be
// CUDA tensors (checked below).
variable_list Gather::apply(variable_list&& inputs) {
  // Validate devices and detect the all-scalars case in a single pass.
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  // Special case: gathering scalars along dim 0. Each scalar is viewed as a
  // 1-element vector further down so the concatenation yields a 1-D result.
  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    TORCH_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  std::shared_ptr<Node> grad_fn;
  // compute this before moving variables from `inputs`
  if (compute_requires_grad(inputs)) {
    // Record each input's device and its extent along dim_ so the backward
    // Scatter can route gradient chunks back to the right devices/sizes.
    std::vector<at::Device> source_devices;
    source_devices.reserve(inputs.size());
    std::vector<int64_t> input_sizes;
    input_sizes.reserve(inputs.size());
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      // NOTE(review): in the unsqueeze_scalars case the inputs are still
      // 0-dim here; this relies on size(dim_) being accepted for scalar
      // tensors — confirm against at::Tensor::size semantics.
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

  // Move the variables into a plain tensor list for the comm call,
  // unsqueezing scalars to 1-element vectors when required (see above).
  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // Disable the autograd during the actual computation
  // torch::cuda::gather does not return a view or change things inplace
  // so no need for extra logic here
  at::Tensor variable;
  {
    at::AutoDispatchBelowAutograd mode;
    // This is special logic for torch::cuda::gather!
    // A destination index of -1 is used when the destination is the CPU.
    const auto destination_index =
        destination_device_.is_cpu() ? -1 : destination_device_.index();
    variable = torch::cuda::gather(tensors, dim_, destination_index);
  }
  // Attach the result to the autograd graph after the dispatch-below-autograd
  // region has ended.
  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}
140
141} // namespace autograd
142} // namespace torch
143