1#include <gtest/gtest.h>
2#include <torch/torch.h>
3#include <algorithm>
4#include <memory>
5#include <vector>
6
7#include <test/cpp/api/support.h>
8
9using namespace torch::nn;
10using namespace torch::test;
11
12struct ModuleDictTest : torch::test::SeedingFixture {};
13
14TEST_F(ModuleDictTest, ConstructsFromList) {
15 struct M : Module {
16 explicit M(int value_) : value(value_) {}
17 int value;
18 };
19
20 std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
21 {"module_1", std::make_shared<M>(1)},
22 {"module_2", std::make_shared<M>(2)},
23 {"module_3", std::make_shared<M>(3)}};
24 ModuleDict dict(list);
25 ASSERT_EQ(dict->size(), 3);
26}
27
28TEST_F(ModuleDictTest, ConstructsFromordereddict) {
29 struct M : Module {
30 explicit M(int value_) : value(value_) {}
31 int value;
32 };
33
34 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
35 {"module_1", std::make_shared<M>(1)},
36 {"module_2", std::make_shared<M>(2)},
37 {"module_3", std::make_shared<M>(3)},
38 };
39 ModuleDict dict(ordereddict);
40 ASSERT_EQ(dict->size(), 3);
41}
42
43TEST_F(ModuleDictTest, UpdatePopClearContains) {
44 struct M : Module {
45 explicit M(int value_) : value(value_) {}
46 int value;
47 };
48
49 ModuleDict dict;
50 ASSERT_TRUE(dict->empty());
51 // Update by List
52 std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
53 {"module_1", std::make_shared<M>(1)}};
54 dict->update(list1);
55 ASSERT_EQ(dict->size(), 1);
56 ASSERT_TRUE(dict->contains("module_1"));
57 // Update by OrderedDict
58 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
59 {"module_2", std::make_shared<M>(2)}};
60 dict->update(ordereddict);
61 ASSERT_EQ(dict->size(), 2);
62 ASSERT_TRUE(dict->contains("module_2"));
63 // Update by another ModuleDict
64 std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
65 {"module_3", std::make_shared<M>(3)}};
66 ModuleDict updatedict(list2);
67 dict->update(*updatedict);
68 ASSERT_EQ(dict->size(), 3);
69 ASSERT_TRUE(dict->contains("module_3"));
70 // Pop
71 dict->pop("module_1");
72 ASSERT_EQ(dict->size(), 2);
73 // Pop unexist
74 ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined");
75 // Clear
76 dict->clear();
77 ASSERT_EQ(dict->size(), 0);
78}
79
80TEST_F(ModuleDictTest, UpdateExist) {
81 struct M : Module {
82 explicit M(int value_) : value(value_) {}
83 int value;
84 };
85 std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
86 {"module_1", std::make_shared<M>(1)},
87 {"module_2", std::make_shared<M>(2)}};
88 ModuleDict dict(list1);
89 ASSERT_EQ(dict->at<M>("module_2").value, 2);
90 // Update by list
91 std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
92 {"module_2", std::make_shared<M>(0)},
93 {"module_3", std::make_shared<M>(3)}};
94 dict->update(list2);
95 ASSERT_EQ(dict->size(), 3);
96 ASSERT_EQ(dict->at<M>("module_2").value, 0);
97 // Update by ordereddict
98 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
99 {"module_3", std::make_shared<M>(0)},
100 {"module_4", std::make_shared<M>(4)}};
101 dict->update(ordereddict);
102 ASSERT_EQ(dict->size(), 4);
103 ASSERT_EQ(dict->at<M>("module_3").value, 0);
104 // Update by ModuleDict
105 std::vector<std::pair<std::string, std::shared_ptr<Module>>> list3 = {
106 {"module_4", std::make_shared<M>(0)},
107 {"module_1", std::make_shared<M>(0)}};
108 ModuleDict dict2(list3);
109 dict->update(*dict2);
110 ASSERT_EQ(dict->size(), 4);
111 ASSERT_EQ(dict->at<M>("module_1").value, 0);
112 ASSERT_EQ(dict->at<M>("module_4").value, 0);
113}
114
115TEST_F(ModuleDictTest, Keys) {
116 struct M : Module {
117 explicit M(int value_) : value(value_) {}
118 int value;
119 };
120
121 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
122 {"linear", Linear(10, 3).ptr()},
123 {"conv", Conv2d(1, 2, 3).ptr()},
124 {"dropout", Dropout(0.5).ptr()},
125 };
126 ModuleDict dict(ordereddict);
127 const auto& keys = dict->keys();
128 std::vector<std::string> expected{"linear", "conv", "dropout"};
129 ASSERT_EQ(keys, expected);
130 ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined");
131
132 ASSERT_TRUE(dict["linear"]->as<Linear>());
133 ASSERT_TRUE(dict["conv"]->as<Conv2d>());
134 ASSERT_TRUE(dict["dropout"]->as<Dropout>());
135}
136
137TEST_F(ModuleDictTest, Values) {
138 struct M : Module {
139 explicit M(int value_) : value(value_) {}
140 int value;
141 };
142
143 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
144 {"module_1", std::make_shared<M>(1)},
145 {"module_2", std::make_shared<M>(2)},
146 };
147 ModuleDict dict(ordereddict);
148 const auto& values = dict->values();
149 const auto& expected = ordereddict.values();
150 ASSERT_EQ(values, expected);
151 ASSERT_TRUE(std::equal(
152 dict->begin(),
153 dict->end(),
154 ordereddict.begin(),
155 [](const auto& lhs, const auto& rhs) {
156 return lhs.value().get() == rhs.value().get();
157 }));
158}
159
160TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) {
161 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
162 {"linear", Linear(10, 3).ptr()},
163 {"conv", Conv2d(1, 2, 3).ptr()},
164 {"dropout", Dropout(0.5).ptr()},
165 {"batch", BatchNorm2d(5).ptr()},
166 {"embedding", Embedding(4, 10).ptr()},
167 {"lstm", LSTM(4, 5).ptr()}};
168 ModuleDict dict(ordereddict);
169}
170
171TEST_F(ModuleDictTest, HasReferenceSemantics) {
172 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
173 {"linear1", Linear(2, 3).ptr()},
174 {"linear2", Linear(3, 4).ptr()},
175 {"linear3", Linear(4, 5).ptr()},
176 };
177 ModuleDict first(ordereddict);
178 ModuleDict second(ordereddict);
179
180 ASSERT_EQ(first->size(), second->size());
181 ASSERT_TRUE(std::equal(
182 first->begin(),
183 first->end(),
184 second->begin(),
185 [](const auto& lhs, const auto& rhs) {
186 return lhs.value().get() == rhs.value().get();
187 }));
188}
189
190void iscloneable_helper(torch::Device device) {
191 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
192 {"linear", Linear(2, 3).ptr()},
193 {"relu", Functional(torch::relu).ptr()},
194 {"batch", BatchNorm1d(3).ptr()},
195 };
196 ModuleDict dict(ordereddict);
197 dict->to(device);
198 ModuleDict clone =
199 std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
200 ASSERT_EQ(dict->size(), clone->size());
201
202 for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end();
203 ++it, ++it_c) {
204 // The key should be same
205 ASSERT_EQ(it->key(), it_c->key());
206 // The modules should be the same kind (type).
207 ASSERT_EQ(it->value()->name(), it_c->value()->name());
208 // But not pointer-equal (distinct objects).
209 ASSERT_NE(it->value(), it_c->value());
210 }
211
212 // Verify that the clone is deep, i.e. parameters of modules are cloned too.
213 torch::NoGradGuard no_grad;
214
215 auto params1 = dict->named_parameters();
216 auto params2 = clone->named_parameters();
217 ASSERT_EQ(params1.size(), params2.size());
218 for (auto& param : params1) {
219 ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
220 ASSERT_EQ(param->device(), params2[param.key()].device());
221 ASSERT_TRUE(param->allclose(params2[param.key()]));
222 param->add_(2);
223 }
224 for (auto& param : params1) {
225 ASSERT_FALSE(param->allclose(params2[param.key()]));
226 }
227}
228
229TEST_F(ModuleDictTest, IsCloneable) {
230 iscloneable_helper(torch::kCPU);
231}
232
233TEST_F(ModuleDictTest, IsCloneable_CUDA) {
234 iscloneable_helper({torch::kCUDA, 0});
235}
236
237TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) {
238 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = {
239 {"linear", Linear(10, 3).ptr()},
240 {"conv", Conv2d(1, 2, 3).ptr()},
241 {"test", Dropout(0.5).ptr()},
242 };
243 ModuleDict dict(ordereddict1);
244
245 auto modules = dict->children();
246 ASSERT_TRUE(modules[0]->as<Linear>());
247 ASSERT_TRUE(modules[1]->as<Conv2d>());
248 ASSERT_TRUE(modules[2]->as<Dropout>());
249
250 // Update Existing
251 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = {
252 {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}};
253 dict->update(ordereddict2);
254
255 modules = dict->children();
256 ASSERT_TRUE(modules[0]->as<Linear>());
257 ASSERT_TRUE(modules[1]->as<Conv2d>());
258 // Keep Order
259 ASSERT_TRUE(modules[2]->as<BatchNorm2d>());
260 ASSERT_TRUE(modules[3]->as<LSTM>());
261}
262
263TEST_F(ModuleDictTest, CloneToDevice_CUDA) {
264 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
265 {"linear", Linear(2, 3).ptr()},
266 {"relu", Functional(torch::relu).ptr()},
267 {"batch", BatchNorm1d(3).ptr()},
268 };
269 ModuleDict dict(ordereddict);
270 torch::Device device(torch::kCUDA, 0);
271 ModuleDict clone =
272 std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
273 for (const auto& p : clone->parameters()) {
274 ASSERT_EQ(p.device(), device);
275 }
276 for (const auto& b : clone->buffers()) {
277 ASSERT_EQ(b.device(), device);
278 }
279}
280
281TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
282 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
283 {"linear", Linear(10, 3).ptr()},
284 {"conv", Conv2d(1, 2, 3).ptr()},
285 {"dropout", Dropout(0.5).ptr()},
286 {"batch", BatchNorm2d(5).ptr()},
287 {"embedding", Embedding(4, 10).ptr()},
288 {"lstm", LSTM(4, 5).ptr()}};
289 ModuleDict dict(ordereddict);
290
291 ASSERT_EQ(
292 c10::str(dict),
293 "torch::nn::ModuleDict(\n"
294 " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
295 " (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
296 " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
297 " (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
298 " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
299 " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
300 ")");
301}
302
303TEST_F(ModuleDictTest, InvalidAt) {
304 torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
305 {"linear", Linear(10, 3).ptr()}};
306 ModuleDict dict(ordereddict);
307 ASSERT_THROWS_WITH(
308 dict->at<torch::nn::Dropout2dImpl>("linear"), "Unable to cast module");
309}
310