1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
9 | |
10 | #include <test/cpp/api/support.h> |
11 | |
12 | using namespace torch::nn; |
13 | using namespace torch::test; |
14 | |
15 | struct ModuleListTest : torch::test::SeedingFixture {}; |
16 | |
17 | TEST_F(ModuleListTest, ConstructsFromSharedPointer) { |
18 | struct M : torch::nn::Module { |
19 | explicit M(int value_) : value(value_) {} |
20 | int value; |
21 | }; |
22 | ModuleList list( |
23 | std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)); |
24 | ASSERT_EQ(list->size(), 3); |
25 | } |
26 | |
27 | TEST_F(ModuleListTest, ConstructsFromConcreteType) { |
28 | static int copy_count; |
29 | |
30 | struct M : torch::nn::Module { |
31 | explicit M(int value_) : value(value_) {} |
32 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
33 | M(const M& other) : torch::nn::Module(other) { |
34 | copy_count++; |
35 | } |
36 | int value; |
37 | }; |
38 | |
39 | copy_count = 0; |
40 | ModuleList list(M(1), M(2), M(3)); |
41 | ASSERT_EQ(list->size(), 3); |
42 | // NOTE: The current implementation expects each module to be copied exactly |
43 | // once, which happens when the module is passed into `std::make_shared<T>()`. |
44 | // TODO: Find a way to avoid copying, and then delete the copy constructor of |
45 | // `M`. |
46 | ASSERT_EQ(copy_count, 3); |
47 | } |
48 | |
49 | TEST_F(ModuleListTest, ConstructsFromModuleHolder) { |
50 | struct MImpl : torch::nn::Module { |
51 | explicit MImpl(int value_) : value(value_) {} |
52 | int value; |
53 | }; |
54 | |
55 | struct M : torch::nn::ModuleHolder<MImpl> { |
56 | using torch::nn::ModuleHolder<MImpl>::ModuleHolder; |
57 | using torch::nn::ModuleHolder<MImpl>::get; |
58 | }; |
59 | |
60 | ModuleList list(M(1), M(2), M(3)); |
61 | ASSERT_EQ(list->size(), 3); |
62 | } |
63 | |
64 | TEST_F(ModuleListTest, PushBackAddsAnElement) { |
65 | struct M : torch::nn::Module { |
66 | explicit M(int value_) : value(value_) {} |
67 | int value; |
68 | }; |
69 | |
70 | ModuleList list; |
71 | ASSERT_EQ(list->size(), 0); |
72 | ASSERT_TRUE(list->is_empty()); |
73 | list->push_back(Linear(3, 4)); |
74 | ASSERT_EQ(list->size(), 1); |
75 | list->push_back(std::make_shared<M>(1)); |
76 | ASSERT_EQ(list->size(), 2); |
77 | list->push_back(M(2)); |
78 | ASSERT_EQ(list->size(), 3); |
79 | } |
80 | |
81 | TEST_F(ModuleListTest, Insertion) { |
82 | struct MImpl : torch::nn::Module { |
83 | explicit MImpl(int value_) : value(value_) {} |
84 | int value; |
85 | }; |
86 | TORCH_MODULE(M); |
87 | |
88 | ModuleList list; |
89 | list->push_back(MImpl(1)); |
90 | ASSERT_EQ(list->size(), 1); |
91 | list->insert(0, std::make_shared<MImpl>(2)); |
92 | ASSERT_EQ(list->size(), 2); |
93 | list->insert(1, M(3)); |
94 | ASSERT_EQ(list->size(), 3); |
95 | list->insert(3, M(4)); |
96 | ASSERT_EQ(list->size(), 4); |
97 | ASSERT_EQ(list->at<MImpl>(0).value, 2); |
98 | ASSERT_EQ(list->at<MImpl>(1).value, 3); |
99 | ASSERT_EQ(list->at<MImpl>(2).value, 1); |
100 | ASSERT_EQ(list->at<MImpl>(3).value, 4); |
101 | |
102 | std::unordered_map<size_t, size_t> U = {{0, 2}, {1, 3}, {2, 1}, {3, 4}}; |
103 | for (const auto& P : list->named_modules("" , false)) |
104 | ASSERT_EQ(U[std::stoul(P.key())], P.value()->as<M>()->value); |
105 | } |
106 | |
107 | TEST_F(ModuleListTest, AccessWithAt) { |
108 | struct M : torch::nn::Module { |
109 | explicit M(int value_) : value(value_) {} |
110 | int value; |
111 | }; |
112 | std::vector<std::shared_ptr<M>> modules = { |
113 | std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)}; |
114 | |
115 | ModuleList list; |
116 | for (auto& module : modules) { |
117 | list->push_back(module); |
118 | } |
119 | ASSERT_EQ(list->size(), 3); |
120 | |
121 | // returns the correct module for a given index |
122 | for (const auto i : c10::irange(modules.size())) { |
123 | ASSERT_EQ(&list->at<M>(i), modules[i].get()); |
124 | } |
125 | |
126 | // throws for a bad index |
127 | ASSERT_THROWS_WITH(list->at<M>(modules.size() + 1), "Index out of range" ); |
128 | ASSERT_THROWS_WITH( |
129 | list->at<M>(modules.size() + 1000000), "Index out of range" ); |
130 | } |
131 | |
132 | TEST_F(ModuleListTest, AccessWithPtr) { |
133 | struct M : torch::nn::Module { |
134 | explicit M(int value_) : value(value_) {} |
135 | int value; |
136 | }; |
137 | std::vector<std::shared_ptr<M>> modules = { |
138 | std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)}; |
139 | |
140 | ModuleList list; |
141 | for (auto& module : modules) { |
142 | list->push_back(module); |
143 | } |
144 | ASSERT_EQ(list->size(), 3); |
145 | |
146 | // returns the correct module for a given index |
147 | for (const auto i : c10::irange(modules.size())) { |
148 | ASSERT_EQ(list->ptr(i).get(), modules[i].get()); |
149 | ASSERT_EQ(list[i].get(), modules[i].get()); |
150 | ASSERT_EQ(list->ptr<M>(i).get(), modules[i].get()); |
151 | } |
152 | |
153 | // throws for a bad index |
154 | ASSERT_THROWS_WITH(list->ptr(modules.size() + 1), "Index out of range" ); |
155 | ASSERT_THROWS_WITH(list->ptr(modules.size() + 1000000), "Index out of range" ); |
156 | } |
157 | |
158 | TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) { |
159 | ModuleList list( |
160 | Linear(10, 3), |
161 | Conv2d(1, 2, 3), |
162 | Dropout(0.5), |
163 | BatchNorm2d(5), |
164 | Embedding(4, 10), |
165 | LSTM(4, 5)); |
166 | } |
167 | |
168 | TEST_F(ModuleListTest, ExtendPushesModulesFromOtherModuleList) { |
169 | struct A : torch::nn::Module {}; |
170 | struct B : torch::nn::Module {}; |
171 | struct C : torch::nn::Module {}; |
172 | struct D : torch::nn::Module {}; |
173 | ModuleList a(A{}, B{}); |
174 | ModuleList b(C{}, D{}); |
175 | a->extend(*b); |
176 | |
177 | ASSERT_EQ(a->size(), 4); |
178 | ASSERT_TRUE(a[0]->as<A>()); |
179 | ASSERT_TRUE(a[1]->as<B>()); |
180 | ASSERT_TRUE(a[2]->as<C>()); |
181 | ASSERT_TRUE(a[3]->as<D>()); |
182 | |
183 | ASSERT_EQ(b->size(), 2); |
184 | ASSERT_TRUE(b[0]->as<C>()); |
185 | ASSERT_TRUE(b[1]->as<D>()); |
186 | |
187 | std::vector<std::shared_ptr<A>> c = { |
188 | std::make_shared<A>(), std::make_shared<A>()}; |
189 | b->extend(c); |
190 | |
191 | ASSERT_EQ(b->size(), 4); |
192 | ASSERT_TRUE(b[0]->as<C>()); |
193 | ASSERT_TRUE(b[1]->as<D>()); |
194 | ASSERT_TRUE(b[2]->as<A>()); |
195 | ASSERT_TRUE(b[3]->as<A>()); |
196 | } |
197 | |
198 | TEST_F(ModuleListTest, HasReferenceSemantics) { |
199 | ModuleList first(Linear(2, 3), Linear(4, 4), Linear(4, 5)); |
200 | ModuleList second(first); |
201 | |
202 | ASSERT_EQ(first.get(), second.get()); |
203 | ASSERT_EQ(first->size(), second->size()); |
204 | ASSERT_TRUE(std::equal( |
205 | first->begin(), |
206 | first->end(), |
207 | second->begin(), |
208 | [](const std::shared_ptr<Module>& first, |
209 | const std::shared_ptr<Module>& second) { |
210 | return first.get() == second.get(); |
211 | })); |
212 | } |
213 | |
214 | TEST_F(ModuleListTest, IsCloneable) { |
215 | ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3)); |
216 | ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone()); |
217 | ASSERT_EQ(list->size(), clone->size()); |
218 | |
219 | for (size_t i = 0; i < list->size(); ++i) { |
220 | // The modules should be the same kind (type). |
221 | ASSERT_EQ(list[i]->name(), clone[i]->name()); |
222 | // But not pointer-equal (distinct objects). |
223 | ASSERT_NE(list[i], clone[i]); |
224 | } |
225 | |
226 | // Verify that the clone is deep, i.e. parameters of modules are cloned too. |
227 | |
228 | torch::NoGradGuard no_grad; |
229 | |
230 | auto params1 = list->named_parameters(); |
231 | auto params2 = clone->named_parameters(); |
232 | ASSERT_EQ(params1.size(), params2.size()); |
233 | for (auto& param : params1) { |
234 | ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); |
235 | ASSERT_EQ(param->device(), params2[param.key()].device()); |
236 | ASSERT_TRUE(param->allclose(params2[param.key()])); |
237 | param->add_(2); |
238 | } |
239 | for (auto& param : params1) { |
240 | ASSERT_FALSE(param->allclose(params2[param.key()])); |
241 | } |
242 | } |
243 | |
244 | TEST_F(ModuleListTest, RegistersElementsAsSubmodules) { |
245 | ModuleList list(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5)); |
246 | |
247 | auto modules = list->children(); |
248 | ASSERT_TRUE(modules[0]->as<Linear>()); |
249 | ASSERT_TRUE(modules[1]->as<Conv2d>()); |
250 | ASSERT_TRUE(modules[2]->as<Dropout2d>()); |
251 | } |
252 | |
253 | TEST_F(ModuleListTest, NestingIsPossible) { |
254 | ModuleList list( |
255 | (ModuleList(Dropout(), Dropout())), |
256 | (ModuleList(Dropout(), Dropout()), Dropout())); |
257 | } |
258 | |
259 | TEST_F(ModuleListTest, CloneToDevice_CUDA) { |
260 | ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3)); |
261 | torch::Device device(torch::kCUDA, 0); |
262 | ModuleList clone = |
263 | std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device)); |
264 | for (const auto& p : clone->parameters()) { |
265 | ASSERT_EQ(p.device(), device); |
266 | } |
267 | for (const auto& b : clone->buffers()) { |
268 | ASSERT_EQ(b.device(), device); |
269 | } |
270 | } |
271 | |
272 | TEST_F(ModuleListTest, PrettyPrintModuleList) { |
273 | ModuleList list( |
274 | Linear(10, 3), |
275 | Conv2d(1, 2, 3), |
276 | Dropout(0.5), |
277 | BatchNorm2d(5), |
278 | Embedding(4, 10), |
279 | LSTM(4, 5)); |
280 | ASSERT_EQ( |
281 | c10::str(list), |
282 | "torch::nn::ModuleList(\n" |
283 | " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n" |
284 | " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n" |
285 | " (2): torch::nn::Dropout(p=0.5, inplace=false)\n" |
286 | " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n" |
287 | " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n" |
288 | " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n" |
289 | ")" ); |
290 | } |
291 | |
292 | TEST_F(ModuleListTest, RangeBasedForLoop) { |
293 | torch::nn::ModuleList mlist( |
294 | torch::nn::Linear(3, 4), |
295 | torch::nn::BatchNorm1d(4), |
296 | torch::nn::Dropout(0.5)); |
297 | |
298 | std::stringstream buffer; |
299 | for (const auto& module : *mlist) { |
300 | module->pretty_print(buffer); |
301 | } |
302 | } |
303 | |
304 | TEST_F(ModuleListTest, InvalidAt) { |
305 | torch::nn::ModuleList m(torch::nn::Linear(1, 2)); |
306 | ASSERT_THROWS_WITH( |
307 | m->at<torch::nn::Dropout2dImpl>(0), "Unable to cast module" ); |
308 | } |
309 | |