1 | #include <c10/core/CompileTimeFunctionPointer.h> |
2 | #include <gtest/gtest.h> |
3 | |
4 | namespace test_is_compile_time_function_pointer { |
5 | static_assert(!c10::is_compile_time_function_pointer<void()>::value, "" ); |
6 | |
7 | void dummy() {} |
8 | static_assert( |
9 | c10::is_compile_time_function_pointer<TORCH_FN_TYPE(dummy)>::value, |
10 | "" ); |
11 | } // namespace test_is_compile_time_function_pointer |
12 | |
13 | namespace test_access_through_type { |
14 | void dummy() {} |
15 | using dummy_ptr = TORCH_FN_TYPE(dummy); |
16 | static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "" ); |
17 | static_assert(dummy_ptr::func_ptr() == &dummy, "" ); |
18 | static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "" ); |
19 | } // namespace test_access_through_type |
20 | |
21 | namespace test_access_through_value { |
22 | void dummy() {} |
23 | constexpr auto dummy_ptr = TORCH_FN(dummy); |
24 | static_assert(dummy_ptr.func_ptr() == &dummy, "" ); |
25 | static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "" ); |
26 | } // namespace test_access_through_value |
27 | |
28 | namespace test_access_through_type_also_works_if_specified_as_pointer { |
29 | void dummy() {} |
30 | using dummy_ptr = TORCH_FN_TYPE(&dummy); |
31 | static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "" ); |
32 | static_assert(dummy_ptr::func_ptr() == &dummy, "" ); |
33 | static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "" ); |
34 | } // namespace test_access_through_type_also_works_if_specified_as_pointer |
35 | |
36 | namespace test_access_through_value_also_works_if_specified_as_pointer { |
37 | void dummy() {} |
38 | constexpr auto dummy_ptr = TORCH_FN(&dummy); |
39 | static_assert(dummy_ptr.func_ptr() == &dummy, "" ); |
40 | static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "" ); |
41 | } // namespace test_access_through_value_also_works_if_specified_as_pointer |
42 | |
43 | namespace test_run_through_type { |
44 | int add(int a, int b) { |
45 | return a + b; |
46 | } |
47 | using Add = TORCH_FN_TYPE(add); |
48 | template <class Func> |
49 | struct Executor { |
50 | int execute(int a, int b) { |
51 | return Func::func_ptr()(a, b); |
52 | } |
53 | }; |
54 | |
55 | TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) { |
56 | Executor<Add> executor; |
57 | EXPECT_EQ(3, executor.execute(1, 2)); |
58 | } |
59 | } // namespace test_run_through_type |
60 | |
61 | namespace test_run_through_value { |
62 | int add(int a, int b) { |
63 | return a + b; |
64 | } |
65 | template <class Func> |
66 | int execute(Func, int a, int b) { |
67 | return Func::func_ptr()(a, b); |
68 | } |
69 | |
70 | TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) { |
71 | EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); |
72 | } |
73 | } // namespace test_run_through_value |
74 | |