#include <torch/csrc/jit/passes/requires_grad_analysis.h>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <vector>

namespace torch {
namespace jit {

namespace {

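// Read the requires_grad bit recorded on a Value's type; used as a
// projection with fmap/any_of below.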
bool getRequiresGrad(Value* value) {
  return value->requires_grad();
}

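// Record a requires_grad value on a Value's TensorType; non-tensor values
// are left untouched.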
void setRequiresGrad(Value* value, bool req_value) {
  if (auto type = value->type()->cast<TensorType>()) {
    value->setType(type->withRequiresGrad(req_value));
  }
}

void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  for (const auto i : c10::irange(values.size())) {
    setRequiresGrad(outputs[i], values[i]);
  }
}

void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  setRequiresGrad(node->outputs(), values);
}

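// Element-wise OR of two equal-length requires_grad masks.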
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (const auto i : c10::irange(a.size())) {
    a[i] = a[i] || b[i];
  }
  return a;
}

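// Propagate requires_grad through a node with no nested blocks. Comparisons
// and detach never require grad; aten::tensor consults its requires_grad
// argument when it is constant; every other op falls through to the default
// rule at the end of this function.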
void PropagateRequiresGradSimpleNode(Node* node) {
  static const OperatorSet comparison_ops = {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
  };

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (node->isMemberOf(comparison_ops)) {
    return setRequiresGrad(node->output(), false);
  } else if (node->matches(
                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    return setRequiresGrad(node->output(), node->input(0)->requires_grad());
  } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
    return setRequiresGrad(node->output(), false);
  } else if (node->kind() == aten::tensor) {
    if (auto grad_index =
            node->schema().argumentIndexWithName("requires_grad")) {
      if (auto const_arg = constant_as<bool>(node->inputs().at(*grad_index))) {
        return setRequiresGrad(node->output(), *const_arg);
      }
    }
    if (auto type = node->output()->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            node->output(),
            autograd::isDifferentiableType(*type->scalarType()));
      }
    }
    return;
  }

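  // Default rule: an output requires grad iff any input does, restricted to
  // outputs whose scalar type is differentiable.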
  auto inputs = node->inputs();
  auto outputs = node->outputs();
  bool should_require =
      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
  for (Value* output : outputs) {
    if (auto type = output->type()->cast<TensorType>()) {
      if (type->scalarType()) {
        setRequiresGrad(
            output,
            should_require &&
                autograd::isDifferentiableType(*type->scalarType()));
      }
    }
  }
}

void PropagateRequiresGrad(Block* block);

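// Dispatch on control flow: prim::If ORs the two branches' results,
// prim::Loop iterates its body to a fixed point, and all other nodes use
// the simple-node rule above.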
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
    auto body = node->blocks().at(0);
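    // Loop inputs are (max_trip_count, initial_condition, carried...); the
    // body's params and returns carry an extra leading trip count /
    // continue condition, hence the slice(2) and slice(1) offsets below.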
    std::vector<bool> loop_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_inputs_require = loop_inputs_require;
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

    std::vector<bool> new_body_inputs_require = body_inputs_require;
    std::vector<bool> new_body_outputs_require = body_outputs_require;

    // continue iterating until the results have converged
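    // (the masks only ever flip from false to true, so the iteration is
    // monotone and guaranteed to terminate)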
    do {
      body_inputs_require = new_body_inputs_require;
      body_outputs_require = new_body_outputs_require;

      new_body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      setRequiresGrad(
          body->param_node()->outputs().slice(1), new_body_inputs_require);
      PropagateRequiresGrad(body);
      new_body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    } while (new_body_inputs_require != body_inputs_require ||
             new_body_outputs_require != body_outputs_require);

    setRequiresGrad(node, bitwiseOr(body_outputs_require, loop_inputs_require));
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}

void PropagateRequiresGrad(Block* block) {
  for (Node* node : block->nodes()) {
    PropagateRequiresGrad(node);
  }
}
} // anonymous namespace

void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
  PropagateRequiresGrad(graph->block());
}
} // namespace jit
} // namespace torch