1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <algorithm> |
7 | #include <memory> |
8 | #include <vector> |
9 | |
10 | #include <test/cpp/api/support.h> |
11 | |
12 | using namespace torch::nn; |
13 | using namespace torch::test; |
14 | |
15 | struct ParameterListTest : torch::test::SeedingFixture {}; |
16 | |
17 | TEST_F(ParameterListTest, ConstructsFromSharedPointer) { |
18 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
19 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
20 | torch::Tensor tc = torch::randn({1, 2}); |
21 | ASSERT_TRUE(ta.requires_grad()); |
22 | ASSERT_FALSE(tb.requires_grad()); |
23 | ParameterList list(ta, tb, tc); |
24 | ASSERT_EQ(list->size(), 3); |
25 | } |
26 | |
27 | TEST_F(ParameterListTest, isEmpty) { |
28 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
29 | ParameterList list; |
30 | ASSERT_TRUE(list->is_empty()); |
31 | list->append(ta); |
32 | ASSERT_FALSE(list->is_empty()); |
33 | ASSERT_EQ(list->size(), 1); |
34 | } |
35 | |
36 | TEST_F(ParameterListTest, PushBackAddsAnElement) { |
37 | ParameterList list; |
38 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
39 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
40 | torch::Tensor tc = torch::randn({1, 2}); |
41 | torch::Tensor td = torch::randn({1, 2, 3}); |
42 | ASSERT_EQ(list->size(), 0); |
43 | ASSERT_TRUE(list->is_empty()); |
44 | list->append(ta); |
45 | ASSERT_EQ(list->size(), 1); |
46 | list->append(tb); |
47 | ASSERT_EQ(list->size(), 2); |
48 | list->append(tc); |
49 | ASSERT_EQ(list->size(), 3); |
50 | list->append(td); |
51 | ASSERT_EQ(list->size(), 4); |
52 | } |
53 | TEST_F(ParameterListTest, ForEachLoop) { |
54 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
55 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
56 | torch::Tensor tc = torch::randn({1, 2}); |
57 | torch::Tensor td = torch::randn({1, 2, 3}); |
58 | ParameterList list(ta, tb, tc, td); |
59 | std::vector<torch::Tensor> params = {ta, tb, tc, td}; |
60 | ASSERT_EQ(list->size(), 4); |
61 | int idx = 0; |
62 | for (const auto& pair : *list) { |
63 | ASSERT_TRUE( |
64 | torch::all(torch::eq(pair.value(), params[idx++])).item<bool>()); |
65 | } |
66 | } |
67 | |
68 | TEST_F(ParameterListTest, AccessWithAt) { |
69 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
70 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
71 | torch::Tensor tc = torch::randn({1, 2}); |
72 | torch::Tensor td = torch::randn({1, 2, 3}); |
73 | std::vector<torch::Tensor> params = {ta, tb, tc, td}; |
74 | |
75 | ParameterList list; |
76 | for (auto& param : params) { |
77 | list->append(param); |
78 | } |
79 | ASSERT_EQ(list->size(), 4); |
80 | |
81 | // returns the correct module for a given index |
82 | for (const auto i : c10::irange(params.size())) { |
83 | ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item<bool>()); |
84 | } |
85 | |
86 | for (const auto i : c10::irange(params.size())) { |
87 | ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item<bool>()); |
88 | } |
89 | |
90 | // throws for a bad index |
91 | ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range" ); |
92 | ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range" ); |
93 | ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range" ); |
94 | } |
95 | |
96 | TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) { |
97 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
98 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
99 | torch::Tensor tc = torch::randn({1, 2}); |
100 | torch::Tensor td = torch::randn({1, 2, 3}); |
101 | torch::Tensor te = torch::randn({1, 2}); |
102 | torch::Tensor tf = torch::randn({1, 2, 3}); |
103 | ParameterList a(ta, tb); |
104 | ParameterList b(tc, td); |
105 | a->extend(*b); |
106 | |
107 | ASSERT_EQ(a->size(), 4); |
108 | ASSERT_TRUE(torch::all(torch::eq(a[0], ta)).item<bool>()); |
109 | ASSERT_TRUE(torch::all(torch::eq(a[1], tb)).item<bool>()); |
110 | ASSERT_TRUE(torch::all(torch::eq(a[2], tc)).item<bool>()); |
111 | ASSERT_TRUE(torch::all(torch::eq(a[3], td)).item<bool>()); |
112 | |
113 | ASSERT_EQ(b->size(), 2); |
114 | ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>()); |
115 | ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>()); |
116 | |
117 | std::vector<torch::Tensor> c = {te, tf}; |
118 | b->extend(c); |
119 | |
120 | ASSERT_EQ(b->size(), 4); |
121 | ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>()); |
122 | ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>()); |
123 | ASSERT_TRUE(torch::all(torch::eq(b[2], te)).item<bool>()); |
124 | ASSERT_TRUE(torch::all(torch::eq(b[3], tf)).item<bool>()); |
125 | } |
126 | |
127 | TEST_F(ParameterListTest, PrettyPrintParameterList) { |
128 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
129 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
130 | torch::Tensor tc = torch::randn({1, 2}); |
131 | ParameterList list(ta, tb, tc); |
132 | ASSERT_EQ( |
133 | c10::str(list), |
134 | "torch::nn::ParameterList(\n" |
135 | "(0): Parameter containing: [Float of size [1, 2]]\n" |
136 | "(1): Parameter containing: [Float of size [1, 2]]\n" |
137 | "(2): Parameter containing: [Float of size [1, 2]]\n" |
138 | ")" ); |
139 | } |
140 | |
141 | TEST_F(ParameterListTest, IncrementAdd) { |
142 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
143 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
144 | torch::Tensor tc = torch::randn({1, 2}); |
145 | torch::Tensor td = torch::randn({1, 2, 3}); |
146 | torch::Tensor te = torch::randn({1, 2}); |
147 | torch::Tensor tf = torch::randn({1, 2, 3}); |
148 | ParameterList listA(ta, tb, tc); |
149 | ParameterList listB(td, te, tf); |
150 | std::vector<torch::Tensor> tensors{ta, tb, tc, td, te, tf}; |
151 | int idx = 0; |
152 | *listA += *listB; |
153 | ASSERT_TRUE(torch::all(torch::eq(listA[0], ta)).item<bool>()); |
154 | ASSERT_TRUE(torch::all(torch::eq(listA[1], tb)).item<bool>()); |
155 | ASSERT_TRUE(torch::all(torch::eq(listA[2], tc)).item<bool>()); |
156 | ASSERT_TRUE(torch::all(torch::eq(listA[3], td)).item<bool>()); |
157 | ASSERT_TRUE(torch::all(torch::eq(listA[4], te)).item<bool>()); |
158 | ASSERT_TRUE(torch::all(torch::eq(listA[5], tf)).item<bool>()); |
159 | for (const auto& P : listA->named_parameters(false)) |
160 | ASSERT_TRUE(torch::all(torch::eq(P.value(), tensors[idx++])).item<bool>()); |
161 | |
162 | ASSERT_EQ(idx, 6); |
163 | } |
164 | |