1#include <ATen/core/boxing/impl/boxing.h>
2#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
3#include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
4#include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
5
6namespace c10 {
7
8inline KernelFunction::KernelFunction()
9 : boxed_kernel_func_()
10 , unboxed_kernel_func_(nullptr)
11 , sym_unboxed_kernel_func_(nullptr)
12{}
13
14inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
15 : boxed_kernel_func_(std::move(functor), boxed_kernel_func)
16 , unboxed_kernel_func_(unboxed_kernel_func)
17 , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
18{}
19
20inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
21 : boxed_kernel_func_(std::move(boxed_fn))
22 , unboxed_kernel_func_(unboxed_kernel_func)
23 , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
24{}
25
26inline bool KernelFunction::isValidUnboxed() const {
27 return unboxed_kernel_func_ != nullptr;
28}
29
30inline bool KernelFunction::isValidSymUnboxed() const {
31 return sym_unboxed_kernel_func_ != nullptr;
32}
33
34inline bool KernelFunction::isValid() const {
35 return boxed_kernel_func_.isValid();
36}
37
38inline bool KernelFunction::isFallthrough() const {
39 return boxed_kernel_func_.isFallthrough();
40}
41
42inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
43 boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
44}
45
46template<class Return, class... Args>
47inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) {
48 using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...);
49 ActualSignature* func = reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
50 return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
51}
52
53// This template requires you to explicitly specify the argument you want to
54// forward; it doesn't work if you try to deduce it
55// NB: keep this in sync with cloneWithRealTypes in function_schema.cpp
56
57template <typename T>
58inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
59
60template <>
61inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
62 return x.expect_int();
63}
64
65template <>
66inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
67 return C10_AS_INTARRAYREF_SLOW(x);
68}
69
70template <>
71inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
72 return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt;
73}
74
75template <>
76inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(at::OptionalSymIntArrayRef x) {
77 return x.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : c10::nullopt;
78}
79
80template<class Return, class... Args>
81C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
82 // note: Args above is intentionally not Args&&. We don't want perfect
83 // forwarding, which would require Args to be deduced, but instead we
84 // want callers to explicitly specify the Args.
85
86 // This should get inlined by compiler
87 if (guts::disjunction<has_symint<Args>...>::value) {
88 if (sym_unboxed_kernel_func_ != nullptr) {
89 auto *functor = boxed_kernel_func_.getFunctor();
90 return callUnboxedKernelFunction<Return, Args...>(
91 sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
92 }
93
94 if (unboxed_kernel_func_ != nullptr) {
95 auto *functor = boxed_kernel_func_.getFunctor();
96 return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
97 unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
98 }
99 } else {
100 if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
101 auto *functor = boxed_kernel_func_.getFunctor();
102 return callUnboxedKernelFunction<Return, Args...>(
103 unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
104 }
105 }
106
107 return impl::BoxedKernelWrapper<Return(Args...)>::call(
108 boxed_kernel_func_,
109 opHandle,
110 dispatchKeySet,
111 std::forward<Args>(args)...
112 );
113}
114
115inline KernelFunction KernelFunction::makeFromBoxedKernel(BoxedKernel boxed_fn) {
116 return KernelFunction(std::move(boxed_fn), nullptr); // no unboxed function pointer
117}
118
119template<KernelFunction::BoxedKernelFunction* func>
120inline KernelFunction KernelFunction::makeFromBoxedFunction() {
121 return KernelFunction::makeFromBoxedKernel(
122 BoxedKernel::makeFromFunction<func>());
123}
124
125template<KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
126inline KernelFunction KernelFunction::makeFromBoxedFunction() {
127 return KernelFunction::makeFromBoxedKernel(
128 BoxedKernel::makeFromFunction<func>());
129}
130
131inline KernelFunction KernelFunction::makeFallthrough() {
132 return KernelFunction::makeFromBoxedKernel(
133 BoxedKernel::makeFallthrough());
134}
135
136inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
137 return KernelFunction::makeFromBoxedKernel(
138 BoxedKernel::makeAmbiguousAutogradOther());
139}
140
141inline KernelFunction KernelFunction::makeNamedNotSupported() {
142 return KernelFunction::makeFromBoxedKernel(
143 BoxedKernel::makeNamedNotSupported());
144}
145
146template<bool AllowLegacyTypes, class KernelFunctor>
147inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) {
148#ifndef NDEBUG
149 // This assertion is costly for build time so it's debug-gated.
150 static_assert(guts::is_functor<KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
151#endif
152 static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
153
154 auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
155 void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
156 bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
157 return KernelFunction(
158 std::move(kernelFunctor),
159 &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
160 is_symint ? nullptr : void_unboxed_fn,
161 is_symint ? void_unboxed_fn : nullptr
162 );
163}
164
165template<class KernelFunctor>
166inline KernelFunction KernelFunction::makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
167 return KernelFunction::makeFromBoxedKernel(
168 BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
169}
170
171template<class FuncPtr, bool AllowLegacyTypes>
172inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) {
173 static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
174 static_assert(!std::is_same<typename FuncPtr::FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
175 static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
176
177#if !defined(C10_MOBILE)
178 (void)func_ptr; // Suppress unused variable warning
179 return makeFromUnboxedFunctor<AllowLegacyTypes, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
180 guts::make_unique_base<OperatorKernel, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>()
181 );
182#else
183 // On mobile, we rather want to optimize for binary size than for performance,
184 // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction
185 // instead.
186 return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
187#endif
188}
189
190template<bool AllowLegacyTypes, class FuncType>
191inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) {
192 static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
193 static_assert(!std::is_same<FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
194 TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
195
196 return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
197 guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func)
198 );
199}
200
201template<bool AllowLegacyTypes, class Lambda>
202inline std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
203 static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
204
205#if !defined(C10_MOBILE)
206 return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
207 guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda))
208 );
209#else
210 // On mobile, we rather want to optimize for binary size than for performance,
211 // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction
212 // instead.
213 using FuncType = typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
214 return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
215#endif
216}
217
218template<bool AllowLegacyTypes, class Lambda>
219inline std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
220 static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
221
222 return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
223 guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda))
224 );
225}
226
227}
228