#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;
14
15struct ModuleListTest : torch::test::SeedingFixture {};
16
17TEST_F(ModuleListTest, ConstructsFromSharedPointer) {
18 struct M : torch::nn::Module {
19 explicit M(int value_) : value(value_) {}
20 int value;
21 };
22 ModuleList list(
23 std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
24 ASSERT_EQ(list->size(), 3);
25}
26
27TEST_F(ModuleListTest, ConstructsFromConcreteType) {
28 static int copy_count;
29
30 struct M : torch::nn::Module {
31 explicit M(int value_) : value(value_) {}
32 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
33 M(const M& other) : torch::nn::Module(other) {
34 copy_count++;
35 }
36 int value;
37 };
38
39 copy_count = 0;
40 ModuleList list(M(1), M(2), M(3));
41 ASSERT_EQ(list->size(), 3);
42 // NOTE: The current implementation expects each module to be copied exactly
43 // once, which happens when the module is passed into `std::make_shared<T>()`.
44 // TODO: Find a way to avoid copying, and then delete the copy constructor of
45 // `M`.
46 ASSERT_EQ(copy_count, 3);
47}
48
49TEST_F(ModuleListTest, ConstructsFromModuleHolder) {
50 struct MImpl : torch::nn::Module {
51 explicit MImpl(int value_) : value(value_) {}
52 int value;
53 };
54
55 struct M : torch::nn::ModuleHolder<MImpl> {
56 using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
57 using torch::nn::ModuleHolder<MImpl>::get;
58 };
59
60 ModuleList list(M(1), M(2), M(3));
61 ASSERT_EQ(list->size(), 3);
62}
63
64TEST_F(ModuleListTest, PushBackAddsAnElement) {
65 struct M : torch::nn::Module {
66 explicit M(int value_) : value(value_) {}
67 int value;
68 };
69
70 ModuleList list;
71 ASSERT_EQ(list->size(), 0);
72 ASSERT_TRUE(list->is_empty());
73 list->push_back(Linear(3, 4));
74 ASSERT_EQ(list->size(), 1);
75 list->push_back(std::make_shared<M>(1));
76 ASSERT_EQ(list->size(), 2);
77 list->push_back(M(2));
78 ASSERT_EQ(list->size(), 3);
79}
80
81TEST_F(ModuleListTest, Insertion) {
82 struct MImpl : torch::nn::Module {
83 explicit MImpl(int value_) : value(value_) {}
84 int value;
85 };
86 TORCH_MODULE(M);
87
88 ModuleList list;
89 list->push_back(MImpl(1));
90 ASSERT_EQ(list->size(), 1);
91 list->insert(0, std::make_shared<MImpl>(2));
92 ASSERT_EQ(list->size(), 2);
93 list->insert(1, M(3));
94 ASSERT_EQ(list->size(), 3);
95 list->insert(3, M(4));
96 ASSERT_EQ(list->size(), 4);
97 ASSERT_EQ(list->at<MImpl>(0).value, 2);
98 ASSERT_EQ(list->at<MImpl>(1).value, 3);
99 ASSERT_EQ(list->at<MImpl>(2).value, 1);
100 ASSERT_EQ(list->at<MImpl>(3).value, 4);
101
102 std::unordered_map<size_t, size_t> U = {{0, 2}, {1, 3}, {2, 1}, {3, 4}};
103 for (const auto& P : list->named_modules("", false))
104 ASSERT_EQ(U[std::stoul(P.key())], P.value()->as<M>()->value);
105}
106
107TEST_F(ModuleListTest, AccessWithAt) {
108 struct M : torch::nn::Module {
109 explicit M(int value_) : value(value_) {}
110 int value;
111 };
112 std::vector<std::shared_ptr<M>> modules = {
113 std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};
114
115 ModuleList list;
116 for (auto& module : modules) {
117 list->push_back(module);
118 }
119 ASSERT_EQ(list->size(), 3);
120
121 // returns the correct module for a given index
122 for (const auto i : c10::irange(modules.size())) {
123 ASSERT_EQ(&list->at<M>(i), modules[i].get());
124 }
125
126 // throws for a bad index
127 ASSERT_THROWS_WITH(list->at<M>(modules.size() + 1), "Index out of range");
128 ASSERT_THROWS_WITH(
129 list->at<M>(modules.size() + 1000000), "Index out of range");
130}
131
132TEST_F(ModuleListTest, AccessWithPtr) {
133 struct M : torch::nn::Module {
134 explicit M(int value_) : value(value_) {}
135 int value;
136 };
137 std::vector<std::shared_ptr<M>> modules = {
138 std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};
139
140 ModuleList list;
141 for (auto& module : modules) {
142 list->push_back(module);
143 }
144 ASSERT_EQ(list->size(), 3);
145
146 // returns the correct module for a given index
147 for (const auto i : c10::irange(modules.size())) {
148 ASSERT_EQ(list->ptr(i).get(), modules[i].get());
149 ASSERT_EQ(list[i].get(), modules[i].get());
150 ASSERT_EQ(list->ptr<M>(i).get(), modules[i].get());
151 }
152
153 // throws for a bad index
154 ASSERT_THROWS_WITH(list->ptr(modules.size() + 1), "Index out of range");
155 ASSERT_THROWS_WITH(list->ptr(modules.size() + 1000000), "Index out of range");
156}
157
158TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) {
159 ModuleList list(
160 Linear(10, 3),
161 Conv2d(1, 2, 3),
162 Dropout(0.5),
163 BatchNorm2d(5),
164 Embedding(4, 10),
165 LSTM(4, 5));
166}
167
168TEST_F(ModuleListTest, ExtendPushesModulesFromOtherModuleList) {
169 struct A : torch::nn::Module {};
170 struct B : torch::nn::Module {};
171 struct C : torch::nn::Module {};
172 struct D : torch::nn::Module {};
173 ModuleList a(A{}, B{});
174 ModuleList b(C{}, D{});
175 a->extend(*b);
176
177 ASSERT_EQ(a->size(), 4);
178 ASSERT_TRUE(a[0]->as<A>());
179 ASSERT_TRUE(a[1]->as<B>());
180 ASSERT_TRUE(a[2]->as<C>());
181 ASSERT_TRUE(a[3]->as<D>());
182
183 ASSERT_EQ(b->size(), 2);
184 ASSERT_TRUE(b[0]->as<C>());
185 ASSERT_TRUE(b[1]->as<D>());
186
187 std::vector<std::shared_ptr<A>> c = {
188 std::make_shared<A>(), std::make_shared<A>()};
189 b->extend(c);
190
191 ASSERT_EQ(b->size(), 4);
192 ASSERT_TRUE(b[0]->as<C>());
193 ASSERT_TRUE(b[1]->as<D>());
194 ASSERT_TRUE(b[2]->as<A>());
195 ASSERT_TRUE(b[3]->as<A>());
196}
197
198TEST_F(ModuleListTest, HasReferenceSemantics) {
199 ModuleList first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
200 ModuleList second(first);
201
202 ASSERT_EQ(first.get(), second.get());
203 ASSERT_EQ(first->size(), second->size());
204 ASSERT_TRUE(std::equal(
205 first->begin(),
206 first->end(),
207 second->begin(),
208 [](const std::shared_ptr<Module>& first,
209 const std::shared_ptr<Module>& second) {
210 return first.get() == second.get();
211 }));
212}
213
214TEST_F(ModuleListTest, IsCloneable) {
215 ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
216 ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone());
217 ASSERT_EQ(list->size(), clone->size());
218
219 for (size_t i = 0; i < list->size(); ++i) {
220 // The modules should be the same kind (type).
221 ASSERT_EQ(list[i]->name(), clone[i]->name());
222 // But not pointer-equal (distinct objects).
223 ASSERT_NE(list[i], clone[i]);
224 }
225
226 // Verify that the clone is deep, i.e. parameters of modules are cloned too.
227
228 torch::NoGradGuard no_grad;
229
230 auto params1 = list->named_parameters();
231 auto params2 = clone->named_parameters();
232 ASSERT_EQ(params1.size(), params2.size());
233 for (auto& param : params1) {
234 ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
235 ASSERT_EQ(param->device(), params2[param.key()].device());
236 ASSERT_TRUE(param->allclose(params2[param.key()]));
237 param->add_(2);
238 }
239 for (auto& param : params1) {
240 ASSERT_FALSE(param->allclose(params2[param.key()]));
241 }
242}
243
244TEST_F(ModuleListTest, RegistersElementsAsSubmodules) {
245 ModuleList list(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));
246
247 auto modules = list->children();
248 ASSERT_TRUE(modules[0]->as<Linear>());
249 ASSERT_TRUE(modules[1]->as<Conv2d>());
250 ASSERT_TRUE(modules[2]->as<Dropout2d>());
251}
252
253TEST_F(ModuleListTest, NestingIsPossible) {
254 ModuleList list(
255 (ModuleList(Dropout(), Dropout())),
256 (ModuleList(Dropout(), Dropout()), Dropout()));
257}
258
259TEST_F(ModuleListTest, CloneToDevice_CUDA) {
260 ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
261 torch::Device device(torch::kCUDA, 0);
262 ModuleList clone =
263 std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device));
264 for (const auto& p : clone->parameters()) {
265 ASSERT_EQ(p.device(), device);
266 }
267 for (const auto& b : clone->buffers()) {
268 ASSERT_EQ(b.device(), device);
269 }
270}
271
272TEST_F(ModuleListTest, PrettyPrintModuleList) {
273 ModuleList list(
274 Linear(10, 3),
275 Conv2d(1, 2, 3),
276 Dropout(0.5),
277 BatchNorm2d(5),
278 Embedding(4, 10),
279 LSTM(4, 5));
280 ASSERT_EQ(
281 c10::str(list),
282 "torch::nn::ModuleList(\n"
283 " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
284 " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
285 " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
286 " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
287 " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
288 " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
289 ")");
290}
291
292TEST_F(ModuleListTest, RangeBasedForLoop) {
293 torch::nn::ModuleList mlist(
294 torch::nn::Linear(3, 4),
295 torch::nn::BatchNorm1d(4),
296 torch::nn::Dropout(0.5));
297
298 std::stringstream buffer;
299 for (const auto& module : *mlist) {
300 module->pretty_print(buffer);
301 }
302}
303
304TEST_F(ModuleListTest, InvalidAt) {
305 torch::nn::ModuleList m(torch::nn::Linear(1, 2));
306 ASSERT_THROWS_WITH(
307 m->at<torch::nn::Dropout2dImpl>(0), "Unable to cast module");
308}
309