1 | #pragma once |
2 | |
3 | namespace c10 { |
4 | |
5 | inline BoxedKernel::BoxedKernel() |
6 | : functor_() |
7 | , boxed_kernel_func_(nullptr) |
8 | {} |
9 | |
10 | inline BoxedKernel::BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func) |
11 | : functor_(std::move(functor)) |
12 | , boxed_kernel_func_(boxed_kernel_func) |
13 | {} |
14 | |
15 | template<BoxedKernel::BoxedKernelFunction* func> |
16 | inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) { |
17 | // Note that we're dropping the DispatchKeySet argument. |
18 | // See Note [Plumbing Keys Through The Dispatcher 2] for details. |
19 | func(opHandle, stack); |
20 | } |
21 | |
22 | template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func> |
23 | inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) { |
24 | // See Note [Plumbing Keys Through The Dispatcher 2] for details. |
25 | func(opHandle, ks, stack); |
26 | } |
27 | |
28 | inline bool BoxedKernel::isValid() const { |
29 | return boxed_kernel_func_ != nullptr; |
30 | } |
31 | |
32 | inline bool BoxedKernel::isFallthrough() const { |
33 | return boxed_kernel_func_ == &fallthrough_kernel; |
34 | } |
35 | |
36 | inline void BoxedKernel::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const { |
37 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
38 | boxed_kernel_func_ != nullptr, |
39 | "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel." |
40 | ); |
41 | (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack); |
42 | } |
43 | |
44 | template<BoxedKernel::BoxedKernelFunction* func> |
45 | inline BoxedKernel BoxedKernel::makeFromFunction() { |
46 | return BoxedKernel( |
47 | nullptr, // no functor_ object |
48 | &make_boxed_function<func> |
49 | ); |
50 | } |
51 | |
52 | template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func> |
53 | inline BoxedKernel BoxedKernel::makeFromFunction() { |
54 | return BoxedKernel( |
55 | nullptr, // no functor_ object |
56 | &make_boxed_function<func> |
57 | ); |
58 | } |
59 | |
60 | inline BoxedKernel BoxedKernel::makeFallthrough() { |
61 | return BoxedKernel( |
62 | nullptr, // no functor_ object |
63 | &fallthrough_kernel |
64 | ); |
65 | } |
66 | |
67 | inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() { |
68 | return BoxedKernel( |
69 | nullptr, // no functor_ object |
70 | &ambiguous_autogradother_kernel |
71 | ); |
72 | } |
73 | |
74 | inline BoxedKernel BoxedKernel::makeNamedNotSupported() { |
75 | return BoxedKernel( |
76 | nullptr, // no functor_ object |
77 | &named_not_supported_kernel |
78 | ); |
79 | } |
80 | |
81 | template<class KernelFunctor> |
82 | inline BoxedKernel BoxedKernel::makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) { |
83 | static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call BoxedKernel::makeFromFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." ); |
84 | return BoxedKernel( |
85 | std::move(kernelFunctor), |
86 | [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) { |
87 | (*static_cast<KernelFunctor*>(kernel))(op, ks, stack); |
88 | } |
89 | ); |
90 | } |
91 | |
92 | inline OperatorKernel* BoxedKernel::getFunctor() const { |
93 | return functor_.get(); |
94 | } |
95 | inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const { |
96 | return boxed_kernel_func_; |
97 | } |
98 | |
99 | } // namespace c10 |
100 | |