1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
7 | #include <algorithm> |
8 | #include <string> |
9 | |
10 | using namespace torch::nn; |
11 | |
12 | struct AnyModuleTest : torch::test::SeedingFixture {}; |
13 | |
14 | TEST_F(AnyModuleTest, SimpleReturnType) { |
15 | struct M : torch::nn::Module { |
16 | int forward() { |
17 | return 123; |
18 | } |
19 | }; |
20 | AnyModule any(M{}); |
21 | ASSERT_EQ(any.forward<int>(), 123); |
22 | } |
23 | |
24 | TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) { |
25 | struct M : torch::nn::Module { |
26 | int forward(int x) { |
27 | return x; |
28 | } |
29 | }; |
30 | AnyModule any(M{}); |
31 | ASSERT_EQ(any.forward<int>(5), 5); |
32 | } |
33 | |
34 | TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) { |
35 | struct M : torch::nn::Module { |
36 | const char* forward(const char* x) { |
37 | return x; |
38 | } |
39 | }; |
40 | AnyModule any(M{}); |
41 | ASSERT_EQ(any.forward<const char*>("hello" ), std::string("hello" )); |
42 | } |
43 | |
44 | TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) { |
45 | struct M : torch::nn::Module { |
46 | std::string forward(int x, const double f) { |
47 | return std::to_string(static_cast<int>(x + f)); |
48 | } |
49 | }; |
50 | AnyModule any(M{}); |
51 | int x = 4; |
52 | ASSERT_EQ(any.forward<std::string>(x, 3.14), std::string("7" )); |
53 | } |
54 | |
55 | TEST_F( |
56 | AnyModuleTest, |
57 | TensorReturnTypeAndStringArgumentsWithFunkyQualifications) { |
58 | struct M : torch::nn::Module { |
59 | torch::Tensor forward( |
60 | std::string a, |
61 | const std::string& b, |
62 | std::string&& c) { |
63 | const auto s = a + b + c; |
64 | return torch::ones({static_cast<int64_t>(s.size())}); |
65 | } |
66 | }; |
67 | AnyModule any(M{}); |
68 | ASSERT_TRUE( |
69 | any.forward(std::string("a" ), std::string("ab" ), std::string("abc" )) |
70 | .sum() |
71 | .item<int32_t>() == 6); |
72 | } |
73 | |
74 | TEST_F(AnyModuleTest, WrongArgumentType) { |
75 | struct M : torch::nn::Module { |
76 | int forward(float x) { |
77 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
78 | return x; |
79 | } |
80 | }; |
81 | AnyModule any(M{}); |
82 | ASSERT_THROWS_WITH( |
83 | any.forward(5.0), |
84 | "Expected argument #0 to be of type float, " |
85 | "but received value of type double" ); |
86 | } |
87 | |
88 | struct M_test_wrong_number_of_arguments : torch::nn::Module { |
89 | int forward(int a, int b) { |
90 | return a + b; |
91 | } |
92 | }; |
93 | |
94 | TEST_F(AnyModuleTest, WrongNumberOfArguments) { |
95 | AnyModule any(M_test_wrong_number_of_arguments{}); |
96 | #if defined(_MSC_VER) |
97 | std::string module_name = "struct M_test_wrong_number_of_arguments" ; |
98 | #else |
99 | std::string module_name = "M_test_wrong_number_of_arguments" ; |
100 | #endif |
101 | ASSERT_THROWS_WITH( |
102 | any.forward(), |
103 | module_name + |
104 | "'s forward() method expects 2 argument(s), but received 0. " |
105 | "If " + |
106 | module_name + |
107 | "'s forward() method has default arguments, " |
108 | "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." ); |
109 | ASSERT_THROWS_WITH( |
110 | any.forward(5), |
111 | module_name + |
112 | "'s forward() method expects 2 argument(s), but received 1. " |
113 | "If " + |
114 | module_name + |
115 | "'s forward() method has default arguments, " |
116 | "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." ); |
117 | ASSERT_THROWS_WITH( |
118 | any.forward(1, 2, 3), |
119 | module_name + |
120 | "'s forward() method expects 2 argument(s), but received 3." ); |
121 | } |
122 | |
123 | struct M_default_arg_with_macro : torch::nn::Module { |
124 | double forward(int a, int b = 2, double c = 3.0) { |
125 | return a + b + c; |
126 | } |
127 | |
128 | protected: |
129 | FORWARD_HAS_DEFAULT_ARGS( |
130 | {1, torch::nn::AnyValue(2)}, |
131 | {2, torch::nn::AnyValue(3.0)}) |
132 | }; |
133 | |
134 | struct M_default_arg_without_macro : torch::nn::Module { |
135 | double forward(int a, int b = 2, double c = 3.0) { |
136 | return a + b + c; |
137 | } |
138 | }; |
139 | |
140 | TEST_F( |
141 | AnyModuleTest, |
142 | PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) { |
143 | { |
144 | AnyModule any(M_default_arg_with_macro{}); |
145 | |
146 | ASSERT_EQ(any.forward<double>(1), 6.0); |
147 | ASSERT_EQ(any.forward<double>(1, 3), 7.0); |
148 | ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0); |
149 | |
150 | ASSERT_THROWS_WITH( |
151 | any.forward(), |
152 | "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0." ); |
153 | ASSERT_THROWS_WITH( |
154 | any.forward(1, 2, 3.0, 4), |
155 | "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4." ); |
156 | } |
157 | { |
158 | AnyModule any(M_default_arg_without_macro{}); |
159 | |
160 | ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0); |
161 | |
162 | #if defined(_MSC_VER) |
163 | std::string module_name = "struct M_default_arg_without_macro" ; |
164 | #else |
165 | std::string module_name = "M_default_arg_without_macro" ; |
166 | #endif |
167 | |
168 | ASSERT_THROWS_WITH( |
169 | any.forward(), |
170 | module_name + |
171 | "'s forward() method expects 3 argument(s), but received 0. " |
172 | "If " + |
173 | module_name + |
174 | "'s forward() method has default arguments, " |
175 | "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." ); |
176 | ASSERT_THROWS_WITH( |
177 | any.forward<double>(1), |
178 | module_name + |
179 | "'s forward() method expects 3 argument(s), but received 1. " |
180 | "If " + |
181 | module_name + |
182 | "'s forward() method has default arguments, " |
183 | "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." ); |
184 | ASSERT_THROWS_WITH( |
185 | any.forward<double>(1, 3), |
186 | module_name + |
187 | "'s forward() method expects 3 argument(s), but received 2. " |
188 | "If " + |
189 | module_name + |
190 | "'s forward() method has default arguments, " |
191 | "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." ); |
192 | ASSERT_THROWS_WITH( |
193 | any.forward(1, 2, 3.0, 4), |
194 | module_name + |
195 | "'s forward() method expects 3 argument(s), but received 4." ); |
196 | } |
197 | } |
198 | |
199 | struct M : torch::nn::Module { |
200 | explicit M(int value_) : torch::nn::Module("M" ), value(value_) {} |
201 | int value; |
202 | int forward(float x) { |
203 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
204 | return x; |
205 | } |
206 | }; |
207 | |
208 | TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) { |
209 | AnyModule any(M{5}); |
210 | ASSERT_EQ(any.get<M>().value, 5); |
211 | } |
212 | |
213 | TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) { |
214 | struct N : torch::nn::Module { |
215 | torch::Tensor forward(torch::Tensor input) { |
216 | return input; |
217 | } |
218 | }; |
219 | AnyModule any(M{5}); |
220 | ASSERT_THROWS_WITH(any.get<N>(), "Attempted to cast module" ); |
221 | } |
222 | |
223 | TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) { |
224 | AnyModule any(M{5}); |
225 | auto ptr = any.ptr(); |
226 | ASSERT_NE(ptr, nullptr); |
227 | ASSERT_EQ(ptr->name(), "M" ); |
228 | } |
229 | |
230 | TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) { |
231 | AnyModule any(M{5}); |
232 | auto ptr = any.ptr<M>(); |
233 | ASSERT_NE(ptr, nullptr); |
234 | ASSERT_EQ(ptr->value, 5); |
235 | } |
236 | |
237 | TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) { |
238 | struct N : torch::nn::Module { |
239 | torch::Tensor forward(torch::Tensor input) { |
240 | return input; |
241 | } |
242 | }; |
243 | AnyModule any(M{5}); |
244 | ASSERT_THROWS_WITH(any.ptr<N>(), "Attempted to cast module" ); |
245 | } |
246 | |
247 | TEST_F(AnyModuleTest, DefaultStateIsEmpty) { |
248 | struct M : torch::nn::Module { |
249 | explicit M(int value_) : value(value_) {} |
250 | int value; |
251 | int forward(float x) { |
252 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
253 | return x; |
254 | } |
255 | }; |
256 | AnyModule any; |
257 | ASSERT_TRUE(any.is_empty()); |
258 | any = std::make_shared<M>(5); |
259 | ASSERT_FALSE(any.is_empty()); |
260 | ASSERT_EQ(any.get<M>().value, 5); |
261 | } |
262 | |
263 | TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) { |
264 | struct M : torch::nn::Module { |
265 | int forward(int x) { |
266 | return x; |
267 | } |
268 | }; |
269 | AnyModule any; |
270 | ASSERT_TRUE(any.is_empty()); |
271 | ASSERT_THROWS_WITH(any.get<M>(), "Cannot call get() on an empty AnyModule" ); |
272 | ASSERT_THROWS_WITH(any.ptr<M>(), "Cannot call ptr() on an empty AnyModule" ); |
273 | ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule" ); |
274 | ASSERT_THROWS_WITH( |
275 | any.type_info(), "Cannot call type_info() on an empty AnyModule" ); |
276 | ASSERT_THROWS_WITH( |
277 | any.forward<int>(5), "Cannot call forward() on an empty AnyModule" ); |
278 | } |
279 | |
280 | TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) { |
281 | struct M : torch::nn::Module { |
282 | std::string forward(int x) { |
283 | return std::to_string(x); |
284 | } |
285 | }; |
286 | struct N : torch::nn::Module { |
287 | int forward(float x) { |
288 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
289 | return 3 + x; |
290 | } |
291 | }; |
292 | AnyModule any; |
293 | ASSERT_TRUE(any.is_empty()); |
294 | any = std::make_shared<M>(); |
295 | ASSERT_FALSE(any.is_empty()); |
296 | ASSERT_EQ(any.forward<std::string>(5), "5" ); |
297 | any = std::make_shared<N>(); |
298 | ASSERT_FALSE(any.is_empty()); |
299 | ASSERT_EQ(any.forward<int>(5.0f), 8); |
300 | } |
301 | |
302 | TEST_F(AnyModuleTest, ConstructsFromModuleHolder) { |
303 | struct MImpl : torch::nn::Module { |
304 | explicit MImpl(int value_) : torch::nn::Module("M" ), value(value_) {} |
305 | int value; |
306 | int forward(float x) { |
307 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
308 | return x; |
309 | } |
310 | }; |
311 | |
312 | struct M : torch::nn::ModuleHolder<MImpl> { |
313 | using torch::nn::ModuleHolder<MImpl>::ModuleHolder; |
314 | using torch::nn::ModuleHolder<MImpl>::get; |
315 | }; |
316 | |
317 | AnyModule any(M{5}); |
318 | ASSERT_EQ(any.get<MImpl>().value, 5); |
319 | ASSERT_EQ(any.get<M>()->value, 5); |
320 | |
321 | AnyModule module(Linear(3, 4)); |
322 | std::shared_ptr<Module> ptr = module.ptr(); |
323 | Linear linear(module.get<Linear>()); |
324 | } |
325 | |
326 | TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) { |
327 | struct M : torch::nn::Module { |
328 | torch::Tensor forward(torch::Tensor input) { |
329 | return input; |
330 | } |
331 | }; |
332 | |
333 | // When you have an autograd::Variable, it should be converted to a |
334 | // torch::Tensor before being passed to the function (to avoid a type |
335 | // mismatch). |
336 | AnyModule any(M{}); |
337 | ASSERT_TRUE( |
338 | any.forward(torch::autograd::Variable(torch::ones(5))) |
339 | .sum() |
340 | .item<float>() == 5); |
341 | // at::Tensors that are not variables work too. |
342 | ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5); |
343 | } |
344 | |
345 | namespace torch { |
346 | namespace nn { |
347 | struct TestAnyValue { |
348 | template <typename T> |
349 | // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) |
350 | explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {} |
351 | AnyValue operator()() { |
352 | return std::move(value_); |
353 | } |
354 | AnyValue value_; |
355 | }; |
356 | template <typename T> |
357 | AnyValue make_value(T&& value) { |
358 | return TestAnyValue(std::forward<T>(value))(); |
359 | } |
360 | } // namespace nn |
361 | } // namespace torch |
362 | |
363 | struct AnyValueTest : torch::test::SeedingFixture {}; |
364 | |
365 | TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) { |
366 | auto value = make_value<int>(5); |
367 | ASSERT_NE(value.try_get<int>(), nullptr); |
368 | // const and non-const types have the same typeid(), |
369 | // but casting Holder<int> to Holder<const int> is undefined |
370 | // behavior according to UBSAN: |
371 | // https://github.com/pytorch/pytorch/issues/26964 |
372 | // ASSERT_NE(value.try_get<const int>(), nullptr); |
373 | ASSERT_EQ(value.get<int>(), 5); |
374 | } |
// This test is disabled because it does not work: the const qualifier appears
// to be decayed away on the way into the AnyValue, so make_value<const int>
// ends up storing a plain int.
377 | // TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) { |
378 | // auto value = make_value<const int>(5); |
379 | // ASSERT_NE(value.try_get<const int>(), nullptr); |
380 | // // ASSERT_NE(value.try_get<int>(), nullptr); |
381 | // ASSERT_EQ(value.get<const int>(), 5); |
382 | //} |
383 | TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) { |
384 | auto value = make_value("hello" ); |
385 | ASSERT_NE(value.try_get<const char*>(), nullptr); |
386 | ASSERT_EQ(value.get<const char*>(), std::string("hello" )); |
387 | } |
388 | TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) { |
389 | auto value = make_value(std::string("hello" )); |
390 | ASSERT_NE(value.try_get<std::string>(), nullptr); |
391 | ASSERT_EQ(value.get<std::string>(), "hello" ); |
392 | } |
393 | TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) { |
394 | std::string s("hello" ); |
395 | std::string* p = &s; |
396 | auto value = make_value(p); |
397 | ASSERT_NE(value.try_get<std::string*>(), nullptr); |
398 | ASSERT_EQ(*value.get<std::string*>(), "hello" ); |
399 | } |
400 | TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) { |
401 | std::string s("hello" ); |
402 | const std::string& t = s; |
403 | auto value = make_value(t); |
404 | ASSERT_NE(value.try_get<std::string>(), nullptr); |
405 | ASSERT_EQ(value.get<std::string>(), "hello" ); |
406 | } |
407 | |
408 | TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) { |
409 | auto value = make_value(5); |
410 | ASSERT_NE(value.try_get<int>(), nullptr); |
411 | ASSERT_EQ(value.try_get<float>(), nullptr); |
412 | ASSERT_EQ(value.try_get<long>(), nullptr); |
413 | ASSERT_EQ(value.try_get<std::string>(), nullptr); |
414 | } |
415 | |
416 | TEST_F(AnyValueTest, GetThrowsForTheWrongType) { |
417 | auto value = make_value(5); |
418 | ASSERT_NE(value.try_get<int>(), nullptr); |
419 | ASSERT_THROWS_WITH( |
420 | value.get<float>(), |
421 | "Attempted to cast AnyValue to float, " |
422 | "but its actual type is int" ); |
423 | ASSERT_THROWS_WITH( |
424 | value.get<long>(), |
425 | "Attempted to cast AnyValue to long, " |
426 | "but its actual type is int" ); |
427 | } |
428 | |
429 | TEST_F(AnyValueTest, MoveConstructionIsAllowed) { |
430 | auto value = make_value(5); |
431 | auto copy = make_value(std::move(value)); |
432 | ASSERT_NE(copy.try_get<int>(), nullptr); |
433 | ASSERT_EQ(copy.get<int>(), 5); |
434 | } |
435 | |
436 | TEST_F(AnyValueTest, MoveAssignmentIsAllowed) { |
437 | auto value = make_value(5); |
438 | auto copy = make_value(10); |
439 | copy = std::move(value); |
440 | ASSERT_NE(copy.try_get<int>(), nullptr); |
441 | ASSERT_EQ(copy.get<int>(), 5); |
442 | } |
443 | |
444 | TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) { |
445 | auto value = make_value(5); |
446 | ASSERT_EQ(value.type_info().hash_code(), typeid(int).hash_code()); |
447 | } |
448 | |
449 | TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) { |
450 | auto value = make_value("hello" ); |
451 | ASSERT_EQ(value.type_info().hash_code(), typeid(const char*).hash_code()); |
452 | } |
453 | |
454 | TEST_F(AnyValueTest, TypeInfoIsCorrectForString) { |
455 | auto value = make_value(std::string("hello" )); |
456 | ASSERT_EQ(value.type_info().hash_code(), typeid(std::string).hash_code()); |
457 | } |
458 | |