1 | #include <gtest/gtest.h> |
2 | #include <torch/torch.h> |
3 | #include <algorithm> |
4 | #include <memory> |
5 | #include <vector> |
6 | |
7 | #include <test/cpp/api/support.h> |
8 | |
9 | using namespace torch::nn; |
10 | using namespace torch::test; |
11 | |
12 | struct ParameterDictTest : torch::test::SeedingFixture {}; |
13 | |
14 | TEST_F(ParameterDictTest, ConstructFromTensor) { |
15 | ParameterDict dict; |
16 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
17 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
18 | torch::Tensor tc = torch::randn({1, 2}); |
19 | ASSERT_TRUE(ta.requires_grad()); |
20 | ASSERT_FALSE(tb.requires_grad()); |
21 | dict->insert("A" , ta); |
22 | dict->insert("B" , tb); |
23 | dict->insert("C" , tc); |
24 | ASSERT_EQ(dict->size(), 3); |
25 | ASSERT_TRUE(torch::all(torch::eq(dict["A" ], ta)).item<bool>()); |
26 | ASSERT_TRUE(dict["A" ].requires_grad()); |
27 | ASSERT_TRUE(torch::all(torch::eq(dict["B" ], tb)).item<bool>()); |
28 | ASSERT_FALSE(dict["B" ].requires_grad()); |
29 | } |
30 | |
31 | TEST_F(ParameterDictTest, ConstructFromOrderedDict) { |
32 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
33 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
34 | torch::Tensor tc = torch::randn({1, 2}); |
35 | torch::OrderedDict<std::string, torch::Tensor> params = { |
36 | {"A" , ta}, {"B" , tb}, {"C" , tc}}; |
37 | auto dict = torch::nn::ParameterDict(params); |
38 | ASSERT_EQ(dict->size(), 3); |
39 | ASSERT_TRUE(torch::all(torch::eq(dict["A" ], ta)).item<bool>()); |
40 | ASSERT_TRUE(dict["A" ].requires_grad()); |
41 | ASSERT_TRUE(torch::all(torch::eq(dict["B" ], tb)).item<bool>()); |
42 | ASSERT_FALSE(dict["B" ].requires_grad()); |
43 | } |
44 | |
45 | TEST_F(ParameterDictTest, InsertAndContains) { |
46 | ParameterDict dict; |
47 | dict->insert("A" , torch::tensor({1.0})); |
48 | ASSERT_EQ(dict->size(), 1); |
49 | ASSERT_TRUE(dict->contains("A" )); |
50 | ASSERT_FALSE(dict->contains("C" )); |
51 | } |
52 | |
53 | TEST_F(ParameterDictTest, InsertAndClear) { |
54 | ParameterDict dict; |
55 | dict->insert("A" , torch::tensor({1.0})); |
56 | ASSERT_EQ(dict->size(), 1); |
57 | dict->clear(); |
58 | ASSERT_EQ(dict->size(), 0); |
59 | } |
60 | |
61 | TEST_F(ParameterDictTest, InsertAndPop) { |
62 | ParameterDict dict; |
63 | dict->insert("A" , torch::tensor({1.0})); |
64 | ASSERT_EQ(dict->size(), 1); |
65 | ASSERT_THROWS_WITH(dict->pop("B" ), "Parameter 'B' is not defined" ); |
66 | torch::Tensor p = dict->pop("A" ); |
67 | ASSERT_EQ(dict->size(), 0); |
68 | ASSERT_TRUE(torch::eq(p, torch::tensor({1.0})).item<bool>()); |
69 | } |
70 | |
71 | TEST_F(ParameterDictTest, SimpleUpdate) { |
72 | ParameterDict dict; |
73 | ParameterDict wrongDict; |
74 | ParameterDict rightDict; |
75 | dict->insert("A" , torch::tensor({1.0})); |
76 | dict->insert("B" , torch::tensor({2.0})); |
77 | dict->insert("C" , torch::tensor({3.0})); |
78 | wrongDict->insert("A" , torch::tensor({5.0})); |
79 | wrongDict->insert("D" , torch::tensor({5.0})); |
80 | ASSERT_THROWS_WITH(dict->update(*wrongDict), "Parameter 'D' is not defined" ); |
81 | rightDict->insert("A" , torch::tensor({5.0})); |
82 | dict->update(*rightDict); |
83 | ASSERT_EQ(dict->size(), 3); |
84 | ASSERT_TRUE(torch::eq(dict["A" ], torch::tensor({5.0})).item<bool>()); |
85 | } |
86 | |
87 | TEST_F(ParameterDictTest, Keys) { |
88 | torch::OrderedDict<std::string, torch::Tensor> params = { |
89 | {"a" , torch::tensor({1.0})}, |
90 | {"b" , torch::tensor({2.0})}, |
91 | {"c" , torch::tensor({1.0, 2.0})}}; |
92 | auto dict = torch::nn::ParameterDict(params); |
93 | std::vector<std::string> keys = dict->keys(); |
94 | std::vector<std::string> true_keys{"a" , "b" , "c" }; |
95 | ASSERT_EQ(keys, true_keys); |
96 | } |
97 | |
98 | TEST_F(ParameterDictTest, Values) { |
99 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
100 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
101 | torch::Tensor tc = torch::randn({1, 2}); |
102 | torch::OrderedDict<std::string, torch::Tensor> params = { |
103 | {"a" , ta}, {"b" , tb}, {"c" , tc}}; |
104 | auto dict = torch::nn::ParameterDict(params); |
105 | std::vector<torch::Tensor> values = dict->values(); |
106 | std::vector<torch::Tensor> true_values{ta, tb, tc}; |
107 | for (auto i = 0U; i < values.size(); i += 1) { |
108 | ASSERT_TRUE(torch::all(torch::eq(values[i], true_values[i])).item<bool>()); |
109 | } |
110 | } |
111 | |
112 | TEST_F(ParameterDictTest, Get) { |
113 | ParameterDict dict; |
114 | torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
115 | torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
116 | torch::Tensor tc = torch::randn({1, 2}); |
117 | ASSERT_TRUE(ta.requires_grad()); |
118 | ASSERT_FALSE(tb.requires_grad()); |
119 | dict->insert("A" , ta); |
120 | dict->insert("B" , tb); |
121 | dict->insert("C" , tc); |
122 | ASSERT_EQ(dict->size(), 3); |
123 | ASSERT_TRUE(torch::all(torch::eq(dict->get("A" ), ta)).item<bool>()); |
124 | ASSERT_TRUE(dict->get("A" ).requires_grad()); |
125 | ASSERT_TRUE(torch::all(torch::eq(dict->get("B" ), tb)).item<bool>()); |
126 | ASSERT_FALSE(dict->get("B" ).requires_grad()); |
127 | } |
128 | |
129 | TEST_F(ParameterDictTest, PrettyPrintParameterDict) { |
130 | torch::OrderedDict<std::string, torch::Tensor> params = { |
131 | {"a" , torch::tensor({1.0})}, |
132 | {"b" , torch::tensor({2.0, 1.0})}, |
133 | {"c" , torch::tensor({{3.0}, {2.1}})}, |
134 | {"d" , torch::tensor({{3.0, 1.3}, {1.2, 2.1}})}}; |
135 | auto dict = torch::nn::ParameterDict(params); |
136 | ASSERT_EQ( |
137 | c10::str(dict), |
138 | "torch::nn::ParameterDict(\n" |
139 | "(a): Parameter containing: [Float of size [1]]\n" |
140 | "(b): Parameter containing: [Float of size [2]]\n" |
141 | "(c): Parameter containing: [Float of size [2, 1]]\n" |
142 | "(d): Parameter containing: [Float of size [2, 2]]\n" |
143 | ")" ); |
144 | } |
145 | |