1#pragma once
2
3#include <ATen/core/ATen_fwd.h>
4#include <ATen/core/boxing/BoxedKernel.h>
5#include <ATen/core/stack.h>
6#include <c10/core/DispatchKeySet.h>
7#include <c10/util/intrusive_ptr.h>
8#include <c10/util/TypeList.h>
9
10namespace c10 {
11
12using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
13
14class OperatorHandle;
15struct OperatorKernel;
16class KernelFunction;
17
18template <typename T>
19using has_symint =
20 guts::disjunction<
21 std::is_same<c10::SymInt, std::decay_t<T>>,
22 std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
23 std::is_same<at::OptionalSymIntArrayRef, std::decay_t<T>>,
24 std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
25 >;
26
/**
 * Maps a type to its counterpart without symbolic integers.
 * Primary template: a type that has no dedicated specialization maps to
 * itself.
 */
template <typename T>
struct remove_symint {
  using type = T;
};
31
32template <>
33struct remove_symint<c10::SymInt> {
34 using type = int64_t;
35};
36
37template <>
38struct remove_symint<at::OptionalSymIntArrayRef> {
39 using type = OptionalIntArrayRef;
40};
41
42template <>
43struct remove_symint<c10::SymIntArrayRef> {
44 using type = c10::IntArrayRef;
45};
46
47template <>
48struct remove_symint<c10::optional<c10::SymInt>> {
49 using type = c10::optional<int64_t>;
50};
51
52
/**
 * Selects between T and remove_symint<T>::type depending on the `symint`
 * flag. The primary template is empty; the nested `type` is only provided by
 * the <true, T> and <false, T> partial specializations.
 */
template <bool symint, typename T>
struct maybe_keep_symint final {};

// symint == true: T is kept unchanged.
template <typename T>
struct maybe_keep_symint<true, T> {
  using type = T;
};
58
59template <typename T>
60struct maybe_keep_symint<false, T> { using type = typename remove_symint<T>::type; };
61
62template <typename T>
63using fn_has_symint = typename guts::typelist::true_for_any_type<
64 has_symint,
65 typename guts::infer_function_traits<T>::type::parameter_types
66>;
67
68/**
69 * KernelFunction is similar to std::function but stores a kernel function.
70 * You can create a KernelFunction from a boxed or unboxed function/functor/lambda
71 * and call it in a boxed or unboxed way. If the way it was created doesn't
72 * match the way it was called, it will do boxing or unboxing as necessary.
73 */
74class TORCH_API KernelFunction final {
75public:
76 using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
77 using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
78 using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys;
79
80 KernelFunction();
81
82 // Fast path for dispatch to allow not touching the boxed kernel in
83 // the common case where unboxed is available.
84 bool isValidUnboxed() const;
85 bool isValidSymUnboxed() const;
86 bool isValid() const;
87 bool isFallthrough() const;
88
89 /**
90 * Call the function in a boxed way.
91 * If the kernel function was created with an unboxed function,
92 * this will call an unboxing wrapper which then calls into that
93 * unboxed function.
94 *
95 * Example:
96 *
97 * > void boxed_func(OperatorKernel*, Stack* stack) {...}
98 * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
99 * > Tensor result = func.callBoxed(stack);
100 *
101 * Or, with an unboxed implementation:
102 *
103 * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
104 * > [] (Tensor a, bool b) -> Tensor {...});
105 * > Tensor result = func.callBoxed(stack);
106 */
107 void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;
108
109 /**
110 * Call the function in an unboxed way.
111 * If the kernel function was created with a boxed function,
112 * this will box all inputs and then call into that boxed function.
113 *
114 * Note that this doesn't work for all types yet.
115 *
116 * Example:
117 *
118 * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
119 * > [] (Tensor a, bool b) -> Tensor {...});
120 * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
121 *
122 * Or, with a boxed implementation:
123 *
124 * > void boxed_func(OperatorKernel*, Stack* stack) {...}
125 * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
126 * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
127 */
128 template<class Return, class... Args>
129 Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const;
130
131 /**
132 * Create a KernelFunction from a BoxedKernel.
133 */
134 static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);
135
136 /**
137 * Create a KernelFunction from a boxed function.
138 *
139 * Example:
140 *
141 * > void boxed_func(OperatorKernel*, Stack* stack) {...}
142 * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
143 */
144 template<BoxedKernelFunction* func>
145 static KernelFunction makeFromBoxedFunction();
146
147 /**
148 * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
149 * See Note [Plumbing Keys Through The Dispatcher] for details.
150 */
151 template<BoxedKernelFunction_withDispatchKeys* func>
152 static KernelFunction makeFromBoxedFunction();
153
154 /**
155 * Create a KernelFunction from an unboxed functor.
156 *
157 * Example:
158 *
159 * > class MyFunctor final : public c10::OperatorKernel {
160 * > public:
161 * > Tensor operator()(Tensor a, Tensor b) {...}
162 * > };
163 * > KernelFunction func = KernelFunction::makeFromUnboxedFunctor<MyFunctor>(std::make_unique<MyFunctor>());
164 */
165 template<bool AllowLegacyTypes = false, class KernelFunctor>
166 static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
167
168 /**
169 * Create a KernelFunction from a boxed functor.
170 *
171 * Example:
172 *
173 * > class MyFunctor final : public c10::OperatorKernel {
174 * > public:
175 * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
176 * > };
177 * > KernelFunction func = KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
178 */
179 template<class KernelFunctor>
180 static KernelFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);
181
182 /**
183 * Create a KernelFunction from an unboxed function.
184 * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
185 * because knowing the function pointer as a template argument (i.e. at
186 * compile time) allows the compiler to inline the function into its
187 * unboxing wrapper and yields better performance when calling the function.
188 *
189 * Example:
190 *
191 * > Tensor unboxed_func(Tensor a, Tensor b) {...}
192 * > KernelFunction func = KernelFunction::makeFromUnboxedFunction<decltype(unboxed_func), &unboxed_func>();
193 */
194 template<class FuncPtr, bool AllowLegacyTypes = false>
195 static KernelFunction makeFromUnboxedFunction(FuncPtr);
196
197 /**
198 * Create a KernelFunction from an unboxed function.
199 * KernelFunction::makeFromUnboxedFunction is usually a better choice than
200 * this if you know the function pointer at compile time, see doc comment
201 * there for an explanation.
202 *
203 * Example:
204 *
205 * > Tensor unboxed_func(Tensor a, Tensor b) {...}
206 * > KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
207 */
208 template<bool AllowLegacyTypes = false, class FuncType>
209 static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
210
211 static KernelFunction makeFallthrough();
212 static KernelFunction makeAmbiguousAutogradOther();
213 static KernelFunction makeNamedNotSupported();
214
215 /**
216 * Create a KernelFunction from an unboxed lambda.
217 *
218 * Example:
219 *
220 * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
221 * > [] (Tensor a, bool b) -> Tensor {...});
222 */
223 template<bool AllowLegacyTypes = false, class Lambda>
224 static std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
225 template<bool AllowLegacyTypes = false, class Lambda>
226 static std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
227
228 std::string dumpState() const;
229 // For testing internal invariants only
230 bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
231
232private:
233
234 explicit KernelFunction(
235 std::unique_ptr<OperatorKernel> functor,
236 InternalBoxedKernelFunction* boxed_kernel_func,
237 void* unboxed_kernel_func,
238 void* sym_unboxed_kernel_func);
239 explicit KernelFunction(
240 BoxedKernel boxed_fn,
241 void* unboxed_kernel_func,
242 void* sym_unboxed_kernel_func);
243
244 BoxedKernel boxed_kernel_func_;
245 void* unboxed_kernel_func_;
246 void* sym_unboxed_kernel_func_;
247};
248
249}
250
251#include <ATen/core/boxing/KernelFunction_impl.h>
252