#pragma once

#include <c10/core/CompileTimeFunctionPointer.h>

#include <utility>

5 | namespace c10 { |
6 | namespace impl { |
7 | namespace detail { |
8 | template<class FuncPtr, class ReturnType, class ParameterList> class WrapFunctionIntoFunctor_ {}; |
9 | template<class FuncPtr, class ReturnType, class... Parameters> |
10 | class WrapFunctionIntoFunctor_<FuncPtr, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel { |
11 | public: |
12 | C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) { |
13 | return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...); |
14 | } |
15 | }; |
16 | } |
17 | |
18 | // WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel functor. |
19 | // Since it is a compile time function pointer, many compilers can inline it |
20 | // into the wrapper and you don't get any performance overhead for wrapping. |
21 | template<class FuncPtr> |
22 | struct WrapFunctionIntoFunctor final { |
23 | static_assert(c10::is_compile_time_function_pointer<FuncPtr>::value, "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN." ); |
24 | using type = detail::WrapFunctionIntoFunctor_< |
25 | FuncPtr, |
26 | typename guts::function_traits<typename FuncPtr::FuncType>::return_type, |
27 | typename guts::function_traits<typename FuncPtr::FuncType>::parameter_types |
28 | >; |
29 | }; |
30 | } |
31 | |
32 | } |
33 | |