1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4
namespace torch {
namespace jit {

// Keeps the requires_grad property recorded on prim::profile nodes up to date.
//
// Differentiable graphs detach the gradients of their input Tensors, so
// creating and inlining differentiable graphs changes the requires_grad
// property of tensors in the graph. This pass updates the requires_grad
// recorded by prim::profile nodes so that profiled properties stay accurate.
//
// It intentionally does not update the grad properties of other nodes, such
// as graph inputs. The only downstream user of the grad property is the
// profiling executor, which only reads the types of prim::profile nodes.
//
// @param diff_forward_graph  Graph whose prim::profile nodes are updated.
// @param new_requires_grad   requires_grad value to record on those nodes.
//                            NOTE(review): the meaning of c10::nullopt
//                            (presumably "requires_grad unknown") is not
//                            visible in this header; confirm against the
//                            implementation.
TORCH_API void UpdateDifferentiableGraphRequiresGrad(
    std::shared_ptr<Graph>& diff_forward_graph,
    c10::optional<bool> new_requires_grad);

} // namespace jit
} // namespace torch
21