1 | #include <gtest/gtest.h> |
2 | #include <test/cpp/api/support.h> |
3 | #include <torch/script.h> |
4 | |
5 | using namespace torch::autograd; |
6 | using namespace torch::test; |
7 | |
8 | TEST(GradModeTest, TestRequiresGradFunctionalOp) { |
9 | torch::AutoGradMode mode(false); |
10 | for (bool requires_grad : {true, false}) { |
11 | torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
12 | |
13 | torch::Tensor func_out = c * c; |
14 | ASSERT_FALSE(func_out.requires_grad()); |
15 | ASSERT_TRUE(func_out.is_leaf()); |
16 | } |
17 | } |
18 | |
19 | TEST(GradModeTest, TestRequiresGradInplaceOp) { |
20 | torch::AutoGradMode mode(false); |
21 | for (bool requires_grad : {true, false}) { |
22 | torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
23 | |
24 | c.mul_(2); |
25 | ASSERT_EQ(c.requires_grad(), requires_grad); |
26 | } |
27 | } |
28 | |
29 | TEST(GradModeTest, TestRequiresGradViewOp) { |
30 | torch::AutoGradMode mode(false); |
31 | for (bool requires_grad : {true, false}) { |
32 | torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
33 | |
34 | torch::Tensor view_out = c.view({2, 3}); |
35 | ASSERT_EQ(view_out.requires_grad(), requires_grad); |
36 | ASSERT_TRUE(view_out.is_leaf()); |
37 | } |
38 | } |
39 | |
40 | TEST(GradModeTest, TestRequiresGradViewOpExiting) { |
41 | for (bool requires_grad : {true, false}) { |
42 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
43 | torch::Tensor a = s.clone(); |
44 | torch::Tensor view_out, tmp; |
45 | |
46 | { |
47 | torch::AutoGradMode mode(false); |
48 | view_out = a.view( |
49 | {2, 3}); // go through kernels: VariableType, ADInplaceOrView, CPU |
50 | assert_tensor_creation_meta( |
51 | view_out, torch::autograd::CreationMeta::NO_GRAD_MODE); |
52 | ASSERT_EQ(view_out.requires_grad(), requires_grad); |
53 | ASSERT_TRUE(view_out.is_leaf()); |
54 | } |
55 | |
56 | tmp = view_out * view_out; |
57 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
58 | if (requires_grad) { |
59 | tmp.backward(torch::ones_like(tmp)); |
60 | // TODO: this behavior is a side effect of issue #11390. |
61 | ASSERT_FALSE(view_out.grad().defined()); |
62 | } |
63 | |
64 | if (requires_grad) { |
65 | ASSERT_THROWS_WITH( |
66 | view_out.mul_( |
67 | 2), // go through kernels: VariableType, ADInplaceOrView, CPU |
68 | "A view was created in no_grad mode and is being modified inplace" ); |
69 | } else { |
70 | view_out.mul_(2); |
71 | } |
72 | |
73 | tmp = view_out.view({2, 3}); |
74 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
75 | assert_tensor_creation_meta( |
76 | tmp, torch::autograd::CreationMeta::NO_GRAD_MODE); |
77 | } |
78 | } |
79 | |