1#include <c10/core/CompileTimeFunctionPointer.h>
2#include <gtest/gtest.h>
3
4namespace test_is_compile_time_function_pointer {
5static_assert(!c10::is_compile_time_function_pointer<void()>::value, "");
6
7void dummy() {}
8static_assert(
9 c10::is_compile_time_function_pointer<TORCH_FN_TYPE(dummy)>::value,
10 "");
11} // namespace test_is_compile_time_function_pointer
12
13namespace test_access_through_type {
14void dummy() {}
15using dummy_ptr = TORCH_FN_TYPE(dummy);
16static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
17static_assert(dummy_ptr::func_ptr() == &dummy, "");
18static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
19} // namespace test_access_through_type
20
21namespace test_access_through_value {
22void dummy() {}
23constexpr auto dummy_ptr = TORCH_FN(dummy);
24static_assert(dummy_ptr.func_ptr() == &dummy, "");
25static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
26} // namespace test_access_through_value
27
28namespace test_access_through_type_also_works_if_specified_as_pointer {
29void dummy() {}
30using dummy_ptr = TORCH_FN_TYPE(&dummy);
31static_assert(c10::is_compile_time_function_pointer<dummy_ptr>::value, "");
32static_assert(dummy_ptr::func_ptr() == &dummy, "");
33static_assert(std::is_same<void(), dummy_ptr::FuncType>::value, "");
34} // namespace test_access_through_type_also_works_if_specified_as_pointer
35
36namespace test_access_through_value_also_works_if_specified_as_pointer {
37void dummy() {}
38constexpr auto dummy_ptr = TORCH_FN(&dummy);
39static_assert(dummy_ptr.func_ptr() == &dummy, "");
40static_assert(std::is_same<void(), decltype(dummy_ptr)::FuncType>::value, "");
41} // namespace test_access_through_value_also_works_if_specified_as_pointer
42
43namespace test_run_through_type {
44int add(int a, int b) {
45 return a + b;
46}
47using Add = TORCH_FN_TYPE(add);
48template <class Func>
49struct Executor {
50 int execute(int a, int b) {
51 return Func::func_ptr()(a, b);
52 }
53};
54
55TEST(CompileTimeFunctionPointerTest, runFunctionThroughType) {
56 Executor<Add> executor;
57 EXPECT_EQ(3, executor.execute(1, 2));
58}
59} // namespace test_run_through_type
60
61namespace test_run_through_value {
62int add(int a, int b) {
63 return a + b;
64}
65template <class Func>
66int execute(Func, int a, int b) {
67 return Func::func_ptr()(a, b);
68}
69
70TEST(CompileTimeFunctionPointerTest, runFunctionThroughValue) {
71 EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
72}
73} // namespace test_run_through_value
74