#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#endif

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

#include <c10/util/irange.h>

namespace torch {
namespace autograd {

// NB: This code duplicates existing logic from torch/autograd/__init__.py and
// torch._C._EngineBase.run_backward in torch/csrc/autograd/python_engine.cpp.
// It is a pure C++ API for autograd with no dependency on Python, so it can be
// exposed through the PyTorch C++ API and TorchScript. If either file changes,
// the logic here and in the Python file must be kept in sync.
// TODO: Make the Python API above just call this C++ API.
variable_list _make_grads(
    const variable_list& outputs,
    const variable_list& grad_outputs) {
  size_t num_tensors = outputs.size();
  size_t num_gradients = grad_outputs.size();
  variable_list new_grads;
  new_grads.reserve(num_tensors);
  if (grad_outputs.empty()) {
    for (const Variable& output : outputs) {
      if (output.requires_grad()) {
        TORCH_CHECK(
            output.numel() == 1,
            "grad can be implicitly created only for scalar outputs");
        new_grads.emplace_back(
            at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
      }
    }
  } else {
    TORCH_CHECK(
        num_tensors == num_gradients,
        "got ",
        num_tensors,
        " tensors and ",
        num_gradients,
        " gradients");
    for (const auto i : c10::irange(outputs.size())) {
      const Variable& output = outputs[i];
      const Variable& grad_output = grad_outputs[i];
      if (!grad_output.defined()) {
        if (output.requires_grad()) {
          TORCH_CHECK(
              output.numel() == 1,
              "grad can be implicitly created only for scalar outputs");
          new_grads.emplace_back(
              at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
        }
      } else {
        TORCH_CHECK(
            grad_output.is_complex() == output.is_complex(),
            "For complex Tensors, both grad_output and output are required ",
            "to have the same dtype. Mismatch in dtype: grad_output[",
            grad_output,
            "] has a dtype of ",
            grad_output.scalar_type(),
            " and output[",
            output,
            "] has a dtype of ",
            output.scalar_type(),
            ".");
        // grad output is defined, just append it to new_grads
        new_grads.emplace_back(grad_output);
      }
    }
  }
  return new_grads;
}
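
// A minimal usage sketch of _make_grads (illustrative only, not part of this
// file's API): for a scalar output with no explicit grad_outputs an implicit
// gradient of ones is created; a non-scalar output must come with an explicit
// grad_output. Assumes the torch C++ frontend is available:
//
//   auto x = torch::ones({2, 2}, torch::requires_grad());
//   auto scalar = (x * x).sum();
//   auto g1 = _make_grads({scalar}, {});               // {ones_like(scalar)}
//   auto y = x * 2;                                     // non-scalar output
//   auto g2 = _make_grads({y}, {torch::ones_like(y)});  // explicit grad kept
//   // _make_grads({y}, {}) would throw: "grad can be implicitly created
//   // only for scalar outputs".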
variable_list run_backward(
    const variable_list& outputs,
    const variable_list& grad_outputs,
    bool keep_graph,
    bool create_graph,
    const variable_list& inputs,
    bool allow_unused,
    bool accumulate_grad) {
  size_t num_tensors = outputs.size();
  edge_list roots;
  roots.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    const Variable& output = outputs[i];
    auto gradient_edge = impl::gradient_edge(output);
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));
  }

  edge_list output_edges;
  if (!inputs.empty()) {
    size_t num_inputs = inputs.size();
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      const Variable& input = inputs[i];
      const auto output_nr = input.output_nr();
      auto grad_fn = input.grad_fn();
      if (!grad_fn) {
        grad_fn = impl::try_get_grad_accumulator(input);
      }
      if (accumulate_grad) {
        input.retain_grad();
      }
      TORCH_CHECK(
          input.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        // See NOTE [ Autograd Unreachable Input ] for details
        output_edges.emplace_back(std::make_shared<Identity>(), 0);
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

  variable_list grad_inputs = Engine::get_default_engine().execute(
      roots,
      grad_outputs,
      keep_graph,
      create_graph,
      accumulate_grad,
      output_edges);
  // Check whether grad_inputs contains undefined Tensors, based on the
  // allow_unused flag.
  if (!inputs.empty() && !allow_unused) {
    size_t num_inputs = inputs.size();
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          grad_inputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
    }
  }
  return grad_inputs;
}
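
// Illustration of the allow_unused flag (a minimal sketch; grad() below
// forwards its allow_unused argument here):
//
//   auto x = torch::ones({2}, torch::requires_grad());
//   auto z = torch::ones({2}, torch::requires_grad());
//   auto y = (x * x).sum(); // z never participates in the graph of y
//   auto g = torch::autograd::grad(
//       {y}, {x, z}, {}, c10::nullopt, false, /*allow_unused=*/true);
//   // g[1] is an undefined Tensor; with allow_unused=false the call throws.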

void backward(
    const variable_list& tensors,
    const variable_list& grad_tensors,
    c10::optional<bool> retain_graph,
    bool create_graph,
    const variable_list& inputs) {
  variable_list gradients = _make_grads(tensors, grad_tensors);
  if (!retain_graph) {
    retain_graph = create_graph;
  }
  run_backward(
      tensors,
      gradients,
      retain_graph.value(),
      create_graph,
      inputs,
      /*allow_unused=*/true,
      /*accumulate_grad=*/true);
}
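
// A minimal usage sketch for backward() (assuming the torch C++ frontend and
// the default arguments declared in torch/csrc/autograd/autograd.h):
//
//   auto x = torch::ones({2, 2}, torch::requires_grad());
//   auto loss = (x * x).sum();
//   torch::autograd::backward({loss});
//   // x.grad() now holds d(loss)/dx == 2 * x, accumulated in place because
//   // accumulate_grad is true.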

variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs,
    c10::optional<bool> retain_graph,
    bool create_graph,
    bool allow_unused) {
  variable_list gradients = _make_grads(outputs, grad_outputs);
  if (!retain_graph) {
    retain_graph = create_graph;
  }
  return run_backward(
      outputs,
      gradients,
      retain_graph.value(),
      create_graph,
      inputs,
      allow_unused,
      /*accumulate_grad=*/false);
}
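
// A minimal usage sketch for grad() (assuming the torch C++ frontend and the
// default arguments declared in torch/csrc/autograd/autograd.h):
//
//   auto x = torch::ones({2, 2}, torch::requires_grad());
//   auto y = (x * x + x).sum();
//   auto dydx = torch::autograd::grad({y}, {x});
//   // dydx[0] == 2 * x + 1; x.grad() is left untouched because
//   // accumulate_grad is false here.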

namespace forward_ad {

uint64_t enter_dual_level() {
  return ForwardADLevel::get_next_idx();
}

void exit_dual_level(uint64_t level) {
  ForwardADLevel::release_idx(level);
}
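
// A minimal sketch of pairing the two calls above (the level handle should be
// released once all forward-mode computations at that level are finished):
//
//   uint64_t level = torch::autograd::forward_ad::enter_dual_level();
//   // ... build dual Tensors and run forward-mode AD at `level` ...
//   torch::autograd::forward_ad::exit_dual_level(level);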

} // namespace forward_ad

} // namespace autograd
} // namespace torch