#include <torch/csrc/autograd/functions/comm.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <memory>
#include <vector>

namespace torch {
namespace autograd {
Scatter::Scatter(
    std::vector<at::Device> devices,
    c10::optional<std::vector<int64_t>> chunk_sizes,
    int64_t dim,
    c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(std::move(chunk_sizes)),
      dim_(dim),
      streams_(std::move(streams)),
      unsqueeze_scalars_(unsqueeze_scalars) {}

Scatter::~Scatter() = default;

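// Scatter::apply splits its single input tensor into chunks along `dim_` and
// distributes one chunk to each device in `devices_`, optionally using the
// given `streams_`. If the input requires grad, a Gather node targeting the
// input's original device is recorded so gradients are reassembled in the
// backward pass.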
variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

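  // The backward of a scatter is a gather: if the input requires grad, record
  // a Gather node that reassembles the incoming gradients on the input's
  // original device along the same dimension.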
  std::shared_ptr<Node> grad_fn;
  if (compute_requires_grad(input)) {
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }

  auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
    return device.index();
  });
  auto tensors = torch::cuda::scatter(
      std::move(input), device_indices, chunk_sizes_, dim_, streams_);

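  // Wrap each chunk as a Variable. `unsqueeze_scalars_` is set when this
  // Scatter acts as the backward of a Gather that promoted 0-dim inputs to
  // 1-element vectors; indexing with `tensor[0]` restores the scalar shape.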
  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  if (grad_fn) {
    set_history(variables, grad_fn);
  }

  return variables;
}

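// A minimal usage sketch for the Scatter/Gather pair (illustrative only, not
// part of this file; assumes at least two CUDA devices are available):
//
//   auto input = torch::randn({4, 3}, torch::requires_grad());
//   Scatter scatter(
//       {at::Device(at::kCUDA, 0), at::Device(at::kCUDA, 1)},
//       /*chunk_sizes=*/c10::nullopt,
//       /*dim=*/0);
//   auto chunks = scatter.apply({input}); // two 2x3 chunks, one per GPU
//
//   Gather gather(/*destination_device=*/at::Device(at::kCPU), /*dim=*/0);
//   auto outputs = gather.apply(std::move(chunks)); // outputs[0]: 4x3 on CPU
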
Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

Gather::~Gather() = default;

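// Gather::apply concatenates its CUDA inputs onto `destination_device_` along
// `dim_`. If any input requires grad, a Scatter node back to the original
// source devices is recorded as the grad_fn. All-scalar inputs gathered along
// dim 0 are first viewed as 1-element vectors (see the warning below).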
variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    TORCH_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }
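  // The backward of a gather is a scatter back to the source devices. Record
  // each input's device and its size along `dim_` so the incoming gradient
  // can be split into chunks that match the original inputs exactly.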
  std::shared_ptr<Node> grad_fn;
  // Compute this before moving variables out of `inputs` below.
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    source_devices.reserve(inputs.size());
    std::vector<int64_t> input_sizes;
    input_sizes.reserve(inputs.size());
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }
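  // 0-dim tensors cannot be concatenated along a dimension, so each scalar is
  // viewed as a 1-element vector first; the Scatter recorded above undoes
  // this in the backward pass via `unsqueeze_scalars`.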
  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // Disable autograd for the actual computation. torch::cuda::gather does not
  // return a view and does not modify anything in place, so no extra
  // view/in-place handling is needed here.
  at::Tensor variable;
  {
    at::AutoDispatchBelowAutograd mode;
    // Special case for torch::cuda::gather: a destination index of -1 means
    // "gather to the CPU".
    const auto destination_index =
        destination_device_.is_cpu() ? -1 : destination_device_.index();
    variable = torch::cuda::gather(tensors, dim_, destination_index);
  }
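  // The gather above ran below the autograd dispatch key, so the result has
  // no history of its own; attach the manually constructed grad_fn here.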
  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}

} // namespace autograd
} // namespace torch