#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/runtime/argument_spec.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/graph_iterator.h"
#include "torch/csrc/jit/runtime/profiling_graph_executor_impl.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"

namespace torch {
namespace jit {

using namespace torch::autograd;

using var_meta_type = std::vector<int64_t>;
using var_meta_list = std::vector<var_meta_type>;
using test_fn_type = std::function<variable_list(const variable_list&)>;

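// Describes a single autodiff test case: a name, the shapes of the randomly
// generated inputs, the forward function under test, and an optional clamp
// applied to the inputs (used to keep gradients away from zero for saturating
// ops such as tanh).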
struct ADTestSpec {
  ADTestSpec(
      const char* name,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      var_meta_list input_meta,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      test_fn_type test_fn,
      float clampMax = -1.0f)
      : name(name),
        input_meta(input_meta),
        test_fn(test_fn),
        clampMax(clampMax) {}

  variable_list operator()(const variable_list& inputs) const {
    return test_fn(inputs);
  }

  std::vector<Variable> make_vars() const {
    std::vector<Variable> out;
    for (const auto& m : input_meta) {
      if (clampMax > 0.0f) {
        out.push_back(torch::randn(m, at::requires_grad(true))
                          .clamp(-clampMax, clampMax));
        continue;
      }
      out.push_back(torch::randn(m, at::requires_grad(true)));
    }
    return out;
  }

  const char* name;
  var_meta_list input_meta;
  test_fn_type test_fn;
  float clampMax;
};

// Random gradient seeds matching each output's shape and tensor options.
variable_list get_grad_outputs(const variable_list& vars) {
  return fmap(vars, [](const Variable& v) -> Variable {
    return at::randn(v.sizes(), v.options());
  });
}

// Computes reference gradients by invoking the eager autograd engine directly.
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs) {
  const auto get_edge = [](const Variable& v) {
    return torch::autograd::impl::gradient_edge(v);
  };
  auto& engine = torch::autograd::Engine::get_default_engine();
  return engine.execute(
      fmap(outputs, get_edge),
      grad_outputs,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      /*accumulate_grad=*/false,
      fmap(inputs, get_edge));
}

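// For each spec in ad_tests: compute reference outputs and gradients with
// eager autograd, trace the same function into a JIT graph, differentiate the
// traced graph, run the resulting forward/backward pair through the
// interpreter, and check that outputs and gradients match the references.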
TEST(AutodiffTest, ADFormulas) {
  const auto cast = [](const Variable& v) {
    return static_cast<at::Tensor>(v);
  };

  using VL = variable_list;
  const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
  const var_meta_list unary_pointwise = {{2, 3, 4, 5}};
  const var_meta_list unary_pointwise_2d = {{2, 3}};
  const std::vector<ADTestSpec> ad_tests = {
      {"add",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] + v[1]}; }},
      {"sub",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] - v[1]}; }},
      {"mul",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] * v[1]}; }},
      {"sigmoid",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].sigmoid()}; }},
      // Clamp tanh input tensor values to [-3, 3]
      // to set a minimum on gradient absolute values
      {"tanh",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].tanh()}; },
       3.0f},
      {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
      {"view",
       unary_pointwise_2d,
       [](const VL& v) -> VL {
         return {v[0].view({3, 2})};
       }},
      {"expand",
       {{2, 1}},
       [](const VL& v) -> VL {
         return {v[0].expand({2, 3})};
       }},
      {"mm",
       {{10, 12}, {12, 15}},
       [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
      // TODO: enable once we can capture lists across forward-backward
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(4, 1)); }},
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(3, 2)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(4, 1)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(3, 2)); }},
  };

  for (const auto& test : ad_tests) {
    // Get reference values from autograd
    auto vars_in = test.make_vars();
    auto vars_out = test(vars_in);
    auto var_grads_in = get_grad_outputs(vars_out);
    auto var_grads_out = grad(vars_out, vars_in, var_grads_in);

    // Trace and differentiate the op
    auto graph = tracer::trace(
                     fmap<IValue>(vars_in),
                     [&test](Stack in) -> Stack {
                       auto ivalue_inps = fmap(in, [](const IValue& v) {
                         return Variable(v.toTensor());
                       });
                       return fmap<IValue>(test(ivalue_inps));
                     },
                     [](const Variable& var) { return ""; })
                     .first->graph;
    EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
    ConstantPropagation(graph);
    auto grad_spec = differentiate(graph);
    LowerGradOf(*grad_spec.df);
    // Get outputs from the interpreter
    auto tensors_in = fmap(vars_in, cast);
    auto tensor_grads_in = fmap(var_grads_in, cast);
    tensor_list tensors_out, tensor_grads_out;
    std::tie(tensors_out, tensor_grads_out) =
        runGradient(grad_spec, tensors_in, tensor_grads_in);

    // Compare results
    auto expected_tensors_out = fmap(vars_out, cast);
    auto expected_tensor_grads_out = fmap(var_grads_out, cast);
    assertAllClose(tensors_out, expected_tensors_out);
    assertAllClose(tensor_grads_out, expected_tensor_grads_out);
  }
}

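// Differentiates a hand-built graph computing a * b * a + b and checks the
// resulting Gradient spec. Roughly: f is the forward graph (with captured
// intermediates appended to its outputs), df is the backward graph,
// df_input_vjps indexes the f outputs whose incoming gradients become df
// inputs, and df_output_vjps indexes the f inputs for which df produces
// gradients.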
TEST(AutodiffTest, Differentiate) {
  // Note: can't use IRParser for this test due to issue #23989
  auto graph = std::make_shared<Graph>();
  std::vector<int64_t> sizes{2, 3, 4};
  std::vector<int64_t> strides{12, 4, 1};
  const auto type = TensorType::create(
      at::ScalarType::Float,
      at::kCPU,
      c10::VaryingShape<int64_t>{sizes},
      c10::VaryingShape<int64_t>{strides},
      true);

  // Build the graph computing a * b * a + b
  auto* a = graph->addInput()->setType(type);
  auto* b = graph->addInput()->setType(type);
  auto* cOne = graph->insertConstant(1);

  auto* ab = graph->insertNode(graph->create(aten::mul, /*num_outputs=*/1));
  ab->addInput(a);
  ab->addInput(b);

  auto* aba = graph->insertNode(graph->create(aten::mul, /*num_outputs=*/1));
  aba->addInput(ab->output());
  aba->addInput(a);

  auto* abaplusb =
      graph->insertNode(graph->create(aten::add, /*num_outputs=*/1));
  abaplusb->addInput(aba->output());
  abaplusb->addInput(b);
  abaplusb->addInput(cOne);

  graph->registerOutput(abaplusb->output());

  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_captured_inputs = {0, 1};
  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
  std::vector<size_t> expected_input_vjps = {0, 1};
  std::vector<size_t> expected_output_vjps = {0, 1};
  ASSERT_EQ(grad_spec.f_real_outputs, 1);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
  ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check_count("aten::mul", 2)
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check("prim::GradOf[name=\"aten::add\"]")
      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
      ->check_count("AutogradAdd", 2)
      ->run(*grad_spec.df);
}

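// Differentiates a parsed IR graph after propagating requires_grad: only the
// first input (%0) requires grad, so df should produce a gradient only for it,
// and output %4, which does not depend on %0, gets no incoming vjp.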
TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::mul(%1, %1)
      %4 : Tensor = aten::add(%3, %1, %2)
      %5 : Tensor = aten::add(%4, %0, %2)
      %6 : Tensor = aten::mul(%5, %0)
      %7 : Tensor = aten::add(%6, %1, %2)
      return (%4, %7))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  auto a_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()),
      /*requires_grad=*/true);
  auto b_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()),
      /*requires_grad=*/false);

  ArgumentSpecCreator asc(*g);
  asc.specializeTypes(*g, asc.create(true, {a_var, b_var}));

  PropagateInputShapes(g);
  PropagateRequiresGrad(g);

  auto grad_spec = differentiate(g);
  // vjps come in for %7 and for the captured intermediate that requires grad;
  // %4 depends only on %1, so it receives none
  std::vector<size_t> expected_input_vjps = {1, 2};
  std::vector<size_t> expected_output_vjps = {0}; // only %0 (a) requires grad
  ASSERT_EQ(grad_spec.f_real_outputs, 2);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
  ASSERT_EQ(
      grad_spec.df_input_captured_outputs,
      std::vector<size_t>({2, 3, 4, 5, 6}));
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check("aten::mul")
      ->check_count("aten::add", 2)
      ->check("aten::mul")
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);

  testing::FileCheck()
      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true)
      ->run(*grad_spec.df);
}

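// Fixture that turns on the profiling executor and disables autodiff subgraph
// inlining so that prim::DifferentiableGraph nodes stay visible in the
// optimized graph; TearDown restores the previous settings.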
class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    prev_exec = getExecutorMode();
    getExecutorMode() = true;
    prev_inline_autodiff = getAutodiffSubgraphInlining();
    debugSetAutodiffSubgraphInlining(false);
  }
  void TearDown() override {
    getExecutorMode() = prev_exec;
    debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
  }

  bool prev_exec;
  bool prev_profiling;
  bool prev_inline_autodiff;
};

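// aten::linear with an input that does not require grad: the backward graph
// should skip grad_input entirely and only perform the single matmul needed
// for grad_weight.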
TEST_F(AutodiffRemoveUnusedGradientsTest, Linear) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
graph(%inp.1 : Tensor,
      %weight.1 : Tensor,
      %bias.1 : Tensor):
  %6 : Tensor = aten::linear(%inp.1, %weight.1, %bias.1)
  return (%6))IR";
  parseIR(input, graph.get());

  auto inp = torch::randn({10, 10}).requires_grad_(false);
  auto weight = torch::randn({10, 10}).requires_grad_(true);
  auto bias = torch::randn({1, 10}).requires_grad_(true);
  auto stack = createStack({inp, weight, bias});

  ProfilingGraphExecutorImpl executor(graph, "linear");

  // initial run to profile requires_grad information
  auto plan = executor.getPlanFor(stack, /*remaining_bailout_depth=*/20);
  InterpreterState is{plan.code};
  is.run(stack);

  auto optimized_plan =
      executor.getPlanFor(stack, /*remaining_bailout_depth=*/20);
  DepthFirstGraphNodeIterator it(optimized_plan.graph);
  Node* diff_graph_node = nullptr;

  while ((diff_graph_node = it.next()) != nullptr) {
    if (diff_graph_node->kind() == prim::DifferentiableGraph) {
      break;
    }
  }
  ASSERT_NE(nullptr, diff_graph_node);

  auto backward_graph = diff_graph_node->g(attr::ReverseSubgraph);

  // We expect to compute grad_weight (which requires a matmul) but not
  // grad_input, so exactly one matmul should appear.
  // Note: this could change, e.g. if mm is used instead.
  testing::FileCheck()
      .check_count("matmul", 1, /*exactly*/ true)
      ->run(*backward_graph);
}

} // namespace jit
} // namespace torch