#pragma once

#include <ATen/core/ivalue.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <vector>

namespace torch {
namespace autograd {

using optional_variable_list = std::vector<c10::optional<Variable>>;
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;

TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
    const variable_list& input_vars,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<c10::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    _jvp_fn_t jvp_user_function);

TORCH_API void check_variable_result(
    const at::TensorBase& original,
    const at::TensorBase& result,
    std::string hook_name);

// Get the return type of the forward function of the custom Function class X
template <typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));

/// To use custom autograd operations, implement a Function subclass with
/// static forward and backward functions:
///
/// `forward` can take as many arguments as you want and should return either a
/// variable list or a Variable. It must take a pointer to
/// `torch::autograd::AutogradContext` as its first argument. Any Variable
/// passed directly as an argument will be registered in the graph, but
/// vectors/sets or any other data structures will not be traversed. You can
/// use `c10::optional<Tensor>` as one of the arguments, and it will be
/// registered as a variable in the graph if the argument has a value.
/// Variables can be saved in the `ctx` using `ctx->save_for_backward`
/// (see `torch::autograd::AutogradContext::save_for_backward`) and other data
/// can be saved in the `ctx->saved_data` map
/// (see `torch::autograd::AutogradContext::saved_data`)
/// in the form of `<std::string, at::IValue>` pairs.
///
/// `backward` should take a pointer to `torch::autograd::AutogradContext`
/// and a variable list containing as many Variables as there were outputs from
/// `forward`. It should return as many Variables as there were inputs, each
/// containing the gradient w.r.t. its corresponding input. Variables saved in
/// `forward` can be accessed with `ctx->get_saved_variables` (see
/// `torch::autograd::AutogradContext::get_saved_variables`) and other saved
/// data can be accessed from `ctx->saved_data`.
///
/// For example:
/// ```
/// class MyFunction : public Function<MyFunction> {
///  public:
///   static variable_list forward(AutogradContext *ctx, int n, Variable var) {
///     // Save data for backward in context
///     ctx->saved_data["n"] = n;
///     var.mul_(2);
///     // Mark var as modified by inplace operation
///     ctx->mark_dirty({var});
///     return {var};
///   }
///
///   static variable_list backward(
///       AutogradContext *ctx,
///       variable_list grad_output) {
///     // Use data saved in forward
///     auto n = ctx->saved_data["n"].toInt();
///     return {grad_output[0] * n};
///   }
/// };
/// ```
///
/// To use `MyFunction`:
/// ```
/// Variable x;
/// auto y = MyFunction::apply(6, x);
/// // Example backward call
/// y[0].sum().backward();
/// ```
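///
/// An additional illustrative sketch (`MyMul` is a hypothetical name, not an
/// operator shipped with the library) showing how inputs can be saved with
/// `save_for_backward` and retrieved in `backward` via `get_saved_variables`:
/// ```
/// class MyMul : public Function<MyMul> {
///  public:
///   static Variable forward(AutogradContext *ctx, Variable a, Variable b) {
///     // Save both inputs so backward can compute both gradients.
///     ctx->save_for_backward({a, b});
///     return a * b;
///   }
///
///   static variable_list backward(
///       AutogradContext *ctx,
///       variable_list grad_output) {
///     auto saved = ctx->get_saved_variables();
///     auto a = saved[0];
///     auto b = saved[1];
///     // d(a*b)/da = b and d(a*b)/db = a
///     return {grad_output[0] * b, grad_output[0] * a};
///   }
/// };
/// ```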
template <class T>
struct TORCH_API Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  template <typename X = T, typename... Args>
  static auto apply(Args&&... args)
      -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>>;
};

/// Context to save information during `forward` that can be accessed in
/// `backward` in custom autograd operations (see `torch::autograd::Function`
/// for details).
struct TORCH_API AutogradContext {
  AutogradContext() = default;
  AutogradContext(const AutogradContext& other) = delete;
  AutogradContext& operator=(const AutogradContext& other) = delete;

  /// Can be used to save non-variable data for `backward`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ska::flat_hash_map<std::string, at::IValue> saved_data;

  /// Saves the list of variables for a future call to `backward`. This
  /// should be called at most once from inside of `forward`.
  void save_for_backward(variable_list to_save);
  /// Marks variables in the list as modified in an in-place operation. This
  /// should be called at most once from inside of `forward` and all arguments
  /// should be inputs.
  void mark_dirty(const variable_list& inputs);
  /// Marks outputs in the list as not requiring gradients. This should be
  /// called at most once from inside of `forward` and all arguments should be
  /// outputs.
  void mark_non_differentiable(const variable_list& outputs);
  /// Sets whether undefined output grad tensors should be expanded to tensors
  /// full of zeros before calling the backward function. The default is true.
  void set_materialize_grads(bool value);

  /// Get the list of variables that were saved in `forward` using
  /// `save_for_backward()`. Before returning them to the user, a check is made
  /// to ensure that they were not modified by any in-place operations.
  variable_list get_saved_variables() const;
  const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
  const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

  /// Expose the Node's `task_should_compute_output` method to the cpp
  /// custom autograd Function as `needs_input_grad`.
  bool needs_input_grad(size_t output_edge_index) const;
  bool needs_input_grad(std::initializer_list<IndexRange> idxs) const;

 private:
  std::unordered_set<at::TensorImpl*> non_differentiable_;
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
  std::vector<torch::autograd::SavedVariable> saved_variables_;
  variable_list to_save_;
  bool materialize_grads_{true};

  // The CppNode in the autograd graph that owns this AutogradContext. We need
  // a weak_ptr to avoid a refcycle. Since grad_fn_ owns this AutogradContext,
  // it will always be alive when we want to use it.
  std::weak_ptr<Node> grad_fn_;
  bool has_freed_buffers_{false};

  void save_variables();

  template <class T>
  friend struct CppNode;
};

struct TORCH_API VariableInfo {
  explicit VariableInfo();
  explicit VariableInfo(const Variable& var);

  Variable zeros(at::OptionalDeviceGuard& device_guard) const;

  at::Layout layout = at::Layout::Strided;
  at::Device device = at::kCPU;
  at::ScalarType scalar_type = at::kFloat;
  std::vector<int64_t> size;
  bool requires_grad;
  bool is_empty;
};

// CppNode<T> is the Node in the autograd graph that represents the
// user-defined backward function for Function<T>. Calls to CppNode::apply are
// forwarded to T::backward().
template <class T>
struct CppNode : public Node {
  variable_list apply(variable_list&& inputs) override;
  AutogradContext ctx_;
  std::vector<bool> is_variable_input_;
  std::vector<VariableInfo> input_info_;
  std::vector<VariableInfo> output_info_;

  void release_variables() override;

  void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
  void save_variables_to_ctx();
};

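// Visitor over the arguments passed to Function<T>::apply: for each argument
// it records whether that argument is a Variable (is_var_) and, if so,
// collects it into list_ so it can be registered as an input to the graph.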
struct ExtractVariables : IterArgs<ExtractVariables> {
  std::vector<bool>& is_var_;
  variable_list& list_;
  ExtractVariables(std::vector<bool>& is_var, variable_list& list)
      : is_var_(is_var), list_(list) {}
  void operator()(const c10::optional<at::Tensor>& x) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (x.has_value() && x.value().defined()) {
      is_var_.push_back(true);
      list_.emplace_back(x.value());
    } else {
      is_var_.push_back(false);
    }
  }
  void operator()(const at::Tensor& x) {
    is_var_.push_back(true);
    list_.emplace_back(x);
  }
  void operator()(const at::TensorList& list) {
    for (const at::Tensor& x : list) {
      is_var_.push_back(true);
      list_.emplace_back(x);
    }
  }
  template <typename T>
  void operator()(const T& x) {
    is_var_.push_back(false);
  }
};

template <typename... Args>
inline void extract_vars(
    std::vector<bool>& is_var,
    variable_list& list,
    Args&&... args) {
  ExtractVariables(is_var, list).apply(std::forward<Args>(args)...);
}

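// Convert the std::vector<c10::optional<Variable>> produced by _wrap_outputs
// back into the return type of the user's forward(): either a variable_list
// or a single Variable.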
template <typename T>
typename std::enable_if<std::is_same<T, variable_list>::value, T>::type
to_output_type(std::vector<c10::optional<Variable>>& output_list) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list result;
  std::transform(
      output_list.begin(),
      output_list.end(),
      std::back_inserter(result),
      [](const c10::optional<Variable>& var) { return *var; });
  return result;
}

template <typename T>
typename std::enable_if<std::is_same<T, Variable>::value, T>::type
to_output_type(std::vector<c10::optional<Variable>>& output_list) {
  return *output_list[0];
}

inline std::vector<c10::optional<Variable>> to_optional(Variable& output) {
  return std::vector<c10::optional<Variable>>{output};
}

inline std::vector<c10::optional<Variable>> to_optional(variable_list& output) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<c10::optional<Variable>> result;
  std::transform(
      output.begin(),
      output.end(),
      std::back_inserter(result),
      [](const Variable& var) { return var; });
  return result;
}

template <class T>
template <typename X, typename... Args>
auto Function<T>::apply(Args&&... args)
    -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
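  // Create the Node that will run T::backward() during the backward pass; its
  // AutogradContext (ctx_) is populated by the user's forward() below.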
  std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list input_vars;

  const size_t num_inputs = sizeof...(Args);
  input_vars.reserve(num_inputs);
  node->is_variable_input_.reserve(num_inputs);
  // TODO Add tracing here
  extract_vars(node->is_variable_input_, input_vars, args...);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool is_executable =
      GradMode::is_enabled() && any_variable_requires_grad(input_vars);
  auto next_edges =
      (is_executable ? collect_next_edges(input_vars) : edge_list());
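  // Hook the node into the graph: give the context a weak handle back to the
  // node (used, e.g., by needs_input_grad()) and wire up the edges to the
  // inputs' grad_fns collected above.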
  node->set_ctx_grad_fn(node);
  node->set_next_edges(std::move(next_edges));
  node->clear_input_metadata();

  node->input_info_.reserve(input_vars.size());
  for (auto& var : input_vars) {
    node->input_info_.emplace_back(var);
  }

  using forward_return_t = forward_t<X, Args...>;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  forward_return_t outputs;
  {
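    // Run the user's forward() with grad mode disabled so that the operations
    // it performs are not themselves recorded; gradients for this op instead
    // flow through T::backward() via the CppNode created above.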
    AutoGradMode grad_mode(false);
    outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
  }

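  // Forward-mode AD is not supported for C++ custom Functions; install a stub
  // jvp function that raises if it is ever invoked.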
  _jvp_fn_t jvp_fn = [](variable_list inputs,
                        variable_list gI) -> variable_list {
    TORCH_CHECK(
        false,
        "jvp is not implemented for the c++ API of custom Function yet.",
        "Please open a feature request on GitHub if you need this.");
  };

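  // _wrap_outputs attaches this node as the grad_fn of the differentiable
  // outputs (when is_executable) and applies the dirty / non-differentiable
  // markings recorded in the ctx.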
  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      node->ctx_.get_non_differentiable(),
      node->ctx_.get_and_bump_dirty(),
      to_optional(outputs),
      is_executable ? node : nullptr,
      jvp_fn);

  node->output_info_.reserve(wrapped_outputs.size());
  for (auto& output : wrapped_outputs) {
    if (is_executable && output.has_value()) {
      node->output_info_.emplace_back(output.value());
    } else if (is_executable) {
      node->output_info_.emplace_back();
    }
  }

  if (is_executable) {
    node->save_variables_to_ctx();
  }

  // wrapped_outputs will be a variable_list, so convert it to the correct
  // return type. Only Variable and variable_list are accepted as return types.
  return to_output_type<forward_return_t>(wrapped_outputs);
}

// The logic here is the same as PyNode::apply, so changes to it should be done
// in both places.
template <class T>
variable_list CppNode<T>::apply(variable_list&& inputs) {
  at::OptionalDeviceGuard _device_guard;

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int num_inputs = inputs.size();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list backward_inputs;
  backward_inputs.reserve(num_inputs);
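  // Materialize undefined incoming grads as zero-filled tensors unless the
  // user disabled this via ctx->set_materialize_grads(false).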
  for (const auto i : c10::irange(num_inputs)) {
    if (inputs[i].defined() || !ctx_.materialize_grads_) {
      backward_inputs.emplace_back(inputs[i]);
    } else {
      backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
    }
  }

  // Acquire the lock here to protect thread safety on the custom C++ autograd
  // Node. This is needed because we don't know whether the user-defined Node
  // will write to shared data during backward.
  // See Note [Thread Safety on Autograd Node].
  std::lock_guard<std::mutex> lock(mutex_);

  auto outputs = T::backward(&ctx_, backward_inputs);

  const auto num_forward_inputs =
      static_cast<int64_t>(is_variable_input_.size());
  auto num_outputs = static_cast<int64_t>(outputs.size());
  // Returning too many results is ok, but only as long as they're all
  // undefined. Truncate the result vector in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_undef = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_undef &= (!outputs[i].defined());
    }
    if (all_undef) {
      outputs.resize(num_forward_inputs);
      num_outputs = num_forward_inputs;
    }
  }

  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += c10::to_string(num_forward_inputs) + ", got ";
    msg += c10::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list results;
  results.reserve(num_outputs);
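  // Positions corresponding to non-Variable forward arguments must not receive
  // a defined gradient and are skipped in the returned list.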
  for (const auto i : c10::irange(num_outputs)) {
    if (!is_variable_input_[i]) {
      if (outputs[i].defined()) {
        std::string msg("function ");
        msg += name() + " returned a defined gradient at position ";
        msg += c10::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    results.emplace_back(outputs[i]);
  }
  return results;
}

template <class T>
void CppNode<T>::release_variables() {
  // lock to ensure thread safety, see [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);
  ctx_.saved_variables_.clear();
  ctx_.has_freed_buffers_ = true;
}

template <class T>
void CppNode<T>::save_variables_to_ctx() {
  ctx_.save_variables();
}

template <class T>
void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node>& node) {
  ctx_.grad_fn_ = node;
}

} // namespace autograd
} // namespace torch