1#pragma once
2
3#include <ATen/core/boxing/OperatorKernel.h>
4#include <c10/core/DispatchKeySet.h>
5#include <c10/util/intrusive_ptr.h>
6
7namespace c10 {
8
9struct IValue;
10using Stack = std::vector<IValue>;
11
12class OperatorHandle;
13class KernelFunction;
14
// This kernel implements the behavior of falling through to the next available
// registered dispatch key. The implementation of this function is FAST; it is
// no overhead to fallthrough to the next key. See cpp file for some more
// implementation notes; notably, this does NOT actually go through the
// boxing/unboxing codepath.
// (The signature matches BoxedKernel::InternalBoxedKernelFunction below, so it
// can be stored directly as a boxed kernel.)
TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
21
// Note [Ambiguity in AutogradOther kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This error-reporting kernel is registered to the AutogradOther entry in the
// dispatch table when there is both a CompositeImplicitAutograd kernel and a
// backend kernel for ANY backend that maps to AutogradOther. To see why
// this is necessary in the AutogradOther case, it's helpful to first see
// why everything works out fine for a backend that has a reserved Autograd
// entry (see rule 2.2 in [Note] DispatchTable computation):
//
//    CPU   AutogradCPU
//    reg?  registers with...
//    -------------------------------------------------
//    y     Autograd registration takes precedence
//          over CompositeImplicitAutograd.
//          This is good, because the CPU specific backend
//          implementation is more specialized and typically better;
//          if we used the composite, we would bypass it.
//          (NB: the Autograd key is guaranteed to exist because
//          the autograd codegen requires it!)
//
//    n     CompositeImplicitAutograd takes precedence.
//          This is also good, because the Autograd
//          registration (if it exists) would try to redispatch
//          to the (non-existent) CPU implementation; by
//          using the composite, we ensure the operator
//          actually works.
//
// As you can see, when we have a specific Autograd key (AutogradCPU), we can
// decide whether or not to use the CompositeImplicitAutograd kernel or the
// Autograd kernel based on whether or not the backend kernel exists.
//
// However, for AutogradOther (which is the catchall autograd kernel for
// everything that doesn't have a specific Autograd key), we can't do this
// trick because there isn't any unique backend to peek at to disambiguate:
// backends that do have implementations would prefer the Autograd kernel,
// while unimplemented backends would prefer CompositeImplicitAutograd.
// Rather than arbitrarily pick one or the other, we just register a kernel
// that raises an error and let the user decide how to proceed.
TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
61
// Note [named_not_supported_kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This kernel implements reporting an error message saying that named tensor is
// not supported. This kernel doesn't rely on the Stack, and so it is special
// cased in the dispatcher to be triggered before we attempt boxing (so we can
// give a good error message in cases when boxing is not supported). When
// boxing is universally supported this can be removed.
// Always throws, hence [[noreturn]].
[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
70
/**
 * BoxedKernel is similar to a std::function storing a boxed kernel.
 */
class TORCH_API BoxedKernel final {
public:
  // This is how boxed kernels are actually stored
  //
  // Note [Plumbing Keys Through The Dispatcher]
  // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS)
  // upon every dispatch call in order to compute which kernel to dispatch to.
  //
  // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores
  // to have a first argument of type DispatchKeySet.
  //
  // What are the invariants of the DispatchKeySet when it gets passed to a kernel?
  // - All keys to the left of the current dispatch key have been masked out.
  //   (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer)
  // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments
  //   are still in the set.
  //
  // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches:
  // recalculating the highest-priority dispatch key, which involves reading from TLS. Instead, the kernels that opt in will
  // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher
  // upon redispatching.
  //
  // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature
  // to be of type DispatchKeySet. See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples.
  //
  // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h.
  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
  using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
  // This is the public API for how boxed kernels are defined
  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
  using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);

  // Constructs an empty BoxedKernel.
  // NOTE(review): presumably isValid() is false for a default-constructed
  // kernel — confirm against the implementation in BoxedKernel_impl.h.
  BoxedKernel();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  bool isValid() const;
  bool isFallthrough() const;

  /**
   * Call the function with boxed arguments.
   */
  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;

  /**
   * Create a BoxedKernel from a boxed function.
   *
   * Example:
   *
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > BoxedKernel func = BoxedKernel::makeFromFunction<&boxed_func>();
   */
  template<BoxedKernelFunction* func>
  static BoxedKernel makeFromFunction();

  /**
   * Same as above, but for a boxed function that additionally receives the
   * DispatchKeySet (see Note [Plumbing Keys Through The Dispatcher] above).
   *
   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
   * See Note [Plumbing Keys Through The Dispatcher] for details.
   */
  template<BoxedKernelFunction_withDispatchKeys* func>
  static BoxedKernel makeFromFunction();

  /**
   * Create a BoxedKernel from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * > public:
   * >   void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > BoxedKernel func = BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
   */
  template<class KernelFunctor>
  static BoxedKernel makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);

  // Factories for the special kernels declared at the top of this file
  // (fallthrough_kernel, ambiguous_autogradother_kernel,
  // named_not_supported_kernel).
  static BoxedKernel makeFallthrough();
  static BoxedKernel makeAmbiguousAutogradOther();
  static BoxedKernel makeNamedNotSupported();

private:

  friend class KernelFunction;

  // Adapters that wrap a BoxedKernelFunction(_withDispatchKeys) so it matches
  // the InternalBoxedKernelFunction signature used for storage.
  template<BoxedKernelFunction* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  template<BoxedKernelFunction_withDispatchKeys* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  explicit BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func);

  OperatorKernel* getFunctor() const;
  InternalBoxedKernelFunction* getFnPtr() const;

  // Owning handle to the kernel's state (OperatorKernel).
  // NOTE(review): presumably passed as the OperatorKernel* first argument of
  // boxed_kernel_func_ when calling — confirm in BoxedKernel_impl.h.
  c10::intrusive_ptr<OperatorKernel> functor_;
  // The stored kernel, in the dispatcher's internal calling convention.
  InternalBoxedKernelFunction* boxed_kernel_func_;
};
173
174} // namespace c10
175
176#include <ATen/core/boxing/BoxedKernel_impl.h>
177