1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/utils/variadic.h> |
5 | #include <torch/detail/static.h> |
6 | #include <torch/torch.h> |
7 | |
8 | #include <string> |
9 | #include <vector> |
10 | |
11 | template < |
12 | typename T, |
13 | typename = torch::enable_if_t<!torch::detail::is_module<T>::value>> |
14 | bool f(T&& m) { |
15 | return false; |
16 | } |
17 | |
18 | template <typename T> |
19 | torch::detail::enable_if_module_t<T, bool> f(T&& m) { |
20 | return true; |
21 | } |
22 | |
23 | TEST(TestStatic, AllOf) { |
24 | ASSERT_TRUE(torch::all_of<>::value); |
25 | ASSERT_TRUE(torch::all_of<true>::value); |
26 | ASSERT_TRUE((torch::all_of<true, true, true>::value)); |
27 | ASSERT_FALSE(torch::all_of<false>::value); |
28 | ASSERT_FALSE((torch::all_of<false, false, false>::value)); |
29 | ASSERT_FALSE((torch::all_of<true, true, false>::value)); |
30 | } |
31 | |
32 | TEST(TestStatic, AnyOf) { |
33 | ASSERT_FALSE(torch::any_of<>::value); |
34 | ASSERT_TRUE(bool((torch::any_of<true>::value))); |
35 | ASSERT_TRUE(bool((torch::any_of<true, true, true>::value))); |
36 | ASSERT_FALSE(bool((torch::any_of<false>::value))); |
37 | } |
38 | |
39 | TEST(TestStatic, EnableIfModule) { |
40 | ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2))); |
41 | ASSERT_FALSE(f(5)); |
42 | ASSERT_TRUE(torch::detail::check_not_lvalue_references<int>()); |
43 | ASSERT_TRUE((torch::detail::check_not_lvalue_references<float, int, char>())); |
44 | ASSERT_FALSE( |
45 | (torch::detail::check_not_lvalue_references<float, int&, char>())); |
46 | ASSERT_TRUE(torch::detail::check_not_lvalue_references<std::string>()); |
47 | ASSERT_FALSE(torch::detail::check_not_lvalue_references<std::string&>()); |
48 | } |
49 | |
50 | namespace { |
51 | |
52 | struct A : torch::nn::Module { |
53 | int forward() { |
54 | return 5; |
55 | } |
56 | }; |
57 | |
58 | struct B : torch::nn::Module { |
59 | std::string forward(torch::Tensor tensor) { |
60 | return "" ; |
61 | } |
62 | }; |
63 | |
64 | struct C : torch::nn::Module { |
65 | float forward(torch::Tensor& tensor) { |
66 | return 5.0; |
67 | } |
68 | }; |
69 | |
70 | struct D : torch::nn::Module { |
71 | char forward(torch::Tensor&& tensor) { |
72 | return 'x'; |
73 | } |
74 | }; |
75 | |
76 | struct E : torch::nn::Module {}; |
77 | |
78 | } // anonymous namespace |
79 | |
80 | // Put in a function because macros don't handle the comma between arguments to |
81 | // is_same well ... |
82 | template <typename Module, typename ExpectedType, typename... Args> |
83 | void assert_has_expected_type() { |
84 | using ReturnType = |
85 | typename torch::detail::return_type_of_forward<Module, Args...>::type; |
86 | constexpr bool is_expected_type = |
87 | std::is_same<ReturnType, ExpectedType>::value; |
88 | ASSERT_TRUE(is_expected_type) << Module().name(); |
89 | } |
90 | |
91 | TEST(TestStatic, ReturnTypeOfForward) { |
92 | assert_has_expected_type<A, int>(); |
93 | assert_has_expected_type<B, std::string, torch::Tensor>(); |
94 | assert_has_expected_type<C, float, torch::Tensor&>(); |
95 | assert_has_expected_type<D, char, torch::Tensor&&>(); |
96 | assert_has_expected_type<E, void>(); |
97 | } |
98 | |
99 | TEST(TestStatic, Apply) { |
100 | std::vector<int> v; |
101 | torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5); |
102 | ASSERT_EQ(v.size(), 5); |
103 | for (const auto i : c10::irange(v.size())) { |
104 | ASSERT_EQ(v.at(i), i + 1); |
105 | } |
106 | } |
107 | |