1 | #include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | void UpdateDifferentiableGraphRequiresGrad( |
10 | Block* block, |
11 | c10::optional<bool> new_requires_grad) { |
12 | for (Node* n : block->nodes()) { |
13 | for (Value* v : n->inputs()) { |
14 | auto ty = v->type()->cast<TensorType>(); |
15 | if (ty) { |
16 | v->setType(ty->withRequiresGrad(new_requires_grad)); |
17 | } |
18 | } |
19 | if (n->kind() == prim::profile) { |
20 | n->ty_( |
21 | attr::profiled_type, |
22 | n->ty(attr::profiled_type) |
23 | ->expectRef<TensorType>() |
24 | .withRequiresGrad(new_requires_grad)); |
25 | } |
26 | for (Block* b : n->blocks()) { |
27 | UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad); |
28 | } |
29 | } |
30 | } |
31 | |
32 | void UpdateDifferentiableGraphRequiresGrad( |
33 | std::shared_ptr<Graph>& diff_forward_graph, |
34 | c10::optional<bool> new_requires_grad) { |
35 | UpdateDifferentiableGraphRequiresGrad( |
36 | diff_forward_graph->block(), new_requires_grad); |
37 | } |
38 | |
39 | } // namespace jit |
40 | } // namespace torch |
41 |