1#include <gtest/gtest.h>
2
3#include <torch/torch.h>
4
5#include <test/cpp/api/support.h>
6
7#include <algorithm>
8#include <string>
9
10using namespace torch::nn;
11
12struct AnyModuleTest : torch::test::SeedingFixture {};
13
14TEST_F(AnyModuleTest, SimpleReturnType) {
15 struct M : torch::nn::Module {
16 int forward() {
17 return 123;
18 }
19 };
20 AnyModule any(M{});
21 ASSERT_EQ(any.forward<int>(), 123);
22}
23
24TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
25 struct M : torch::nn::Module {
26 int forward(int x) {
27 return x;
28 }
29 };
30 AnyModule any(M{});
31 ASSERT_EQ(any.forward<int>(5), 5);
32}
33
34TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
35 struct M : torch::nn::Module {
36 const char* forward(const char* x) {
37 return x;
38 }
39 };
40 AnyModule any(M{});
41 ASSERT_EQ(any.forward<const char*>("hello"), std::string("hello"));
42}
43
44TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
45 struct M : torch::nn::Module {
46 std::string forward(int x, const double f) {
47 return std::to_string(static_cast<int>(x + f));
48 }
49 };
50 AnyModule any(M{});
51 int x = 4;
52 ASSERT_EQ(any.forward<std::string>(x, 3.14), std::string("7"));
53}
54
55TEST_F(
56 AnyModuleTest,
57 TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
58 struct M : torch::nn::Module {
59 torch::Tensor forward(
60 std::string a,
61 const std::string& b,
62 std::string&& c) {
63 const auto s = a + b + c;
64 return torch::ones({static_cast<int64_t>(s.size())});
65 }
66 };
67 AnyModule any(M{});
68 ASSERT_TRUE(
69 any.forward(std::string("a"), std::string("ab"), std::string("abc"))
70 .sum()
71 .item<int32_t>() == 6);
72}
73
74TEST_F(AnyModuleTest, WrongArgumentType) {
75 struct M : torch::nn::Module {
76 int forward(float x) {
77 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
78 return x;
79 }
80 };
81 AnyModule any(M{});
82 ASSERT_THROWS_WITH(
83 any.forward(5.0),
84 "Expected argument #0 to be of type float, "
85 "but received value of type double");
86}
87
88struct M_test_wrong_number_of_arguments : torch::nn::Module {
89 int forward(int a, int b) {
90 return a + b;
91 }
92};
93
94TEST_F(AnyModuleTest, WrongNumberOfArguments) {
95 AnyModule any(M_test_wrong_number_of_arguments{});
96#if defined(_MSC_VER)
97 std::string module_name = "struct M_test_wrong_number_of_arguments";
98#else
99 std::string module_name = "M_test_wrong_number_of_arguments";
100#endif
101 ASSERT_THROWS_WITH(
102 any.forward(),
103 module_name +
104 "'s forward() method expects 2 argument(s), but received 0. "
105 "If " +
106 module_name +
107 "'s forward() method has default arguments, "
108 "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
109 ASSERT_THROWS_WITH(
110 any.forward(5),
111 module_name +
112 "'s forward() method expects 2 argument(s), but received 1. "
113 "If " +
114 module_name +
115 "'s forward() method has default arguments, "
116 "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
117 ASSERT_THROWS_WITH(
118 any.forward(1, 2, 3),
119 module_name +
120 "'s forward() method expects 2 argument(s), but received 3.");
121}
122
123struct M_default_arg_with_macro : torch::nn::Module {
124 double forward(int a, int b = 2, double c = 3.0) {
125 return a + b + c;
126 }
127
128 protected:
129 FORWARD_HAS_DEFAULT_ARGS(
130 {1, torch::nn::AnyValue(2)},
131 {2, torch::nn::AnyValue(3.0)})
132};
133
134struct M_default_arg_without_macro : torch::nn::Module {
135 double forward(int a, int b = 2, double c = 3.0) {
136 return a + b + c;
137 }
138};
139
140TEST_F(
141 AnyModuleTest,
142 PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
143 {
144 AnyModule any(M_default_arg_with_macro{});
145
146 ASSERT_EQ(any.forward<double>(1), 6.0);
147 ASSERT_EQ(any.forward<double>(1, 3), 7.0);
148 ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);
149
150 ASSERT_THROWS_WITH(
151 any.forward(),
152 "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0.");
153 ASSERT_THROWS_WITH(
154 any.forward(1, 2, 3.0, 4),
155 "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4.");
156 }
157 {
158 AnyModule any(M_default_arg_without_macro{});
159
160 ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);
161
162#if defined(_MSC_VER)
163 std::string module_name = "struct M_default_arg_without_macro";
164#else
165 std::string module_name = "M_default_arg_without_macro";
166#endif
167
168 ASSERT_THROWS_WITH(
169 any.forward(),
170 module_name +
171 "'s forward() method expects 3 argument(s), but received 0. "
172 "If " +
173 module_name +
174 "'s forward() method has default arguments, "
175 "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
176 ASSERT_THROWS_WITH(
177 any.forward<double>(1),
178 module_name +
179 "'s forward() method expects 3 argument(s), but received 1. "
180 "If " +
181 module_name +
182 "'s forward() method has default arguments, "
183 "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
184 ASSERT_THROWS_WITH(
185 any.forward<double>(1, 3),
186 module_name +
187 "'s forward() method expects 3 argument(s), but received 2. "
188 "If " +
189 module_name +
190 "'s forward() method has default arguments, "
191 "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
192 ASSERT_THROWS_WITH(
193 any.forward(1, 2, 3.0, 4),
194 module_name +
195 "'s forward() method expects 3 argument(s), but received 4.");
196 }
197}
198
199struct M : torch::nn::Module {
200 explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
201 int value;
202 int forward(float x) {
203 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
204 return x;
205 }
206};
207
208TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
209 AnyModule any(M{5});
210 ASSERT_EQ(any.get<M>().value, 5);
211}
212
213TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
214 struct N : torch::nn::Module {
215 torch::Tensor forward(torch::Tensor input) {
216 return input;
217 }
218 };
219 AnyModule any(M{5});
220 ASSERT_THROWS_WITH(any.get<N>(), "Attempted to cast module");
221}
222
223TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
224 AnyModule any(M{5});
225 auto ptr = any.ptr();
226 ASSERT_NE(ptr, nullptr);
227 ASSERT_EQ(ptr->name(), "M");
228}
229
230TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
231 AnyModule any(M{5});
232 auto ptr = any.ptr<M>();
233 ASSERT_NE(ptr, nullptr);
234 ASSERT_EQ(ptr->value, 5);
235}
236
237TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
238 struct N : torch::nn::Module {
239 torch::Tensor forward(torch::Tensor input) {
240 return input;
241 }
242 };
243 AnyModule any(M{5});
244 ASSERT_THROWS_WITH(any.ptr<N>(), "Attempted to cast module");
245}
246
247TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
248 struct M : torch::nn::Module {
249 explicit M(int value_) : value(value_) {}
250 int value;
251 int forward(float x) {
252 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
253 return x;
254 }
255 };
256 AnyModule any;
257 ASSERT_TRUE(any.is_empty());
258 any = std::make_shared<M>(5);
259 ASSERT_FALSE(any.is_empty());
260 ASSERT_EQ(any.get<M>().value, 5);
261}
262
263TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
264 struct M : torch::nn::Module {
265 int forward(int x) {
266 return x;
267 }
268 };
269 AnyModule any;
270 ASSERT_TRUE(any.is_empty());
271 ASSERT_THROWS_WITH(any.get<M>(), "Cannot call get() on an empty AnyModule");
272 ASSERT_THROWS_WITH(any.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
273 ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule");
274 ASSERT_THROWS_WITH(
275 any.type_info(), "Cannot call type_info() on an empty AnyModule");
276 ASSERT_THROWS_WITH(
277 any.forward<int>(5), "Cannot call forward() on an empty AnyModule");
278}
279
280TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
281 struct M : torch::nn::Module {
282 std::string forward(int x) {
283 return std::to_string(x);
284 }
285 };
286 struct N : torch::nn::Module {
287 int forward(float x) {
288 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
289 return 3 + x;
290 }
291 };
292 AnyModule any;
293 ASSERT_TRUE(any.is_empty());
294 any = std::make_shared<M>();
295 ASSERT_FALSE(any.is_empty());
296 ASSERT_EQ(any.forward<std::string>(5), "5");
297 any = std::make_shared<N>();
298 ASSERT_FALSE(any.is_empty());
299 ASSERT_EQ(any.forward<int>(5.0f), 8);
300}
301
302TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
303 struct MImpl : torch::nn::Module {
304 explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
305 int value;
306 int forward(float x) {
307 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
308 return x;
309 }
310 };
311
312 struct M : torch::nn::ModuleHolder<MImpl> {
313 using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
314 using torch::nn::ModuleHolder<MImpl>::get;
315 };
316
317 AnyModule any(M{5});
318 ASSERT_EQ(any.get<MImpl>().value, 5);
319 ASSERT_EQ(any.get<M>()->value, 5);
320
321 AnyModule module(Linear(3, 4));
322 std::shared_ptr<Module> ptr = module.ptr();
323 Linear linear(module.get<Linear>());
324}
325
326TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
327 struct M : torch::nn::Module {
328 torch::Tensor forward(torch::Tensor input) {
329 return input;
330 }
331 };
332
333 // When you have an autograd::Variable, it should be converted to a
334 // torch::Tensor before being passed to the function (to avoid a type
335 // mismatch).
336 AnyModule any(M{});
337 ASSERT_TRUE(
338 any.forward(torch::autograd::Variable(torch::ones(5)))
339 .sum()
340 .item<float>() == 5);
341 // at::Tensors that are not variables work too.
342 ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
343}
344
345namespace torch {
346namespace nn {
347struct TestAnyValue {
348 template <typename T>
349 // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
350 explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
351 AnyValue operator()() {
352 return std::move(value_);
353 }
354 AnyValue value_;
355};
356template <typename T>
357AnyValue make_value(T&& value) {
358 return TestAnyValue(std::forward<T>(value))();
359}
360} // namespace nn
361} // namespace torch
362
363struct AnyValueTest : torch::test::SeedingFixture {};
364
365TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
366 auto value = make_value<int>(5);
367 ASSERT_NE(value.try_get<int>(), nullptr);
368 // const and non-const types have the same typeid(),
369 // but casting Holder<int> to Holder<const int> is undefined
370 // behavior according to UBSAN:
371 // https://github.com/pytorch/pytorch/issues/26964
372 // ASSERT_NE(value.try_get<const int>(), nullptr);
373 ASSERT_EQ(value.get<int>(), 5);
374}
// This test is disabled because it does not work: make_value appears to
// decay `const int` into `int`.
377// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
378// auto value = make_value<const int>(5);
379// ASSERT_NE(value.try_get<const int>(), nullptr);
380// // ASSERT_NE(value.try_get<int>(), nullptr);
381// ASSERT_EQ(value.get<const int>(), 5);
382//}
383TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
384 auto value = make_value("hello");
385 ASSERT_NE(value.try_get<const char*>(), nullptr);
386 ASSERT_EQ(value.get<const char*>(), std::string("hello"));
387}
388TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
389 auto value = make_value(std::string("hello"));
390 ASSERT_NE(value.try_get<std::string>(), nullptr);
391 ASSERT_EQ(value.get<std::string>(), "hello");
392}
393TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
394 std::string s("hello");
395 std::string* p = &s;
396 auto value = make_value(p);
397 ASSERT_NE(value.try_get<std::string*>(), nullptr);
398 ASSERT_EQ(*value.get<std::string*>(), "hello");
399}
400TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
401 std::string s("hello");
402 const std::string& t = s;
403 auto value = make_value(t);
404 ASSERT_NE(value.try_get<std::string>(), nullptr);
405 ASSERT_EQ(value.get<std::string>(), "hello");
406}
407
408TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
409 auto value = make_value(5);
410 ASSERT_NE(value.try_get<int>(), nullptr);
411 ASSERT_EQ(value.try_get<float>(), nullptr);
412 ASSERT_EQ(value.try_get<long>(), nullptr);
413 ASSERT_EQ(value.try_get<std::string>(), nullptr);
414}
415
416TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
417 auto value = make_value(5);
418 ASSERT_NE(value.try_get<int>(), nullptr);
419 ASSERT_THROWS_WITH(
420 value.get<float>(),
421 "Attempted to cast AnyValue to float, "
422 "but its actual type is int");
423 ASSERT_THROWS_WITH(
424 value.get<long>(),
425 "Attempted to cast AnyValue to long, "
426 "but its actual type is int");
427}
428
429TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
430 auto value = make_value(5);
431 auto copy = make_value(std::move(value));
432 ASSERT_NE(copy.try_get<int>(), nullptr);
433 ASSERT_EQ(copy.get<int>(), 5);
434}
435
436TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
437 auto value = make_value(5);
438 auto copy = make_value(10);
439 copy = std::move(value);
440 ASSERT_NE(copy.try_get<int>(), nullptr);
441 ASSERT_EQ(copy.get<int>(), 5);
442}
443
444TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
445 auto value = make_value(5);
446 ASSERT_EQ(value.type_info().hash_code(), typeid(int).hash_code());
447}
448
449TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
450 auto value = make_value("hello");
451 ASSERT_EQ(value.type_info().hash_code(), typeid(const char*).hash_code());
452}
453
454TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
455 auto value = make_value(std::string("hello"));
456 ASSERT_EQ(value.type_info().hash_code(), typeid(std::string).hash_code());
457}
458