1 | #include <ATen/core/boxing/impl/test_helpers.h> |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <ATen/core/op_registration/op_registration.h> |
5 | #include <torch/torch.h> |
6 | |
7 | #include <torch/csrc/autograd/FunctionsManual.h> |
8 | #include <torch/csrc/autograd/functions/basic_ops.h> |
9 | |
10 | #include <test/cpp/api/support.h> |
11 | |
12 | using namespace torch::autograd; |
13 | using namespace torch::test; |
14 | |
15 | #define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b))) |
16 | #define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b))) |
17 | |
18 | std::string graph_desc(std::shared_ptr<Node> node) { |
19 | if (!node) { |
20 | return "None" ; |
21 | } |
22 | auto result = node->name() + "(" ; |
23 | auto next_edges = node->next_edges(); |
24 | for (auto& edge : next_edges) { |
25 | result += graph_desc(edge.function); |
26 | } |
27 | return result + ")" ; |
28 | } |
29 | |
30 | Variable simple_fn(const Variable& x, const Variable& y) { |
31 | return x + 2 * y + x * y; |
32 | } |
33 | |
34 | TEST(AutogradAPITests, BackwardSimpleTest) { |
35 | Variable x = torch::randn({2, 2}, torch::requires_grad()); |
36 | Variable y = torch::randn({2, 2}, torch::requires_grad()); |
37 | auto res = simple_fn(x, y); |
38 | backward({res.sum()}, {}); |
39 | |
40 | ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2})); |
41 | ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2); |
42 | } |
43 | |
TEST(AutogradAPITests, BackwardTest) {
  // Calling backward() twice over the same graph accumulates gradients into
  // .grad(), so the expected values below are doubled.
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  // First call passes keep_graph=true so the graph survives for the second
  // call.
  backward({res}, {torch::ones({2, 2})}, {}, true);

  backward({res}, {torch::ones({2, 2})});

  // Gradients from both passes sum: 2 * (df/dx) and 2 * (df/dy).
  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}
55 | |
56 | TEST(AutogradAPITests, GradSimpleTest) { |
57 | // basic grad |
58 | Variable x = torch::randn({2, 2}, torch::requires_grad()); |
59 | Variable y = torch::randn({2, 2}, torch::requires_grad()); |
60 | auto res = simple_fn(x, y); |
61 | auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})}); |
62 | |
63 | ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2})); |
64 | ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2); |
65 | } |
66 | |
TEST(AutogradAPITests, GradTest) {
  // Double-backward: build a graph of the backward pass (create_graph=true),
  // then differentiate an expression of the first-order grads. Also checks
  // that the second grad() call leaves .grad() untouched.
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  // args appear to be (gradient, keep_graph=false, create_graph=true) —
  // TODO(review): confirm argument order against Tensor::backward.
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  // grad_sum = 2*(y+1) + (x+2); d(grad_sum)/dx = 1.
  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  // .grad() fields must be unchanged by the grad() call above.
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}
85 | |
TEST(AutogradAPITests, GradNonLeafTest) {
  // grad() on a non-leaf input returns the gradient directly and must not
  // populate .grad() on any tensor involved.
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  // Take a few gradient-ascent steps through the non-leaf x.
  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    // grad() must leave .grad() undefined on both x and y.
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  // The ascent steps should have increased the objective.
  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  // A regular backward() through the chain does populate leaf .grad()s.
  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}
111 | |
TEST(AutogradAPITests, GradUnreachableTest) {
  // Inputs unreachable from the outputs yield undefined grads when
  // allow_unused=true, and an error otherwise.
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  // y does not contribute to x * 2, so its grad slot stays undefined.
  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // allow_unused=False, but grads contains None inside, should throw
  ASSERT_THROWS_WITH(
      grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True" );
}
135 | |
TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  // Test that certain nodes are not erroneously executed when an input
  // is unreachable. See #39784
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // Fails the test if the engine ever schedules this node.
      ADD_FAILURE() << "This node should not be executed!" ;
      return grad_output;
    }
  };

  auto x = torch::randn(1, torch::requires_grad());
  auto x1 = torch::randn(1);
  auto x2 = MyFunction::apply(x + x1);

  // y is unrelated to x2, so the discovery pass must prune MyFunction's
  // backward entirely instead of running it.
  auto y = torch::randn(1, torch::requires_grad());
  auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
  ASSERT_FALSE(grad_res[0].defined());
}
160 | |
161 | TEST(AutogradAPITests, EmptyInput) { |
162 | Variable x = torch::ones({1}, torch::requires_grad()); |
163 | ASSERT_THROWS_WITH( |
164 | grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs." ); |
165 | } |
166 | |
TEST(AutogradAPITests, RetainGrad) {
  // retain_grad() makes a non-leaf tensor accumulate into .grad(); it is
  // idempotent and a no-op on leaves.
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  {
    // Warning when grad is accessed for non-leaf tensor
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_TRUE(warnings.str().find("is not a leaf" ) != std::string::npos);
  }
  // It should be possible to call retain_grad() multiple times
  h1.retain_grad();
  h1.retain_grad();
  {
    // If retain_grad is true for a non-leaf tensor,
    // there should not be any warning when grad is accessed
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_FALSE(warnings.str().find("is not a leaf" ) != std::string::npos);
  }

  // Gradient should be accumulated
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  // d(out)/d(h1) = 2 * h1 after one pass...
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  // ...and doubles again after the second pass (accumulation).
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // It should be a no-op for leaves
  input.retain_grad();
  input.retain_grad();
  out.backward();
  // d(out)/d(input) = 2 * h1 * 3 = 18 * input for this single pass.
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}
207 | |
TEST(AutogradAPITests, AnomalyMode) {
  // Needs to have backtrace as warning and then throw an error
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    // In-place modification of y after it was used by z must be caught, and
    // anomaly mode should emit the forward traceback as a warning.
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace" );
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward" ) != std::string::npos);
  }
  // d/dx x^1.5 = 1.5 * x^0.5; its derivative at x=0 is NaN, which anomaly
  // mode's NaN check should detect during double backward.
  auto double_backward_produce_nan = [](bool should_throw) {
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr =
        // NOLINTNEXTLINE(bugprone-argument-comment)
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
    if (should_throw) {
      WarningCapture warnings;
      ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});
                         , "returned nan" );
      // Two tracebacks are expected: the erroring call and the one that
      // induced the earlier calculation.
      auto msgs = warnings.messages();
      ASSERT_EQ(msgs.size(), 2);
      ASSERT_TRUE(
          msgs[0].find("Traceback of forward call that caused the error" ) !=
          std::string::npos);
      ASSERT_TRUE(
          msgs[1].find(
              "Traceback of forward call that induced the previous calculation" ) !=
          std::string::npos);
    } else {
      grad({gr[0]}, {x}, {torch::tensor({0.0})});
    }
  };

  double_backward_produce_nan(true);
  {
    // Nested guards: check_nan=false suppresses the NaN error, and an inner
    // check_nan=true guard re-enables it.
    torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
    double_backward_produce_nan(false);
    {
      torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
      double_backward_produce_nan(true);
    }
  }
  // Leaving the inner scopes restores the outer guard's NaN checking.
  double_backward_produce_nan(true);
}
256 | |
TEST(CustomAutogradTest, CustomFunction) {
  // Custom Function mixing Variable and non-Variable (int) arguments:
  // forward computes var1 + mul * var2 + var1 * var2.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      // Non-tensor state goes through saved_data; tensors through
      // save_for_backward.
      ctx->saved_data["mul" ] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      int mul = ctx->saved_data["mul" ].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      // One grad slot per forward input; the int input gets an undefined
      // Variable.
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  // df/dx = 1 + y, df/dy = 2 + x.
  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}
293 | |
TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
  // Custom Function taking an at::TensorList argument; gradients must still
  // flow back to the underlying tensors.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
      // Copy the list into a variable_list so it can be saved for backward.
      torch::autograd::variable_list vars;
      for (const at::Tensor& tensor : tensors) {
        vars.push_back(tensor);
      }
      ctx->save_for_backward(vars);
      return tensors[0] + tensors[1] + tensors[0] * tensors[1];
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      // df/dvar1 = 1 + var2, df/dvar2 = 1 + var1.
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          grad_output[0] + grad_output[0] * var1};
      return output;
    }
  };

  at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
  torch::autograd::variable_list variables = {x, y};
  at::TensorList tensors = variables;
  auto res = MyFunction::apply(tensors);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
}
329 | |
TEST(CustomAutogradTest, GraphTaskTrimEdges) {
  // Verifies that grad() over a subset of inputs trims the unused edges and
  // that ctx->needs_input_grad reflects exactly which inputs were requested.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2,
        int mul,
        bool needs_input1_grad,
        bool needs_input2_grad) {
      // setup the expected should and should not compute idx
      ctx->saved_data["needs_input1_grad" ] = needs_input1_grad;
      ctx->saved_data["needs_input2_grad" ] = needs_input2_grad;

      ctx->saved_data["mul" ] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // Test `needs_input_grad` method is working correctly.
      // We have to test this within the backward function.
      auto needs_input1_grad = ctx->saved_data["needs_input1_grad" ].toBool();
      auto needs_input2_grad = ctx->saved_data["needs_input2_grad" ].toBool();
      IndexRange var1_idx = {0, 1};
      IndexRange var2_idx = {1, 2};
      EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
      EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
      EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
      EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
      // The multi-range overload is true if ANY covered input needs grad.
      EXPECT_EQ(
          ctx->needs_input_grad({var1_idx, var2_idx}),
          needs_input1_grad || needs_input2_grad);

      // calculate gradients
      int mul = ctx->saved_data["mul" ].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];

      // Only compute the gradients the engine actually asked for.
      Variable grad_var1, grad_var2;
      if (ctx->needs_input_grad(0)) {
        grad_var1 = grad_output[0] + grad_output[0] * var2;
      }
      if (ctx->needs_input_grad(1)) {
        grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
      }
      // Undefined slots for the three non-Variable forward arguments.
      variable_list output = {
          grad_var1,
          grad_var2,
          Variable(),
          Variable(),
          Variable(),
      };
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto go = torch::ones_like(x);
  Variable out;

  // grad_x
  out = MyFunction::apply(
      x,
      y,
      2,
      /* needs_input1_grad= */ true,
      /* needs_input2_grad= */ false);
  auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
  ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));

  // grad_y
  out = MyFunction::apply(
      x,
      y,
      2,
      /* needs_input1_grad= */ false,
      /* needs_input2_grad= */ true);
  auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
  ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);

  // grad_x and grad_y
  out = MyFunction::apply(
      x,
      y,
      2,
      /* needs_input1_grad= */ true,
      /* needs_input2_grad= */ true);
  auto grads = torch::autograd::grad({out}, {x, y}, {go});
  ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
}
425 | |
426 | TEST(CustomAutogradTest, FunctionReturnsInput) { |
427 | struct MyFunction : public Function<MyFunction> { |
428 | static Variable forward(AutogradContext* ctx, Variable var1) { |
429 | return var1; |
430 | } |
431 | |
432 | static variable_list backward( |
433 | AutogradContext* ctx, |
434 | variable_list grad_output) { |
435 | return {grad_output[0] * 2}; |
436 | } |
437 | }; |
438 | |
439 | Variable x(torch::ones(1, torch::requires_grad())); |
440 | MyFunction::apply(x).backward(torch::ones(1), true, true); |
441 | ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.)); |
442 | } |
443 | |
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  // A backward that returns an undefined tensor means "no gradient"; the
  // leaf's .grad() must stay undefined in every invocation style.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var * 2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      at::Tensor undefined_tensor;
      return {undefined_tensor};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  // grad() with allow_unused=true likewise yields an undefined result.
  ASSERT_FALSE(torch::autograd::grad(
                   {MyFunction::apply(x)}, {x}, {}, false, false, true)[0]
                   .defined());
}
473 | |
474 | TEST(CustomAutogradTest, MaterializeGrads) { |
475 | struct MyFunction : public Function<MyFunction> { |
476 | static Variable forward(AutogradContext* ctx, Variable var) { |
477 | return var; |
478 | } |
479 | |
480 | static variable_list backward( |
481 | AutogradContext* ctx, |
482 | variable_list grad_output) { |
483 | EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1)); |
484 | return grad_output; |
485 | } |
486 | }; |
487 | |
488 | auto x = torch::ones(1, torch::requires_grad()); |
489 | UndefinedGrad().apply({MyFunction::apply(x)})[0].backward(); |
490 | } |
491 | |
492 | TEST(CustomAutogradTest, DontMaterializeGrads) { |
493 | struct MyFunction : public Function<MyFunction> { |
494 | static Variable forward(AutogradContext* ctx, Variable var) { |
495 | ctx->set_materialize_grads(false); |
496 | return var; |
497 | } |
498 | |
499 | static variable_list backward( |
500 | AutogradContext* ctx, |
501 | variable_list grad_output) { |
502 | EXPECT_FALSE(grad_output[0].defined()); |
503 | return grad_output; |
504 | } |
505 | }; |
506 | |
507 | auto x = torch::ones(1, torch::requires_grad()); |
508 | UndefinedGrad().apply({MyFunction::apply(x)})[0].backward(); |
509 | } |
510 | |
511 | TEST(CustomAutogradTest, NoGradCustomFunction) { |
512 | // Custom Function should respect grad mode |
513 | struct MyOp : public Function<MyOp> { |
514 | static Variable forward(AutogradContext* ctx, Variable x) { |
515 | return x + 1; |
516 | } |
517 | |
518 | static variable_list backward(AutogradContext* ctx, variable_list dy) { |
519 | return dy; |
520 | } |
521 | }; |
522 | |
523 | auto x = torch::ones({5, 5}, torch::requires_grad()); |
524 | { |
525 | at::NoGradGuard no_grad; |
526 | auto y = MyOp::apply(x); |
527 | ASSERT_FALSE(y.requires_grad()); |
528 | } |
529 | } |
530 | |
TEST(CustomAutogradTest, MarkDirty) {
  // mark_dirty() tells autograd a tensor was modified in place; the tensor's
  // version counter must be bumped and backward must still run.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      // Change the value inplace
      auto v_data = v.data_ptr<float>();
      v_data[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 2.0)};
    }
  };

  // Clone here because modifying leafs inplace is not allowed
  auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
  auto version_before = x._version();
  auto out = MyFunction::apply(x);
  auto version_after = x._version();
  // mark_dirty must have bumped the version counter at least once.
  ASSERT_TRUE(version_after >= (version_before + 1));
  out.sum().backward();
}
556 | |
TEST(CustomAutogradTest, MarkNonDifferentiable) {
  // mark_non_differentiable() makes the output not require grad, and using
  // that output downstream must not break backward.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      // Boolean mask output — inherently non-differentiable.
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}
578 | |
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  // Mixed outputs: one marked non-differentiable, one differentiable. The
  // non-differentiable output's grad arrives as zeros in backward.
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
      // Grad for the non-differentiable output is materialized as zeros.
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = MyFunction::apply(x);

  ASSERT_FALSE(out[0].requires_grad());
  ASSERT_TRUE(out[1].requires_grad());
  out[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}
606 | |
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  // When the sole output is non-differentiable, backward may return an empty
  // list; using the output in further computation must still work.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      auto output = input.clone();
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  (r * x).sum().backward();
}
626 | |
TEST(CustomAutogradTest, ReturnLeafInplace) {
  // Returning a leaf that was modified in place (and marked dirty): the
  // returned alias participates in the graph and grads flow to y.
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  // x is a non-requires-grad leaf, so the in-place write is permitted.
  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto& q = out[0];
  // q aliases x but now requires grad through the custom function.
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}
651 | |
TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  // Returning the same dirtied tensor twice: in-place on a requires-grad
  // leaf must throw; on a clone the duplicate outputs stay equal.
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext* ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(
        AutogradContext* ctsx,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());

  // In-place mutation of a requires-grad leaf is forbidden.
  ASSERT_THROWS_WITH(
      DoubleInplace::apply(x), "leaf Variable that requires grad" );
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
  // output");

  auto out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}
677 | |
678 | TEST(CustomAutogradTest, ReturnDuplicate) { |
679 | struct DoubleDuplicate : public Function<DoubleDuplicate> { |
680 | static variable_list forward(AutogradContext* ctx, Variable x) { |
681 | auto output = x * 2; |
682 | return {output, output}; |
683 | } |
684 | |
685 | static variable_list backward( |
686 | AutogradContext* ctx, |
687 | variable_list grad_outputs) { |
688 | return {grad_outputs[0] * 2 + grad_outputs[1] * 2}; |
689 | } |
690 | }; |
691 | |
692 | auto x = torch::randn({5, 5}, torch::requires_grad()); |
693 | auto out = DoubleDuplicate::apply(x); |
694 | ASSERT_TRUE(torch::equal(out[0], out[1])); |
695 | } |
696 | |
TEST(CustomAutogradTest, SaveEmptyForBackward) {
  // Undefined Variables interleaved in save_for_backward must round-trip as
  // undefined through get_saved_variables.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      // d(input^2)/d(input) = 2 * input.
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}
719 | |
TEST(CustomAutogradTest, InvalidGradients) {
  // A backward returning a gradient with the wrong shape must be rejected by
  // the engine.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      return x * 2;
    }

    static variable_list backward(
        AutogradContext* ctsx,
        variable_list grad_outputs) {
      // Deliberately wrong shape: 1-D size-10 instead of the input's shape.
      return {
          torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 =
      torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape" );
  // NOTE(review): input2 is never used — looks like a disabled wrong-dtype
  // assertion; confirm whether the check should be restored or input2 removed.
  auto input2 =
      torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}
741 | |
742 | TEST(CustomAutogradTest, NoGradInput) { |
743 | struct MyFunction : public Function<MyFunction> { |
744 | static Variable forward(AutogradContext*, Variable x) { |
745 | return x; |
746 | } |
747 | |
748 | static variable_list backward( |
749 | AutogradContext*, |
750 | variable_list grad_outputs) { |
751 | return grad_outputs; |
752 | } |
753 | }; |
754 | |
755 | Variable x = torch::randn({5, 5}, torch::requires_grad()); |
756 | Variable y; |
757 | { |
758 | at::NoGradGuard no_grad; |
759 | y = MyFunction::apply(x); |
760 | } |
761 | |
762 | ASSERT_TRUE(x.requires_grad()); |
763 | ASSERT_FALSE(y.grad_fn()); |
764 | } |
765 | |
TEST(CustomAutogradTest, TooManyGrads) {
  // NOTE(review): MyFunction is defined but never applied, so this test body
  // currently exercises nothing — presumably the "too many gradients
  // returned" check was disabled; confirm and either invoke it or remove.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      // Returns more grads (3) than forward had inputs (1).
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };
}
778 | |
TEST(CustomAutogradTest, DepNoGrad) {
  // A non-differentiable side output fed into another Function must not
  // create a gradient dependency between the two.
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      auto out = torch::randn(input.sizes());
      ctx->mark_non_differentiable({out});
      return {input, out};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto out = F1::apply(x);
  Variable &a = out[0], &b = out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
815 | |
TEST(CustomAutogradTest, Reentrant) {
  // Reentrant backward: the custom backward itself runs a nested backward
  // pass over an internally-built graph.
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable output;
      {
        // Re-enable grad mode so an internal graph is recorded even if the
        // surrounding forward runs with grad disabled.
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        output = x * y;

        ctx->saved_data["x" ] = x;
        ctx->saved_data["y" ] = y;
        ctx->saved_data["output_var" ] = output;
      }
      // Detach so the outer graph sees this as the function's own output.
      return output.detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      {
        // Nested (reentrant) backward over the saved internal graph.
        at::AutoGradMode enable_grad(true);
        auto out = ctx->saved_data["output_var" ].toTensor();
        out.sum().backward();
      }
      return {ctx->saved_data["x" ].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto out = Reenter::apply(x);
  out.sum().backward();
  // d(x * y_data)/dx = y_data.
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
851 | |
852 | // NOTE: If this fails for apparently unrelated reasons in TSAN be aware of |
853 | // the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950 |
TEST(CustomAutogradTest, DeepReentrant) {
  // Each backward re-applies the function on x-1 and runs a nested backward,
  // recursing ~8193 levels; the engine must handle this without a stack
  // overflow (see the TSAN note above this test).
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x" ] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x" ].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // Recursion base case: stop once the counter reaches zero.
      if (!at::native::is_nonzero(ctx->saved_data["x" ].toTensor())) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        // Reentrant step: nested backward through a fresh application.
        apply(ctx->saved_data["x" ].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  // This should not stack overflow
  auto v =
      torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(v).sum().backward();
}
883 | |
TEST(CustomAutogradTest, ReentrantPriority) {
  // Records execution order to verify the engine prioritizes reentrant
  // backward tasks over ordinary ones.
  static std::vector<int> order;

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      // Ordinary (non-reentrant) task: tagged 0.
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x" ] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x" ].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // Reentrant task: tagged 1; recurses until the counter hits zero.
      order.push_back(1);
      if (!at::native::is_nonzero(ctx->saved_data["x" ].toTensor())) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x" ].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(
      torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(
      torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task.
  ASSERT_EQ(order.size(), 10);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear static variable in case test get executed in a loop
  order.clear();
}
937 | |
// Exercises registration, accumulation, removal, and gradient-modifying
// behavior of tensor hooks.
TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  int counter = 0;

  // Each invocation bumps `counter` by the bound increment.
  std::function<void(int, Variable)> bw_hook(
      [&counter](int inc, Variable grad) { counter += inc; });

  Variable z = x * x + x * 2 + x * y + y;
  // The hook on x adds 0, so it never changes the counter.
  x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
  auto hook_1 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  // Both hooks on z fired: +1 and +2.
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  // Only hook_1 remains: +1.
  ASSERT_EQ(counter, 5);

  // A hook may replace the gradient by returning a new tensor.
  std::function<Variable(Variable)> bw_hook_modify(
      [](Variable grad) { return grad.mul(2); });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  // dz/dy = x + 1, doubled once by the modifying hook on z.
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  // Doubled twice: once by z's hook, once by y's own hook.
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  // Removing a hook index that does not exist must throw.
  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
980 | |
// Hooks registered before an in-place op observe the gradient w.r.t. the
// original value; hooks registered after observe the gradient w.r.t. the
// rewritten value.
TEST(CustomAutogradTest, HooksInplace) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    // Registered before mul_(2): the grad is scaled by the in-place factor.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook2_count = 0;
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    // Registered after mul_(2): sees the incoming gradient unchanged.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.mul_(2);
  a.register_hook(hook2);

  auto out = (a + 1).sum();
  out.backward();

  // Each hook must fire exactly once.
  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
}
1006 | |
// Like HooksInplace, but with a.retain_grad() called before the in-place op:
// hooks registered either side of retain_grad() but before mul_(2) behave
// identically, and the retained grad reflects the post-inplace tensor.
TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    // Registered before mul_(2): grad is scaled by the in-place factor.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook2_count = 0;
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    // Also registered before mul_(2) (after retain_grad()): same scaling.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook3_count = 0;
  auto hook3 = ([&hook3_count](Variable grad) {
    hook3_count++;
    // Registered after mul_(2): sees the incoming gradient unchanged.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  // Each hook must fire exactly once.
  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  // retain_grad() survives the in-place op; the retained grad is the grad
  // w.r.t. the current (post-inplace) value of `a`.
  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
1045 | |
// Same as HooksInplaceWithRetainsGrad but with two in-place multiplies:
// hooks registered before both mul_(2) calls see the gradient scaled by 4.
TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    // Registered before both mul_(2) calls: scaled by 2 * 2 = 4.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  });

  int hook2_count = 0;
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    // Also registered before both in-place ops: same scaling.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  });

  int hook3_count = 0;
  auto hook3 = ([&hook3_count](Variable grad) {
    hook3_count++;
    // Registered after both in-place ops: unchanged incoming gradient.
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  // Each hook must fire exactly once.
  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  // Retained grad is w.r.t. the post-inplace value of `a`.
  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
1085 | |
1086 | TEST(CustomAutogradTest, HookNone) { |
1087 | struct NoneGradientFunction : public Function<NoneGradientFunction> { |
1088 | static variable_list forward(AutogradContext* ctx, Variable x, Variable y) { |
1089 | return {x, y}; |
1090 | } |
1091 | |
1092 | static variable_list backward(AutogradContext* ctx, variable_list grad) { |
1093 | return {grad[0], Variable()}; |
1094 | } |
1095 | }; |
1096 | |
1097 | bool was_called = false; |
1098 | |
1099 | auto hook = ([&was_called](Variable grad) { |
1100 | ASSERT_TRUE(grad.defined()); |
1101 | was_called = true; |
1102 | }); |
1103 | |
1104 | auto x = torch::randn({5, 5}, torch::requires_grad()); |
1105 | auto y = torch::randn({5, 5}); |
1106 | |
1107 | auto out = NoneGradientFunction::apply(x, y); |
1108 | Variable rx = x[0], ry = x[1]; |
1109 | |
1110 | rx.register_hook(hook); |
1111 | ry.register_hook(hook); |
1112 | (rx + ry).sum().backward(); |
1113 | ASSERT_TRUE(was_called); |
1114 | } |
1115 | |
1116 | TEST(CustomAutogradTest, BackwardWithInputs) { |
1117 | Variable x = torch::randn({5, 5}, torch::requires_grad()); |
1118 | Variable y = torch::randn({5, 5}, torch::requires_grad()); |
1119 | Variable z = x * x + x * y + y * y; |
1120 | Variable x_grad_expected = 2 * x + y; |
1121 | Variable y_grad_expected = x + 2 * y; |
1122 | |
1123 | z.backward(torch::ones({5, 5}), false, false, {x}); |
1124 | |
1125 | ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected); |
1126 | ASSERT_FALSE(y.grad().defined()); |
1127 | } |
1128 | |
1129 | TEST(CustomAutogradTest, BackwardWithEmptyInputs) { |
1130 | Variable x = torch::randn({5, 5}, torch::requires_grad()); |
1131 | Variable y = torch::randn({5, 5}, torch::requires_grad()); |
1132 | Variable z = x * x + x * y + y * y; |
1133 | Variable x_grad_expected = 2 * x + y; |
1134 | Variable y_grad_expected = x + 2 * y; |
1135 | ASSERT_THROWS_WITH( |
1136 | z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}), |
1137 | "cannot be empty" ); |
1138 | } |
1139 | |
1140 | TEST(CustomAutogradTest, BackwardWithNonLeafInputs) { |
1141 | Variable x = torch::randn({5, 5}, torch::requires_grad()); |
1142 | Variable y = torch::randn({5, 5}, torch::requires_grad()); |
1143 | Variable z = x * x; |
1144 | Variable w = y * z + x * y + y * y; |
1145 | |
1146 | Variable x_grad_expected = 2 * x * y + y; |
1147 | Variable z_grad_expected = y; |
1148 | |
1149 | w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z}); |
1150 | |
1151 | ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected); |
1152 | ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected); |
1153 | ASSERT_FALSE(y.grad().defined()); |
1154 | } |
1155 | |
// backward(create_graph=true) should emit a warning, both through the Tensor
// method and through the free torch::autograd::backward function.
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  // WarnAlways so the warning fires even if already raised earlier in the
  // process.
  c10::WarningUtils::WarnAlways guard(true);

  torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
  auto z = x * x;
  {
    WarningCapture warnings;
    z.backward(torch::ones({5, 5}), c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos);
  }

  {
    WarningCapture warnings;
    torch::autograd::backward({z}, {torch::ones({5, 5})}, c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos);
  }
}
1177 | |
1178 | /** |
1179 | * Tests for AutogradNotImplementedFallback |
1180 | * - Check that we created the NotImplemented kernel when inputs require grad |
1181 | * but when no inputs require grad, we should not create this node |
1182 | * - check_inplace logic |
1183 | * - view ops |
1184 | * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't |
1185 | * test non-NDEBUG builds. |
1186 | * - tensorlist input and output |
1187 | * - multiple outputs / non-tensor output |
1188 | * - rebase_history vs set_history |
1189 | */ |
1190 | namespace { |
1191 | |
1192 | torch::Tensor inplace_op( |
1193 | const torch::Tensor& self, |
1194 | const torch::Tensor& other) { |
1195 | return self.add_(other); |
1196 | } |
1197 | |
1198 | std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op( |
1199 | const torch::Tensor& self, |
1200 | const torch::Tensor& other) { |
1201 | other.add_(self); |
1202 | self.add_(other); |
1203 | return std::tuple<torch::Tensor, torch::Tensor>(self, other); |
1204 | } |
1205 | |
1206 | std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op( |
1207 | const torch::Tensor& self, |
1208 | const torch::Tensor& other) { |
1209 | // This is not allowed. We test below that this calling into the boxed kernel |
1210 | // will raise an error |
1211 | return std::tuple<torch::Tensor, torch::Tensor>(self, other); |
1212 | } |
1213 | |
1214 | std::tuple<torch::Tensor, torch::Tensor> non_first_view_op( |
1215 | const torch::Tensor& self, |
1216 | const torch::Tensor& other) { |
1217 | // This is not allowed. We test below that this calling into the boxed kernel |
1218 | // will raise an error |
1219 | return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other); |
1220 | } |
1221 | |
1222 | int64_t ret_single_non_tensor( |
1223 | const torch::Tensor& self, |
1224 | const torch::Tensor& other) { |
1225 | return 12; |
1226 | } |
1227 | |
1228 | torch::Tensor opt_op( |
1229 | const torch::Tensor& self, |
1230 | const c10::optional<at::Tensor>& other) { |
1231 | if (other.has_value()) { |
1232 | return self + other.value(); |
1233 | } else { |
1234 | return self.clone(); |
1235 | } |
1236 | } |
1237 | |
1238 | torch::Tensor my_custom_op( |
1239 | const torch::Tensor& self, |
1240 | const torch::Tensor& other) { |
1241 | return self + other; |
1242 | } |
1243 | |
1244 | std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor( |
1245 | const torch::Tensor& self, |
1246 | const torch::Tensor& other) { |
1247 | auto a = self - other; |
1248 | auto b = self + other; |
1249 | return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12); |
1250 | } |
1251 | |
1252 | torch::Tensor view_op(const torch::Tensor& self) { |
1253 | return self.alias(); |
1254 | } |
1255 | |
1256 | torch::Tensor ( |
1257 | const torch::Tensor& self, |
1258 | const torch::Tensor& other) { |
1259 | return self.alias(); |
1260 | } |
1261 | |
1262 | std::vector<torch::Tensor> ret_tensor_vector_view( |
1263 | const torch::Tensor& self, |
1264 | const torch::Tensor& other) { |
1265 | return {self.alias(), self.alias()}; |
1266 | } |
1267 | |
1268 | std::vector<at::Tensor> ret_tensor_vector( |
1269 | const torch::Tensor& self, |
1270 | const torch::Tensor& other) { |
1271 | std::vector<at::Tensor> out; |
1272 | out.push_back(self + other); |
1273 | out.push_back(self - other); |
1274 | return out; |
1275 | } |
1276 | |
1277 | torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) { |
1278 | const auto& res = self.clone(); |
1279 | for (const auto& t : other) { |
1280 | res.add_(t); |
1281 | } |
1282 | return res; |
1283 | } |
1284 | |
// Registers a test operator `_test::<name>` with:
//  - `fn` as the CPU kernel,
//  - autogradNotImplementedFallback() as the boxed Autograd kernel,
//  - autogradNotImplementedInplaceOrViewFallback() for ADInplaceOrView.
// NOTE: the macro declares local library handles (m, m_autograd, m_cpu,
// m_inplaceorview), so it can be used at most once per scope; registrations
// are dropped when those handles go out of scope.
// (Comments must stay outside the macro: a // comment before a line
// continuation would swallow the backslash.)
#define REGISTER_TEST_OP(name, schema, fn)                                  \
  auto m = MAKE_TORCH_LIBRARY(_test);                                       \
  m.def(schema);                                                            \
  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);               \
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                         \
  auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView);   \
  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                    \
  m_autograd.impl(                                                          \
      name, c10::DispatchKey::Autograd, autogradNotImplementedFallback());  \
  m_inplaceorview.impl(                                                     \
      name,                                                                 \
      c10::DispatchKey::ADInplaceOrView,                                    \
      autogradNotImplementedInplaceOrViewFallback());
1298 | |
1299 | template <typename F> |
1300 | void assertBasicChecks(F op) { |
1301 | auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
1302 | auto b = torch::tensor({1.}, {torch::kFloat32}); |
1303 | auto c = torch::tensor({1.}, {torch::kFloat32}); |
1304 | |
1305 | // If any inputs require grad, |
1306 | auto out1 = op(a, b); |
1307 | ASSERT_THROWS_WITH(out1.backward(), "is not implemented" ); |
1308 | |
1309 | // # Should not have grad_fn if none require grad |
1310 | auto out2 = op(b, c); |
1311 | ASSERT_THROWS_WITH( |
1312 | out2.backward(), |
1313 | "element 0 of tensors does not require grad and does not have a grad_fn" ); |
1314 | |
1315 | // TODO: Forward AD Tests? |
1316 | } |
1317 | |
1318 | } // namespace |
1319 | |
1320 | // These tests trigger an MSVC bug in the internal arvr build |
1321 | // Reproduce with: buck build @arvr/mode/win/opt |
1322 | // //xplat/caffe2:autograd_libtorch_test_ovrsource It is probably caused by the |
1323 | // lambda, see https://github.com/pytorch/pytorch/issues/48763 |
1324 | #if !defined(_MSC_VER) |
1325 | |
1326 | TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) { |
1327 | REGISTER_TEST_OP( |
1328 | "ret_single_non_tensor" , |
1329 | "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int" , |
1330 | ret_single_non_tensor); |
1331 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1332 | "_test::ret_single_non_tensor" , "" ); |
1333 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1334 | return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>( |
1335 | opHandle, _1, _2); |
1336 | }; |
1337 | |
1338 | auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
1339 | auto b = torch::tensor({1.}, {torch::kFloat32}); |
1340 | |
1341 | ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b)); |
1342 | } |
1343 | |
// check_inplace behavior for an in-place op routed through the fallback:
// leaf/requires-grad errors, in-place on views, and views created in no-grad
// mode.
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP(
      "inplace_op",
      "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
      inplace_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Check in-place
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  // Mutating b (which does not require grad) is fine.
  op(b, a);
  // Non-leaf clones may be mutated in-place.
  a = a.clone();
  b = b.clone();
  auto c = op(a, b);
  ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));

  // Test in-place on view
  auto base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto t = torch::tensor({1.}, {torch::kFloat32});

  torch::Tensor v_nograd;
  {
    // A view created under no-grad may be mutated while grad is disabled...
    c10::NoGradGuard guard;
    v_nograd = base.view(-1);
    op(v_nograd, t);
  }

  // ...but mutating it with grad enabled must fail.
  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
  // In-place on a proper view returns the very same TensorImpl,
  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
  // and rebasing its history installs an AsStridedBackward grad_fn.
  ASSERT_THAT(
      op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}
1389 | |
// Both arguments of a doubly in-place op must go through check_inplace:
// leaf-requiring-grad errors for either argument, and both version counters
// must be bumped on a successful call.
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP(
      "two_arg_inplace_op",
      "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
      two_arg_inplace_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_arg_inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place!
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(
      op(b, a),
      "a leaf Variable that requires grad is being used in an in-place operation");

  // Non-leaf clones: the call must succeed and bump both version counters.
  auto c =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto d =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();

  auto saved_version_c = c._version();
  auto saved_version_d = d._version();
  op(c, d);
  ASSERT_NE(c._version(), saved_version_c);
  ASSERT_NE(d._version(), saved_version_d);
}
1425 | |
1426 | TEST(TestAutogradNotImplementedFallback, OptOp) { |
1427 | REGISTER_TEST_OP( |
1428 | "opt_op" , "_test::opt_op(Tensor self, Tensor? other) -> Tensor" , opt_op); |
1429 | auto opHandle = |
1430 | c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op" , "" ); |
1431 | auto op = [&](const torch::Tensor& _1, |
1432 | const c10::optional<torch::Tensor>& _2) { |
1433 | return callOpUnboxed< |
1434 | torch::Tensor, |
1435 | const torch::Tensor&, |
1436 | const c10::optional<torch::Tensor>&>(opHandle, _1, _2); |
1437 | }; |
1438 | |
1439 | auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
1440 | auto b = torch::tensor({1.}, {torch::kFloat32}); |
1441 | |
1442 | ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b))); |
1443 | ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {}))); |
1444 | } |
1445 | |
1446 | TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) { |
1447 | REGISTER_TEST_OP( |
1448 | "my_custom_op" , |
1449 | "_test::my_custom_op(Tensor self, Tensor other) -> Tensor" , |
1450 | my_custom_op); |
1451 | auto opHandle = |
1452 | c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op" , "" ); |
1453 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1454 | return callOpUnboxed< |
1455 | torch::Tensor, |
1456 | const torch::Tensor&, |
1457 | const torch::Tensor&>(opHandle, _1, _2); |
1458 | }; |
1459 | |
1460 | assertBasicChecks(op); |
1461 | } |
1462 | |
1463 | TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) { |
1464 | REGISTER_TEST_OP( |
1465 | "ret_tuple_non_tensor" , |
1466 | "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)" , |
1467 | ret_tuple_non_tensor); |
1468 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1469 | "_test::ret_tuple_non_tensor" , "" ); |
1470 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1471 | torch::Tensor out0; |
1472 | torch::Tensor out1; |
1473 | int64_t out2; |
1474 | auto out = callOpUnboxed< |
1475 | std::tuple<torch::Tensor, torch::Tensor, int64_t>, |
1476 | const torch::Tensor&, |
1477 | const torch::Tensor&>(opHandle, _1, _2); |
1478 | std::tie(out0, out1, out2) = std::move(out); |
1479 | return out0; |
1480 | }; |
1481 | |
1482 | assertBasicChecks(op); |
1483 | } |
1484 | |
// View handling in the fallback: outputs must be real views of the input,
// and in-place ops that force a rebase_history on a not-implemented grad_fn
// must raise.
TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP(
      "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  auto op = [&](const torch::Tensor& _1) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
  };
  // Without grad: output is still registered as a view of the input.
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto v = op(b);
  ASSERT_TRUE(v.is_view());
  ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());

  // With grad (non-leaf base): same view relationship.
  auto b1 =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto v1 = op(b1);
  ASSERT_TRUE(v1.is_view());
  ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());

  // Test inplace on view
  auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);

  // raise on rebase_history when it refreshes grad_fn
  ASSERT_THROWS_WITH(
      v1.add_(t), "which does not have a derivative implemented is forbidden");
  // base should not be aware of the views, so this is still okay
  b1.add_(t);
  // After mutating the base, recomputing the view's grad_fn hits the same
  // not-implemented error.
  ASSERT_THROWS_WITH(
      v1.grad_fn(),
      "which does not have a derivative implemented is forbidden");
}
1516 | |
1517 | TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) { |
1518 | REGISTER_TEST_OP( |
1519 | "view_op_with_extra_arg" , |
1520 | "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)" , |
1521 | view_op_with_extra_arg); |
1522 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1523 | "_test::view_op_with_extra_arg" , "" ); |
1524 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1525 | return callOpUnboxed< |
1526 | torch::Tensor, |
1527 | const torch::Tensor&, |
1528 | const torch::Tensor&>(opHandle, _1, _2); |
1529 | }; |
1530 | assertBasicChecks(op); |
1531 | auto a = torch::tensor({1.}, {torch::kFloat32}); |
1532 | auto b = torch::tensor({2.}, {torch::kFloat32}); |
1533 | auto out1 = op(a, b); |
1534 | ASSERT_TRUE(out1.is_view()); |
1535 | ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl()); |
1536 | } |
1537 | |
1538 | TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) { |
1539 | REGISTER_TEST_OP( |
1540 | "ret_tensor_vector_view" , |
1541 | "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)" , |
1542 | ret_tensor_vector_view); |
1543 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1544 | "_test::ret_tensor_vector_view" , "" ); |
1545 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1546 | return callOpUnboxed< |
1547 | std::vector<at::Tensor>, |
1548 | const torch::Tensor&, |
1549 | const torch::Tensor&>(opHandle, _1, _2); |
1550 | }; |
1551 | auto a = torch::tensor({1.}, {torch::kFloat32}); |
1552 | auto b = torch::tensor({1.}, {torch::kFloat32}); |
1553 | auto out = op(a, b); |
1554 | ASSERT_TRUE(out[0].is_view()); |
1555 | ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl()); |
1556 | ASSERT_TRUE(out[1].is_view()); |
1557 | ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl()); |
1558 | } |
1559 | |
1560 | TEST(TestAutogradNotImplementedFallback, DoubleViewOP) { |
1561 | REGISTER_TEST_OP( |
1562 | "two_pairs_of_view_op" , |
1563 | "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))" , |
1564 | two_pairs_of_view_op); |
1565 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1566 | "_test::two_pairs_of_view_op" , "" ); |
1567 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1568 | return callOpUnboxed< |
1569 | std::tuple<torch::Tensor, torch::Tensor>, |
1570 | const torch::Tensor&, |
1571 | const torch::Tensor&>(opHandle, _1, _2); |
1572 | }; |
1573 | auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
1574 | auto b = torch::tensor({1.}, {torch::kFloat32}); |
1575 | ASSERT_THROWS_WITH( |
1576 | op(a, b), |
1577 | "Expected only a single output in the operator schema to have a non-write alias annotation" ); |
1578 | } |
1579 | |
1580 | TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) { |
1581 | REGISTER_TEST_OP( |
1582 | "non_first_view_op" , |
1583 | "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))" , |
1584 | non_first_view_op); |
1585 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1586 | "_test::non_first_view_op" , "" ); |
1587 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1588 | return callOpUnboxed< |
1589 | std::tuple<torch::Tensor, torch::Tensor>, |
1590 | const torch::Tensor&, |
1591 | const torch::Tensor&>(opHandle, _1, _2); |
1592 | }; |
1593 | auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
1594 | auto b = torch::tensor({1.}, {torch::kFloat32}); |
1595 | ASSERT_THROWS_WITH( |
1596 | op(a, b), "can only create view relationships between the first" ); |
1597 | } |
1598 | |
1599 | TEST(TestAutogradNotImplementedFallback, RetTensorVector) { |
1600 | REGISTER_TEST_OP( |
1601 | "ret_tensor_vector" , |
1602 | "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]" , |
1603 | ret_tensor_vector); |
1604 | auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
1605 | "_test::ret_tensor_vector" , "" ); |
1606 | auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
1607 | return callOpUnboxed< |
1608 | std::vector<at::Tensor>, |
1609 | const torch::Tensor&, |
1610 | const torch::Tensor&>(opHandle, _1, _2)[0]; |
1611 | }; |
1612 | assertBasicChecks(op); |
1613 | } |
1614 | |
// TensorList arguments: per-element grad errors are preserved, and the boxed
// result matches the plain unboxed kernel.
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP(
      "tensorlist_op",
      "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
      tensorlist_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::tensorlist_op", "");
  auto op = [&](torch::Tensor _1, at::TensorList _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
        opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> vec = {b, c};
  auto out = op(a, vec);

  // b does not require grad: differentiating w.r.t. it is an error.
  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[0]}),
      "One of the differentiated Tensors does not require grad");
  // c requires grad, but the op's backward is the NotImplemented node.
  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[1]}), "is not implemented");

  ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}
1641 | |
1642 | #endif |
1643 | |
1644 | // TODO add these tests if needed |
1645 | // test_once_differentiable |
1646 | // test_sparse_backward |
1647 | // test_save_output_nr |
1648 | // test_free_deep_graph_pyfunction |
1649 | // test_naughty_anomaly_access |
// test_naughty_autograd_function_stashing_ctx
1651 | // test_custom_autograd_repeated_grad_grad |
1652 | // test_return_leaf |
1653 | // test_anomaly_detect_nan |
1654 | // test_no_grad_copy |
1655 | |