#include <torch/csrc/jit/passes/requires_grad_analysis.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <vector>

namespace torch {
namespace jit {

namespace {

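// Reads the requires_grad flag recorded on a Value, in a form usable with
// fmap/any_of below.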
bool getRequiresGrad(Value* value) {
  return value->requires_grad();
}

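// Records the requires_grad flag on Tensor-typed values; non-tensor values
// are left untouched.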
void setRequiresGrad(Value* value, bool req_value) {
  if (auto type = value->type()->cast<TensorType>()) {
    value->setType(type->withRequiresGrad(req_value));
  }
}

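// Applies the per-value setter to a list of outputs, one flag per output.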
void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  for (const auto i : c10::irange(values.size())) {
    setRequiresGrad(outputs[i], values[i]);
  }
}

void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  setRequiresGrad(node->outputs(), values);
}

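// Element-wise logical OR of two bool vectors of equal length.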
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (const auto i : c10::irange(a.size())) {
    a[i] = a[i] || b[i];
  }
  return a;
}

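// Propagation rule for ordinary (non-control-flow) nodes.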
void PropagateRequiresGradSimpleNode(Node* node) {
  static const OperatorSet comparison_ops = {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
  };

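  // Comparison results and detach() never require grad; type_as follows its
  // first input; aten::tensor honors an explicit constant requires_grad
  // argument when one is provided.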
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (node->isMemberOf(comparison_ops)) {
    return setRequiresGrad(node->output(), false);
  } else if (node->matches(
                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    return setRequiresGrad(node->output(), node->input(0)->requires_grad());
  } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
    return setRequiresGrad(node->output(), false);
  } else if (node->kind() == aten::tensor) {
    if (auto grad_index =
            node->schema().argumentIndexWithName("requires_grad")) {
      if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
        return setRequiresGrad(node->output(), *const_arg);
      }
    }
    if (auto type = node->output()->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            node->output(),
            autograd::isDifferentiableType(*type->scalarType()));
      }
    }
    return;
  }

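  // Default rule: an output requires grad only if some input requires grad
  // and the output has a differentiable scalar type.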
  auto inputs = node->inputs();
  auto outputs = node->outputs();
  bool should_require =
      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
  for (Value* output : outputs) {
    if (auto type = output->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            output,
            should_require &&
                autograd::isDifferentiableType(*type->scalarType()));
      }
    }
  }
}

void PropagateRequiresGrad(Block* block);

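// Dispatches on the node kind: prim::If and prim::Loop need structured
// handling of their blocks; everything else uses the simple per-node rule.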
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

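    // An If output requires grad if the corresponding output of either
    // branch does.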
    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
    auto body = node->blocks().at(0);
    std::vector<bool> loop_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_inputs_require = loop_inputs_require;
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

    std::vector<bool> new_body_inputs_require = body_inputs_require;
    std::vector<bool> new_body_outputs_require = body_outputs_require;

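    // A flag set inside the body can flow back into the next iteration
    // through the loop-carried inputs, so the body is re-analyzed until no
    // requires_grad flag changes.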
    // continue iterating until the results have converged
    do {
      body_inputs_require = new_body_inputs_require;
      body_outputs_require = new_body_outputs_require;

      new_body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      setRequiresGrad(
          body->param_node()->outputs().slice(1), new_body_inputs_require);
      PropagateRequiresGrad(body);
      new_body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    } while (new_body_inputs_require != body_inputs_require ||
             new_body_outputs_require != body_outputs_require);

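    // Loop outputs require grad if either the final body outputs or the
    // initial loop-carried inputs do.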
    setRequiresGrad(node, bitwiseOr(body_outputs_require, loop_inputs_require));
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}

void PropagateRequiresGrad(Block* block) {
  for (Node* node : block->nodes()) {
    PropagateRequiresGrad(node);
  }
}
} // anonymous namespace

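// Entry point: propagates requires_grad information through the whole graph.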
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
  PropagateRequiresGrad(graph->block());
}
} // namespace jit
} // namespace torch