1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
7 | #include <functional> |
8 | |
9 | using namespace torch::test; |
10 | |
11 | void torch_warn_once_A() { |
12 | TORCH_WARN_ONCE("warn once" ); |
13 | } |
14 | |
15 | void torch_warn_once_B() { |
16 | TORCH_WARN_ONCE("warn something else once" ); |
17 | } |
18 | |
19 | void torch_warn() { |
20 | TORCH_WARN("warn multiple times" ); |
21 | } |
22 | |
23 | TEST(UtilsTest, WarnOnce) { |
24 | { |
25 | WarningCapture warnings; |
26 | |
27 | torch_warn_once_A(); |
28 | torch_warn_once_A(); |
29 | torch_warn_once_B(); |
30 | torch_warn_once_B(); |
31 | |
32 | ASSERT_EQ(count_substr_occurrences(warnings.str(), "warn once" ), 1); |
33 | ASSERT_EQ( |
34 | count_substr_occurrences(warnings.str(), "warn something else once" ), |
35 | 1); |
36 | } |
37 | { |
38 | WarningCapture warnings; |
39 | |
40 | torch_warn(); |
41 | torch_warn(); |
42 | torch_warn(); |
43 | |
44 | ASSERT_EQ( |
45 | count_substr_occurrences(warnings.str(), "warn multiple times" ), 3); |
46 | } |
47 | } |
48 | |
49 | TEST(NoGradTest, SetsGradModeCorrectly) { |
50 | torch::manual_seed(0); |
51 | torch::NoGradGuard guard; |
52 | torch::nn::Linear model(5, 2); |
53 | auto x = torch::randn({10, 5}, torch::requires_grad()); |
54 | auto y = model->forward(x); |
55 | torch::Tensor s = y.sum(); |
56 | |
57 | // Mimicking python API behavior: |
58 | ASSERT_THROWS_WITH( |
59 | s.backward(), |
60 | "element 0 of tensors does not require grad and does not have a grad_fn" ) |
61 | } |
62 | |
63 | struct AutogradTest : torch::test::SeedingFixture { |
64 | AutogradTest() { |
65 | x = torch::randn({3, 3}, torch::requires_grad()); |
66 | y = torch::randn({3, 3}); |
67 | z = x * y; |
68 | } |
69 | torch::Tensor x, y, z; |
70 | }; |
71 | |
72 | TEST_F(AutogradTest, CanTakeDerivatives) { |
73 | z.backward(torch::ones_like(z)); |
74 | ASSERT_TRUE(x.grad().allclose(y)); |
75 | } |
76 | |
77 | TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) { |
78 | z.sum().backward(); |
79 | ASSERT_TRUE(x.grad().allclose(y)); |
80 | } |
81 | |
82 | TEST_F(AutogradTest, CanPassCustomGradientInputs) { |
83 | z.sum().backward(torch::ones({}) * 2); |
84 | ASSERT_TRUE(x.grad().allclose(y * 2)); |
85 | } |
86 | |
87 | TEST(UtilsTest, AmbiguousOperatorDefaults) { |
88 | auto tmp = at::empty({}, at::kCPU); |
89 | at::_test_ambiguous_defaults(tmp); |
90 | at::_test_ambiguous_defaults(tmp, 1); |
91 | at::_test_ambiguous_defaults(tmp, 1, 1); |
92 | at::_test_ambiguous_defaults(tmp, 2, "2" ); |
93 | } |
94 | |
95 | int64_t get_first_element(c10::OptionalIntArrayRef arr) { |
96 | return arr.value()[0]; |
97 | } |
98 | |
99 | TEST(OptionalArrayRefTest, DanglingPointerFix) { |
100 | // Ensure that the converting constructor of `OptionalArrayRef` does not |
101 | // create a dangling pointer when given a single value |
102 | ASSERT_TRUE(get_first_element(300) == 300); |
103 | ASSERT_TRUE(get_first_element({400}) == 400); |
104 | } |
105 | |