#include <ATen/core/boxing/impl/test_helpers.h>
#include <gtest/gtest.h>

#include <ATen/core/op_registration/op_registration.h>
#include <torch/torch.h>

#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;
using namespace torch::test;

#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))

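// Test helper: renders the backward graph rooted at `node` as a nested string
// of node names (children in parentheses), with missing edges shown as "None".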
std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

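// f(x, y) = x + 2*y + x*y is used throughout these tests; its analytic
// gradients are df/dx = 1 + y and df/dy = 2 + x, which is what the
// ASSERT_VARIABLE_EQ checks below expect.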
Variable simple_fn(const Variable& x, const Variable& y) {
  return x + 2 * y + x * y;
}

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res.sum()}, {});

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res}, {torch::ones({2, 2})}, {}, true);

  backward({res}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}

TEST(AutogradAPITests, GradSimpleTest) {
  // basic grad
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}

TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}

TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // allow_unused=False, but grads contains None inside, should throw
  ASSERT_THROWS_WITH(
      grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}

TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  // Test that certain nodes are not erroneously executed when an input
  // is unreachable. See #39784
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      ADD_FAILURE() << "This node should not be executed!";
      return grad_output;
    }
  };

  auto x = torch::randn(1, torch::requires_grad());
  auto x1 = torch::randn(1);
  auto x2 = MyFunction::apply(x + x1);

  auto y = torch::randn(1, torch::requires_grad());
  auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
  ASSERT_FALSE(grad_res[0].defined());
}

TEST(AutogradAPITests, EmptyInput) {
  Variable x = torch::ones({1}, torch::requires_grad());
  ASSERT_THROWS_WITH(
      grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs.");
}

TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  {
    // Warning when grad is accessed for non-leaf tensor
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
  }
  // It should be possible to call retain_grad() multiple times
  h1.retain_grad();
  h1.retain_grad();
  {
    // If retain_grad is true for a non-leaf tensor,
    // there should not be any warning when grad is accessed
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
  }

  // Gradient should be accumulated across repeated backward calls
  out.backward({}, /*retain_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  out.backward({}, /*retain_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // It should be a no-op for leaves
  input.retain_grad();
  input.retain_grad();
  out.backward();
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}

TEST(AutogradAPITests, AnomalyMode) {
  // Anomaly mode should emit the forward traceback as a warning and then
  // throw an error.
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace");
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward") != std::string::npos);
  }
  auto double_backward_produce_nan = [](bool should_throw) {
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr =
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_graph=*/true);
    if (should_throw) {
      WarningCapture warnings;
      ASSERT_THROWS_WITH(
          grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
      auto msgs = warnings.messages();
      ASSERT_EQ(msgs.size(), 2);
      ASSERT_TRUE(
          msgs[0].find("Traceback of forward call that caused the error") !=
          std::string::npos);
      ASSERT_TRUE(
          msgs[1].find(
              "Traceback of forward call that induced the previous calculation") !=
          std::string::npos);
    } else {
      grad({gr[0]}, {x}, {torch::tensor({0.0})});
    }
  };

  double_backward_produce_nan(true);
  {
    torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
    double_backward_produce_nan(false);
    {
      torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
      double_backward_produce_nan(true);
    }
  }
  double_backward_produce_nan(true);
}

TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
      torch::autograd::variable_list vars;
      for (const at::Tensor& tensor : tensors) {
        vars.push_back(tensor);
      }
      ctx->save_for_backward(vars);
      return tensors[0] + tensors[1] + tensors[0] * tensors[1];
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          grad_output[0] + grad_output[0] * var1};
      return output;
    }
  };

  at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
  torch::autograd::variable_list variables = {x, y};
  at::TensorList tensors = variables;
  auto res = MyFunction::apply(tensors);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
}

TEST(CustomAutogradTest, GraphTaskTrimEdges) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2,
        int mul,
        bool needs_input1_grad,
        bool needs_input2_grad) {
      // Record which inputs are expected to need gradients so that backward
      // can check needs_input_grad() against them.
      ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
      ctx->saved_data["needs_input2_grad"] = needs_input2_grad;

      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // Check that the `needs_input_grad` method works correctly.
      // This has to be checked from within the backward function.
      auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
      auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
      IndexRange var1_idx = {0, 1};
      IndexRange var2_idx = {1, 2};
      EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
      EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
      EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
      EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
      EXPECT_EQ(
          ctx->needs_input_grad({var1_idx, var2_idx}),
          needs_input1_grad || needs_input2_grad);

      // calculate gradients
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];

      Variable grad_var1, grad_var2;
      if (ctx->needs_input_grad(0)) {
        grad_var1 = grad_output[0] + grad_output[0] * var2;
      }
      if (ctx->needs_input_grad(1)) {
        grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
      }
      variable_list output = {
          grad_var1,
          grad_var2,
          Variable(),
          Variable(),
          Variable(),
      };
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto go = torch::ones_like(x);
  Variable out;

  // grad_x
  out = MyFunction::apply(
      x,
      y,
      2,
      /* needs_input1_grad= */ true,
      /* needs_input2_grad= */ false);
  auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
  ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));

  // grad_y
  out = MyFunction::apply(
      x,
      y,
      2,
      /* needs_input1_grad= */ false,
      /* needs_input2_grad= */ true);
  auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
  ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);

  // grad_x and grad_y
  out = MyFunction::apply(
      x,
      y,
      2,
      /* needs_input1_grad= */ true,
      /* needs_input2_grad= */ true);
  auto grads = torch::autograd::grad({out}, {x, y}, {go});
  ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var1) {
      return var1;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0] * 2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x).backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}

TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var * 2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      at::Tensor undefined_tensor;
      return {undefined_tensor};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  ASSERT_FALSE(torch::autograd::grad(
                   {MyFunction::apply(x)}, {x}, {}, false, false, true)[0]
                   .defined());
}

TEST(CustomAutogradTest, MaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}

TEST(CustomAutogradTest, DontMaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable var) {
      ctx->set_materialize_grads(false);
      return var;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      EXPECT_FALSE(grad_output[0].defined());
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      return x + 1;
    }

    static variable_list backward(AutogradContext* ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5, 5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x);
    ASSERT_FALSE(y.requires_grad());
  }
}

TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      // Change the value in-place
      auto v_data = v.data_ptr<float>();
      v_data[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 2.0)};
    }
  };

  // Clone here because modifying leaves in-place is not allowed
  auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
  auto version_before = x._version();
  auto out = MyFunction::apply(x);
  auto version_after = x._version();
  ASSERT_TRUE(version_after >= (version_before + 1));
  out.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = MyFunction::apply(x);

  ASSERT_FALSE(out[0].requires_grad());
  ASSERT_TRUE(out[1].requires_grad());
  out[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      auto output = input.clone();
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  (r * x).sum().backward();
}

TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto& q = out[0];
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext* ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());

  ASSERT_THROWS_WITH(
      DoubleInplace::apply(x), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
  // output");

  auto out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext* ctx, Variable x) {
      auto output = x * 2;
      return {output, output};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = DoubleDuplicate::apply(x);
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}

TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      return x * 2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      return {
          torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 =
      torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  auto input2 =
      torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}

TEST(CustomAutogradTest, NoGradInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(
        AutogradContext*,
        variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y;
  {
    at::NoGradGuard no_grad;
    y = MyFunction::apply(x);
  }

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(y.grad_fn());
}

TEST(CustomAutogradTest, TooManyGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };
}

TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      auto out = torch::randn(input.sizes());
      ctx->mark_non_differentiable({out});
      return {input, out};
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto out = F1::apply(x);
  Variable &a = out[0], &b = out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}

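// The following reentrant tests make backward() re-enter the autograd engine:
// each backward() implementation builds a fresh graph under
// at::AutoGradMode(true) and calls backward() on it from inside an ongoing
// backward pass.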
TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable output;
      {
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        output = x * y;

        ctx->saved_data["x"] = x;
        ctx->saved_data["y"] = y;
        ctx->saved_data["output_var"] = output;
      }
      return output.detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      {
        at::AutoGradMode enable_grad(true);
        auto out = ctx->saved_data["output_var"].toTensor();
        out.sum().backward();
      }
      return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto out = Reenter::apply(x);
  out.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}

// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  // This should not stack overflow
  auto v =
      torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(v).sum().backward();
}

TEST(CustomAutogradTest, ReentrantPriority) {
  static std::vector<int> order;

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      order.push_back(1);
      if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(
      torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(
      torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task.
  ASSERT_EQ(order.size(), 10);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear the static variable in case the test gets executed in a loop
  order.clear();
}

TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  int counter = 0;

  std::function<void(int, Variable)> bw_hook(
      [&counter](int inc, Variable grad) { counter += inc; });

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
  auto hook_1 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 5);

  std::function<Variable(Variable)> bw_hook_modify(
      [](Variable grad) { return grad.mul(2); });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}

TEST(CustomAutogradTest, HooksInplace) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook2_count = 0;
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.mul_(2);
  a.register_hook(hook2);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
}

TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook2_count = 0;
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  });

  int hook3_count = 0;
  auto hook3 = ([&hook3_count](Variable grad) {
    hook3_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  int hook1_count = 0;
  auto hook1 = ([&hook1_count](Variable grad) {
    hook1_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  });

  int hook2_count = 0;
  auto hook2 = ([&hook2_count](Variable grad) {
    hook2_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  });

  int hook3_count = 0;
  auto hook3 = ([&hook3_count](Variable grad) {
    hook3_count++;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  });

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext* ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  Variable rx = out[0], ry = out[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}

TEST(CustomAutogradTest, BackwardWithInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  Variable x_grad_expected = 2 * x + y;
  Variable y_grad_expected = x + 2 * y;

  z.backward(torch::ones({5, 5}), false, false, {x});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  Variable x_grad_expected = 2 * x + y;
  Variable y_grad_expected = x + 2 * y;
  ASSERT_THROWS_WITH(
      z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
      "cannot be empty");
}

TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x;
  Variable w = y * z + x * y + y * y;

  Variable x_grad_expected = 2 * x * y + y;
  Variable z_grad_expected = y;

  w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  c10::WarningUtils::WarnAlways guard(true);

  torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
  auto z = x * x;
  {
    WarningCapture warnings;
    z.backward(torch::ones({5, 5}), c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos);
  }

  {
    WarningCapture warnings;
    torch::autograd::backward({z}, {torch::ones({5, 5})}, c10::nullopt, true);
    ASSERT_TRUE(
        warnings.str().find("Using backward() with create_graph=True") !=
        std::string::npos);
  }
}

/**
 * Tests for AutogradNotImplementedFallback
 * - Check that the NotImplemented node is attached when inputs require grad,
 *   and that no such node is created when no inputs require grad
 * - check_inplace logic
 * - view ops
 * - TODO: Tests for debug-only checks? Not needed for now because CI doesn't
 *   test non-NDEBUG builds.
 * - tensorlist input and output
 * - multiple outputs / non-tensor output
 * - rebase_history vs set_history
 */
namespace {

torch::Tensor inplace_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self.add_(other);
}

std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  other.add_(self);
  self.add_(other);
  return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}

std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  // This is not allowed. We test below that calling this through the boxed
  // kernel raises an error.
  return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}

std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  // This is not allowed. We test below that calling this through the boxed
  // kernel raises an error.
  return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
}

int64_t ret_single_non_tensor(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return 12;
}

torch::Tensor opt_op(
    const torch::Tensor& self,
    const c10::optional<at::Tensor>& other) {
  if (other.has_value()) {
    return self + other.value();
  } else {
    return self.clone();
  }
}

torch::Tensor my_custom_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self + other;
}

std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  auto a = self - other;
  auto b = self + other;
  return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
}

torch::Tensor view_op(const torch::Tensor& self) {
  return self.alias();
}

torch::Tensor view_op_with_extra_arg(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self.alias();
}

std::vector<torch::Tensor> ret_tensor_vector_view(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return {self.alias(), self.alias()};
}

std::vector<at::Tensor> ret_tensor_vector(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  std::vector<at::Tensor> out;
  out.push_back(self + other);
  out.push_back(self - other);
  return out;
}

torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
  const auto& res = self.clone();
  for (const auto& t : other) {
    res.add_(t);
  }
  return res;
}

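// REGISTER_TEST_OP defines `schema` in the `_test` library, registers `fn` as
// the CPU kernel, and installs autogradNotImplementedFallback() /
// autogradNotImplementedInplaceOrViewFallback() for the Autograd and
// ADInplaceOrView dispatch keys. Because it declares local library handles
// (`m`, `m_cpu`, ...), it can only be expanded once per test body.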
#define REGISTER_TEST_OP(name, schema, fn)                                 \
  auto m = MAKE_TORCH_LIBRARY(_test);                                      \
  m.def(schema);                                                           \
  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);              \
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                        \
  auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView);  \
  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                   \
  m_autograd.impl(                                                         \
      name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
  m_inplaceorview.impl(                                                    \
      name,                                                                \
      c10::DispatchKey::ADInplaceOrView,                                   \
      autogradNotImplementedInplaceOrViewFallback());

template <typename F>
void assertBasicChecks(F op) {
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32});

  // If any input requires grad, backward on the output should fail because
  // the fallback attaches a NotImplemented node.
  auto out1 = op(a, b);
  ASSERT_THROWS_WITH(out1.backward(), "is not implemented");

  // Should not have a grad_fn if no input requires grad.
  auto out2 = op(b, c);
  ASSERT_THROWS_WITH(
      out2.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn");

  // TODO: Forward AD Tests?
}

} // namespace

// These tests trigger an MSVC bug in the internal arvr build
// Reproduce with: buck build @arvr/mode/win/opt
// //xplat/caffe2:autograd_libtorch_test_ovrsource It is probably caused by the
// lambda, see https://github.com/pytorch/pytorch/issues/48763
#if !defined(_MSC_VER)

TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
  REGISTER_TEST_OP(
      "ret_single_non_tensor",
      "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
      ret_single_non_tensor);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_single_non_tensor", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
        opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
}

TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP(
      "inplace_op",
      "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
      inplace_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Check in-place
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  op(b, a);
  a = a.clone();
  b = b.clone();
  auto c = op(a, b);
  ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));

  // Test in-place on view
  auto base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto t = torch::tensor({1.}, {torch::kFloat32});

  torch::Tensor v_nograd;
  {
    c10::NoGradGuard guard;
    v_nograd = base.view(-1);
    op(v_nograd, t);
  }

  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
  ASSERT_THAT(
      op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}

TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP(
      "two_arg_inplace_op",
      "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
      two_arg_inplace_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_arg_inplace_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place!
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(
      op(b, a),
      "a leaf Variable that requires grad is being used in an in-place operation");

  auto c =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto d =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();

  auto saved_version_c = c._version();
  auto saved_version_d = d._version();
  op(c, d);
  ASSERT_NE(c._version(), saved_version_c);
  ASSERT_NE(d._version(), saved_version_d);
}

TEST(TestAutogradNotImplementedFallback, OptOp) {
  REGISTER_TEST_OP(
      "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
  auto op = [&](const torch::Tensor& _1,
                const c10::optional<torch::Tensor>& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const c10::optional<torch::Tensor>&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
  ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
}

TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
  REGISTER_TEST_OP(
      "my_custom_op",
      "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
      my_custom_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
  REGISTER_TEST_OP(
      "ret_tuple_non_tensor",
      "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
      ret_tuple_non_tensor);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tuple_non_tensor", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    torch::Tensor out0;
    torch::Tensor out1;
    int64_t out2;
    auto out = callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor, int64_t>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
    std::tie(out0, out1, out2) = std::move(out);
    return out0;
  };

  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP(
      "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  auto op = [&](const torch::Tensor& _1) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
  };
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto v = op(b);
  ASSERT_TRUE(v.is_view());
  ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());

  auto b1 =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto v1 = op(b1);
  ASSERT_TRUE(v1.is_view());
  ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());

  // Test inplace on view
  auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);

  // raise on rebase_history when it refreshes grad_fn
  ASSERT_THROWS_WITH(
      v1.add_(t), "which does not have a derivative implemented is forbidden");
  // base should not be aware of the views, so this is still okay
  b1.add_(t);
  ASSERT_THROWS_WITH(
      v1.grad_fn(),
      "which does not have a derivative implemented is forbidden");
}

TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
  REGISTER_TEST_OP(
      "view_op_with_extra_arg",
      "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
      view_op_with_extra_arg);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::view_op_with_extra_arg", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  assertBasicChecks(op);
  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({2.}, {torch::kFloat32});
  auto out1 = op(a, b);
  ASSERT_TRUE(out1.is_view());
  ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}

TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
  REGISTER_TEST_OP(
      "ret_tensor_vector_view",
      "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
      ret_tensor_vector_view);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector_view", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto out = op(a, b);
  ASSERT_TRUE(out[0].is_view());
  ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
  ASSERT_TRUE(out[1].is_view());
  ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}

TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
  REGISTER_TEST_OP(
      "two_pairs_of_view_op",
      "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
      two_pairs_of_view_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_pairs_of_view_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      op(a, b),
      "Expected only a single output in the operator schema to have a non-write alias annotation");
}

TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
  REGISTER_TEST_OP(
      "non_first_view_op",
      "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
      non_first_view_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::non_first_view_op", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };
  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      op(a, b), "can only create view relationships between the first");
}

TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
  REGISTER_TEST_OP(
      "ret_tensor_vector",
      "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
      ret_tensor_vector);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector", "");
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2)[0];
  };
  assertBasicChecks(op);
}

TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP(
      "tensorlist_op",
      "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
      tensorlist_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::tensorlist_op", "");
  auto op = [&](torch::Tensor _1, at::TensorList _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
        opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> vec = {b, c};
  auto out = op(a, vec);

  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[0]}),
      "One of the differentiated Tensors does not require grad");
  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[1]}), "is not implemented");

  ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}

#endif

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy
1655