1 | #include <gtest/gtest.h> |
2 | #include <torch/torch.h> |
3 | #include <algorithm> |
4 | #include <memory> |
5 | #include <vector> |
6 | |
7 | #include <test/cpp/api/support.h> |
8 | |
9 | using namespace torch::nn; |
10 | using namespace torch::test; |
11 | |
12 | struct ModuleDictTest : torch::test::SeedingFixture {}; |
13 | |
14 | TEST_F(ModuleDictTest, ConstructsFromList) { |
15 | struct M : Module { |
16 | explicit M(int value_) : value(value_) {} |
17 | int value; |
18 | }; |
19 | |
20 | std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = { |
21 | {"module_1" , std::make_shared<M>(1)}, |
22 | {"module_2" , std::make_shared<M>(2)}, |
23 | {"module_3" , std::make_shared<M>(3)}}; |
24 | ModuleDict dict(list); |
25 | ASSERT_EQ(dict->size(), 3); |
26 | } |
27 | |
28 | TEST_F(ModuleDictTest, ConstructsFromordereddict) { |
29 | struct M : Module { |
30 | explicit M(int value_) : value(value_) {} |
31 | int value; |
32 | }; |
33 | |
34 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
35 | {"module_1" , std::make_shared<M>(1)}, |
36 | {"module_2" , std::make_shared<M>(2)}, |
37 | {"module_3" , std::make_shared<M>(3)}, |
38 | }; |
39 | ModuleDict dict(ordereddict); |
40 | ASSERT_EQ(dict->size(), 3); |
41 | } |
42 | |
43 | TEST_F(ModuleDictTest, UpdatePopClearContains) { |
44 | struct M : Module { |
45 | explicit M(int value_) : value(value_) {} |
46 | int value; |
47 | }; |
48 | |
49 | ModuleDict dict; |
50 | ASSERT_TRUE(dict->empty()); |
51 | // Update by List |
52 | std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = { |
53 | {"module_1" , std::make_shared<M>(1)}}; |
54 | dict->update(list1); |
55 | ASSERT_EQ(dict->size(), 1); |
56 | ASSERT_TRUE(dict->contains("module_1" )); |
57 | // Update by OrderedDict |
58 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
59 | {"module_2" , std::make_shared<M>(2)}}; |
60 | dict->update(ordereddict); |
61 | ASSERT_EQ(dict->size(), 2); |
62 | ASSERT_TRUE(dict->contains("module_2" )); |
63 | // Update by another ModuleDict |
64 | std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = { |
65 | {"module_3" , std::make_shared<M>(3)}}; |
66 | ModuleDict updatedict(list2); |
67 | dict->update(*updatedict); |
68 | ASSERT_EQ(dict->size(), 3); |
69 | ASSERT_TRUE(dict->contains("module_3" )); |
70 | // Pop |
71 | dict->pop("module_1" ); |
72 | ASSERT_EQ(dict->size(), 2); |
73 | // Pop unexist |
74 | ASSERT_THROWS_WITH(dict->pop("module_4" ), " 'module_4' is not defined" ); |
75 | // Clear |
76 | dict->clear(); |
77 | ASSERT_EQ(dict->size(), 0); |
78 | } |
79 | |
80 | TEST_F(ModuleDictTest, UpdateExist) { |
81 | struct M : Module { |
82 | explicit M(int value_) : value(value_) {} |
83 | int value; |
84 | }; |
85 | std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = { |
86 | {"module_1" , std::make_shared<M>(1)}, |
87 | {"module_2" , std::make_shared<M>(2)}}; |
88 | ModuleDict dict(list1); |
89 | ASSERT_EQ(dict->at<M>("module_2" ).value, 2); |
90 | // Update by list |
91 | std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = { |
92 | {"module_2" , std::make_shared<M>(0)}, |
93 | {"module_3" , std::make_shared<M>(3)}}; |
94 | dict->update(list2); |
95 | ASSERT_EQ(dict->size(), 3); |
96 | ASSERT_EQ(dict->at<M>("module_2" ).value, 0); |
97 | // Update by ordereddict |
98 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
99 | {"module_3" , std::make_shared<M>(0)}, |
100 | {"module_4" , std::make_shared<M>(4)}}; |
101 | dict->update(ordereddict); |
102 | ASSERT_EQ(dict->size(), 4); |
103 | ASSERT_EQ(dict->at<M>("module_3" ).value, 0); |
104 | // Update by ModuleDict |
105 | std::vector<std::pair<std::string, std::shared_ptr<Module>>> list3 = { |
106 | {"module_4" , std::make_shared<M>(0)}, |
107 | {"module_1" , std::make_shared<M>(0)}}; |
108 | ModuleDict dict2(list3); |
109 | dict->update(*dict2); |
110 | ASSERT_EQ(dict->size(), 4); |
111 | ASSERT_EQ(dict->at<M>("module_1" ).value, 0); |
112 | ASSERT_EQ(dict->at<M>("module_4" ).value, 0); |
113 | } |
114 | |
115 | TEST_F(ModuleDictTest, Keys) { |
116 | struct M : Module { |
117 | explicit M(int value_) : value(value_) {} |
118 | int value; |
119 | }; |
120 | |
121 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
122 | {"linear" , Linear(10, 3).ptr()}, |
123 | {"conv" , Conv2d(1, 2, 3).ptr()}, |
124 | {"dropout" , Dropout(0.5).ptr()}, |
125 | }; |
126 | ModuleDict dict(ordereddict); |
127 | const auto& keys = dict->keys(); |
128 | std::vector<std::string> expected{"linear" , "conv" , "dropout" }; |
129 | ASSERT_EQ(keys, expected); |
130 | ASSERT_THROWS_WITH(dict["batch" ], " 'batch' is not defined" ); |
131 | |
132 | ASSERT_TRUE(dict["linear" ]->as<Linear>()); |
133 | ASSERT_TRUE(dict["conv" ]->as<Conv2d>()); |
134 | ASSERT_TRUE(dict["dropout" ]->as<Dropout>()); |
135 | } |
136 | |
137 | TEST_F(ModuleDictTest, Values) { |
138 | struct M : Module { |
139 | explicit M(int value_) : value(value_) {} |
140 | int value; |
141 | }; |
142 | |
143 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
144 | {"module_1" , std::make_shared<M>(1)}, |
145 | {"module_2" , std::make_shared<M>(2)}, |
146 | }; |
147 | ModuleDict dict(ordereddict); |
148 | const auto& values = dict->values(); |
149 | const auto& expected = ordereddict.values(); |
150 | ASSERT_EQ(values, expected); |
151 | ASSERT_TRUE(std::equal( |
152 | dict->begin(), |
153 | dict->end(), |
154 | ordereddict.begin(), |
155 | [](const auto& lhs, const auto& rhs) { |
156 | return lhs.value().get() == rhs.value().get(); |
157 | })); |
158 | } |
159 | |
160 | TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) { |
161 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
162 | {"linear" , Linear(10, 3).ptr()}, |
163 | {"conv" , Conv2d(1, 2, 3).ptr()}, |
164 | {"dropout" , Dropout(0.5).ptr()}, |
165 | {"batch" , BatchNorm2d(5).ptr()}, |
166 | {"embedding" , Embedding(4, 10).ptr()}, |
167 | {"lstm" , LSTM(4, 5).ptr()}}; |
168 | ModuleDict dict(ordereddict); |
169 | } |
170 | |
171 | TEST_F(ModuleDictTest, HasReferenceSemantics) { |
172 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
173 | {"linear1" , Linear(2, 3).ptr()}, |
174 | {"linear2" , Linear(3, 4).ptr()}, |
175 | {"linear3" , Linear(4, 5).ptr()}, |
176 | }; |
177 | ModuleDict first(ordereddict); |
178 | ModuleDict second(ordereddict); |
179 | |
180 | ASSERT_EQ(first->size(), second->size()); |
181 | ASSERT_TRUE(std::equal( |
182 | first->begin(), |
183 | first->end(), |
184 | second->begin(), |
185 | [](const auto& lhs, const auto& rhs) { |
186 | return lhs.value().get() == rhs.value().get(); |
187 | })); |
188 | } |
189 | |
190 | void iscloneable_helper(torch::Device device) { |
191 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
192 | {"linear" , Linear(2, 3).ptr()}, |
193 | {"relu" , Functional(torch::relu).ptr()}, |
194 | {"batch" , BatchNorm1d(3).ptr()}, |
195 | }; |
196 | ModuleDict dict(ordereddict); |
197 | dict->to(device); |
198 | ModuleDict clone = |
199 | std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device)); |
200 | ASSERT_EQ(dict->size(), clone->size()); |
201 | |
202 | for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end(); |
203 | ++it, ++it_c) { |
204 | // The key should be same |
205 | ASSERT_EQ(it->key(), it_c->key()); |
206 | // The modules should be the same kind (type). |
207 | ASSERT_EQ(it->value()->name(), it_c->value()->name()); |
208 | // But not pointer-equal (distinct objects). |
209 | ASSERT_NE(it->value(), it_c->value()); |
210 | } |
211 | |
212 | // Verify that the clone is deep, i.e. parameters of modules are cloned too. |
213 | torch::NoGradGuard no_grad; |
214 | |
215 | auto params1 = dict->named_parameters(); |
216 | auto params2 = clone->named_parameters(); |
217 | ASSERT_EQ(params1.size(), params2.size()); |
218 | for (auto& param : params1) { |
219 | ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); |
220 | ASSERT_EQ(param->device(), params2[param.key()].device()); |
221 | ASSERT_TRUE(param->allclose(params2[param.key()])); |
222 | param->add_(2); |
223 | } |
224 | for (auto& param : params1) { |
225 | ASSERT_FALSE(param->allclose(params2[param.key()])); |
226 | } |
227 | } |
228 | |
229 | TEST_F(ModuleDictTest, IsCloneable) { |
230 | iscloneable_helper(torch::kCPU); |
231 | } |
232 | |
233 | TEST_F(ModuleDictTest, IsCloneable_CUDA) { |
234 | iscloneable_helper({torch::kCUDA, 0}); |
235 | } |
236 | |
237 | TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) { |
238 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = { |
239 | {"linear" , Linear(10, 3).ptr()}, |
240 | {"conv" , Conv2d(1, 2, 3).ptr()}, |
241 | {"test" , Dropout(0.5).ptr()}, |
242 | }; |
243 | ModuleDict dict(ordereddict1); |
244 | |
245 | auto modules = dict->children(); |
246 | ASSERT_TRUE(modules[0]->as<Linear>()); |
247 | ASSERT_TRUE(modules[1]->as<Conv2d>()); |
248 | ASSERT_TRUE(modules[2]->as<Dropout>()); |
249 | |
250 | // Update Existing |
251 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = { |
252 | {"lstm" , LSTM(4, 5).ptr()}, {"test" , BatchNorm2d(5).ptr()}}; |
253 | dict->update(ordereddict2); |
254 | |
255 | modules = dict->children(); |
256 | ASSERT_TRUE(modules[0]->as<Linear>()); |
257 | ASSERT_TRUE(modules[1]->as<Conv2d>()); |
258 | // Keep Order |
259 | ASSERT_TRUE(modules[2]->as<BatchNorm2d>()); |
260 | ASSERT_TRUE(modules[3]->as<LSTM>()); |
261 | } |
262 | |
263 | TEST_F(ModuleDictTest, CloneToDevice_CUDA) { |
264 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
265 | {"linear" , Linear(2, 3).ptr()}, |
266 | {"relu" , Functional(torch::relu).ptr()}, |
267 | {"batch" , BatchNorm1d(3).ptr()}, |
268 | }; |
269 | ModuleDict dict(ordereddict); |
270 | torch::Device device(torch::kCUDA, 0); |
271 | ModuleDict clone = |
272 | std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device)); |
273 | for (const auto& p : clone->parameters()) { |
274 | ASSERT_EQ(p.device(), device); |
275 | } |
276 | for (const auto& b : clone->buffers()) { |
277 | ASSERT_EQ(b.device(), device); |
278 | } |
279 | } |
280 | |
281 | TEST_F(ModuleDictTest, PrettyPrintModuleDict) { |
282 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
283 | {"linear" , Linear(10, 3).ptr()}, |
284 | {"conv" , Conv2d(1, 2, 3).ptr()}, |
285 | {"dropout" , Dropout(0.5).ptr()}, |
286 | {"batch" , BatchNorm2d(5).ptr()}, |
287 | {"embedding" , Embedding(4, 10).ptr()}, |
288 | {"lstm" , LSTM(4, 5).ptr()}}; |
289 | ModuleDict dict(ordereddict); |
290 | |
291 | ASSERT_EQ( |
292 | c10::str(dict), |
293 | "torch::nn::ModuleDict(\n" |
294 | " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n" |
295 | " (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n" |
296 | " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n" |
297 | " (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n" |
298 | " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n" |
299 | " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n" |
300 | ")" ); |
301 | } |
302 | |
303 | TEST_F(ModuleDictTest, InvalidAt) { |
304 | torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
305 | {"linear" , Linear(10, 3).ptr()}}; |
306 | ModuleDict dict(ordereddict); |
307 | ASSERT_THROWS_WITH( |
308 | dict->at<torch::nn::Dropout2dImpl>("linear" ), "Unable to cast module" ); |
309 | } |
310 | |