1#pragma once
2
#include <ATen/core/boxing/OperatorKernel.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <c10/util/TypeList.h>
#include <ATen/core/IListRef.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/Metaprogramming.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
12
13namespace c10 {
14
15using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
16class OperatorHandle;
17
18/*
19 * [Note: Argument forwarding in the dispatcher]
20 *
21 * The dispatcher uses a somewhat unusual way to forward arguments through several layers of
22 * wrapper functions. This can be confusing because an experienced C++ programmer would look at this
23 * and think "oh this is supposed to be forwarding a universal reference but the && is missing. This is a bug.".
24 * It is not a bug. The common way in C++ to forward arguments is to use universal references:
25 *
26 * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
27 *
28 * but that relies on inferring the correct reference type (i.e. value vs & vs &&) from the argument.
29 * In our case, we cannot rely on the argument as supplied by the caller, because that could infer a
30 * different reference type than was used in the kernel function. The correct reference type
31 * is dictated by the kernel signature and must be identical since we cast function pointers
32 * through void* pointers and mismatches would be UB. So we need a forwarding pattern that determines
33 * the reference type to use by looking at the explicitly supplied operator signature, not by looking at
34 * the argument we're calling it with.
35 *
36 * What does std::forward do, exactly?
37 * ------------------------------------
38 * std::forward<T>(t) is a way to cast t to the reference type supplied in T.
39 * Let's assume decay_t<T> == U and T is either U or some reference of U.
40 * - std::forward<T&>(t) will return U&, no matter what kind of reference t is.
41 * - std::forward<T&&>(t) will return U&&, no matter what kind of reference t is.
42 * - std::forward<T>(t) will return U&& (not U!), no matter what kind of reference t is.
43 *
44 * For universal references, that means that in the following function
45 * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
46 *
 * - when called with an rvalue argument (an rvalue reference or a non-reference value), T gets
 *   inferred to be a non-reference U, and std::forward<T>(arg) will return U&&, correctly moving
 *   the argument.
 * - when called with an lvalue argument, T gets inferred to be U& because that's the only
 *   way to match the signature (in C++, a type that is (U&)&& collapses to U&).
 *   That means std::forward<T>(arg) will return U& and the value will not be moved but passed on as
 *   an lvalue reference.
53 *
54 * How do we use that?
55 * ------------------------------------
 * But std::forward can also be used outside of the common "universal forwarding" pattern to change
 * reference types. So instead of following the common C++ pattern, we rely on what
 * std::forward<T>() actually does: it takes a value and casts it to the reference type
 * given by T. If we don't infer T but specify it explicitly, we can use this to forward
 * based on an explicitly specified reference type instead of the inferred argument type.
61 *
 * This is why many of the dispatcher functions look like
 * > template<class T> void func(T t) { func2<T>(std::forward<T>(t)); }
 * instead of the common
 * > template<class T> void func(T&& t) { func2(std::forward<T>(t)); }
66 *
67 * and are expected to be called by explicitly specifying the template parameters in a way that matches
68 * the expected operator signature at each call site.
69 */
70
71namespace impl {
72 // supported_primitive_arg_types defines which primitive types we allow in
73 // kernel functions as arguments or returns.
74 // Additionally, we support lists, dicts and optionals containing these types.
75 using supported_primitive_arg_types = guts::typelist::typelist<
76 int64_t,
77 double,
78 bool,
79 c10::string_view,
80 at::Tensor,
81 at::Scalar,
82 c10::QScheme,
83 c10::ScalarType,
84 c10::Device,
85 c10::Layout,
86 c10::MemoryFormat,
87 at::Dimname
88 >;
89
90 // We have an unboxed functor in hand that takes C++ arguments, and
91 // we're building a boxed functor wrapper for it that takes IValues.
92 // So "outside" is boxed and "inside" is unboxed.
93 //
94 // So a valid input type is one that our boxed functor wrapper can
95 // unbox from an IValue into a C++ value.
96 //
// Whereas a valid output type is one that our wrapper can receive
// as a C++ value from the unboxed functor, and box into an IValue.
99
100 //
101 // assert_is_valid_input_type
102 // checks that T can be unboxed from an IValue into a C++ value.
103 //
104
105 template<class T, bool AllowDeprecatedTypes, class Enable = void>
106 struct assert_is_valid_input_type {
107 assert_is_valid_input_type() {
108 guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
109 /* everything is ok, this is a primitive type */
110 }, /* else */ [] {
111 /* otherwise this must be an instance of a valid custom class, since it can only
112 have been created via IValue(x), which ensures this. */
113 });
114 }
115 };
116
117 template<class T, bool AllowDeprecatedTypes>
118 struct assert_is_valid_input_type<c10::optional<T>, AllowDeprecatedTypes>
119 : assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
120
121 template <bool AllowDeprecatedTypes, class... Args>
122 struct TypeCheckHelper;
123
124 template <bool AllowDeprecatedTypes>
125 struct TypeCheckHelper<AllowDeprecatedTypes> {};
126
127 template <bool AllowDeprecatedTypes, class Head, class... Rest>
128 struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
129 : TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
130 assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
131 };
132
133 template<class... Contained, bool AllowDeprecatedTypes>
134 struct assert_is_valid_input_type<std::tuple<Contained...>, AllowDeprecatedTypes>
135 : TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
136
137 template<class Key, class Value, bool AllowDeprecatedTypes>
138 struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
139 : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
140 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
141 "You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
142 };
143
144 template<class Key, class Value, bool AllowDeprecatedTypes>
145 struct assert_is_valid_input_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes>
146 : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
147 static_assert(AllowDeprecatedTypes,
148 "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
149 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
150 "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
151 };
152
153 template<class T, bool AllowDeprecatedTypes>
154 struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
155 : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
156 static_assert(!std::is_same<T, at::Scalar>::value,
157 "You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
158 };
159
160 template<class T, bool AllowDeprecatedTypes>
161 struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
162 : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
163 static_assert(!std::is_same<T, at::Scalar>::value,
164 "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
165 };
166
167 template<class T, bool AllowDeprecatedTypes>
168 struct assert_is_valid_input_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes>
169 : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
170 static_assert(!std::is_same<T, at::Scalar>::value,
171 "You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
172 };
173
174 template<class T, size_t N, bool AllowDeprecatedTypes>
175 struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
176 : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
177 static_assert(!std::is_same<T, at::Scalar>::value,
178 "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
179 };
180
181 // The following specialisations of assert_is_valid_input_type are technically not
182 // necessary since we would hit the base case and show an error message
183 // there if they didn't exist, but we can show a better error message
184 // in some common error scenarios.
185 template<class T, bool AllowDeprecatedTypes>
186 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
187 // There is no reason to support float when we have double. Keep the API lean.
188 static_assert(guts::false_t<T>::value,
189 "You tried to register a kernel with an unsupported input type: float. Please use double instead.");
190 };
191 template<class T, bool AllowDeprecatedTypes>
192 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> {
193 static_assert(guts::false_t<T>::value,
194 "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead.");
195 };
196 template<class T, bool AllowDeprecatedTypes>
197 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> {
198 static_assert(guts::false_t<T>::value,
199 "You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead.");
200 };
201 template<class T, bool AllowDeprecatedTypes>
202 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
203 static_assert(guts::false_t<T>::value,
204 "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead.");
205 };
206
207 //
208 // assert_is_valid_output_type
209 //
210
211 template<class T, bool AllowDeprecatedTypes, class Enable = void>
212 struct assert_is_valid_output_type {
213 assert_is_valid_output_type() {
214 guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
215 /* everything is ok, this is a primitive type */
216 }, /* else */ [] {
217 /* otherwise T is verified to be a registered custom class in the IValue
218 constructor, so no benefit in double-checking here */
219 });
220 }
221 };
222
223 template<class T, bool AllowDeprecatedTypes>
224 struct assert_is_valid_output_type<c10::optional<T>, AllowDeprecatedTypes>
225 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
226
227 template<class T, bool AllowDeprecatedTypes>
228 struct assert_is_valid_output_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes>
229 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
230
231 template<class Key, class Value, bool AllowDeprecatedTypes>
232 struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
233 : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
234 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
235 "You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
236 static_assert(!std::is_same<Value, at::Scalar>::value,
237 "You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
238 };
239
240 template<class Key, class Value, bool AllowDeprecatedTypes>
241 struct assert_is_valid_output_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes>
242 : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
243 static_assert(AllowDeprecatedTypes,
244 "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
245 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
246 "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
247 static_assert(!std::is_same<Value, at::Scalar>::value,
248 "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
249 };
250
251 template<class T, bool AllowDeprecatedTypes>
252 struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
253 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
254 static_assert(!std::is_same<T, at::Scalar>::value,
255 "You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
256 };
257
258 template<class T, bool AllowDeprecatedTypes>
259 struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
260 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
261 static_assert(!std::is_same<T, at::Scalar>::value,
262 "You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
263 // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector<T>. Please use List<T> instead.");
264 };
265
266 template<class T, size_t N, bool AllowDeprecatedTypes>
267 struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
268 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
269 static_assert(!std::is_same<T, at::Scalar>::value,
270 "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
271 };
272
273 // The following specialisations of assert_is_valid_output_type are technically not
274 // necessary since we would hit the base case and show an error message
275 // there if they didn't exist, but we can show a better error message
276 // in some common error scenarios.
277 template<class T, bool AllowDeprecatedTypes>
278 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
279 // There is no reason to support float when we have double. Keep the API lean.
280 static_assert(guts::false_t<T>::value,
281 "You tried to register a kernel with an unsupported output type: float. Please use double instead.");
282 };
283 template<class T, bool AllowDeprecatedTypes>
284 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> {
285 static_assert(guts::false_t<T>::value,
286 "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead.");
287 };
288 template<class T, bool AllowDeprecatedTypes>
289 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> {
290 static_assert(guts::false_t<T>::value,
291 "You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead.");
292 };
293 template<class T, bool AllowDeprecatedTypes>
294 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
295 static_assert(guts::false_t<T>::value,
296 "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead.");
297 };
298
299 // ivalue_to_arg
300
301 template<class T>
302 struct decay_if_not_tensor final {
303 using type = std::decay_t<T>;
304 };
305
306 template<>
307 struct decay_if_not_tensor<at::Tensor&> final {
308 using type = at::Tensor&;
309 };
310
311 template<>
312 struct decay_if_not_tensor<const at::Tensor&> final {
313 using type = const at::Tensor&;
314 };
315
316 template<class T, bool AllowDeprecatedTypes>
317 struct ivalue_to_arg final {
318 static decltype(auto) call(IValue& v) {
319 assert_is_valid_input_type<T, AllowDeprecatedTypes>();
320 return std::move(v).to<T>();
321 }
322 };
323
324 // The following two specializations take advantage of specialized
325 // `toTensor()` overloads on IValue to avoid copying.
326 template<bool AllowDeprecatedTypes>
327 struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
328 // We cannot use the default implementation if they asked for a
329 // `at::Tensor&` because it moves from the IValue, so it can't get
330 // an lvalue reference.
331 static at::Tensor& call(IValue& v) {
332 // Tensor& is valid, don't bother asserting
333 return v.toTensor();
334 }
335 };
336
337 template<bool AllowDeprecatedTypes>
338 struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
339 // We should not use the default implementation if they asked for
340 // a `const at::Tensor&` because it moves from the IValue and they
341 // didn't ask for that.
342 static const at::Tensor& call(IValue& v) {
343 // const Tensor& is valid, don't bother asserting
344 return v.toTensor();
345 }
346 };
347
348 template<bool AllowDeprecatedTypes>
349 struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
350 static List<at::Tensor> call(IValue& v) {
351 return v.toTensorList();
352 }
353 };
354
355 template<class T, bool AllowDeprecatedTypes>
356 struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
357 // If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
358 // to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
359 static std::vector<T> call(IValue& v) {
360 return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
361 }
362 };
363 template<bool AllowDeprecatedTypes>
364 struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
365 static std::vector<c10::SymInt> call(IValue& v) {
366 if (v.isIntList()) {
367 std::vector<c10::SymInt> r;
368 auto src = v.toIntList();
369 std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
370 return r;
371 } else {
372 return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
373 }
374 }
375 };
376 template<bool AllowDeprecatedTypes>
377 struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes> final {
378 static OptionalArray<c10::SymInt> call(IValue& v) {
379 if (v.isIntList()) {
380 std::vector<c10::SymInt> r;
381 auto src = v.toIntList();
382 std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
383 return OptionalArray<c10::SymInt>(std::move(r));
384 } else {
385 return std::move(v).to<OptionalArray<c10::SymInt>>();
386 }
387 }
388 };
389 template<class T, bool AllowDeprecatedTypes>
390 struct ivalue_to_arg<optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
391 // If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
392 // to the operator. OptionalArray<T> is basically a optional<std::vector<T>> but implicitly convertible
393 // to optional<ArrayRef<T>>.
394 static OptionalArray<T> call(IValue& v) {
395 return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
396 }
397 };
398
399 template<class T, bool AllowDeprecatedTypes>
400 struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final {
401 // If an argument is OptionalArrayRef<T>, convert the IValue to an
402 // optional<std::vector<T>> and pass that to the operator. OptionalArray<T>
403 // is basically a optional<std::vector<T>> but implicitly convertible to
404 // OptionalArrayRef<T>
405 static OptionalArray<T> call(IValue& v) {
406 return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
407 }
408 };
409
410 // return_to_ivalue
411 template<class T, bool AllowDeprecatedTypes, class Enable = void>
412 struct return_to_ivalue final {};
413
414 template<class T, bool AllowDeprecatedTypes>
415 struct return_to_ivalue<T, AllowDeprecatedTypes, std::enable_if_t<!std::is_same<at::Tensor&, T>::value>> final {
416 static IValue call(T&& v) {
417 assert_is_valid_output_type<T, AllowDeprecatedTypes>();
418 return c10::ivalue::from(std::move(v));
419 }
420 static IValue copy(const T& v) {
421 assert_is_valid_output_type<T, AllowDeprecatedTypes>();
422 return IValue(v);
423 }
424 };
425
426 // Special case to allow kernels to return `Tensor&`.
427 // TODO Delete this once kernels don't do that anymore
428 template<bool AllowDeprecatedTypes>
429 struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
430 static IValue call(at::Tensor& v) {
431 return c10::ivalue::from(v);
432 }
433 static IValue copy(at::Tensor& v) {
434 return IValue(v);
435 }
436 };
437
438 // wrap_kernel_functor_unboxed_
439
440 template<class KernelFunctor, class OpSignature>
441 struct wrap_kernel_functor_unboxed_ final {};
442
443 // This specialization is for kernels with a first argument that is NOT of type DispatchKeySet
444 // This includes kernels with 0 arguments.
445 template<class KernelFunctor, class ReturnType, class... ParameterTypes>
446 struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(ParameterTypes...)> final {
447 static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value,
448 "Return type mismatch");
449 static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value,
450 "Parameter types mismatch");
451
452 // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use &&
453 static ReturnType call(OperatorKernel* functor, DispatchKeySet, ParameterTypes... args) {
454 KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
455 // Note [Plumbing Keys Through The Dispatcher 2]
456 // See Note [Plumbing Keys Through The Dispatcher] for the background.
457 // This functor explicitly takes in a dispatchKeySet and drops it on the floor- it does not forward it to the registered kernel.
458 //
459 // This is due to the calling convention within the dispatcher, which expects all registered kernels to have a first argument of type
460 // DispatchKeySet.
461 // This is not the case for pretty much all manually written kernels, however- this functor serves to separate the calling convention
462 // of the dispatcher from the calling convention of manually written kernels.
463 return (*functor_)(std::forward<ParameterTypes>(args)...);
464 }
465 };
466
467 // This specialization is for kernels with a first argument of type DispatchKeySet
468 template<class KernelFunctor, class ReturnType, class... ParameterTypes>
469 struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(DispatchKeySet, ParameterTypes...)> final {
470 static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value,
471 "Return type mismatch");
472 static_assert(std::is_same<guts::typelist::typelist<DispatchKeySet, ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value,
473 "Parameter types mismatch");
474
475 // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use &&
476 static ReturnType call(OperatorKernel* functor, DispatchKeySet dispatchKeySet, ParameterTypes... args) {
477 KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
478 // We're explicitly taking in a dispatchKeySet and forwarding it to the registered kernel.
479 // See Note [Plumbing Keys Through The Dispatcher 2] for details.
480 return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
481 }
482 };
483
484 template<class KernelFunctor>
485 using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<KernelFunctor, typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
486
487 // call_functor_with_args_from_stack
488
489 template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices, typename... ArgTypes>
490 std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
491 call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
492 (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
493
494 // We're explicitly filtering out DispatchKeySet from the argument list.
495 // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
496 // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
497 // See Note [Plumbing Keys Through The Dispatcher] for the background.
498 return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet,
499 ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call(
500 torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))
501 )...);
502 }
503
504 template<class Functor, bool AllowDeprecatedTypes>
505 std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
506 call_functor_with_args_from_stack(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack) {
507 // We're explicitly filtering out DispatchKeySet from the argument list.
508 // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
509 // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
510 // See Note [Plumbing Keys Through The Dispatcher] for the background.
511 using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<Functor>::parameter_types;
512 constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value;
513 return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack, std::make_index_sequence<num_ivalue_args>(), static_cast<ArgTypes*>(nullptr));
514 }
515
516 // push_outputs
517
518 template<class OutputType, bool AllowDeprecatedTypes>
519 struct push_outputs final {
520 // Contrary to [Note: Argument forwarding in the dispatcher], we use OutputType&& here
521 // to avoid one extra call to the move constructor in this case. This is still not a
522 // universal reference though because OutputType is an explicitly specified class
523 // template parameter.
524 static void call(OutputType&& output, Stack* stack) {
525 torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(std::forward<OutputType>(output)));
526 }
527 static void copy(const OutputType& output, Stack* stack) {
528 torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
529 }
530 };
531 template<class... OutputTypes, bool AllowDeprecatedTypes>
532 struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
533 static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
534 call_(std::move(output), stack, std::make_index_sequence<sizeof...(OutputTypes)>());
535 }
536 static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
537 copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
538 }
539
540 private:
541 template<size_t... indices>
542 static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, std::index_sequence<indices...>) {
543 torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(std::forward<OutputTypes>(std::get<indices>(output)))...);
544 }
545 template<size_t... indices>
546 static void copy_(const std::tuple<OutputTypes...>& output, Stack* stack, std::index_sequence<indices...>) {
547 torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(std::get<indices>(output))...);
548 }
549 };
550 template<bool AllowDeprecatedTypes>
551 struct push_outputs<void, AllowDeprecatedTypes> final {
552 static void call(int /*dummy*/, Stack* /*stack*/) {
553 }
554 static void copy(int /*dummy*/, Stack* /*stack*/) {
555 }
556 };
557
// make_boxed_from_unboxed_functor
//
// Builds the boxed entry point for an unboxed KernelFunctor. call() unboxes the
// kernel's inputs from the stack, invokes the kernel, drops the num_inputs input
// entries, and, for non-void kernels, pushes the boxed result(s) back.
//
// Two code paths compute the same thing: native `if constexpr` when the
// compiler defines __cpp_if_constexpr, and a guts::if_constexpr lambda-based
// fallback otherwise. In the fallback, `delay_check` makes the expressions
// dependent so the branch that is not taken is never instantiated.

template<class KernelFunctor, bool AllowDeprecatedTypes>
struct make_boxed_from_unboxed_functor final {
  static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value,
    "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");

  // The OperatorHandle parameter is unused (unnamed) here.
  static void call(OperatorKernel* functor, const OperatorHandle&, DispatchKeySet dispatchKeySet, Stack* stack) {
    using ReturnType = typename guts::infer_function_traits_t<KernelFunctor>::return_type;
    // We're explicitly filtering out DispatchKeySet from the argument list.
    // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
    // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
    // See Note [Plumbing Keys Through The Dispatcher] for the background.
    using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::parameter_types;
    constexpr bool has_outputs = !std::is_same<void, ReturnType>::value;
    // Number of stack entries consumed by the kernel (DispatchKeySet excluded).
    constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
#ifdef __cpp_if_constexpr
    if constexpr (has_outputs) {
#else
    guts::if_constexpr<has_outputs>([&] (auto delay_check) {
#endif
      // Decay ReturnType to ReturnType_ so that if a reference gets returned, we actually store it by value
      // and don't get a dangling reference. This is only required because some kernels still return `Tensor&`.
#ifdef __cpp_if_constexpr
      // [Note: VC++ and 'std': ambiguous symbol]
      using ReturnType_ = ::std::decay_t<ReturnType>;
      ReturnType_ output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
#else
      using ReturnType_ = std::decay_t<typename decltype(delay_check)::template type_identity<ReturnType>>;
      ReturnType_ output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, delay_check(stack));
#endif
      // Inputs are only peeked during the call; remove them before pushing results.
      torch::jit::drop(*stack, num_inputs);
      // See note [ VC++ and 'std': ambiguous symbol]
      push_outputs<ReturnType_, AllowDeprecatedTypes>::call(::std::move(output), stack);
#ifdef __cpp_if_constexpr
    } else {
#else
    }, /* else */ [&] {
#endif
      // void kernel: invoke for its side effects, then pop the inputs.
      call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
      torch::jit::drop(*stack, num_inputs);
#ifdef __cpp_if_constexpr
    }
#else
    });
#endif
  }
};
606} // namespace impl
607
608} // namespace c10
609
610namespace torch {
611 using OperatorKernel = c10::OperatorKernel;
612}
613