1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <c10/util/flat_hash_map.h> |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/autograd/function.h> |
7 | #include <torch/csrc/autograd/variable.h> |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace autograd { |
12 | |
13 | using optional_variable_list = std::vector<c10::optional<Variable>>; |
14 | using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>; |
15 | |
16 | TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs( |
17 | const variable_list& input_vars, |
18 | const std::unordered_set<at::TensorImpl*>& non_differentiable, |
19 | const std::unordered_set<at::TensorImpl*>& dirty_inputs, |
20 | const at::ArrayRef<c10::optional<Variable>> raw_outputs, |
21 | const std::shared_ptr<Node>& cdata, |
22 | _jvp_fn_t jvp_user_function); |
23 | |
24 | TORCH_API void check_variable_result( |
25 | const at::TensorBase& original, |
26 | const at::TensorBase& result, |
27 | std::string hook_name); |
28 | |
29 | // Get the return type of the forward function of the custom Function class X |
30 | template <typename X, typename... Args> |
31 | using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...)); |
32 | |
33 | /// To use custom autograd operations, implement a Function subclass with |
34 | /// static forward and backward functions: |
35 | /// |
/// `forward` should take a pointer to `torch::autograd::AutogradContext` as
/// its first argument, followed by as many arguments as you want, and should
/// return either a variable list or a Variable. Any Variable passed directly
/// as an argument will be registered in the graph, but Variables nested in
/// vectors, sets, or other data structures will not be traversed. You can use
/// c10::optional<Tensor> as one of the arguments, and it will be registered as
/// a variable in the graph if the argument has a value. Variables can be saved
/// in the `ctx` using `ctx->save_for_backward`
/// (see `torch::autograd::AutogradContext::save_for_backward`), and other data
/// can be saved in the `ctx->saved_data` map
/// (see `torch::autograd::AutogradContext::saved_data`)
/// as `<std::string, at::IValue>` pairs.
48 | /// |
/// `backward` should take a pointer to `torch::autograd::AutogradContext`
/// and a variable list containing as many Variables as there were outputs from
/// `forward`. It should return as many Variables as there were inputs, each of
/// them containing the gradient w.r.t. its corresponding input. Variables
/// saved in `forward` can be accessed with `ctx->get_saved_variables` (see
/// `torch::autograd::AutogradContext::get_saved_variables`), and other saved
/// data can be accessed from `ctx->saved_data`.
57 | /// |
58 | /// For example: |
59 | /// ``` |
60 | /// class MyFunction : public Function<MyFunction> { |
61 | /// public: |
62 | /// static variable_list forward(AutogradContext *ctx, int n, Variable var) { |
63 | /// // Save data for backward in context |
64 | /// ctx->saved_data["n"] = n; |
65 | /// var.mul_(2); |
66 | /// // Mark var as modified by inplace operation |
67 | /// ctx->mark_dirty({var}); |
68 | /// return {var}; |
69 | /// } |
70 | /// |
71 | /// static variable_list backward(AutogradContext *ctx, variable_list |
72 | /// grad_output) { |
73 | /// // Use data saved in forward |
74 | /// auto n = ctx->saved_data["n"].toInt(); |
75 | /// return {grad_output[0]*n}; |
76 | /// } |
77 | /// }; |
78 | /// ``` |
79 | /// |
80 | /// To use `MyFunction`: |
81 | /// ``` |
82 | /// Variable x; |
83 | /// auto y = MyFunction::apply(6, x); |
84 | /// // Example backward call |
85 | /// y[0].sum().backward(); |
86 | /// ``` |
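///
/// A second example showing `save_for_backward`/`get_saved_variables` (a
/// minimal sketch; `MyMulFunction` is an illustrative name, not an existing
/// class):
/// ```
/// class MyMulFunction : public Function<MyMulFunction> {
///  public:
///   static Variable forward(AutogradContext *ctx, Variable a, Variable b) {
///     // Save the input tensors so backward can reuse them
///     ctx->save_for_backward({a, b});
///     return a * b;
///   }
///
///   static variable_list backward(AutogradContext *ctx,
///                                 variable_list grad_output) {
///     // Retrieve the tensors saved in forward
///     auto saved = ctx->get_saved_variables();
///     auto a = saved[0];
///     auto b = saved[1];
///     // d(a*b)/da = b, d(a*b)/db = a
///     return {grad_output[0] * b, grad_output[0] * a};
///   }
/// };
/// ```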
87 | template <class T> |
88 | struct TORCH_API Function { |
89 | // We need to use a different template parameter than T here because T will |
90 | // inherit from Function, and when Function<T> is instantiated, T::forward |
91 | // is not declared yet. |
92 | // The enable_if check is to ensure that the user doesn't explicitly provide |
93 | // the parameter X. |
94 | template <typename X = T, typename... Args> |
95 | static auto apply(Args&&... args) |
96 | -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>>; |
97 | }; |
98 | |
99 | /// Context to save information during `forward` that can be accessed in |
100 | /// `backward` in custom autograd operations (see `torch::autograd::Function` |
101 | /// for details). |
102 | struct TORCH_API AutogradContext { |
103 | AutogradContext() = default; |
104 | AutogradContext(const AutogradContext& other) = delete; |
105 | AutogradContext& operator=(const AutogradContext& other) = delete; |
106 | |
107 | /// Can be used to save non-variable data for `backward`. |
108 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
109 | ska::flat_hash_map<std::string, at::IValue> saved_data; |
110 | |
111 | /// Saves the list of variables for a future call to `backward`. This |
112 | /// should be called at most once from inside of `forward`. |
113 | void save_for_backward(variable_list to_save); |
114 | /// Marks variables in the list as modified in an in-place operation. This |
115 | /// should be called at most once from inside of `forward` and all arguments |
116 | /// should be inputs. |
117 | void mark_dirty(const variable_list& inputs); |
118 | /// Marks outputs in the list as not requiring gradients. This should be |
119 | /// called at most once from inside of `forward` and all arguments should be |
120 | /// outputs. |
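  /// For example, in `forward` (a minimal sketch; `argmax_indices` is an
  /// illustrative second output that should not receive gradients):
  /// ```
  /// ctx->mark_non_differentiable({argmax_indices});
  /// return {values, argmax_indices};
  /// ```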
121 | void mark_non_differentiable(const variable_list& outputs); |
  /// Sets whether undefined output grad tensors should be expanded to tensors
  /// full of zeros before calling the backward function. Default value is
  /// true.
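  /// For example (a minimal sketch; with materialization disabled, `backward`
  /// must be prepared to receive undefined gradient tensors):
  /// ```
  /// // in forward:
  /// ctx->set_materialize_grads(false);
  /// // in backward:
  /// if (!grad_output[0].defined()) {
  ///   return {Variable()};  // this gradient was never needed
  /// }
  /// ```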
124 | void set_materialize_grads(bool value); |
125 | |
126 | /// Get the list of variables that were saved in `forward` using |
127 | /// `save_for_backward()`. Before returning them to the user, a check is made |
128 | /// to ensure that they were not modified by any in-place operations. |
129 | variable_list get_saved_variables() const; |
130 | const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const; |
131 | const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const; |
132 | |
  /// Exposes the Node's `task_should_compute_output` method to the C++
  /// custom autograd Function as `needs_input_grad`.
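  /// For example, inside `backward` one can skip computing gradients that are
  /// not needed (a minimal sketch; index 0 refers to the first Variable input
  /// of `forward`):
  /// ```
  /// Variable grad_x;  // stays undefined if not needed
  /// if (ctx->needs_input_grad(0)) {
  ///   grad_x = grad_output[0] * 2;
  /// }
  /// return {grad_x};
  /// ```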
135 | bool needs_input_grad(size_t output_edge_index) const; |
136 | bool needs_input_grad(std::initializer_list<IndexRange> idxs) const; |
137 | |
138 | private: |
139 | std::unordered_set<at::TensorImpl*> non_differentiable_; |
140 | std::unordered_set<at::TensorImpl*> dirty_inputs_; |
141 | std::vector<torch::autograd::SavedVariable> saved_variables_; |
142 | variable_list to_save_; |
143 | bool materialize_grads_{true}; |
144 | |
145 | // The CppNode in the autograd graph that owns this AutogradContext. We need a |
146 | // weak_ptr to avoid a refcycle. Since grad_fn_ owns this AutogradContext, it |
147 | // will always be alive when we want to use it. |
148 | std::weak_ptr<Node> grad_fn_; |
149 | bool has_freed_buffers_{false}; |
150 | |
151 | void save_variables(); |
152 | |
153 | template <class T> |
154 | friend struct CppNode; |
155 | }; |
156 | |
157 | struct TORCH_API VariableInfo { |
158 | explicit VariableInfo(); |
159 | explicit VariableInfo(const Variable& var); |
160 | |
161 | Variable zeros(at::OptionalDeviceGuard& device_guard) const; |
162 | |
163 | at::Layout layout = at::Layout::Strided; |
164 | at::Device device = at::kCPU; |
165 | at::ScalarType scalar_type = at::kFloat; |
166 | std::vector<int64_t> size; |
167 | bool requires_grad; |
168 | bool is_empty; |
169 | }; |
170 | |
// CppNode<T> is the Node in the autograd graph that represents the
// user-defined backward function for Function<T>. Calls to CppNode::apply are
// forwarded to T::backward().
174 | template <class T> |
175 | struct CppNode : public Node { |
176 | variable_list apply(variable_list&& inputs) override; |
177 | AutogradContext ctx_; |
178 | std::vector<bool> is_variable_input_; |
179 | std::vector<VariableInfo> input_info_; |
180 | std::vector<VariableInfo> output_info_; |
181 | |
182 | void release_variables() override; |
183 | |
184 | void set_ctx_grad_fn(const std::shared_ptr<Node>& node); |
185 | void save_variables_to_ctx(); |
186 | }; |
187 | |
struct ExtractVariables : IterArgs<ExtractVariables> {
  std::vector<bool>& is_var_;
  variable_list& list_;
  ExtractVariables(std::vector<bool>& is_var, variable_list& list)
      : is_var_(is_var), list_(list) {}
  void operator()(const c10::optional<at::Tensor>& x) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (x.has_value() && x.value().defined()) {
      is_var_.push_back(true);
      list_.emplace_back(x.value());
    } else {
      is_var_.push_back(false);
    }
  }
  void operator()(const at::Tensor& x) {
    is_var_.push_back(true);
    list_.emplace_back(x);
  }
  void operator()(const at::TensorList& list) {
    for (const at::Tensor& x : list) {
      is_var_.push_back(true);
      list_.emplace_back(x);
    }
  }
  template <typename T>
  void operator()(const T& x) {
    is_var_.push_back(false);
  }
};
217 | |
template <typename... Args>
inline void extract_vars(
    std::vector<bool>& is_var,
    variable_list& list,
    Args&&... args) {
  ExtractVariables(is_var, list).apply(std::forward<Args>(args)...);
}
225 | |
226 | template <typename T> |
227 | typename std::enable_if<std::is_same<T, variable_list>::value, T>::type |
228 | to_output_type(std::vector<c10::optional<Variable>>& output_list) { |
229 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
230 | variable_list result; |
231 | std::transform( |
232 | output_list.begin(), |
233 | output_list.end(), |
234 | std::back_inserter(result), |
235 | [](const c10::optional<Variable>& var) { return *var; }); |
236 | return result; |
237 | } |
238 | |
239 | template <typename T> |
240 | typename std::enable_if<std::is_same<T, Variable>::value, T>::type |
241 | to_output_type(std::vector<c10::optional<Variable>>& output_list) { |
242 | return *output_list[0]; |
243 | } |
244 | |
245 | inline std::vector<c10::optional<Variable>> to_optional(Variable& output) { |
246 | return std::vector<c10::optional<Variable>>{output}; |
247 | } |
248 | |
249 | inline std::vector<c10::optional<Variable>> to_optional(variable_list& output) { |
250 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
251 | std::vector<c10::optional<Variable>> result; |
252 | std::transform( |
253 | output.begin(), |
254 | output.end(), |
255 | std::back_inserter(result), |
256 | [](const Variable& var) { return var; }); |
257 | return result; |
258 | } |
259 | |
260 | template <class T> |
261 | template <typename X, typename... Args> |
262 | auto Function<T>::apply(Args&&... args) |
263 | -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> { |
264 | std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode); |
265 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
266 | variable_list input_vars; |
267 | |
268 | const size_t num_inputs = sizeof...(Args); |
269 | input_vars.reserve(num_inputs); |
270 | node->is_variable_input_.reserve(num_inputs); |
271 | // TODO Add tracing here |
272 | extract_vars(node->is_variable_input_, input_vars, args...); |
273 | |
274 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
275 | bool is_executable = |
276 | GradMode::is_enabled() && any_variable_requires_grad(input_vars); |
277 | auto next_edges = |
278 | (is_executable ? collect_next_edges(input_vars) : edge_list()); |
279 | node->set_ctx_grad_fn(node); |
280 | node->set_next_edges(std::move(next_edges)); |
281 | node->clear_input_metadata(); |
282 | |
283 | node->input_info_.reserve(input_vars.size()); |
284 | for (auto& var : input_vars) { |
285 | node->input_info_.emplace_back(var); |
286 | } |
287 | |
288 | using forward_return_t = forward_t<X, Args...>; |
289 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
290 | forward_return_t outputs; |
291 | { |
292 | AutoGradMode grad_mode(false); |
293 | outputs = T::forward(&node->ctx_, std::forward<Args>(args)...); |
294 | } |
295 | |
296 | _jvp_fn_t jvp_fn = [](variable_list inputs, |
297 | variable_list gI) -> variable_list { |
298 | TORCH_CHECK( |
299 | false, |
        "jvp is not implemented for the c++ API of custom Function yet.",
        "Please open a feature request on GitHub if you need this.");
302 | }; |
303 | |
304 | auto wrapped_outputs = _wrap_outputs( |
305 | input_vars, |
306 | node->ctx_.get_non_differentiable(), |
307 | node->ctx_.get_and_bump_dirty(), |
308 | to_optional(outputs), |
309 | is_executable ? node : nullptr, |
310 | jvp_fn); |
311 | |
312 | node->output_info_.reserve(wrapped_outputs.size()); |
313 | for (auto& output : wrapped_outputs) { |
314 | if (is_executable && output.has_value()) { |
315 | node->output_info_.emplace_back(output.value()); |
316 | } else if (is_executable) { |
317 | node->output_info_.emplace_back(); |
318 | } |
319 | } |
320 | |
321 | if (is_executable) { |
322 | node->save_variables_to_ctx(); |
323 | } |
324 | |
325 | // wrapped_outputs will be a variable_list so, convert it to the correct |
326 | // return type. Only Variable and variable_list are accepted as return types. |
327 | return to_output_type<forward_return_t>(wrapped_outputs); |
328 | } |
329 | |
// The logic here is the same as PyNode::apply, so changes to it should be made
// in both places.
332 | template <class T> |
333 | variable_list CppNode<T>::apply(variable_list&& inputs) { |
334 | at::OptionalDeviceGuard _device_guard; |
335 | |
336 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
337 | int num_inputs = inputs.size(); |
338 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
339 | variable_list backward_inputs; |
340 | backward_inputs.reserve(num_inputs); |
341 | for (const auto i : c10::irange(num_inputs)) { |
342 | if (inputs[i].defined() || !ctx_.materialize_grads_) { |
343 | backward_inputs.emplace_back(inputs[i]); |
344 | } else { |
345 | backward_inputs.emplace_back(output_info_[i].zeros(_device_guard)); |
346 | } |
347 | } |
348 | |
  // Acquire the lock here to protect thread safety on the custom C++ Autograd
  // Node. This is needed because we don't know if the user-defined Node will
  // write to shared data during backward.
  // See Note [Thread Safety on Autograd Node]
353 | std::lock_guard<std::mutex> lock(mutex_); |
354 | |
355 | auto outputs = T::backward(&ctx_, backward_inputs); |
356 | |
357 | const auto num_forward_inputs = |
358 | static_cast<int64_t>(is_variable_input_.size()); |
359 | auto num_outputs = static_cast<int64_t>(outputs.size()); |
360 | // Returning too many results is ok, but only as long as they're all |
361 | // undefined. Truncate the result vector in that case. |
362 | if (num_outputs > num_forward_inputs) { |
363 | bool all_undef = true; |
364 | for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { |
365 | all_undef &= (!outputs[i].defined()); |
366 | } |
367 | if (all_undef) { |
368 | outputs.resize(num_forward_inputs); |
369 | num_outputs = num_forward_inputs; |
370 | } |
371 | } |
372 | |
373 | if (num_outputs != num_forward_inputs) { |
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += c10::to_string(num_forward_inputs) + ", got ";
    msg += c10::to_string(num_outputs) + ")";
378 | throw std::runtime_error(msg); |
379 | } |
380 | |
381 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
382 | variable_list results; |
383 | results.reserve(num_outputs); |
384 | for (const auto i : c10::irange(num_outputs)) { |
385 | if (!is_variable_input_[i]) { |
386 | if (outputs[i].defined()) { |
        std::string msg("function ");
        msg += name() + " returned a gradient that is defined at position ";
        msg += c10::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
392 | throw std::runtime_error(msg); |
393 | } |
394 | continue; |
395 | } |
396 | results.emplace_back(outputs[i]); |
397 | } |
398 | return results; |
399 | } |
400 | |
401 | template <class T> |
402 | void CppNode<T>::release_variables() { |
403 | // lock to ensure thread safety, see [Thread Safety on Autograd Node] |
404 | std::lock_guard<std::mutex> lock(mutex_); |
405 | ctx_.saved_variables_.clear(); |
406 | ctx_.has_freed_buffers_ = true; |
407 | } |
408 | |
409 | template <class T> |
410 | void CppNode<T>::save_variables_to_ctx() { |
411 | ctx_.save_variables(); |
412 | } |
413 | |
414 | template <class T> |
415 | void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node>& node) { |
416 | ctx_.grad_fn_ = node; |
417 | } |
418 | |
419 | } // namespace autograd |
420 | } // namespace torch |
421 | |