1#include <gtest/gtest.h>
2
3#include <torch/torch.h>
4
5#include <test/cpp/api/support.h>
6
7#include <functional>
8
9using namespace torch::test;
10
11void torch_warn_once_A() {
12 TORCH_WARN_ONCE("warn once");
13}
14
15void torch_warn_once_B() {
16 TORCH_WARN_ONCE("warn something else once");
17}
18
19void torch_warn() {
20 TORCH_WARN("warn multiple times");
21}
22
23TEST(UtilsTest, WarnOnce) {
24 {
25 WarningCapture warnings;
26
27 torch_warn_once_A();
28 torch_warn_once_A();
29 torch_warn_once_B();
30 torch_warn_once_B();
31
32 ASSERT_EQ(count_substr_occurrences(warnings.str(), "warn once"), 1);
33 ASSERT_EQ(
34 count_substr_occurrences(warnings.str(), "warn something else once"),
35 1);
36 }
37 {
38 WarningCapture warnings;
39
40 torch_warn();
41 torch_warn();
42 torch_warn();
43
44 ASSERT_EQ(
45 count_substr_occurrences(warnings.str(), "warn multiple times"), 3);
46 }
47}
48
49TEST(NoGradTest, SetsGradModeCorrectly) {
50 torch::manual_seed(0);
51 torch::NoGradGuard guard;
52 torch::nn::Linear model(5, 2);
53 auto x = torch::randn({10, 5}, torch::requires_grad());
54 auto y = model->forward(x);
55 torch::Tensor s = y.sum();
56
57 // Mimicking python API behavior:
58 ASSERT_THROWS_WITH(
59 s.backward(),
60 "element 0 of tensors does not require grad and does not have a grad_fn")
61}
62
63struct AutogradTest : torch::test::SeedingFixture {
64 AutogradTest() {
65 x = torch::randn({3, 3}, torch::requires_grad());
66 y = torch::randn({3, 3});
67 z = x * y;
68 }
69 torch::Tensor x, y, z;
70};
71
72TEST_F(AutogradTest, CanTakeDerivatives) {
73 z.backward(torch::ones_like(z));
74 ASSERT_TRUE(x.grad().allclose(y));
75}
76
77TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
78 z.sum().backward();
79 ASSERT_TRUE(x.grad().allclose(y));
80}
81
82TEST_F(AutogradTest, CanPassCustomGradientInputs) {
83 z.sum().backward(torch::ones({}) * 2);
84 ASSERT_TRUE(x.grad().allclose(y * 2));
85}
86
87TEST(UtilsTest, AmbiguousOperatorDefaults) {
88 auto tmp = at::empty({}, at::kCPU);
89 at::_test_ambiguous_defaults(tmp);
90 at::_test_ambiguous_defaults(tmp, 1);
91 at::_test_ambiguous_defaults(tmp, 1, 1);
92 at::_test_ambiguous_defaults(tmp, 2, "2");
93}
94
95int64_t get_first_element(c10::OptionalIntArrayRef arr) {
96 return arr.value()[0];
97}
98
99TEST(OptionalArrayRefTest, DanglingPointerFix) {
100 // Ensure that the converting constructor of `OptionalArrayRef` does not
101 // create a dangling pointer when given a single value
102 ASSERT_TRUE(get_first_element(300) == 300);
103 ASSERT_TRUE(get_first_element({400}) == 400);
104}
105