#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/runtime/argument_spec.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/graph_iterator.h"
#include "torch/csrc/jit/runtime/profiling_graph_executor_impl.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"

namespace torch {
namespace jit {

using namespace torch::autograd;

using var_meta_type = std::vector<int64_t>;
using var_meta_list = std::vector<var_meta_type>;
using test_fn_type = std::function<variable_list(const variable_list&)>;

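// ADTestSpec describes one autodiff test case: a name, the shapes of the
// random inputs to generate, the forward function to trace, and an optional
// clamp applied to the inputs (used to keep gradients away from zero for
// saturating ops such as tanh).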
struct ADTestSpec {
  ADTestSpec(
      const char* name,
      var_meta_list input_meta,
      test_fn_type test_fn,
      float clampMax = -1.0f)
      : name(name),
        input_meta(std::move(input_meta)),
        test_fn(std::move(test_fn)),
        clampMax(clampMax) {}

  variable_list operator()(const variable_list& inputs) const {
    return test_fn(inputs);
  }

  std::vector<Variable> make_vars() const {
    std::vector<Variable> out;
    for (const auto& m : input_meta) {
      if (clampMax > 0.0f) {
        out.push_back(torch::randn(m, at::requires_grad(true))
                          .clamp(-clampMax, clampMax));
        continue;
      }
      out.push_back(torch::randn(m, at::requires_grad(true)));
    }
    return out;
  }

  const char* name;
  var_meta_list input_meta;
  test_fn_type test_fn;
  float clampMax;
};

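// Produces random "incoming" gradients, one per forward output, with the same
// shapes and options as those outputs; these seed both the eager autograd
// reference run and the interpreter run below.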
variable_list get_grad_outputs(const variable_list& vars) {
  return fmap(vars, [](const Variable& v) -> Variable {
    return at::randn(v.sizes(), v.options());
  });
}

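// Reference backward pass: runs the eager autograd engine directly on the
// graph rooted at `outputs`, feeding `grad_outputs` and collecting gradients
// for `inputs`.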
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs) {
  const auto get_edge = [](const Variable& v) {
    return torch::autograd::impl::gradient_edge(v);
  };
  auto& engine = torch::autograd::Engine::get_default_engine();
  return engine.execute(
      fmap(outputs, get_edge),
      grad_outputs,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      /*accumulate_grad=*/false,
      fmap(inputs, get_edge));
}

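// ADFormulas: for every op in `ad_tests`, compute reference outputs and
// gradients with eager autograd, then trace the same op, differentiate the
// traced graph, run the forward/backward pair through the interpreter, and
// check that both results match the reference.
//
// New ops can be covered by appending an ADTestSpec entry; a hypothetical
// example (not part of the current list):
//   {"neg", unary_pointwise, [](const VL& v) -> VL { return {-v[0]}; }},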
TEST(AutodiffTest, ADFormulas) {
  const auto cast = [](const Variable& v) {
    return static_cast<at::Tensor>(v);
  };

  using VL = variable_list;
  const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
  const var_meta_list unary_pointwise = {{2, 3, 4, 5}};
  const var_meta_list unary_pointwise_2d = {{2, 3}};
  const std::vector<ADTestSpec> ad_tests = {
      {"add",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] + v[1]}; }},
      {"sub",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] - v[1]}; }},
      {"mul",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] * v[1]}; }},
      {"sigmoid",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].sigmoid()}; }},
      // Clamp tanh input tensor values to [-3, 3]
      // to set a minimum on gradient absolute values
      {"tanh",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].tanh()}; },
       3.0f},
      {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
      {"view",
       unary_pointwise_2d,
       [](const VL& v) -> VL {
         return {v[0].view({3, 2})};
       }},
      {"expand",
       {{2, 1}},
       [](const VL& v) -> VL {
         return {v[0].expand({2, 3})};
       }},
      {"mm",
       {{10, 12}, {12, 15}},
       [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
      // TODO: enable once we are able to capture lists across
      // forward-backward
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(4, 1)); }},
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(3, 2)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(4, 1)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(3, 2)); }},
  };

  for (const auto& test : ad_tests) {
    // Get reference values from autograd
    auto vars_in = test.make_vars();
    auto vars_out = test(vars_in);
    auto var_grads_in = get_grad_outputs(vars_out);
    auto var_grads_out = grad(vars_out, vars_in, var_grads_in);

    // Trace and differentiate the op
    auto graph = tracer::trace(
                     fmap<IValue>(vars_in),
                     [&test](Stack in) -> Stack {
                       auto ivalue_inps = fmap(in, [](const IValue& v) {
                         return Variable(v.toTensor());
                       });
                       return fmap<IValue>(test(ivalue_inps));
                     },
                     [](const Variable& var) { return ""; })
                     .first->graph;
    EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
    ConstantPropagation(graph);
    auto grad_spec = differentiate(graph);
    LowerGradOf(*grad_spec.df);
    // Get outputs from the interpreter
    auto tensors_in = fmap(vars_in, cast);
    auto tensor_grads_in = fmap(var_grads_in, cast);
    tensor_list tensors_out, tensor_grads_out;
    std::tie(tensors_out, tensor_grads_out) =
        runGradient(grad_spec, tensors_in, tensor_grads_in);

    // Compare results
    auto expected_tensors_out = fmap(vars_out, cast);
    auto expected_tensor_grads_out = fmap(var_grads_out, cast);
    assertAllClose(tensors_out, expected_tensors_out);
    assertAllClose(tensor_grads_out, expected_tensor_grads_out);
  }
}

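// Differentiate: builds the graph for a * b * a + b by hand, runs
// differentiate(), and checks the resulting gradient spec: which values are
// captured for the backward pass, which vjps flow in and out, and the shape
// of the forward/backward graphs (via FileCheck).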
TEST(AutodiffTest, Differentiate) {
  // Note: can't use IRParser for this test due to issue #23989
  auto graph = std::make_shared<Graph>();
  std::vector<int64_t> sizes{2, 3, 4};
  std::vector<int64_t> strides{12, 4, 1};
  const auto type = TensorType::create(
      at::ScalarType::Float,
      at::kCPU,
      c10::VaryingShape<int64_t>{sizes},
      c10::VaryingShape<int64_t>{strides},
      true);

  // Builds graph a * b * a + b
  auto* a = graph->addInput()->setType(type);
  auto* b = graph->addInput()->setType(type);
  auto* cOne = graph->insertConstant(1);

  auto* ab = graph->insertNode(graph->create(aten::mul, /*num_outputs=*/1));
  ab->addInput(a);
  ab->addInput(b);

  auto* aba = graph->insertNode(graph->create(aten::mul, /*num_outputs=*/1));
  aba->addInput(ab->output());
  aba->addInput(a);

  auto* abaplusb =
      graph->insertNode(graph->create(aten::add, /*num_outputs=*/1));
  abaplusb->addInput(aba->output());
  abaplusb->addInput(b);
  abaplusb->addInput(cOne);

  graph->registerOutput(abaplusb->output());

  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_captured_inputs = {0, 1};
  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
  std::vector<size_t> expected_input_vjps = {0, 1};
  std::vector<size_t> expected_output_vjps = {0, 1};
  ASSERT_EQ(grad_spec.f_real_outputs, 1);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
  ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check_count("aten::mul", 2)
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check("prim::GradOf[name=\"aten::add\"]")
      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
      ->check_count("AutogradAdd", 2)
      ->run(*grad_spec.df);
}

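// DifferentiateWithRequiresGrad: only %0 requires grad, so the gradient spec
// should only thread vjps through the part of the graph that actually depends
// on %0, and the backward graph should contain exactly one aten::mul gradient
// node.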
TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::mul(%1, %1)
      %4 : Tensor = aten::add(%3, %1, %2)
      %5 : Tensor = aten::add(%4, %0, %2)
      %6 : Tensor = aten::mul(%5, %0)
      %7 : Tensor = aten::add(%6, %1, %2)
      return (%4, %7))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  auto a_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
  auto b_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);

  ArgumentSpecCreator asc(*g);
  asc.specializeTypes(*g, asc.create(true, {a_var, b_var}));

  PropagateInputShapes(g);
  PropagateRequiresGrad(g);

  auto grad_spec = differentiate(g);
  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
  std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
  ASSERT_EQ(grad_spec.f_real_outputs, 2);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
  ASSERT_EQ(
      grad_spec.df_input_captured_outputs,
      std::vector<size_t>({2, 3, 4, 5, 6}));
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check("aten::mul")
      ->check_count("aten::add", 2)
      ->check("aten::mul")
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);

  testing::FileCheck()
      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly=*/true)
      ->run(*grad_spec.df);
}

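// Fixture for tests that need the profiling graph executor with autodiff
// subgraph inlining disabled; the previous settings are restored in TearDown
// so other tests are unaffected.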
class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    prev_exec = getExecutorMode();
    getExecutorMode() = true;
    prev_inline_autodiff = getAutodiffSubgraphInlining();
    debugSetAutodiffSubgraphInlining(false);
  }
  void TearDown() override {
    getExecutorMode() = prev_exec;
    debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
  }

  bool prev_exec;
  bool prev_inline_autodiff;
};

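// Linear: profiles a single aten::linear call where only the weight and bias
// require grad, then checks that the optimized backward subgraph does not
// compute the unused input gradient.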
TEST_F(AutodiffRemoveUnusedGradientsTest, Linear) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
graph(%inp.1 : Tensor,
      %weight.1 : Tensor,
      %bias.1 : Tensor):
  %6 : Tensor = aten::linear(%inp.1, %weight.1, %bias.1)
  return (%6))IR";
  parseIR(input, graph.get());

  auto inp = torch::randn({10, 10}).requires_grad_(false);
  auto weight = torch::randn({10, 10}).requires_grad_(true);
  auto bias = torch::randn({1, 10}).requires_grad_(true);
  auto stack = createStack({inp, weight, bias});

  ProfilingGraphExecutorImpl executor(graph, "linear");

  // initial run to profile requires_grad information
  auto plan = executor.getPlanFor(stack, 20);
  InterpreterState is{plan.code};
  is.run(stack);

  auto optimized_plan = executor.getPlanFor(stack, 20);
  DepthFirstGraphNodeIterator it(optimized_plan.graph);
  Node* diff_graph_node = nullptr;

  while ((diff_graph_node = it.next()) != nullptr) {
    if (diff_graph_node->kind() == prim::DifferentiableGraph) {
      break;
    }
  }
  ASSERT_NE(nullptr, diff_graph_node);

  auto backward_graph = diff_graph_node->g(attr::ReverseSubgraph);

  // we expect to compute grad_weight (which requires a matmul) but we don't
  // expect to compute grad_input. So, we expect exactly 1 matmul.
  // Note: this could change, e.g. if mm is used instead
  testing::FileCheck().check_count("matmul", 1, true)->run(*backward_graph);
}

} // namespace jit
} // namespace torch
349 | |