1 | #pragma once |
2 | |
#include <c10/util/TypeTraits.h>
#include <type_traits>
4 | |
5 | namespace c10 { |
6 | |
7 | /** |
8 | * Represent a function pointer as a C++ type. |
9 | * This allows using the function pointer as a type |
10 | * in a template and calling it from inside the template |
11 | * allows the compiler to inline the call because it |
12 | * knows the function pointer at compile time. |
13 | * |
14 | * Example 1: |
15 | * int add(int a, int b) {return a + b;} |
16 | * using Add = TORCH_FN_TYPE(add); |
17 | * template<class Func> struct Executor { |
18 | * int execute(int a, int b) { |
19 | * return Func::func_ptr()(a, b); |
20 | * } |
21 | * }; |
22 | * Executor<Add> executor; |
23 | * EXPECT_EQ(3, executor.execute(1, 2)); |
24 | * |
25 | * Example 2: |
26 | * int add(int a, int b) {return a + b;} |
27 | * template<class Func> int execute(Func, int a, int b) { |
28 | * return Func::func_ptr()(a, b); |
29 | * } |
30 | * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); |
31 | */ |
32 | template <class FuncType_, FuncType_* func_ptr_> |
33 | struct CompileTimeFunctionPointer final { |
34 | static_assert( |
35 | guts::is_function_type<FuncType_>::value, |
36 | "TORCH_FN can only wrap function types." ); |
37 | using FuncType = FuncType_; |
38 | |
39 | static constexpr FuncType* func_ptr() { |
40 | return func_ptr_; |
41 | } |
42 | }; |
43 | |
44 | template <class T> |
45 | struct is_compile_time_function_pointer : std::false_type {}; |
46 | template <class FuncType, FuncType* func_ptr> |
47 | struct is_compile_time_function_pointer< |
48 | CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {}; |
49 | |
50 | } // namespace c10 |
51 | |
/**
 * TORCH_FN_TYPE(func) names the ::c10::CompileTimeFunctionPointer type that
 * represents `func`.
 *
 * `func` may be a function name or its address (`add` or `&add`). The
 * function type is recovered by stripping any reference and then any pointer
 * from decltype(func). Note that decltype(func) is ill-formed for an
 * overloaded name, so `func` must denote a single function.
 *
 * TORCH_FN(func) yields a value-initialized instance of that type. This is
 * convenient for passing to templates that deduce the function from an
 * argument.
 *
 * Every library name in the expansion is fully qualified (`::c10::`,
 * `::std::`). Because the macro expands at the call site, an unqualified
 * `std::` could otherwise be captured by a nested namespace named `std`.
 */
#define TORCH_FN_TYPE(func)                                                \
  ::c10::CompileTimeFunctionPointer<                                       \
      ::std::remove_pointer_t< ::std::remove_reference_t<decltype(func)>>, \
      func>
#define TORCH_FN(func) TORCH_FN_TYPE(func)()
57 | |