1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <test/cpp/api/support.h>
7
8using namespace torch::nn;
9using namespace torch::test;
10
// Minimal module in the global namespace; used to test Module::name()
// demangling and failed as<T>() casts.
struct AGIUnit : torch::nn::Module {};
12
namespace test {
// Same type name inside a namespace: name() should report "test::AGIUnit".
struct AGIUnit : torch::nn::Module {};
// Module constructed with an explicit name; name() should report "Foo"
// instead of the demangled type name.
struct AGIUnit2 : torch::nn::Module {
  AGIUnit2() : torch::nn::Module("Foo") {}
};
} // namespace test
19
// Test fixture that seeds the RNG so tests run deterministically.
struct ModuleTest : torch::test::SeedingFixture {};
21
22TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
23 Linear module(3, 4);
24 ASSERT_TRUE(module->is_training());
25
26 module->eval();
27 ASSERT_FALSE(module->is_training());
28
29 module->train();
30 ASSERT_TRUE(module->is_training());
31}
32
33TEST_F(ModuleTest, ZeroGrad) {
34 Linear module(3, 4);
35 auto weight = torch::ones({8, 3}, torch::requires_grad());
36 auto loss = module(weight).sum();
37 loss.backward();
38 for (auto& parameter : module->parameters()) {
39 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
40 auto grad = parameter.grad();
41 ASSERT_TRUE(grad.defined());
42 ASSERT_NE(grad.sum().item<float>(), 0);
43 }
44 module->zero_grad();
45 for (auto& parameter : module->parameters()) {
46 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
47 auto grad = parameter.grad();
48 ASSERT_FALSE(grad.defined());
49 }
50}
51
52TEST_F(ModuleTest, ZeroGradWithUndefined) {
53 struct TestModule : torch::nn::Module {
54 TestModule() {
55 x = register_parameter("x", torch::ones(5, torch::requires_grad()));
56 y = register_parameter("y", torch::ones(5, torch::requires_grad()));
57 }
58 torch::Tensor x, y;
59 };
60
61 TestModule module;
62 auto z = module.x * 2;
63 z.sum().backward();
64
65 ASSERT_TRUE(module.x.grad().defined());
66 ASSERT_FALSE(module.y.grad().defined());
67
68 module.zero_grad(false); // set_to_none = false
69
70 ASSERT_TRUE(module.x.grad().defined());
71 ASSERT_FALSE(module.y.grad().defined());
72
73 ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
74
75 module.zero_grad();
76
77 ASSERT_FALSE(module.x.grad().defined());
78 ASSERT_FALSE(module.y.grad().defined());
79}
80
81TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
82 struct TestModel : public torch::nn::Module {};
83 ASSERT_THROWS_WITH(
84 TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
85 "Submodule name must not contain a dot (got 'name.with.dot')");
86 ASSERT_THROWS_WITH(
87 TestModel{}.register_module("", torch::nn::Linear(3, 4)),
88 "Submodule name must not be empty");
89}
90
91TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
92 struct TestModel : public torch::nn::Module {};
93 TestModel model;
94 model.register_module("linear", torch::nn::Linear(3, 4));
95 ASSERT_THROWS_WITH(
96 model.register_module("linear", torch::nn::Linear(3, 4)),
97 "Submodule 'linear' already defined");
98}
99
100TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
101 torch::nn::Module model;
102 ASSERT_THROWS_WITH(
103 model.replace_module("linear", torch::nn::Linear(3, 4)),
104 "Submodule 'linear' is not defined");
105}
106
107TEST_F(ModuleTest, ReplaceModule) {
108 struct TestModel : public torch::nn::Module {
109 torch::nn::Linear l1{nullptr};
110 TestModel() {
111 l1 = register_module("l1", torch::nn::Linear(3, 4));
112 }
113 };
114 auto model = std::make_shared<TestModel>();
115 model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));
116 ASSERT_EQ(model->named_parameters()["l1.weight"].size(0), 6);
117 ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
118}
119
120TEST_F(ModuleTest, UnregisterModule) {
121 struct TestModel : public torch::nn::Module {};
122 TestModel model;
123 ASSERT_THROWS_WITH(
124 model.unregister_module("linear"),
125 "No Module with name `linear` is registered");
126 model.register_module("linear", torch::nn::Linear(3, 4));
127 model.unregister_module("linear");
128 ASSERT_TRUE(model.children().empty());
129}
130
131TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
132 struct TestModel : public torch::nn::Module {};
133 ASSERT_THROWS_WITH(
134 TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
135 "Parameter name must not contain a dot (got 'name.with.dot')");
136 ASSERT_THROWS_WITH(
137 TestModel{}.register_parameter("", torch::ones(5)),
138 "Parameter name must not be empty");
139}
140
141TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
142 struct TestModel : public torch::nn::Module {};
143 TestModel model;
144 model.register_parameter("p", torch::ones(5));
145 ASSERT_THROWS_WITH(
146 model.register_parameter("p", torch::ones(5)),
147 "Parameter 'p' already defined");
148}
149
150TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
151 struct TestModel : public torch::nn::Module {};
152 {
153 TestModel model;
154 model.register_parameter(
155 "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
156 ASSERT_EQ(model.parameters().size(), 0);
157 }
158 {
159 WarningCapture warnings;
160
161 TestModel model;
162 model.register_parameter("undefined_tensor", torch::Tensor());
163 ASSERT_EQ(model.parameters().size(), 0);
164
165 ASSERT_EQ(
166 count_substr_occurrences(
167 warnings.str(),
168 "Ignoring the `requires_grad=true` function parameter"),
169 1);
170 }
171}
172
173TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
174 struct TestModel : public torch::nn::Module {};
175 ASSERT_THROWS_WITH(
176 TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
177 "Buffer name must not contain a dot (got 'name.with.dot')");
178 ASSERT_THROWS_WITH(
179 TestModel{}.register_buffer("", torch::ones(5)),
180 "Buffer name must not be empty");
181}
182
183TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
184 struct TestModel : public torch::nn::Module {};
185 TestModel model;
186 model.register_buffer("p", torch::ones(5));
187 ASSERT_THROWS_WITH(
188 model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
189}
190
// Verifies Module::name() for a global-namespace type, a namespaced type,
// and a module constructed with an explicit name.
TEST_F(ModuleTest, CanGetName) {
  // Non-fatal EXPECT_EQ (rather than ASSERT_EQ) because demangling may fail
  // on some platforms.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
201
202TEST_F(ModuleTest, AsCastsModulesCorrectly) {
203 Linear module(3, 4);
204 ASSERT_EQ(module->as<Linear>(), module.get());
205 ASSERT_EQ(module->as<LinearImpl>(), module.get());
206 ASSERT_EQ(module->as<Module>(), module.get());
207 ASSERT_EQ(module->as<AGIUnit>(), nullptr);
208
209 std::shared_ptr<Module> raw = module.ptr();
210 ASSERT_EQ(raw->as<Linear>(), module.get());
211 ASSERT_EQ(raw->as<LinearImpl>(), module.get());
212 ASSERT_EQ(raw->as<Module>(), module.get());
213 ASSERT_EQ(raw->as<AGIUnit>(), nullptr);
214
215 Module& raw_ref = *raw.get();
216 ASSERT_EQ(raw_ref.as<Linear>(), module.get());
217 ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
218 ASSERT_EQ(raw_ref.as<Module>(), module.get());
219 ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
220 if (auto* linear = raw_ref.as<Linear>()) {
221 ASSERT_EQ(linear->weight.ndimension(), 2);
222 }
223
224 AGIUnit unit;
225 ASSERT_EQ(unit.as<Linear>(), nullptr);
226 ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
227 ASSERT_EQ(unit.as<AGIUnit>(), &unit);
228}
229
230void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
231 torch::Device to_device,
232 torch::Dtype to_dtype) {
233 {
234 // Case 1: Undefined tensors as parameters
235 Linear module(LinearOptions(10, 20).bias(false));
236 ASSERT_TRUE(module->weight.defined());
237 ASSERT_FALSE(module->bias.defined());
238
239 module->to(to_device);
240 ASSERT_TRUE(module->weight.defined());
241 ASSERT_EQ(module->weight.device().type(), to_device.type());
242 ASSERT_FALSE(module->bias.defined());
243
244 module->to(to_dtype);
245 ASSERT_TRUE(module->weight.defined());
246 ASSERT_EQ(module->weight.dtype(), to_dtype);
247 ASSERT_FALSE(module->bias.defined());
248 }
249 {
250 // Case 2: Undefined tensors as buffers
251 BatchNorm1d module(
252 BatchNorm1dOptions(5).track_running_stats(false).affine(true));
253 ASSERT_TRUE(module->weight.defined());
254 ASSERT_FALSE(module->running_mean.defined());
255
256 module->to(to_device);
257 ASSERT_TRUE(module->weight.defined());
258 ASSERT_EQ(module->weight.device().type(), to_device.type());
259 ASSERT_FALSE(module->running_mean.defined());
260
261 module->to(to_dtype);
262 ASSERT_TRUE(module->weight.defined());
263 ASSERT_EQ(module->weight.dtype(), to_dtype);
264 ASSERT_FALSE(module->running_mean.defined());
265 }
266}
267
// CPU variant of the undefined-tensor conversion check above.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCPU, torch::kDouble);
}
271
// CUDA variant. NOTE(review): the _CUDA suffix presumably gates this test on
// device availability in the test harness — confirm against the runner.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      torch::kCUDA, torch::kDouble);
}
276
277TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
278 {
279 Linear module(LinearOptions(10, 20).bias(false));
280
281 auto params = module->parameters();
282 ASSERT_EQ(params.size(), 1);
283 auto named_params = module->named_parameters();
284 ASSERT_EQ(named_params.size(), 1);
285
286 ASSERT_TRUE(pointer_equal(params[0], named_params["weight"]));
287 ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight));
288 }
289 {
290 BatchNorm1d module(
291 BatchNorm1dOptions(5).track_running_stats(false).affine(false));
292
293 auto buffers = module->buffers();
294 ASSERT_EQ(buffers.size(), 0);
295 auto named_buffers = module->named_buffers();
296 ASSERT_EQ(named_buffers.size(), 0);
297 }
298 {
299 BatchNorm1d module(
300 BatchNorm1dOptions(5).track_running_stats(true).affine(false));
301
302 auto buffers = module->buffers();
303 ASSERT_EQ(buffers.size(), 3);
304 auto named_buffers = module->named_buffers();
305 ASSERT_EQ(named_buffers.size(), 3);
306
307 ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"]));
308 ASSERT_TRUE(
309 pointer_equal(named_buffers["running_mean"], module->running_mean));
310 ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"]));
311 ASSERT_TRUE(
312 pointer_equal(named_buffers["running_var"], module->running_var));
313 ASSERT_TRUE(
314 pointer_equal(buffers[2], named_buffers["num_batches_tracked"]));
315 ASSERT_TRUE(pointer_equal(
316 named_buffers["num_batches_tracked"], module->num_batches_tracked));
317 }
318}
319
320TEST_F(ModuleTest, Conversion_MultiCUDA) {
321 Linear module(128, 64);
322 for (auto& parameter : module->parameters()) {
323 ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
324 ASSERT_EQ(parameter.dtype(), torch::kFloat32);
325 }
326 {
327 module->to({torch::kCUDA, 0});
328 for (auto& parameter : module->parameters()) {
329 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
330 ASSERT_EQ(parameter.device().index(), 0);
331 }
332 module->to({torch::kCUDA, 1});
333 for (auto& parameter : module->parameters()) {
334 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
335 ASSERT_EQ(parameter.device().index(), 1);
336 }
337 }
338 {
339 module->to(torch::Device(torch::kCPU));
340 for (auto& parameter : module->parameters()) {
341 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
342 }
343 }
344 {
345 module->to(torch::kFloat64);
346 for (auto& parameter : module->parameters()) {
347 ASSERT_EQ(parameter.dtype(), torch::kFloat64);
348 }
349 }
350}
351
352TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
353 Linear module(128, 64);
354 for (auto& parameter : module->parameters()) {
355 parameter.requires_grad_(false);
356 }
357 {
358 module->to(torch::kInt32);
359 for (auto& parameter : module->parameters()) {
360 ASSERT_EQ(parameter.dtype(), torch::kInt32);
361 }
362 }
363 {
364 module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
365 for (auto& parameter : module->parameters()) {
366 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
367 ASSERT_EQ(parameter.device().index(), 1);
368 }
369 for (auto& parameter : module->parameters()) {
370 ASSERT_EQ(parameter.dtype(), torch::kUInt8);
371 }
372 }
373}
374
375TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
376 struct UnCloneable : Module {};
377 UnCloneable module;
378 ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
379}
380
381TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
382 struct Cloneable : Module {
383 std::shared_ptr<Module> clone(
384 const torch::optional<torch::Device>& device =
385 torch::nullopt) const override {
386 return nullptr;
387 }
388 };
389 Cloneable module;
390 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
391 ASSERT_NO_THROW({ module.clone(); });
392}
393
// Cloneable module with three Linear submodules (6 parameters) and one
// buffer; exercised by the CloneCreatesDistinctParameters* tests below.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TestDistinctParametersModule
    : public Cloneable<TestDistinctParametersModule> {
  TestDistinctParametersModule() {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }
  // Cloneable requires reset() to (re)create every member.
  void reset() override {
    l1 = register_module("l1", Linear(10, 3));
    l2 = register_module("l2", Linear(3, 5));
    l3 = register_module("l3", Linear(5, 100));
    buffer = register_buffer("buf", torch::ones({2, 2}));
  }

  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
  torch::Tensor buffer;
};
411
412void testDistinctParameters(
413 std::shared_ptr<Module> m1,
414 std::shared_ptr<Module> m2) {
415 auto params1 = m1->named_parameters();
416 auto params2 = m2->named_parameters();
417 ASSERT_EQ(params1.size(), 6);
418 ASSERT_EQ(params2.size(), 6);
419 for (auto& param : params1) {
420 ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
421 ASSERT_TRUE(param->allclose(params2[param.key()]));
422 param->add_(2);
423 }
424 for (auto& param : params1) {
425 ASSERT_FALSE(param->allclose(params2[param.key()]));
426 }
427
428 auto buffers1 = m1->named_buffers();
429 auto buffers2 = m2->named_buffers();
430 ASSERT_EQ(buffers1.size(), 1);
431 ASSERT_EQ(buffers2.size(), 1);
432 for (auto& buffer : buffers1) {
433 ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
434 ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
435 buffer->add_(2);
436 }
437 for (auto& buffer : buffers1) {
438 ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
439 }
440}
441
442TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
443 auto module = std::make_shared<TestDistinctParametersModule>();
444 torch::NoGradGuard no_grad;
445 auto module2 = module->clone();
446 testDistinctParameters(module, module2);
447}
448
449TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
450 auto module = std::make_shared<TestDistinctParametersModule>();
451 torch::NoGradGuard no_grad;
452 torch::Device device(torch::kCUDA, 0);
453 module->to(device);
454 auto module2 = module->clone(device);
455 testDistinctParameters(module, module2);
456}
457
458TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
459 auto module = std::make_shared<TestDistinctParametersModule>();
460 torch::NoGradGuard no_grad;
461 torch::Device d0(torch::kCUDA, 0);
462 torch::Device d1(torch::kCUDA, 1);
463 module->to(d0);
464 auto module2 = module->clone(d1);
465
466 for (auto& param : module->parameters()) {
467 ASSERT_EQ(param.device(), d0);
468 }
469
470 for (auto& param : module2->parameters()) {
471 ASSERT_EQ(param.device(), d1);
472 }
473
474 // need to move the module back to d0 as allclose expects two tensors on
475 // the same device.
476 module2->to(d0);
477 testDistinctParameters(module, module2);
478}
479
480TEST_F(ModuleTest, ClonePreservesExternalReferences) {
481 // NOLINTNEXTLINE(bugprone-exception-escape)
482 struct TestModule : public Cloneable<TestModule> {
483 TestModule() {
484 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
485 reset();
486 }
487 void reset() override {
488 weight = register_parameter("weight", torch::ones({4, 4}));
489 }
490 torch::Tensor weight;
491 };
492 auto module = std::make_shared<TestModule>();
493 {
494 torch::NoGradGuard no_grad;
495 module->weight += 1;
496 }
497 ASSERT_TRUE(
498 pointer_equal(module->weight, module->named_parameters()["weight"]));
499 ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));
500
501 auto module2 = std::dynamic_pointer_cast<TestModule>(
502 std::shared_ptr<Module>(module->clone()));
503 ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
504 ASSERT_TRUE(
505 pointer_equal(module2->weight, module2->named_parameters()["weight"]));
506 ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
507 ASSERT_TRUE(module2->weight.allclose(module->weight));
508 ASSERT_FALSE(
509 pointer_equal(module2->weight, module->named_parameters()["weight"]));
510}
511
// clone() must deep-copy tensors of nested submodules and copy plain
// (non-tensor) fields such as `value` by value.
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    int value = 0;  // plain field, not registered with the module system
  };
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  // Distinct tensor storage in the cloned submodule...
  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  // ...that aliases the clone's own registered parameter...
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  // ...with equal values, and the plain int copied over.
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}
555
556TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
557 // NOLINTNEXTLINE(bugprone-exception-escape)
558 struct TestModule : public Cloneable<TestModule> {
559 TestModule() {
560 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
561 reset();
562 }
563 void reset() override {
564 l1 = register_module("l1", Linear(10, 3));
565 l2 = register_module("l2", Linear(3, 5));
566 l3 = register_module("l3", Linear(5, 100));
567 buffer = register_buffer("buf", torch::ones({2, 2}));
568 }
569
570 Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
571 torch::Tensor buffer;
572 };
573
574 TestModule m;
575 torch::Device device(torch::kCUDA, 0);
576
577 m.to(device);
578
579 auto clone = m.clone();
580 for (const auto& parameter : clone->parameters()) {
581 ASSERT_EQ(parameter.device().type(), device.type());
582 ASSERT_EQ(parameter.device().index(), device.index());
583 }
584 for (const auto& buffer : clone->buffers()) {
585 ASSERT_EQ(buffer.device().type(), device.type());
586 ASSERT_EQ(buffer.device().index(), device.index());
587 }
588}
589
590TEST_F(
591 ModuleTest,
592 CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
593 // NOLINTNEXTLINE(bugprone-exception-escape)
594 struct TestModule : public Cloneable<TestModule> {
595 TestModule() {
596 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
597 reset();
598 }
599 void reset() override {
600 l1 = register_module("l1", Linear(10, 3));
601 l2 = register_module("l2", Linear(3, 5));
602 l3 = register_module("l3", Linear(5, 100));
603 buffer = register_buffer("buf", torch::ones({2, 2}));
604 }
605
606 Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
607 torch::Tensor buffer;
608 };
609
610 TestModule m;
611 torch::Device device(torch::kCUDA, 1);
612 // everything is on CPU here
613 auto clone = m.clone(device);
614 for (const auto& parameter : clone->parameters()) {
615 ASSERT_EQ(parameter.device().type(), device.type());
616 ASSERT_EQ(parameter.device().index(), device.index());
617 }
618 for (const auto& buffer : clone->buffers()) {
619 ASSERT_EQ(buffer.device().type(), device.type());
620 ASSERT_EQ(buffer.device().index(), device.index());
621 }
622}
623
// Module with three named parameters, used by the parameter accessor tests.
struct ParameterTestModule : Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
633
634TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
635 ParameterTestModule module;
636 ASSERT_EQ(module.parameters().size(), 3);
637 ASSERT_EQ(module.named_parameters().size(), 3);
638}
639
640TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
641 ParameterTestModule module;
642 auto parameters = module.named_parameters();
643 ASSERT_TRUE(parameters.contains("a"));
644 ASSERT_TRUE(parameters.contains("b"));
645 ASSERT_TRUE(parameters.contains("c"));
646}
647
// Module with three named buffers, used by the buffer accessor tests.
struct BufferTestModule : Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
657
658TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
659 BufferTestModule module;
660 ASSERT_EQ(module.buffers().size(), 3);
661 ASSERT_EQ(module.named_buffers().size(), 3);
662}
663
664TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
665 BufferTestModule module;
666 auto buffers = module.named_buffers();
667 ASSERT_TRUE(buffers.contains("a"));
668 ASSERT_TRUE(buffers.contains("b"));
669 ASSERT_TRUE(buffers.contains("c"));
670}
671
// Impl with a default and a value constructor, used to test constructor
// forwarding through the ModuleHolder wrapper `A` generated by TORCH_MODULE.
struct AImpl : torch::nn::Module {
  AImpl() : x_(123) {}
  AImpl(int x) : x_(x) {}
  int x_;
};
TORCH_MODULE(A);
678
679TEST_F(
680 ModuleTest,
681 DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
682 A a;
683 ASSERT_TRUE(a);
684 ASSERT_FALSE(a.is_empty());
685 ASSERT_EQ(a->x_, 123);
686}
687
688TEST_F(
689 ModuleTest,
690 ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
691 A a(5);
692 ASSERT_TRUE(a);
693 ASSERT_FALSE(a.is_empty());
694 ASSERT_EQ(a->x_, 5);
695}
696
697TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
698 A a = nullptr;
699 ASSERT_FALSE(a);
700 ASSERT_TRUE(a.is_empty());
701 ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
702}
703
// Module with two parameters (p1, p2) and two buffers (b1, b2) plus an
// identity forward(); shared by the accessor and iteration tests below.
struct TestModule : public torch::nn::Module {
  TestModule(int64_t size) {
    p1 = register_parameter("p1", torch::randn({size}));
    p2 = register_parameter("p2", torch::randn({size}));
    b1 = register_buffer("b1", torch::randn({size}));
    b2 = register_buffer("b2", torch::randn({size}));
  }

  // Identity forward; only the registration machinery is under test.
  torch::Tensor forward(torch::Tensor input) {
    return input;
  }

  torch::Tensor p1, p2, b1, b2;
};
718
719TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
720 torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
721 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
722 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
723 model.ptr(), model[0], model[1], model[2]};
724 ASSERT_EQ(modules.size(), expected.size());
725 for (const auto i : c10::irange(expected.size())) {
726 // Assert pointer equality.
727 ASSERT_EQ(modules[i].get(), expected[i].get());
728 }
729}
730
731TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
732 torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
733 std::vector<std::shared_ptr<torch::nn::Module>> modules =
734 model->modules(/*include_self=*/false);
735 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
736 model[0], model[1], model[2]};
737 ASSERT_EQ(modules.size(), expected.size());
738 for (const auto i : c10::irange(expected.size())) {
739 // Assert pointer equality.
740 ASSERT_EQ(modules[i].get(), expected[i].get());
741 }
742}
743
744TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
745 torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
746 torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
747 model->named_modules();
748 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
749 model.ptr(), model[0], model[1], model[2]};
750 ASSERT_EQ(modules.size(), expected.size());
751 for (const auto i : c10::irange(expected.size())) {
752 // Assert pointer equality.
753 ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
754 ASSERT_EQ(modules[i].value().get(), expected[i].get());
755 }
756}
757
758TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
759 torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
760 torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
761 model->named_modules(
762 /*name_prefix=*/std::string(), /*include_self=*/false);
763 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
764 model[0], model[1], model[2]};
765 ASSERT_EQ(modules.size(), expected.size());
766 for (const auto i : c10::irange(expected.size())) {
767 // Assert pointer equality.
768 ASSERT_EQ(modules[i].key(), std::to_string(i));
769 ASSERT_EQ(modules[i].value().get(), expected[i].get());
770 }
771}
772
773TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
774 torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
775 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
776 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
777 model[0], model[1], model[2]};
778 ASSERT_EQ(modules.size(), expected.size());
779 for (const auto i : c10::irange(expected.size())) {
780 // Assert pointer equality.
781 ASSERT_EQ(modules[i].get(), expected[i].get());
782 }
783
784 // For this flat model, this should be true.
785 ASSERT_EQ(modules, model->modules(/*include_self=*/false));
786}
787
788TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
789 torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
790 torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
791 model->named_children();
792 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
793 model[0], model[1], model[2]};
794 ASSERT_EQ(modules.size(), expected.size());
795 for (const auto i : c10::irange(expected.size())) {
796 // Assert pointer equality.
797 ASSERT_EQ(modules[i].key(), std::to_string(i));
798 ASSERT_EQ(modules[i].value().get(), expected[i].get());
799 }
800}
801
802TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
803 TestModule module(1);
804 std::vector<torch::Tensor> parameters = module.parameters();
805 ASSERT_EQ(parameters.size(), 2);
806 ASSERT_EQ(parameters[0].data_ptr<float>(), module.p1.data_ptr<float>());
807 ASSERT_EQ(parameters[1].data_ptr<float>(), module.p2.data_ptr<float>());
808}
809
810TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
811 TestModule module(1);
812 torch::OrderedDict<std::string, torch::Tensor> parameters =
813 module.named_parameters();
814 ASSERT_EQ(parameters.size(), 2);
815 ASSERT_EQ(parameters[0].key(), "p1");
816 ASSERT_EQ(parameters[0]->data_ptr<float>(), module.p1.data_ptr<float>());
817 ASSERT_EQ(parameters[1].key(), "p2");
818 ASSERT_EQ(parameters[1]->data_ptr<float>(), module.p2.data_ptr<float>());
819}
820
821TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
822 TestModule module(1);
823 std::vector<torch::Tensor> buffers = module.buffers();
824 ASSERT_EQ(buffers.size(), 2);
825 ASSERT_EQ(buffers[0].data_ptr<float>(), module.b1.data_ptr<float>());
826 ASSERT_EQ(buffers[1].data_ptr<float>(), module.b2.data_ptr<float>());
827}
828
829TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
830 TestModule module(1);
831 torch::OrderedDict<std::string, torch::Tensor> buffers =
832 module.named_buffers();
833 ASSERT_EQ(buffers.size(), 2);
834 ASSERT_EQ(buffers[0].key(), "b1");
835 ASSERT_EQ(buffers[0]->data_ptr<float>(), module.b1.data_ptr<float>());
836 ASSERT_EQ(buffers[1].key(), "b2");
837 ASSERT_EQ(buffers[1]->data_ptr<float>(), module.b2.data_ptr<float>());
838}
839
// Module tagged with an integer tensor; registers its (moved-in) children
// under stringified indices. Used to build the deeply nested test model.
struct TestContainer : torch::nn::Module {
  TestContainer(int64_t number, std::vector<TestContainer> modules = {})
      : tensor(torch::tensor(number)) {
    for (const auto i : c10::irange(modules.size())) {
      register_module(
          std::to_string(i),
          std::make_shared<TestContainer>(std::move(modules[i])));
    }
  }
  torch::Tensor tensor;  // holds this node's integer tag
};
851
852int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
853 return std::dynamic_pointer_cast<TestContainer>(module)
854 ->tensor.item<int64_t>();
855}
856
// Builds the 10-node tree (tags 0-9) used by the deep-model tests:
//   0 -> { 1 -> {2, 3}, 4, 5 -> {6, 7 -> {8, 9}} }
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
  return std::make_shared<TestContainer>(TestContainer(
      0,
      {TestContainer(1, {TestContainer(2), TestContainer(3)}),
       TestContainer(4),
       TestContainer(
           5,
           {TestContainer(6),
            TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
}
867
// Expected (hierarchical name, tag) pairs for a pre-order traversal of the
// tree built above, with "test_prefix" used as the root's name.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  return {
      {"test_prefix", 0},
      {"test_prefix.0", 1},
      {"test_prefix.0.0", 2},
      {"test_prefix.0.1", 3},
      {"test_prefix.1", 4},
      {"test_prefix.2", 5},
      {"test_prefix.2.0", 6},
      {"test_prefix.2.1", 7},
      {"test_prefix.2.1.0", 8},
      {"test_prefix.2.1.1", 9}};
}
882
883TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
884 auto model = make_deeply_nested_test_container();
885 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
886
887 ASSERT_EQ(modules.size(), 10);
888 for (const auto i : c10::irange(modules.size())) {
889 ASSERT_EQ(get_test_container_item(modules[i]), i);
890 }
891}
892
893TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
894 auto model = make_deeply_nested_test_container();
895 torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
896 model->named_modules(/*name_prefix=*/"test_prefix");
897 auto expected = make_key_value_pairs_for_deeply_nested_container();
898
899 ASSERT_EQ(modules.size(), expected.size());
900
901 for (const auto i : c10::irange(expected.size())) {
902 ASSERT_EQ(modules[i].key(), expected[i].first);
903 ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
904 }
905}
906
907TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
908 auto model = make_deeply_nested_test_container();
909 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
910
911 ASSERT_EQ(modules.size(), 3);
912 ASSERT_EQ(get_test_container_item(modules[0]), 1);
913 ASSERT_EQ(get_test_container_item(modules[1]), 4);
914 ASSERT_EQ(get_test_container_item(modules[2]), 5);
915}
916
917TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
918 auto model = make_deeply_nested_test_container();
919 torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
920 model->named_children();
921
922 ASSERT_EQ(modules.size(), 3);
923
924 ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
925 ASSERT_EQ(modules[0].key(), "0");
926
927 ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
928 ASSERT_EQ(modules[1].key(), "1");
929
930 ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
931 ASSERT_EQ(modules[2].key(), "2");
932}
933
934TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
935 auto model = make_deeply_nested_test_container();
936 int64_t index = 0;
937 model->apply([&index](torch::nn::Module& module) {
938 ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
939 });
940 ASSERT_EQ(index, 10);
941}
942
943TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
944 std::shared_ptr<const TestContainer> model =
945 make_deeply_nested_test_container();
946 int64_t index = 0;
947 model->apply([&index](const torch::nn::Module& module) {
948 ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
949 });
950 ASSERT_EQ(index, 10);
951}
952
953TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
954 auto model = make_deeply_nested_test_container();
955 auto expected = make_key_value_pairs_for_deeply_nested_container();
956 int64_t index = 0;
957 model->apply(
958 [&index, expected](const std::string& name, torch::nn::Module& module) {
959 ASSERT_EQ(name, expected[index].first);
960 ASSERT_EQ(
961 module.as<TestContainer>()->tensor.item<int64_t>(),
962 expected[index++].second);
963 },
964 /*name_prefix=*/"test_prefix");
965 ASSERT_EQ(index, 10);
966}
967
968TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
969 std::shared_ptr<const TestContainer> model =
970 make_deeply_nested_test_container();
971 auto expected = make_key_value_pairs_for_deeply_nested_container();
972 int64_t index = 0;
973 model->apply(
974 [&index, &expected](
975 const std::string& name, const torch::nn::Module& module) {
976 ASSERT_EQ(name, expected[index].first);
977 ASSERT_EQ(
978 module.as<const TestContainer>()->tensor.item<int64_t>(),
979 expected[index++].second);
980 },
981 /*name_prefix=*/"test_prefix");
982 ASSERT_EQ(index, 10);
983}
984
985TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
986 auto model = make_deeply_nested_test_container();
987 int64_t index = 0;
988 model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) {
989 ASSERT_EQ(get_test_container_item(module), index++);
990 });
991 ASSERT_EQ(index, 10);
992}
993
994TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
995 auto model = make_deeply_nested_test_container();
996 auto expected = make_key_value_pairs_for_deeply_nested_container();
997 int64_t index = 0;
998 model->apply(
999 [&index, &expected](
1000 const std::string& name,
1001 const std::shared_ptr<torch::nn::Module>& module) {
1002 ASSERT_EQ(name, expected[index].first);
1003 ASSERT_EQ(get_test_container_item(module), expected[index++].second);
1004 },
1005 /*name_prefix=*/"test_prefix");
1006 ASSERT_EQ(index, 10);
1007}
1008
1009TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
1010 {
1011 TestModule module(1);
1012 ASSERT_THROWS_WITH(
1013 module.modules(),
1014 "It looks like you attempted to retrieve "
1015 "your top-level module as a shared_ptr")
1016 }
1017 {
1018 TestModule module(1);
1019 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
1020 ASSERT_NO_THROW(module.modules(/*include_self=*/false));
1021 }
1022 {
1023 auto module = std::make_shared<TestModule>(1);
1024 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
1025 ASSERT_NO_THROW(module->modules());
1026 }
1027}
1028
// Module with no pretty_print() override; used to test the default printout.
struct EmptyModule : torch::nn::Module {};
1030
// Streaming a module through c10::str() uses pretty_print() when overridden
// and falls back to the module name otherwise.
TEST_F(ModuleTest, PrettyPrint) {
  struct TestModule : torch::nn::Module {
    TestModule(int x, float y) : x_(x), y_(y) {}

    // Custom printout including the two constructor arguments.
    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
    }

    int x_;
    float y_;
  };

  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
}
1046
// Impl whose forward() returns a non-tensor type (int64_t); used to check
// that the ModuleHolder call operator forwards the return type unchanged.
struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
  int64_t forward(torch::Tensor x) {
    return x.numel();
  }
};
TORCH_MODULE(ModuleWithNonTensorForward);
1053
1054TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
1055 ModuleWithNonTensorForward m;
1056 ASSERT_EQ(m(torch::ones(123)), 123);
1057}
1058