1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <algorithm>
7#include <memory>
8#include <vector>
9
10#include <test/cpp/api/support.h>
11
12using namespace torch::nn;
13using namespace torch::test;
14
15struct ParameterListTest : torch::test::SeedingFixture {};
16
17TEST_F(ParameterListTest, ConstructsFromSharedPointer) {
18 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
19 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
20 torch::Tensor tc = torch::randn({1, 2});
21 ASSERT_TRUE(ta.requires_grad());
22 ASSERT_FALSE(tb.requires_grad());
23 ParameterList list(ta, tb, tc);
24 ASSERT_EQ(list->size(), 3);
25}
26
27TEST_F(ParameterListTest, isEmpty) {
28 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
29 ParameterList list;
30 ASSERT_TRUE(list->is_empty());
31 list->append(ta);
32 ASSERT_FALSE(list->is_empty());
33 ASSERT_EQ(list->size(), 1);
34}
35
36TEST_F(ParameterListTest, PushBackAddsAnElement) {
37 ParameterList list;
38 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
39 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
40 torch::Tensor tc = torch::randn({1, 2});
41 torch::Tensor td = torch::randn({1, 2, 3});
42 ASSERT_EQ(list->size(), 0);
43 ASSERT_TRUE(list->is_empty());
44 list->append(ta);
45 ASSERT_EQ(list->size(), 1);
46 list->append(tb);
47 ASSERT_EQ(list->size(), 2);
48 list->append(tc);
49 ASSERT_EQ(list->size(), 3);
50 list->append(td);
51 ASSERT_EQ(list->size(), 4);
52}
53TEST_F(ParameterListTest, ForEachLoop) {
54 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
55 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
56 torch::Tensor tc = torch::randn({1, 2});
57 torch::Tensor td = torch::randn({1, 2, 3});
58 ParameterList list(ta, tb, tc, td);
59 std::vector<torch::Tensor> params = {ta, tb, tc, td};
60 ASSERT_EQ(list->size(), 4);
61 int idx = 0;
62 for (const auto& pair : *list) {
63 ASSERT_TRUE(
64 torch::all(torch::eq(pair.value(), params[idx++])).item<bool>());
65 }
66}
67
68TEST_F(ParameterListTest, AccessWithAt) {
69 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
70 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
71 torch::Tensor tc = torch::randn({1, 2});
72 torch::Tensor td = torch::randn({1, 2, 3});
73 std::vector<torch::Tensor> params = {ta, tb, tc, td};
74
75 ParameterList list;
76 for (auto& param : params) {
77 list->append(param);
78 }
79 ASSERT_EQ(list->size(), 4);
80
81 // returns the correct module for a given index
82 for (const auto i : c10::irange(params.size())) {
83 ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item<bool>());
84 }
85
86 for (const auto i : c10::irange(params.size())) {
87 ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item<bool>());
88 }
89
90 // throws for a bad index
91 ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range");
92 ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range");
93 ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range");
94}
95
96TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) {
97 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
98 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
99 torch::Tensor tc = torch::randn({1, 2});
100 torch::Tensor td = torch::randn({1, 2, 3});
101 torch::Tensor te = torch::randn({1, 2});
102 torch::Tensor tf = torch::randn({1, 2, 3});
103 ParameterList a(ta, tb);
104 ParameterList b(tc, td);
105 a->extend(*b);
106
107 ASSERT_EQ(a->size(), 4);
108 ASSERT_TRUE(torch::all(torch::eq(a[0], ta)).item<bool>());
109 ASSERT_TRUE(torch::all(torch::eq(a[1], tb)).item<bool>());
110 ASSERT_TRUE(torch::all(torch::eq(a[2], tc)).item<bool>());
111 ASSERT_TRUE(torch::all(torch::eq(a[3], td)).item<bool>());
112
113 ASSERT_EQ(b->size(), 2);
114 ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>());
115 ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>());
116
117 std::vector<torch::Tensor> c = {te, tf};
118 b->extend(c);
119
120 ASSERT_EQ(b->size(), 4);
121 ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>());
122 ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>());
123 ASSERT_TRUE(torch::all(torch::eq(b[2], te)).item<bool>());
124 ASSERT_TRUE(torch::all(torch::eq(b[3], tf)).item<bool>());
125}
126
127TEST_F(ParameterListTest, PrettyPrintParameterList) {
128 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
129 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
130 torch::Tensor tc = torch::randn({1, 2});
131 ParameterList list(ta, tb, tc);
132 ASSERT_EQ(
133 c10::str(list),
134 "torch::nn::ParameterList(\n"
135 "(0): Parameter containing: [Float of size [1, 2]]\n"
136 "(1): Parameter containing: [Float of size [1, 2]]\n"
137 "(2): Parameter containing: [Float of size [1, 2]]\n"
138 ")");
139}
140
141TEST_F(ParameterListTest, IncrementAdd) {
142 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
143 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
144 torch::Tensor tc = torch::randn({1, 2});
145 torch::Tensor td = torch::randn({1, 2, 3});
146 torch::Tensor te = torch::randn({1, 2});
147 torch::Tensor tf = torch::randn({1, 2, 3});
148 ParameterList listA(ta, tb, tc);
149 ParameterList listB(td, te, tf);
150 std::vector<torch::Tensor> tensors{ta, tb, tc, td, te, tf};
151 int idx = 0;
152 *listA += *listB;
153 ASSERT_TRUE(torch::all(torch::eq(listA[0], ta)).item<bool>());
154 ASSERT_TRUE(torch::all(torch::eq(listA[1], tb)).item<bool>());
155 ASSERT_TRUE(torch::all(torch::eq(listA[2], tc)).item<bool>());
156 ASSERT_TRUE(torch::all(torch::eq(listA[3], td)).item<bool>());
157 ASSERT_TRUE(torch::all(torch::eq(listA[4], te)).item<bool>());
158 ASSERT_TRUE(torch::all(torch::eq(listA[5], tf)).item<bool>());
159 for (const auto& P : listA->named_parameters(false))
160 ASSERT_TRUE(torch::all(torch::eq(P.value(), tensors[idx++])).item<bool>());
161
162 ASSERT_EQ(idx, 6);
163}
164