1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <test/cpp/api/support.h> |
7 | |
8 | using namespace torch::nn; |
9 | using namespace torch::test; |
10 | |
11 | struct AGIUnit : torch::nn::Module {}; |
12 | |
13 | namespace test { |
14 | struct AGIUnit : torch::nn::Module {}; |
15 | struct AGIUnit2 : torch::nn::Module { |
16 | AGIUnit2() : torch::nn::Module("Foo" ) {} |
17 | }; |
18 | } // namespace test |
19 | |
20 | struct ModuleTest : torch::test::SeedingFixture {}; |
21 | |
22 | TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) { |
23 | Linear module(3, 4); |
24 | ASSERT_TRUE(module->is_training()); |
25 | |
26 | module->eval(); |
27 | ASSERT_FALSE(module->is_training()); |
28 | |
29 | module->train(); |
30 | ASSERT_TRUE(module->is_training()); |
31 | } |
32 | |
33 | TEST_F(ModuleTest, ZeroGrad) { |
34 | Linear module(3, 4); |
35 | auto weight = torch::ones({8, 3}, torch::requires_grad()); |
36 | auto loss = module(weight).sum(); |
37 | loss.backward(); |
38 | for (auto& parameter : module->parameters()) { |
39 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
40 | auto grad = parameter.grad(); |
41 | ASSERT_TRUE(grad.defined()); |
42 | ASSERT_NE(grad.sum().item<float>(), 0); |
43 | } |
44 | module->zero_grad(); |
45 | for (auto& parameter : module->parameters()) { |
46 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
47 | auto grad = parameter.grad(); |
48 | ASSERT_FALSE(grad.defined()); |
49 | } |
50 | } |
51 | |
52 | TEST_F(ModuleTest, ZeroGradWithUndefined) { |
53 | struct TestModule : torch::nn::Module { |
54 | TestModule() { |
55 | x = register_parameter("x" , torch::ones(5, torch::requires_grad())); |
56 | y = register_parameter("y" , torch::ones(5, torch::requires_grad())); |
57 | } |
58 | torch::Tensor x, y; |
59 | }; |
60 | |
61 | TestModule module; |
62 | auto z = module.x * 2; |
63 | z.sum().backward(); |
64 | |
65 | ASSERT_TRUE(module.x.grad().defined()); |
66 | ASSERT_FALSE(module.y.grad().defined()); |
67 | |
68 | module.zero_grad(false); // set_to_none = false |
69 | |
70 | ASSERT_TRUE(module.x.grad().defined()); |
71 | ASSERT_FALSE(module.y.grad().defined()); |
72 | |
73 | ASSERT_EQ(module.x.grad().sum().item<float>(), 0); |
74 | |
75 | module.zero_grad(); |
76 | |
77 | ASSERT_FALSE(module.x.grad().defined()); |
78 | ASSERT_FALSE(module.y.grad().defined()); |
79 | } |
80 | |
81 | TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) { |
82 | struct TestModel : public torch::nn::Module {}; |
83 | ASSERT_THROWS_WITH( |
84 | TestModel{}.register_module("name.with.dot" , torch::nn::Linear(3, 4)), |
85 | "Submodule name must not contain a dot (got 'name.with.dot')" ); |
86 | ASSERT_THROWS_WITH( |
87 | TestModel{}.register_module("" , torch::nn::Linear(3, 4)), |
88 | "Submodule name must not be empty" ); |
89 | } |
90 | |
91 | TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) { |
92 | struct TestModel : public torch::nn::Module {}; |
93 | TestModel model; |
94 | model.register_module("linear" , torch::nn::Linear(3, 4)); |
95 | ASSERT_THROWS_WITH( |
96 | model.register_module("linear" , torch::nn::Linear(3, 4)), |
97 | "Submodule 'linear' already defined" ); |
98 | } |
99 | |
100 | TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) { |
101 | torch::nn::Module model; |
102 | ASSERT_THROWS_WITH( |
103 | model.replace_module("linear" , torch::nn::Linear(3, 4)), |
104 | "Submodule 'linear' is not defined" ); |
105 | } |
106 | |
107 | TEST_F(ModuleTest, ReplaceModule) { |
108 | struct TestModel : public torch::nn::Module { |
109 | torch::nn::Linear l1{nullptr}; |
110 | TestModel() { |
111 | l1 = register_module("l1" , torch::nn::Linear(3, 4)); |
112 | } |
113 | }; |
114 | auto model = std::make_shared<TestModel>(); |
115 | model->l1 = model->replace_module("l1" , torch::nn::Linear(5, 6)); |
116 | ASSERT_EQ(model->named_parameters()["l1.weight" ].size(0), 6); |
117 | ASSERT_EQ(model->l1.get(), model->named_modules()["l1" ]->as<Linear>()); |
118 | } |
119 | |
120 | TEST_F(ModuleTest, UnregisterModule) { |
121 | struct TestModel : public torch::nn::Module {}; |
122 | TestModel model; |
123 | ASSERT_THROWS_WITH( |
124 | model.unregister_module("linear" ), |
125 | "No Module with name `linear` is registered" ); |
126 | model.register_module("linear" , torch::nn::Linear(3, 4)); |
127 | model.unregister_module("linear" ); |
128 | ASSERT_TRUE(model.children().empty()); |
129 | } |
130 | |
131 | TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) { |
132 | struct TestModel : public torch::nn::Module {}; |
133 | ASSERT_THROWS_WITH( |
134 | TestModel{}.register_parameter("name.with.dot" , torch::ones(5)), |
135 | "Parameter name must not contain a dot (got 'name.with.dot')" ); |
136 | ASSERT_THROWS_WITH( |
137 | TestModel{}.register_parameter("" , torch::ones(5)), |
138 | "Parameter name must not be empty" ); |
139 | } |
140 | |
141 | TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) { |
142 | struct TestModel : public torch::nn::Module {}; |
143 | TestModel model; |
144 | model.register_parameter("p" , torch::ones(5)); |
145 | ASSERT_THROWS_WITH( |
146 | model.register_parameter("p" , torch::ones(5)), |
147 | "Parameter 'p' already defined" ); |
148 | } |
149 | |
150 | TEST_F(ModuleTest, RegisterParameterUndefinedTensor) { |
151 | struct TestModel : public torch::nn::Module {}; |
152 | { |
153 | TestModel model; |
154 | model.register_parameter( |
155 | "undefined_tensor" , torch::Tensor(), /*requires_grad=*/false); |
156 | ASSERT_EQ(model.parameters().size(), 0); |
157 | } |
158 | { |
159 | WarningCapture warnings; |
160 | |
161 | TestModel model; |
162 | model.register_parameter("undefined_tensor" , torch::Tensor()); |
163 | ASSERT_EQ(model.parameters().size(), 0); |
164 | |
165 | ASSERT_EQ( |
166 | count_substr_occurrences( |
167 | warnings.str(), |
168 | "Ignoring the `requires_grad=true` function parameter" ), |
169 | 1); |
170 | } |
171 | } |
172 | |
173 | TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) { |
174 | struct TestModel : public torch::nn::Module {}; |
175 | ASSERT_THROWS_WITH( |
176 | TestModel{}.register_buffer("name.with.dot" , torch::ones(5)), |
177 | "Buffer name must not contain a dot (got 'name.with.dot')" ); |
178 | ASSERT_THROWS_WITH( |
179 | TestModel{}.register_buffer("" , torch::ones(5)), |
180 | "Buffer name must not be empty" ); |
181 | } |
182 | |
183 | TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) { |
184 | struct TestModel : public torch::nn::Module {}; |
185 | TestModel model; |
186 | model.register_buffer("p" , torch::ones(5)); |
187 | ASSERT_THROWS_WITH( |
188 | model.register_buffer("p" , torch::ones(5)), "Buffer 'p' already defined" ); |
189 | } |
190 | |
191 | TEST_F(ModuleTest, CanGetName) { |
192 | // CHECK instead of REQUIRE because demangling may fail. |
193 | AGIUnit agi; |
194 | // Call it twice just to make sure there are no bugs in the lazy |
195 | // initialization semantics. |
196 | EXPECT_EQ(agi.name(), "AGIUnit" ); |
197 | EXPECT_EQ(agi.name(), "AGIUnit" ); |
198 | EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit" ); |
199 | EXPECT_EQ(test::AGIUnit2().name(), "Foo" ); |
200 | } |
201 | |
202 | TEST_F(ModuleTest, AsCastsModulesCorrectly) { |
203 | Linear module(3, 4); |
204 | ASSERT_EQ(module->as<Linear>(), module.get()); |
205 | ASSERT_EQ(module->as<LinearImpl>(), module.get()); |
206 | ASSERT_EQ(module->as<Module>(), module.get()); |
207 | ASSERT_EQ(module->as<AGIUnit>(), nullptr); |
208 | |
209 | std::shared_ptr<Module> raw = module.ptr(); |
210 | ASSERT_EQ(raw->as<Linear>(), module.get()); |
211 | ASSERT_EQ(raw->as<LinearImpl>(), module.get()); |
212 | ASSERT_EQ(raw->as<Module>(), module.get()); |
213 | ASSERT_EQ(raw->as<AGIUnit>(), nullptr); |
214 | |
215 | Module& raw_ref = *raw.get(); |
216 | ASSERT_EQ(raw_ref.as<Linear>(), module.get()); |
217 | ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get()); |
218 | ASSERT_EQ(raw_ref.as<Module>(), module.get()); |
219 | ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr); |
220 | if (auto* linear = raw_ref.as<Linear>()) { |
221 | ASSERT_EQ(linear->weight.ndimension(), 2); |
222 | } |
223 | |
224 | AGIUnit unit; |
225 | ASSERT_EQ(unit.as<Linear>(), nullptr); |
226 | ASSERT_EQ(unit.as<LinearImpl>(), nullptr); |
227 | ASSERT_EQ(unit.as<AGIUnit>(), &unit); |
228 | } |
229 | |
230 | void test_DeviceOrDtypeConversionSkipsUndefinedTensor( |
231 | torch::Device to_device, |
232 | torch::Dtype to_dtype) { |
233 | { |
234 | // Case 1: Undefined tensors as parameters |
235 | Linear module(LinearOptions(10, 20).bias(false)); |
236 | ASSERT_TRUE(module->weight.defined()); |
237 | ASSERT_FALSE(module->bias.defined()); |
238 | |
239 | module->to(to_device); |
240 | ASSERT_TRUE(module->weight.defined()); |
241 | ASSERT_EQ(module->weight.device().type(), to_device.type()); |
242 | ASSERT_FALSE(module->bias.defined()); |
243 | |
244 | module->to(to_dtype); |
245 | ASSERT_TRUE(module->weight.defined()); |
246 | ASSERT_EQ(module->weight.dtype(), to_dtype); |
247 | ASSERT_FALSE(module->bias.defined()); |
248 | } |
249 | { |
250 | // Case 2: Undefined tensors as buffers |
251 | BatchNorm1d module( |
252 | BatchNorm1dOptions(5).track_running_stats(false).affine(true)); |
253 | ASSERT_TRUE(module->weight.defined()); |
254 | ASSERT_FALSE(module->running_mean.defined()); |
255 | |
256 | module->to(to_device); |
257 | ASSERT_TRUE(module->weight.defined()); |
258 | ASSERT_EQ(module->weight.device().type(), to_device.type()); |
259 | ASSERT_FALSE(module->running_mean.defined()); |
260 | |
261 | module->to(to_dtype); |
262 | ASSERT_TRUE(module->weight.defined()); |
263 | ASSERT_EQ(module->weight.dtype(), to_dtype); |
264 | ASSERT_FALSE(module->running_mean.defined()); |
265 | } |
266 | } |
267 | |
268 | TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) { |
269 | test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCPU, torch::kDouble); |
270 | } |
271 | |
272 | TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) { |
273 | test_DeviceOrDtypeConversionSkipsUndefinedTensor( |
274 | torch::kCUDA, torch::kDouble); |
275 | } |
276 | |
277 | TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) { |
278 | { |
279 | Linear module(LinearOptions(10, 20).bias(false)); |
280 | |
281 | auto params = module->parameters(); |
282 | ASSERT_EQ(params.size(), 1); |
283 | auto named_params = module->named_parameters(); |
284 | ASSERT_EQ(named_params.size(), 1); |
285 | |
286 | ASSERT_TRUE(pointer_equal(params[0], named_params["weight" ])); |
287 | ASSERT_TRUE(pointer_equal(named_params["weight" ], module->weight)); |
288 | } |
289 | { |
290 | BatchNorm1d module( |
291 | BatchNorm1dOptions(5).track_running_stats(false).affine(false)); |
292 | |
293 | auto buffers = module->buffers(); |
294 | ASSERT_EQ(buffers.size(), 0); |
295 | auto named_buffers = module->named_buffers(); |
296 | ASSERT_EQ(named_buffers.size(), 0); |
297 | } |
298 | { |
299 | BatchNorm1d module( |
300 | BatchNorm1dOptions(5).track_running_stats(true).affine(false)); |
301 | |
302 | auto buffers = module->buffers(); |
303 | ASSERT_EQ(buffers.size(), 3); |
304 | auto named_buffers = module->named_buffers(); |
305 | ASSERT_EQ(named_buffers.size(), 3); |
306 | |
307 | ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean" ])); |
308 | ASSERT_TRUE( |
309 | pointer_equal(named_buffers["running_mean" ], module->running_mean)); |
310 | ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var" ])); |
311 | ASSERT_TRUE( |
312 | pointer_equal(named_buffers["running_var" ], module->running_var)); |
313 | ASSERT_TRUE( |
314 | pointer_equal(buffers[2], named_buffers["num_batches_tracked" ])); |
315 | ASSERT_TRUE(pointer_equal( |
316 | named_buffers["num_batches_tracked" ], module->num_batches_tracked)); |
317 | } |
318 | } |
319 | |
320 | TEST_F(ModuleTest, Conversion_MultiCUDA) { |
321 | Linear module(128, 64); |
322 | for (auto& parameter : module->parameters()) { |
323 | ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU)); |
324 | ASSERT_EQ(parameter.dtype(), torch::kFloat32); |
325 | } |
326 | { |
327 | module->to({torch::kCUDA, 0}); |
328 | for (auto& parameter : module->parameters()) { |
329 | ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA); |
330 | ASSERT_EQ(parameter.device().index(), 0); |
331 | } |
332 | module->to({torch::kCUDA, 1}); |
333 | for (auto& parameter : module->parameters()) { |
334 | ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA); |
335 | ASSERT_EQ(parameter.device().index(), 1); |
336 | } |
337 | } |
338 | { |
339 | module->to(torch::Device(torch::kCPU)); |
340 | for (auto& parameter : module->parameters()) { |
341 | ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU); |
342 | } |
343 | } |
344 | { |
345 | module->to(torch::kFloat64); |
346 | for (auto& parameter : module->parameters()) { |
347 | ASSERT_EQ(parameter.dtype(), torch::kFloat64); |
348 | } |
349 | } |
350 | } |
351 | |
352 | TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) { |
353 | Linear module(128, 64); |
354 | for (auto& parameter : module->parameters()) { |
355 | parameter.requires_grad_(false); |
356 | } |
357 | { |
358 | module->to(torch::kInt32); |
359 | for (auto& parameter : module->parameters()) { |
360 | ASSERT_EQ(parameter.dtype(), torch::kInt32); |
361 | } |
362 | } |
363 | { |
364 | module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8); |
365 | for (auto& parameter : module->parameters()) { |
366 | ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA); |
367 | ASSERT_EQ(parameter.device().index(), 1); |
368 | } |
369 | for (auto& parameter : module->parameters()) { |
370 | ASSERT_EQ(parameter.dtype(), torch::kUInt8); |
371 | } |
372 | } |
373 | } |
374 | |
375 | TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) { |
376 | struct UnCloneable : Module {}; |
377 | UnCloneable module; |
378 | ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented" ); |
379 | } |
380 | |
381 | TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) { |
382 | struct Cloneable : Module { |
383 | std::shared_ptr<Module> clone( |
384 | const torch::optional<torch::Device>& device = |
385 | torch::nullopt) const override { |
386 | return nullptr; |
387 | } |
388 | }; |
389 | Cloneable module; |
390 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
391 | ASSERT_NO_THROW({ module.clone(); }); |
392 | } |
393 | |
394 | // NOLINTNEXTLINE(bugprone-exception-escape) |
395 | struct TestDistinctParametersModule |
396 | : public Cloneable<TestDistinctParametersModule> { |
397 | TestDistinctParametersModule() { |
398 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
399 | reset(); |
400 | } |
401 | void reset() override { |
402 | l1 = register_module("l1" , Linear(10, 3)); |
403 | l2 = register_module("l2" , Linear(3, 5)); |
404 | l3 = register_module("l3" , Linear(5, 100)); |
405 | buffer = register_buffer("buf" , torch::ones({2, 2})); |
406 | } |
407 | |
408 | Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; |
409 | torch::Tensor buffer; |
410 | }; |
411 | |
412 | void ( |
413 | std::shared_ptr<Module> m1, |
414 | std::shared_ptr<Module> m2) { |
415 | auto params1 = m1->named_parameters(); |
416 | auto params2 = m2->named_parameters(); |
417 | ASSERT_EQ(params1.size(), 6); |
418 | ASSERT_EQ(params2.size(), 6); |
419 | for (auto& param : params1) { |
420 | ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); |
421 | ASSERT_TRUE(param->allclose(params2[param.key()])); |
422 | param->add_(2); |
423 | } |
424 | for (auto& param : params1) { |
425 | ASSERT_FALSE(param->allclose(params2[param.key()])); |
426 | } |
427 | |
428 | auto buffers1 = m1->named_buffers(); |
429 | auto buffers2 = m2->named_buffers(); |
430 | ASSERT_EQ(buffers1.size(), 1); |
431 | ASSERT_EQ(buffers2.size(), 1); |
432 | for (auto& buffer : buffers1) { |
433 | ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()])); |
434 | ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()])); |
435 | buffer->add_(2); |
436 | } |
437 | for (auto& buffer : buffers1) { |
438 | ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()])); |
439 | } |
440 | } |
441 | |
442 | TEST_F(ModuleTest, CloneCreatesDistinctParameters) { |
443 | auto module = std::make_shared<TestDistinctParametersModule>(); |
444 | torch::NoGradGuard no_grad; |
445 | auto module2 = module->clone(); |
446 | testDistinctParameters(module, module2); |
447 | } |
448 | |
449 | TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) { |
450 | auto module = std::make_shared<TestDistinctParametersModule>(); |
451 | torch::NoGradGuard no_grad; |
452 | torch::Device device(torch::kCUDA, 0); |
453 | module->to(device); |
454 | auto module2 = module->clone(device); |
455 | testDistinctParameters(module, module2); |
456 | } |
457 | |
458 | TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) { |
459 | auto module = std::make_shared<TestDistinctParametersModule>(); |
460 | torch::NoGradGuard no_grad; |
461 | torch::Device d0(torch::kCUDA, 0); |
462 | torch::Device d1(torch::kCUDA, 1); |
463 | module->to(d0); |
464 | auto module2 = module->clone(d1); |
465 | |
466 | for (auto& param : module->parameters()) { |
467 | ASSERT_EQ(param.device(), d0); |
468 | } |
469 | |
470 | for (auto& param : module2->parameters()) { |
471 | ASSERT_EQ(param.device(), d1); |
472 | } |
473 | |
474 | // need to move the module back to d0 as allclose expects two tensors on |
475 | // the same device. |
476 | module2->to(d0); |
477 | testDistinctParameters(module, module2); |
478 | } |
479 | |
480 | TEST_F(ModuleTest, ClonePreservesExternalReferences) { |
481 | // NOLINTNEXTLINE(bugprone-exception-escape) |
482 | struct TestModule : public Cloneable<TestModule> { |
483 | TestModule() { |
484 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
485 | reset(); |
486 | } |
487 | void reset() override { |
488 | weight = register_parameter("weight" , torch::ones({4, 4})); |
489 | } |
490 | torch::Tensor weight; |
491 | }; |
492 | auto module = std::make_shared<TestModule>(); |
493 | { |
494 | torch::NoGradGuard no_grad; |
495 | module->weight += 1; |
496 | } |
497 | ASSERT_TRUE( |
498 | pointer_equal(module->weight, module->named_parameters()["weight" ])); |
499 | ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight" ])); |
500 | |
501 | auto module2 = std::dynamic_pointer_cast<TestModule>( |
502 | std::shared_ptr<Module>(module->clone())); |
503 | ASSERT_FALSE(pointer_equal(module2->weight, module->weight)); |
504 | ASSERT_TRUE( |
505 | pointer_equal(module2->weight, module2->named_parameters()["weight" ])); |
506 | ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight" ])); |
507 | ASSERT_TRUE(module2->weight.allclose(module->weight)); |
508 | ASSERT_FALSE( |
509 | pointer_equal(module2->weight, module->named_parameters()["weight" ])); |
510 | } |
511 | |
512 | TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) { |
513 | // NOLINTNEXTLINE(bugprone-exception-escape) |
514 | struct TestModule : public Cloneable<TestModule> { |
515 | TestModule() { |
516 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
517 | reset(); |
518 | } |
519 | void reset() override { |
520 | weight = register_parameter("weight" , torch::ones({4, 4})); |
521 | } |
522 | |
523 | torch::Tensor weight; |
524 | int value = 0; |
525 | }; |
526 | // NOLINTNEXTLINE(bugprone-exception-escape) |
527 | struct NestedModule : public Cloneable<NestedModule> { |
528 | NestedModule() { |
529 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
530 | reset(); |
531 | } |
532 | void reset() override { |
533 | module = register_module("module" , std::make_shared<TestModule>()); |
534 | } |
535 | std::shared_ptr<TestModule> module; |
536 | }; |
537 | |
538 | auto a = std::make_shared<NestedModule>(); |
539 | { |
540 | torch::NoGradGuard no_grad; |
541 | a->module->weight += 1; |
542 | a->module->value = 123; |
543 | } |
544 | |
545 | auto b = std::dynamic_pointer_cast<NestedModule>(a->clone()); |
546 | |
547 | ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight)); |
548 | ASSERT_TRUE(pointer_equal( |
549 | b->module->weight, b->module->named_parameters()["weight" ])); |
550 | ASSERT_TRUE( |
551 | b->module->named_parameters()["weight" ].allclose(a->module->weight)); |
552 | ASSERT_TRUE(b->module->weight.allclose(a->module->weight)); |
553 | ASSERT_EQ(b->module->value, a->module->value); |
554 | } |
555 | |
556 | TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) { |
557 | // NOLINTNEXTLINE(bugprone-exception-escape) |
558 | struct TestModule : public Cloneable<TestModule> { |
559 | TestModule() { |
560 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
561 | reset(); |
562 | } |
563 | void reset() override { |
564 | l1 = register_module("l1" , Linear(10, 3)); |
565 | l2 = register_module("l2" , Linear(3, 5)); |
566 | l3 = register_module("l3" , Linear(5, 100)); |
567 | buffer = register_buffer("buf" , torch::ones({2, 2})); |
568 | } |
569 | |
570 | Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; |
571 | torch::Tensor buffer; |
572 | }; |
573 | |
574 | TestModule m; |
575 | torch::Device device(torch::kCUDA, 0); |
576 | |
577 | m.to(device); |
578 | |
579 | auto clone = m.clone(); |
580 | for (const auto& parameter : clone->parameters()) { |
581 | ASSERT_EQ(parameter.device().type(), device.type()); |
582 | ASSERT_EQ(parameter.device().index(), device.index()); |
583 | } |
584 | for (const auto& buffer : clone->buffers()) { |
585 | ASSERT_EQ(buffer.device().type(), device.type()); |
586 | ASSERT_EQ(buffer.device().index(), device.index()); |
587 | } |
588 | } |
589 | |
590 | TEST_F( |
591 | ModuleTest, |
592 | CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) { |
593 | // NOLINTNEXTLINE(bugprone-exception-escape) |
594 | struct TestModule : public Cloneable<TestModule> { |
595 | TestModule() { |
596 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
597 | reset(); |
598 | } |
599 | void reset() override { |
600 | l1 = register_module("l1" , Linear(10, 3)); |
601 | l2 = register_module("l2" , Linear(3, 5)); |
602 | l3 = register_module("l3" , Linear(5, 100)); |
603 | buffer = register_buffer("buf" , torch::ones({2, 2})); |
604 | } |
605 | |
606 | Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; |
607 | torch::Tensor buffer; |
608 | }; |
609 | |
610 | TestModule m; |
611 | torch::Device device(torch::kCUDA, 1); |
612 | // everything is on CPU here |
613 | auto clone = m.clone(device); |
614 | for (const auto& parameter : clone->parameters()) { |
615 | ASSERT_EQ(parameter.device().type(), device.type()); |
616 | ASSERT_EQ(parameter.device().index(), device.index()); |
617 | } |
618 | for (const auto& buffer : clone->buffers()) { |
619 | ASSERT_EQ(buffer.device().type(), device.type()); |
620 | ASSERT_EQ(buffer.device().index(), device.index()); |
621 | } |
622 | } |
623 | |
624 | struct ParameterTestModule : Module { |
625 | ParameterTestModule() { |
626 | a = register_parameter("a" , torch::zeros({2, 2})); |
627 | b = register_parameter("b" , torch::ones({2, 2})); |
628 | c = register_parameter("c" , torch::ones({2, 2}) * 2); |
629 | } |
630 | |
631 | torch::Tensor a, b, c; |
632 | }; |
633 | |
634 | TEST_F(ModuleTest, HasCorrectNumberOfParameters) { |
635 | ParameterTestModule module; |
636 | ASSERT_EQ(module.parameters().size(), 3); |
637 | ASSERT_EQ(module.named_parameters().size(), 3); |
638 | } |
639 | |
640 | TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) { |
641 | ParameterTestModule module; |
642 | auto parameters = module.named_parameters(); |
643 | ASSERT_TRUE(parameters.contains("a" )); |
644 | ASSERT_TRUE(parameters.contains("b" )); |
645 | ASSERT_TRUE(parameters.contains("c" )); |
646 | } |
647 | |
648 | struct BufferTestModule : Module { |
649 | BufferTestModule() { |
650 | a = register_buffer("a" , torch::zeros({2, 2})); |
651 | b = register_buffer("b" , torch::ones({2, 2})); |
652 | c = register_buffer("c" , torch::ones({2, 2}) * 2); |
653 | } |
654 | |
655 | torch::Tensor a, b, c; |
656 | }; |
657 | |
658 | TEST_F(ModuleTest, HasCorrectNumberOfBuffers) { |
659 | BufferTestModule module; |
660 | ASSERT_EQ(module.buffers().size(), 3); |
661 | ASSERT_EQ(module.named_buffers().size(), 3); |
662 | } |
663 | |
664 | TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) { |
665 | BufferTestModule module; |
666 | auto buffers = module.named_buffers(); |
667 | ASSERT_TRUE(buffers.contains("a" )); |
668 | ASSERT_TRUE(buffers.contains("b" )); |
669 | ASSERT_TRUE(buffers.contains("c" )); |
670 | } |
671 | |
672 | struct AImpl : torch::nn::Module { |
673 | AImpl() : x_(123) {} |
674 | AImpl(int x) : x_(x) {} |
675 | int x_; |
676 | }; |
677 | TORCH_MODULE(A); |
678 | |
679 | TEST_F( |
680 | ModuleTest, |
681 | DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) { |
682 | A a; |
683 | ASSERT_TRUE(a); |
684 | ASSERT_FALSE(a.is_empty()); |
685 | ASSERT_EQ(a->x_, 123); |
686 | } |
687 | |
688 | TEST_F( |
689 | ModuleTest, |
690 | ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) { |
691 | A a(5); |
692 | ASSERT_TRUE(a); |
693 | ASSERT_FALSE(a.is_empty()); |
694 | ASSERT_EQ(a->x_, 5); |
695 | } |
696 | |
697 | TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) { |
698 | A a = nullptr; |
699 | ASSERT_FALSE(a); |
700 | ASSERT_TRUE(a.is_empty()); |
701 | ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder" ); |
702 | } |
703 | |
704 | struct TestModule : public torch::nn::Module { |
705 | TestModule(int64_t size) { |
706 | p1 = register_parameter("p1" , torch::randn({size})); |
707 | p2 = register_parameter("p2" , torch::randn({size})); |
708 | b1 = register_buffer("b1" , torch::randn({size})); |
709 | b2 = register_buffer("b2" , torch::randn({size})); |
710 | } |
711 | |
712 | torch::Tensor forward(torch::Tensor input) { |
713 | return input; |
714 | } |
715 | |
716 | torch::Tensor p1, p2, b1, b2; |
717 | }; |
718 | |
719 | TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) { |
720 | torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3)); |
721 | std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules(); |
722 | std::vector<std::shared_ptr<torch::nn::Module>> expected = { |
723 | model.ptr(), model[0], model[1], model[2]}; |
724 | ASSERT_EQ(modules.size(), expected.size()); |
725 | for (const auto i : c10::irange(expected.size())) { |
726 | // Assert pointer equality. |
727 | ASSERT_EQ(modules[i].get(), expected[i].get()); |
728 | } |
729 | } |
730 | |
731 | TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) { |
732 | torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3)); |
733 | std::vector<std::shared_ptr<torch::nn::Module>> modules = |
734 | model->modules(/*include_self=*/false); |
735 | std::vector<std::shared_ptr<torch::nn::Module>> expected = { |
736 | model[0], model[1], model[2]}; |
737 | ASSERT_EQ(modules.size(), expected.size()); |
738 | for (const auto i : c10::irange(expected.size())) { |
739 | // Assert pointer equality. |
740 | ASSERT_EQ(modules[i].get(), expected[i].get()); |
741 | } |
742 | } |
743 | |
744 | TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) { |
745 | torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3)); |
746 | torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules = |
747 | model->named_modules(); |
748 | std::vector<std::shared_ptr<torch::nn::Module>> expected = { |
749 | model.ptr(), model[0], model[1], model[2]}; |
750 | ASSERT_EQ(modules.size(), expected.size()); |
751 | for (const auto i : c10::irange(expected.size())) { |
752 | // Assert pointer equality. |
753 | ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string()); |
754 | ASSERT_EQ(modules[i].value().get(), expected[i].get()); |
755 | } |
756 | } |
757 | |
758 | TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) { |
759 | torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3)); |
760 | torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules = |
761 | model->named_modules( |
762 | /*name_prefix=*/std::string(), /*include_self=*/false); |
763 | std::vector<std::shared_ptr<torch::nn::Module>> expected = { |
764 | model[0], model[1], model[2]}; |
765 | ASSERT_EQ(modules.size(), expected.size()); |
766 | for (const auto i : c10::irange(expected.size())) { |
767 | // Assert pointer equality. |
768 | ASSERT_EQ(modules[i].key(), std::to_string(i)); |
769 | ASSERT_EQ(modules[i].value().get(), expected[i].get()); |
770 | } |
771 | } |
772 | |
773 | TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) { |
774 | torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3)); |
775 | std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children(); |
776 | std::vector<std::shared_ptr<torch::nn::Module>> expected = { |
777 | model[0], model[1], model[2]}; |
778 | ASSERT_EQ(modules.size(), expected.size()); |
779 | for (const auto i : c10::irange(expected.size())) { |
780 | // Assert pointer equality. |
781 | ASSERT_EQ(modules[i].get(), expected[i].get()); |
782 | } |
783 | |
784 | // For this flat model, this should be true. |
785 | ASSERT_EQ(modules, model->modules(/*include_self=*/false)); |
786 | } |
787 | |
788 | TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) { |
789 | torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3)); |
790 | torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules = |
791 | model->named_children(); |
792 | std::vector<std::shared_ptr<torch::nn::Module>> expected = { |
793 | model[0], model[1], model[2]}; |
794 | ASSERT_EQ(modules.size(), expected.size()); |
795 | for (const auto i : c10::irange(expected.size())) { |
796 | // Assert pointer equality. |
797 | ASSERT_EQ(modules[i].key(), std::to_string(i)); |
798 | ASSERT_EQ(modules[i].value().get(), expected[i].get()); |
799 | } |
800 | } |
801 | |
802 | TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) { |
803 | TestModule module(1); |
804 | std::vector<torch::Tensor> parameters = module.parameters(); |
805 | ASSERT_EQ(parameters.size(), 2); |
806 | ASSERT_EQ(parameters[0].data_ptr<float>(), module.p1.data_ptr<float>()); |
807 | ASSERT_EQ(parameters[1].data_ptr<float>(), module.p2.data_ptr<float>()); |
808 | } |
809 | |
810 | TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) { |
811 | TestModule module(1); |
812 | torch::OrderedDict<std::string, torch::Tensor> parameters = |
813 | module.named_parameters(); |
814 | ASSERT_EQ(parameters.size(), 2); |
815 | ASSERT_EQ(parameters[0].key(), "p1" ); |
816 | ASSERT_EQ(parameters[0]->data_ptr<float>(), module.p1.data_ptr<float>()); |
817 | ASSERT_EQ(parameters[1].key(), "p2" ); |
818 | ASSERT_EQ(parameters[1]->data_ptr<float>(), module.p2.data_ptr<float>()); |
819 | } |
820 | |
821 | TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) { |
822 | TestModule module(1); |
823 | std::vector<torch::Tensor> buffers = module.buffers(); |
824 | ASSERT_EQ(buffers.size(), 2); |
825 | ASSERT_EQ(buffers[0].data_ptr<float>(), module.b1.data_ptr<float>()); |
826 | ASSERT_EQ(buffers[1].data_ptr<float>(), module.b2.data_ptr<float>()); |
827 | } |
828 | |
829 | TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) { |
830 | TestModule module(1); |
831 | torch::OrderedDict<std::string, torch::Tensor> buffers = |
832 | module.named_buffers(); |
833 | ASSERT_EQ(buffers.size(), 2); |
834 | ASSERT_EQ(buffers[0].key(), "b1" ); |
835 | ASSERT_EQ(buffers[0]->data_ptr<float>(), module.b1.data_ptr<float>()); |
836 | ASSERT_EQ(buffers[1].key(), "b2" ); |
837 | ASSERT_EQ(buffers[1]->data_ptr<float>(), module.b2.data_ptr<float>()); |
838 | } |
839 | |
840 | struct TestContainer : torch::nn::Module { |
841 | TestContainer(int64_t number, std::vector<TestContainer> modules = {}) |
842 | : tensor(torch::tensor(number)) { |
843 | for (const auto i : c10::irange(modules.size())) { |
844 | register_module( |
845 | std::to_string(i), |
846 | std::make_shared<TestContainer>(std::move(modules[i]))); |
847 | } |
848 | } |
849 | torch::Tensor tensor; |
850 | }; |
851 | |
852 | int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) { |
853 | return std::dynamic_pointer_cast<TestContainer>(module) |
854 | ->tensor.item<int64_t>(); |
855 | } |
856 | |
857 | std::shared_ptr<TestContainer> make_deeply_nested_test_container() { |
858 | return std::make_shared<TestContainer>(TestContainer( |
859 | 0, |
860 | {TestContainer(1, {TestContainer(2), TestContainer(3)}), |
861 | TestContainer(4), |
862 | TestContainer( |
863 | 5, |
864 | {TestContainer(6), |
865 | TestContainer(7, {TestContainer(8), TestContainer(9)})})})); |
866 | } |
867 | |
// Returns the ten (qualified name, node number) pairs that the tests in this
// file compare against when traversing the deeply nested container with the
// name prefix "test_prefix". Entries are listed parent first, then children.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  std::vector<std::pair<std::string, int64_t>> entries;
  entries.reserve(10);
  entries.emplace_back("test_prefix", 0);
  entries.emplace_back("test_prefix.0", 1);
  entries.emplace_back("test_prefix.0.0", 2);
  entries.emplace_back("test_prefix.0.1", 3);
  entries.emplace_back("test_prefix.1", 4);
  entries.emplace_back("test_prefix.2", 5);
  entries.emplace_back("test_prefix.2.0", 6);
  entries.emplace_back("test_prefix.2.1", 7);
  entries.emplace_back("test_prefix.2.1.0", 8);
  entries.emplace_back("test_prefix.2.1.1", 9);
  return entries;
}
882 | |
883 | TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) { |
884 | auto model = make_deeply_nested_test_container(); |
885 | std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules(); |
886 | |
887 | ASSERT_EQ(modules.size(), 10); |
888 | for (const auto i : c10::irange(modules.size())) { |
889 | ASSERT_EQ(get_test_container_item(modules[i]), i); |
890 | } |
891 | } |
892 | |
893 | TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) { |
894 | auto model = make_deeply_nested_test_container(); |
895 | torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules = |
896 | model->named_modules(/*name_prefix=*/"test_prefix" ); |
897 | auto expected = make_key_value_pairs_for_deeply_nested_container(); |
898 | |
899 | ASSERT_EQ(modules.size(), expected.size()); |
900 | |
901 | for (const auto i : c10::irange(expected.size())) { |
902 | ASSERT_EQ(modules[i].key(), expected[i].first); |
903 | ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second); |
904 | } |
905 | } |
906 | |
907 | TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) { |
908 | auto model = make_deeply_nested_test_container(); |
909 | std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children(); |
910 | |
911 | ASSERT_EQ(modules.size(), 3); |
912 | ASSERT_EQ(get_test_container_item(modules[0]), 1); |
913 | ASSERT_EQ(get_test_container_item(modules[1]), 4); |
914 | ASSERT_EQ(get_test_container_item(modules[2]), 5); |
915 | } |
916 | |
917 | TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) { |
918 | auto model = make_deeply_nested_test_container(); |
919 | torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules = |
920 | model->named_children(); |
921 | |
922 | ASSERT_EQ(modules.size(), 3); |
923 | |
924 | ASSERT_EQ(get_test_container_item(modules[0].value()), 1); |
925 | ASSERT_EQ(modules[0].key(), "0" ); |
926 | |
927 | ASSERT_EQ(get_test_container_item(modules[1].value()), 4); |
928 | ASSERT_EQ(modules[1].key(), "1" ); |
929 | |
930 | ASSERT_EQ(get_test_container_item(modules[2].value()), 5); |
931 | ASSERT_EQ(modules[2].key(), "2" ); |
932 | } |
933 | |
934 | TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) { |
935 | auto model = make_deeply_nested_test_container(); |
936 | int64_t index = 0; |
937 | model->apply([&index](torch::nn::Module& module) { |
938 | ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++); |
939 | }); |
940 | ASSERT_EQ(index, 10); |
941 | } |
942 | |
943 | TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) { |
944 | std::shared_ptr<const TestContainer> model = |
945 | make_deeply_nested_test_container(); |
946 | int64_t index = 0; |
947 | model->apply([&index](const torch::nn::Module& module) { |
948 | ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++); |
949 | }); |
950 | ASSERT_EQ(index, 10); |
951 | } |
952 | |
953 | TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) { |
954 | auto model = make_deeply_nested_test_container(); |
955 | auto expected = make_key_value_pairs_for_deeply_nested_container(); |
956 | int64_t index = 0; |
957 | model->apply( |
958 | [&index, expected](const std::string& name, torch::nn::Module& module) { |
959 | ASSERT_EQ(name, expected[index].first); |
960 | ASSERT_EQ( |
961 | module.as<TestContainer>()->tensor.item<int64_t>(), |
962 | expected[index++].second); |
963 | }, |
964 | /*name_prefix=*/"test_prefix" ); |
965 | ASSERT_EQ(index, 10); |
966 | } |
967 | |
968 | TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) { |
969 | std::shared_ptr<const TestContainer> model = |
970 | make_deeply_nested_test_container(); |
971 | auto expected = make_key_value_pairs_for_deeply_nested_container(); |
972 | int64_t index = 0; |
973 | model->apply( |
974 | [&index, &expected]( |
975 | const std::string& name, const torch::nn::Module& module) { |
976 | ASSERT_EQ(name, expected[index].first); |
977 | ASSERT_EQ( |
978 | module.as<const TestContainer>()->tensor.item<int64_t>(), |
979 | expected[index++].second); |
980 | }, |
981 | /*name_prefix=*/"test_prefix" ); |
982 | ASSERT_EQ(index, 10); |
983 | } |
984 | |
985 | TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) { |
986 | auto model = make_deeply_nested_test_container(); |
987 | int64_t index = 0; |
988 | model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) { |
989 | ASSERT_EQ(get_test_container_item(module), index++); |
990 | }); |
991 | ASSERT_EQ(index, 10); |
992 | } |
993 | |
994 | TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) { |
995 | auto model = make_deeply_nested_test_container(); |
996 | auto expected = make_key_value_pairs_for_deeply_nested_container(); |
997 | int64_t index = 0; |
998 | model->apply( |
999 | [&index, &expected]( |
1000 | const std::string& name, |
1001 | const std::shared_ptr<torch::nn::Module>& module) { |
1002 | ASSERT_EQ(name, expected[index].first); |
1003 | ASSERT_EQ(get_test_container_item(module), expected[index++].second); |
1004 | }, |
1005 | /*name_prefix=*/"test_prefix" ); |
1006 | ASSERT_EQ(index, 10); |
1007 | } |
1008 | |
1009 | TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) { |
1010 | { |
1011 | TestModule module(1); |
1012 | ASSERT_THROWS_WITH( |
1013 | module.modules(), |
1014 | "It looks like you attempted to retrieve " |
1015 | "your top-level module as a shared_ptr" ) |
1016 | } |
1017 | { |
1018 | TestModule module(1); |
1019 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
1020 | ASSERT_NO_THROW(module.modules(/*include_self=*/false)); |
1021 | } |
1022 | { |
1023 | auto module = std::make_shared<TestModule>(1); |
1024 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
1025 | ASSERT_NO_THROW(module->modules()); |
1026 | } |
1027 | } |
1028 | |
1029 | struct EmptyModule : torch::nn::Module {}; |
1030 | |
1031 | TEST_F(ModuleTest, PrettyPrint) { |
1032 | struct TestModule : torch::nn::Module { |
1033 | TestModule(int x, float y) : x_(x), y_(y) {} |
1034 | |
1035 | void pretty_print(std::ostream& stream) const override { |
1036 | stream << "TestModule(x=" << x_ << ", y=" << y_ << ")" ; |
1037 | } |
1038 | |
1039 | int x_; |
1040 | float y_; |
1041 | }; |
1042 | |
1043 | ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule" ); |
1044 | ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)" ); |
1045 | } |
1046 | |
1047 | struct ModuleWithNonTensorForwardImpl : torch::nn::Module { |
1048 | int64_t forward(torch::Tensor x) { |
1049 | return x.numel(); |
1050 | } |
1051 | }; |
1052 | TORCH_MODULE(ModuleWithNonTensorForward); |
1053 | |
1054 | TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) { |
1055 | ModuleWithNonTensorForward m; |
1056 | ASSERT_EQ(m(torch::ones(123)), 123); |
1057 | } |
1058 | |