1 | #include <ATen/core/boxing/impl/boxing.h> |
2 | #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h> |
3 | #include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h> |
4 | #include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h> |
5 | |
6 | namespace c10 { |
7 | |
8 | inline KernelFunction::KernelFunction() |
9 | : boxed_kernel_func_() |
10 | , unboxed_kernel_func_(nullptr) |
11 | , sym_unboxed_kernel_func_(nullptr) |
12 | {} |
13 | |
14 | inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) |
15 | : boxed_kernel_func_(std::move(functor), boxed_kernel_func) |
16 | , unboxed_kernel_func_(unboxed_kernel_func) |
17 | , sym_unboxed_kernel_func_(sym_unboxed_kernel_func) |
18 | {} |
19 | |
20 | inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) |
21 | : boxed_kernel_func_(std::move(boxed_fn)) |
22 | , unboxed_kernel_func_(unboxed_kernel_func) |
23 | , sym_unboxed_kernel_func_(sym_unboxed_kernel_func) |
24 | {} |
25 | |
26 | inline bool KernelFunction::isValidUnboxed() const { |
27 | return unboxed_kernel_func_ != nullptr; |
28 | } |
29 | |
30 | inline bool KernelFunction::isValidSymUnboxed() const { |
31 | return sym_unboxed_kernel_func_ != nullptr; |
32 | } |
33 | |
34 | inline bool KernelFunction::isValid() const { |
35 | return boxed_kernel_func_.isValid(); |
36 | } |
37 | |
38 | inline bool KernelFunction::isFallthrough() const { |
39 | return boxed_kernel_func_.isFallthrough(); |
40 | } |
41 | |
42 | inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const { |
43 | boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack); |
44 | } |
45 | |
46 | template<class Return, class... Args> |
47 | inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) { |
48 | using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...); |
49 | ActualSignature* func = reinterpret_cast<ActualSignature*>(unboxed_kernel_func); |
50 | return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...); |
51 | } |
52 | |
53 | // This template requires you to explicitly specify the argument you want to |
54 | // forward; it doesn't work if you try to deduce it |
55 | // NB: keep this in sync with cloneWithRealTypes in function_schema.cpp |
56 | |
57 | template <typename T> |
58 | inline typename remove_symint<T>::type unpackSymInt(T x) { return x; } |
59 | |
60 | template <> |
61 | inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) { |
62 | return x.expect_int(); |
63 | } |
64 | |
65 | template <> |
66 | inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) { |
67 | return C10_AS_INTARRAYREF_SLOW(x); |
68 | } |
69 | |
70 | template <> |
71 | inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) { |
72 | return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt; |
73 | } |
74 | |
75 | template <> |
76 | inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(at::OptionalSymIntArrayRef x) { |
77 | return x.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : c10::nullopt; |
78 | } |
79 | |
80 | template<class Return, class... Args> |
81 | C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const { |
82 | // note: Args above is intentionally not Args&&. We don't want perfect |
83 | // forwarding, which would require Args to be deduced, but instead we |
84 | // want callers to explicitly specify the Args. |
85 | |
86 | // This should get inlined by compiler |
87 | if (guts::disjunction<has_symint<Args>...>::value) { |
88 | if (sym_unboxed_kernel_func_ != nullptr) { |
89 | auto *functor = boxed_kernel_func_.getFunctor(); |
90 | return callUnboxedKernelFunction<Return, Args...>( |
91 | sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...); |
92 | } |
93 | |
94 | if (unboxed_kernel_func_ != nullptr) { |
95 | auto *functor = boxed_kernel_func_.getFunctor(); |
96 | return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>( |
97 | unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...); |
98 | } |
99 | } else { |
100 | if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { |
101 | auto *functor = boxed_kernel_func_.getFunctor(); |
102 | return callUnboxedKernelFunction<Return, Args...>( |
103 | unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...); |
104 | } |
105 | } |
106 | |
107 | return impl::BoxedKernelWrapper<Return(Args...)>::call( |
108 | boxed_kernel_func_, |
109 | opHandle, |
110 | dispatchKeySet, |
111 | std::forward<Args>(args)... |
112 | ); |
113 | } |
114 | |
115 | inline KernelFunction KernelFunction::makeFromBoxedKernel(BoxedKernel boxed_fn) { |
116 | return KernelFunction(std::move(boxed_fn), nullptr); // no unboxed function pointer |
117 | } |
118 | |
119 | template<KernelFunction::BoxedKernelFunction* func> |
120 | inline KernelFunction KernelFunction::makeFromBoxedFunction() { |
121 | return KernelFunction::makeFromBoxedKernel( |
122 | BoxedKernel::makeFromFunction<func>()); |
123 | } |
124 | |
125 | template<KernelFunction::BoxedKernelFunction_withDispatchKeys* func> |
126 | inline KernelFunction KernelFunction::makeFromBoxedFunction() { |
127 | return KernelFunction::makeFromBoxedKernel( |
128 | BoxedKernel::makeFromFunction<func>()); |
129 | } |
130 | |
131 | inline KernelFunction KernelFunction::makeFallthrough() { |
132 | return KernelFunction::makeFromBoxedKernel( |
133 | BoxedKernel::makeFallthrough()); |
134 | } |
135 | |
136 | inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() { |
137 | return KernelFunction::makeFromBoxedKernel( |
138 | BoxedKernel::makeAmbiguousAutogradOther()); |
139 | } |
140 | |
141 | inline KernelFunction KernelFunction::makeNamedNotSupported() { |
142 | return KernelFunction::makeFromBoxedKernel( |
143 | BoxedKernel::makeNamedNotSupported()); |
144 | } |
145 | |
146 | template<bool AllowLegacyTypes, class KernelFunctor> |
147 | inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) { |
148 | #ifndef NDEBUG |
149 | // This assertion is costly for build time so it's debug-gated. |
150 | static_assert(guts::is_functor<KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor." ); |
151 | #endif |
152 | static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." ); |
153 | |
154 | auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call; |
155 | void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn); |
156 | bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value; |
157 | return KernelFunction( |
158 | std::move(kernelFunctor), |
159 | &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call, |
160 | is_symint ? nullptr : void_unboxed_fn, |
161 | is_symint ? void_unboxed_fn : nullptr |
162 | ); |
163 | } |
164 | |
165 | template<class KernelFunctor> |
166 | inline KernelFunction KernelFunction::makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) { |
167 | return KernelFunction::makeFromBoxedKernel( |
168 | BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); |
169 | } |
170 | |
171 | template<class FuncPtr, bool AllowLegacyTypes> |
172 | inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) { |
173 | static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN." ); |
174 | static_assert(!std::is_same<typename FuncPtr::FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead." ); |
175 | static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr" ); |
176 | |
177 | #if !defined(C10_MOBILE) |
178 | (void)func_ptr; // Suppress unused variable warning |
179 | return makeFromUnboxedFunctor<AllowLegacyTypes, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>( |
180 | guts::make_unique_base<OperatorKernel, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>() |
181 | ); |
182 | #else |
183 | // On mobile, we rather want to optimize for binary size than for performance, |
184 | // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction |
185 | // instead. |
186 | return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr()); |
187 | #endif |
188 | } |
189 | |
190 | template<bool AllowLegacyTypes, class FuncType> |
191 | inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) { |
192 | static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type." ); |
193 | static_assert(!std::is_same<FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead." ); |
194 | TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr" ); |
195 | |
196 | return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>( |
197 | guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func) |
198 | ); |
199 | } |
200 | |
201 | template<bool AllowLegacyTypes, class Lambda> |
202 | inline std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { |
203 | static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type." ); |
204 | |
205 | #if !defined(C10_MOBILE) |
206 | return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>( |
207 | guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda)) |
208 | ); |
209 | #else |
210 | // On mobile, we rather want to optimize for binary size than for performance, |
211 | // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction |
212 | // instead. |
213 | using FuncType = typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type; |
214 | return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda); |
215 | #endif |
216 | } |
217 | |
218 | template<bool AllowLegacyTypes, class Lambda> |
219 | inline std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { |
220 | static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type." ); |
221 | |
222 | return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>( |
223 | guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda)) |
224 | ); |
225 | } |
226 | |
227 | } |
228 | |