1 | #pragma once |
2 | |
3 | #include <ATen/core/ATen_fwd.h> |
4 | #include <ATen/core/boxing/BoxedKernel.h> |
5 | #include <ATen/core/stack.h> |
6 | #include <c10/core/DispatchKeySet.h> |
7 | #include <c10/util/intrusive_ptr.h> |
8 | #include <c10/util/TypeList.h> |
9 | |
10 | namespace c10 { |
11 | |
12 | using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. |
13 | |
14 | class OperatorHandle; |
15 | struct OperatorKernel; |
16 | class KernelFunction; |
17 | |
18 | template <typename T> |
19 | using has_symint = |
20 | guts::disjunction< |
21 | std::is_same<c10::SymInt, std::decay_t<T>>, |
22 | std::is_same<c10::SymIntArrayRef, std::decay_t<T>>, |
23 | std::is_same<at::OptionalSymIntArrayRef, std::decay_t<T>>, |
24 | std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>> |
25 | >; |
26 | |
27 | template <typename T> |
28 | struct remove_symint { |
29 | using type = T; |
30 | }; |
31 | |
32 | template <> |
33 | struct remove_symint<c10::SymInt> { |
34 | using type = int64_t; |
35 | }; |
36 | |
37 | template <> |
38 | struct remove_symint<at::OptionalSymIntArrayRef> { |
39 | using type = OptionalIntArrayRef; |
40 | }; |
41 | |
42 | template <> |
43 | struct remove_symint<c10::SymIntArrayRef> { |
44 | using type = c10::IntArrayRef; |
45 | }; |
46 | |
47 | template <> |
48 | struct remove_symint<c10::optional<c10::SymInt>> { |
49 | using type = c10::optional<int64_t>; |
50 | }; |
51 | |
52 | |
53 | template <bool symint, typename T> |
54 | struct maybe_keep_symint final {}; |
55 | |
56 | template <typename T> |
57 | struct maybe_keep_symint<true, T> { using type = T; }; |
58 | |
59 | template <typename T> |
60 | struct maybe_keep_symint<false, T> { using type = typename remove_symint<T>::type; }; |
61 | |
62 | template <typename T> |
63 | using fn_has_symint = typename guts::typelist::true_for_any_type< |
64 | has_symint, |
65 | typename guts::infer_function_traits<T>::type::parameter_types |
66 | >; |
67 | |
68 | /** |
69 | * KernelFunction is similar to std::function but stores a kernel function. |
70 | * You can create a KernelFunction from a boxed or unboxed function/functor/lambda |
71 | * and call it in a boxed or unboxed way. If the way it was created doesn't |
72 | * match the way it was called, it will do boxing or unboxing as necessary. |
73 | */ |
74 | class TORCH_API KernelFunction final { |
75 | public: |
76 | using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction; |
77 | using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction; |
78 | using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys; |
79 | |
80 | KernelFunction(); |
81 | |
82 | // Fast path for dispatch to allow not touching the boxed kernel in |
83 | // the common case where unboxed is available. |
84 | bool isValidUnboxed() const; |
85 | bool isValidSymUnboxed() const; |
86 | bool isValid() const; |
87 | bool isFallthrough() const; |
88 | |
89 | /** |
90 | * Call the function in a boxed way. |
91 | * If the kernel function was created with an unboxed function, |
92 | * this will call an unboxing wrapper which then calls into that |
93 | * unboxed function. |
94 | * |
95 | * Example: |
96 | * |
97 | * > void boxed_func(OperatorKernel*, Stack* stack) {...} |
98 | * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); |
99 | * > Tensor result = func.callBoxed(stack); |
100 | * |
101 | * Or, with an unboxed implementation: |
102 | * |
103 | * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( |
104 | * > [] (Tensor a, bool b) -> Tensor {...}); |
105 | * > Tensor result = func.callBoxed(stack); |
106 | */ |
107 | void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const; |
108 | |
109 | /** |
110 | * Call the function in an unboxed way. |
111 | * If the kernel function was created with a boxed function, |
112 | * this will box all inputs and then call into that boxed function. |
113 | * |
114 | * Note that this doesn't work for all types yet. |
115 | * |
116 | * Example: |
117 | * |
118 | * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( |
119 | * > [] (Tensor a, bool b) -> Tensor {...}); |
120 | * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true); |
121 | * |
122 | * Or, with a boxed implementation: |
123 | * |
124 | * > void boxed_func(OperatorKernel*, Stack* stack) {...} |
125 | * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); |
126 | * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true); |
127 | */ |
128 | template<class Return, class... Args> |
129 | Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const; |
130 | |
131 | /** |
132 | * Create a KernelFunction from a BoxedKernel. |
133 | */ |
134 | static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn); |
135 | |
136 | /** |
137 | * Create a KernelFunction from a boxed function. |
138 | * |
139 | * Example: |
140 | * |
141 | * > void boxed_func(OperatorKernel*, Stack* stack) {...} |
142 | * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>(); |
143 | */ |
144 | template<BoxedKernelFunction* func> |
145 | static KernelFunction makeFromBoxedFunction(); |
146 | |
147 | /** |
148 | * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none) |
149 | * See Note [Plumbing Keys Through The Dispatcher] for details. |
150 | */ |
151 | template<BoxedKernelFunction_withDispatchKeys* func> |
152 | static KernelFunction makeFromBoxedFunction(); |
153 | |
154 | /** |
155 | * Create a KernelFunction from an unboxed functor. |
156 | * |
157 | * Example: |
158 | * |
159 | * > class MyFunctor final : public c10::OperatorKernel { |
160 | * > public: |
161 | * > Tensor operator()(Tensor a, Tensor b) {...} |
162 | * > }; |
163 | * > KernelFunction func = KernelFunction::makeFromUnboxedFunctor<MyFunctor>(std::make_unique<MyFunctor>()); |
164 | */ |
165 | template<bool AllowLegacyTypes = false, class KernelFunctor> |
166 | static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor); |
167 | |
168 | /** |
169 | * Create a KernelFunction from a boxed functor. |
170 | * |
171 | * Example: |
172 | * |
173 | * > class MyFunctor final : public c10::OperatorKernel { |
174 | * > public: |
175 | * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} |
176 | * > }; |
177 | * > KernelFunction func = KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>()); |
178 | */ |
179 | template<class KernelFunctor> |
180 | static KernelFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor); |
181 | |
182 | /** |
183 | * Create a KernelFunction from an unboxed function. |
184 | * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction |
185 | * because knowing the function pointer as a template argument (i.e. at |
186 | * compile time) allows the compiler to inline the function into its |
187 | * unboxing wrapper and yields better performance when calling the function. |
188 | * |
189 | * Example: |
190 | * |
191 | * > Tensor unboxed_func(Tensor a, Tensor b) {...} |
192 | * > KernelFunction func = KernelFunction::makeFromUnboxedFunction<decltype(unboxed_func), &unboxed_func>(); |
193 | */ |
194 | template<class FuncPtr, bool AllowLegacyTypes = false> |
195 | static KernelFunction makeFromUnboxedFunction(FuncPtr); |
196 | |
197 | /** |
198 | * Create a KernelFunction from an unboxed function. |
199 | * KernelFunction::makeFromUnboxedFunction is usually a better choice than |
200 | * this if you know the function pointer at compile time, see doc comment |
201 | * there for an explanation. |
202 | * |
203 | * Example: |
204 | * |
205 | * > Tensor unboxed_func(Tensor a, Tensor b) {...} |
206 | * > KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func); |
207 | */ |
208 | template<bool AllowLegacyTypes = false, class FuncType> |
209 | static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func); |
210 | |
211 | static KernelFunction makeFallthrough(); |
212 | static KernelFunction makeAmbiguousAutogradOther(); |
213 | static KernelFunction makeNamedNotSupported(); |
214 | |
215 | /** |
216 | * Create a KernelFunction from an unboxed lambda. |
217 | * |
218 | * Example: |
219 | * |
220 | * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( |
221 | * > [] (Tensor a, bool b) -> Tensor {...}); |
222 | */ |
223 | template<bool AllowLegacyTypes = false, class Lambda> |
224 | static std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda); |
225 | template<bool AllowLegacyTypes = false, class Lambda> |
226 | static std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda); |
227 | |
228 | std::string dumpState() const; |
229 | // For testing internal invariants only |
230 | bool _equalsBoxedAndUnboxed(const KernelFunction&) const; |
231 | |
232 | private: |
233 | |
234 | explicit KernelFunction( |
235 | std::unique_ptr<OperatorKernel> functor, |
236 | InternalBoxedKernelFunction* boxed_kernel_func, |
237 | void* unboxed_kernel_func, |
238 | void* sym_unboxed_kernel_func); |
239 | explicit KernelFunction( |
240 | BoxedKernel boxed_fn, |
241 | void* unboxed_kernel_func, |
242 | void* sym_unboxed_kernel_func); |
243 | |
244 | BoxedKernel boxed_kernel_func_; |
245 | void* unboxed_kernel_func_; |
246 | void* sym_unboxed_kernel_func_; |
247 | }; |
248 | |
249 | } |
250 | |
251 | #include <ATen/core/boxing/KernelFunction_impl.h> |
252 | |