1#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
5
6namespace torch {
7namespace jit {
8
9void UpdateDifferentiableGraphRequiresGrad(
10 Block* block,
11 c10::optional<bool> new_requires_grad) {
12 for (Node* n : block->nodes()) {
13 for (Value* v : n->inputs()) {
14 auto ty = v->type()->cast<TensorType>();
15 if (ty) {
16 v->setType(ty->withRequiresGrad(new_requires_grad));
17 }
18 }
19 if (n->kind() == prim::profile) {
20 n->ty_(
21 attr::profiled_type,
22 n->ty(attr::profiled_type)
23 ->expectRef<TensorType>()
24 .withRequiresGrad(new_requires_grad));
25 }
26 for (Block* b : n->blocks()) {
27 UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad);
28 }
29 }
30}
31
32void UpdateDifferentiableGraphRequiresGrad(
33 std::shared_ptr<Graph>& diff_forward_graph,
34 c10::optional<bool> new_requires_grad) {
35 UpdateDifferentiableGraphRequiresGrad(
36 diff_forward_graph->block(), new_requires_grad);
37}
38
39} // namespace jit
40} // namespace torch
41