1#pragma once
2
3namespace c10 {
4
5inline BoxedKernel::BoxedKernel()
6 : functor_()
7, boxed_kernel_func_(nullptr)
8{}
9
10inline BoxedKernel::BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func)
11: functor_(std::move(functor))
12, boxed_kernel_func_(boxed_kernel_func)
13{}
14
15template<BoxedKernel::BoxedKernelFunction* func>
16inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) {
17 // Note that we're dropping the DispatchKeySet argument.
18 // See Note [Plumbing Keys Through The Dispatcher 2] for details.
19 func(opHandle, stack);
20}
21
22template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
23inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) {
24 // See Note [Plumbing Keys Through The Dispatcher 2] for details.
25 func(opHandle, ks, stack);
26}
27
28inline bool BoxedKernel::isValid() const {
29 return boxed_kernel_func_ != nullptr;
30}
31
32inline bool BoxedKernel::isFallthrough() const {
33 return boxed_kernel_func_ == &fallthrough_kernel;
34}
35
36inline void BoxedKernel::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
37 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38 boxed_kernel_func_ != nullptr,
39 "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel."
40 );
41 (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
42}
43
44template<BoxedKernel::BoxedKernelFunction* func>
45inline BoxedKernel BoxedKernel::makeFromFunction() {
46 return BoxedKernel(
47 nullptr, // no functor_ object
48 &make_boxed_function<func>
49 );
50}
51
52template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
53inline BoxedKernel BoxedKernel::makeFromFunction() {
54 return BoxedKernel(
55 nullptr, // no functor_ object
56 &make_boxed_function<func>
57 );
58}
59
60inline BoxedKernel BoxedKernel::makeFallthrough() {
61 return BoxedKernel(
62 nullptr, // no functor_ object
63 &fallthrough_kernel
64 );
65}
66
67inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
68 return BoxedKernel(
69 nullptr, // no functor_ object
70 &ambiguous_autogradother_kernel
71 );
72}
73
74inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
75 return BoxedKernel(
76 nullptr, // no functor_ object
77 &named_not_supported_kernel
78 );
79}
80
81template<class KernelFunctor>
82inline BoxedKernel BoxedKernel::makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
83 static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
84 return BoxedKernel(
85 std::move(kernelFunctor),
86 [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) {
87 (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
88 }
89 );
90}
91
92inline OperatorKernel* BoxedKernel::getFunctor() const {
93 return functor_.get();
94}
95inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
96 return boxed_kernel_func_;
97}
98
99} // namespace c10
100