1#pragma once
2
#include <c10/util/TypeTraits.h>

#include <type_traits>
4
5namespace c10 {
6
7/**
8 * Represent a function pointer as a C++ type.
9 * This allows using the function pointer as a type
10 * in a template and calling it from inside the template
11 * allows the compiler to inline the call because it
12 * knows the function pointer at compile time.
13 *
14 * Example 1:
15 * int add(int a, int b) {return a + b;}
16 * using Add = TORCH_FN_TYPE(add);
17 * template<class Func> struct Executor {
18 * int execute(int a, int b) {
19 * return Func::func_ptr()(a, b);
20 * }
21 * };
22 * Executor<Add> executor;
23 * EXPECT_EQ(3, executor.execute(1, 2));
24 *
25 * Example 2:
26 * int add(int a, int b) {return a + b;}
27 * template<class Func> int execute(Func, int a, int b) {
28 * return Func::func_ptr()(a, b);
29 * }
30 * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
31 */
32template <class FuncType_, FuncType_* func_ptr_>
33struct CompileTimeFunctionPointer final {
34 static_assert(
35 guts::is_function_type<FuncType_>::value,
36 "TORCH_FN can only wrap function types.");
37 using FuncType = FuncType_;
38
39 static constexpr FuncType* func_ptr() {
40 return func_ptr_;
41 }
42};
43
44template <class T>
45struct is_compile_time_function_pointer : std::false_type {};
46template <class FuncType, FuncType* func_ptr>
47struct is_compile_time_function_pointer<
48 CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
49
50} // namespace c10
51
/**
 * TORCH_FN_TYPE(func) names the ::c10::CompileTimeFunctionPointer type that
 * wraps `func`. `func` may be written either as a function name (`add`) or as
 * its address (`&add`). decltype then yields a function type or a function
 * pointer type, and the remove_reference/remove_pointer pair reduces it to the
 * plain function type that CompileTimeFunctionPointer expects as its first
 * template argument.
 *
 * TORCH_FN(func) produces a value of that type. It is useful for passing to
 * function templates that deduce the wrapper type from an argument.
 *
 * Limitations: `func` must name a single function (decltype of an overload
 * set is ill-formed) and must be a constant expression usable as a non-type
 * template argument.
 *
 * Standard-library names are fully qualified (`::std::`), like `::c10::`, so
 * the expansion cannot be captured by a nested `std` namespace at the
 * expansion site.
 */
#define TORCH_FN_TYPE(func)                                               \
  ::c10::CompileTimeFunctionPointer<                                      \
      ::std::remove_pointer_t<::std::remove_reference_t<decltype(func)>>, \
      func>
#define TORCH_FN(func) TORCH_FN_TYPE(func)()
57