1 | #pragma once |
2 | |
3 | #include <ATen/core/boxing/OperatorKernel.h> |
4 | #include <c10/core/DispatchKeySet.h> |
5 | #include <c10/util/intrusive_ptr.h> |
6 | |
7 | namespace c10 { |
8 | |
9 | struct IValue; |
10 | using Stack = std::vector<IValue>; |
11 | |
12 | class OperatorHandle; |
13 | class KernelFunction; |
14 | |
// This kernel implements the behavior of falling through to the next available
// registered dispatch key. The implementation of this function is FAST; it is
// no overhead to fallthrough to the next key. See cpp file for some more
// implementation notes; notably, this does NOT actually go through the
// boxing/unboxing codepath.
//
// NB: the signature deliberately matches BoxedKernel::InternalBoxedKernelFunction
// (declared below) so this function can be stored directly in a BoxedKernel
// (see BoxedKernel::makeFallthrough()).
TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
21 | |
// Note [Ambiguity in AutogradOther kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This error-reporting kernel is registered to the AutogradOther entry in the
// dispatch table when there is both a CompositeImplicitAutograd kernel and a
// backend kernel for ANY backend that maps to AutogradOther. To see why
// this is necessary in the AutogradOther case, it's helpful to first see
// why everything works out fine for a backend that has a reserved Autograd
// entry (see rule 2.2 in [Note] DispatchTable computation):
//
//    CPU   AutogradCPU
//    reg?  registers with...
//    -------------------------------------------------
//    y     Autograd registration takes precedence
//          over CompositeImplicitAutograd.
//          This is good, because the CPU specific backend
//          implementation is more specialized and typically better;
//          if we used the composite, we would bypass it.
//          (NB: the Autograd key is guaranteed to exist because
//          the autograd codegen requires it!)
//
//    n     CompositeImplicitAutograd takes precedence.
//          This is also good, because the Autograd
//          registration (if it exists) would try to redispatch
//          to the (non-existent) CPU implementation; by
//          using the composite, we ensure the operator
//          actually works.
//
// As you can see, when we have a specific Autograd key (AutogradCPU), we can
// decide whether or not to use the CompositeImplicitAutograd kernel or the
// Autograd kernel based on whether or not the backend kernel exists.
//
// However, for AutogradOther (which is the catchall autograd kernel for
// everything that doesn't have a specific Autograd key), we can't do this
// trick because there isn't any unique backend to peek at to disambiguate:
// backends that DO have an implementation would prefer the Autograd kernel,
// while backends WITHOUT an implementation would prefer the
// CompositeImplicitAutograd kernel. Rather than arbitrarily pick one or the
// other, we just register a kernel that raises an error and let the user
// decide how to proceed.
TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
61 | |
// Note [named_not_supported_kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This kernel implements reporting an error message saying that named tensor is
// not supported. This kernel doesn't rely on the Stack, and so it is special
// cased in the dispatcher to be triggered before we attempt boxing (so we can
// give a good error message in cases when boxing is not supported). When
// boxing is universally supported this can be removed.
//
// [[noreturn]]: this kernel never returns to its caller; it always raises the
// "named tensors not supported" error.
[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
70 | |
/**
 * BoxedKernel is similar to a std::function storing a boxed kernel.
 *
 * It pairs an (optionally null) OperatorKernel functor with a plain function
 * pointer of type InternalBoxedKernelFunction; the functor is handed to that
 * function pointer as its first argument (see functor_ / boxed_kernel_func_).
 */
class TORCH_API BoxedKernel final {
public:
  // This is how boxed kernels are actually stored
  //
  // Note [Plumbing Keys Through The Dispatcher]
  // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS)
  // upon every dispatch call in order to compute which kernel to dispatch to.
  //
  // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores
  // to have a first argument of type DispatchKeySet.
  //
  // What are the invariants of the DispatchKeySet when it gets passed to a kernel?
  // - All keys to the left of the current dispatch key have been masked out.
  //   (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer)
  // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments
  //   are still in the set.
  //
  // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches:
  // recalculating the highest-priority dispatch key, which involves reading from TLS. Instead, the kernels that opt in will
  // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher
  // upon redispatching.
  //
  // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature
  // to be of type DispatchKeySet. See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples.
  //
  // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h.
  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
  using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
  // This is the public API for how boxed kernels are defined
  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
  using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);

  // Default constructor. NOTE(review): presumably produces an empty kernel
  // with isValid() == false — confirm in BoxedKernel_impl.h.
  BoxedKernel();

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
  bool isValid() const;
  // True iff this kernel is the fallthrough kernel (see fallthrough_kernel
  // at the top of this file and makeFallthrough() below).
  bool isFallthrough() const;

  /**
   * Call the function with boxed arguments.
   *
   * @param opHandle        The operator being dispatched.
   * @param dispatchKeySet  Keyset with the invariants described in
   *                        Note [Plumbing Keys Through The Dispatcher] above.
   * @param stack           In/out stack of boxed IValue arguments/results.
   */
  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;

  /**
   * Create a KernelFunction from a boxed function.
   *
   * Example:
   *
   * > void boxed_func(const OperatorHandle&, Stack* stack) {...}
   * > BoxedKernel func = BoxedKernel::makeFromFunction<&boxed_func>();
   *
   * NB: the function's signature must match BoxedKernelFunction above
   * (the original example here showed `void boxed_func(OperatorKernel*, Stack*)`
   * and a non-existent `BoxedFunction` type; both were wrong).
   */
  template<BoxedKernelFunction* func>
  static BoxedKernel makeFromFunction();

  /**
   * Overload of makeFromFunction for boxed functions that additionally take
   * the DispatchKeySet (i.e. BoxedKernelFunction_withDispatchKeys).
   *
   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
   * See Note [Plumbing Keys Through The Dispatcher] for details.
   */
  template<BoxedKernelFunction_withDispatchKeys* func>
  static BoxedKernel makeFromFunction();

  /**
   * Create a KernelFunction from a boxed functor.
   *
   * Example:
   *
   * > class MyFunctor final : public c10::OperatorKernel {
   * > public:
   * >   void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
   * > };
   * > BoxedKernel func = BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
   */
  template<class KernelFunctor>
  static BoxedKernel makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);

  // Factories for the three special kernels declared at the top of this file
  // (fallthrough_kernel, ambiguous_autogradother_kernel,
  // named_not_supported_kernel respectively).
  static BoxedKernel makeFallthrough();
  static BoxedKernel makeAmbiguousAutogradOther();
  static BoxedKernel makeNamedNotSupported();

private:

  friend class KernelFunction;

  // Trampolines with the InternalBoxedKernelFunction signature; presumably
  // they forward to the template argument `func`, dropping the unused
  // OperatorKernel* (implementation in BoxedKernel_impl.h — confirm there).
  template<BoxedKernelFunction* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  template<BoxedKernelFunction_withDispatchKeys* func>
  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

  explicit BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func);

  OperatorKernel* getFunctor() const;
  InternalBoxedKernelFunction* getFnPtr() const;

  // Refcounted owner of the functor; presumably passed as the first argument
  // when callBoxed() invokes boxed_kernel_func_ (see getFunctor()).
  c10::intrusive_ptr<OperatorKernel> functor_;
  // Stored kernel entry point; its DispatchKeySet argument obeys the
  // invariants of Note [Plumbing Keys Through The Dispatcher] above.
  InternalBoxedKernelFunction* boxed_kernel_func_;
};
173 | |
174 | } // namespace c10 |
175 | |
176 | #include <ATen/core/boxing/BoxedKernel_impl.h> |
177 | |