1 | #pragma once |
2 | |
3 | /** |
4 | * Include this file if you want to register operators. It includes all |
5 | * functionality needed to do so for you. |
6 | */ |
7 | |
8 | #include <c10/core/DispatchKey.h> |
9 | #include <c10/core/DispatchKeySet.h> |
10 | #include <c10/core/CompileTimeFunctionPointer.h> |
11 | #include <ATen/core/boxing/KernelFunction.h> |
12 | #include <ATen/core/dispatch/CppSignature.h> |
13 | #include <ATen/core/dispatch/RegistrationHandleRAII.h> |
14 | #include <ATen/core/op_registration/infer_schema.h> |
15 | #if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD) |
16 | #include <torch/csrc/jit/frontend/function_schema_parser.h> |
17 | #endif |
18 | #include <ATen/core/ATenOpList.h> |
19 | |
20 | namespace c10 { |
21 | |
22 | namespace detail { |
23 | // The first argument of the schema might be of type DispatchKeySet, in which case we remove it. |
24 | // We do this because every argument in a function schema is expected to be convertable |
25 | // to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of. |
26 | // See Note [Plumbing Keys Through The Dispatcher] |
27 | template<class KernelFunctor> |
28 | std::unique_ptr<FunctionSchema> inferFunctionSchemaFromFunctor() { |
29 | using func_type = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::func_type; |
30 | return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<func_type>()); |
31 | } |
32 | } |
33 | |
34 | /** |
35 | * An instance of this class handles the registration for one or more operators. |
36 | * Make sure you keep the RegisterOperators instance around since it will |
37 | * deregister the operator it's responsible for in its destructor. |
38 | * |
39 | * Example: |
40 | * |
41 | * > namespace { |
42 | * > class my_kernel_cpu final : public c10::OperatorKernel { |
43 | * > public: |
44 | * > Tensor operator()(Tensor a, Tensor b) {...} |
45 | * > }; |
46 | * > } |
47 | * > |
48 | * > static auto registry = c10::RegisterOperators() |
49 | * > .op(c10::RegisterOperators::options() |
50 | * > .schema("my_op") |
51 | * > .kernel<my_kernel_cpu>(DispatchKey::CPU)); |
52 | */ |
53 | class TORCH_API RegisterOperators final { |
54 | public: |
55 | RegisterOperators(); |
56 | ~RegisterOperators(); |
57 | |
58 | RegisterOperators(const RegisterOperators&) = delete; |
59 | RegisterOperators& operator=(const RegisterOperators&) = delete; |
60 | RegisterOperators(RegisterOperators&&) noexcept; |
61 | RegisterOperators& operator=(RegisterOperators&&) noexcept; |
62 | |
63 | class TORCH_API Options final { |
64 | public: |
65 | Options(const Options&) = delete; |
66 | Options(Options&&) noexcept = delete; |
67 | Options& operator=(const Options&) = delete; |
68 | Options& operator=(Options&&) noexcept = delete; |
69 | |
70 | // internal-only for registering stack based kernels |
71 | template<KernelFunction::BoxedKernelFunction* kernel_func> |
72 | Options&& kernel(DispatchKey dispatch_key) && { |
73 | return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr); |
74 | } |
75 | |
76 | // internal-only for registering stack based catch-all kernels |
77 | template<KernelFunction::BoxedKernelFunction* kernel_func> |
78 | Options&& catchAllKernel() && { |
79 | return std::move(*this).kernel(c10::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr); |
80 | } |
81 | |
82 | // internal only for registering caffe2 ops |
83 | Options&& schema(FunctionSchema&& schema) { |
84 | TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration." ); |
85 | schemaOrName_ = c10::make_right<OperatorName, FunctionSchema>(std::move(schema)); |
86 | return std::move(*this); |
87 | } |
88 | |
89 | /** |
90 | * Use this to specify the schema for an operator. You can also specify |
91 | * the operator name only to have the function signature part of the |
92 | * schema be inferred from the kernel function. |
93 | * |
94 | * Example: |
95 | * |
96 | * > // Infer function signature from my_kernel_cpu |
97 | * > static auto registry = c10::RegisterOperators() |
98 | * > .op(c10::RegisterOperators::options() |
99 | * > .schema("my_op") |
100 | * > .kernel<my_kernel_cpu>(DispatchKey::CPU)); |
101 | * > |
102 | * > |
103 | * > // Explicitly specify full schema |
104 | * > static auto registry = c10::RegisterOperators() |
105 | * > .op(c10::RegisterOperators::options() |
106 | * > .schema("my_op(Tensor a) -> Tensor") |
107 | * > .kernel<my_kernel_cpu>(DispatchKey::CPU)); |
108 | */ |
109 | Options&& schema(const std::string& schemaOrName) { |
110 | TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator " , schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration." ); |
111 | |
112 | #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD) |
113 | throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build." ); |
114 | #else |
115 | schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName); |
116 | #endif |
117 | |
118 | return std::move(*this); |
119 | } |
120 | |
121 | /** |
122 | * Use this to register an operator whose kernel is implemented as a functor. |
123 | * The kernel is only called for inputs matching the given dispatch key. |
124 | * You can register multiple kernels for different dispatch keys. |
125 | * |
126 | * Example: |
127 | * |
128 | * > namespace { |
129 | * > class my_kernel_cpu final : public c10::OperatorKernel { |
130 | * > public: |
131 | * > Tensor operator()(Tensor a, Tensor b) {...} |
132 | * > }; |
133 | * > } |
134 | * > |
135 | * > static auto registry = c10::RegisterOperators() |
136 | * > .op(c10::RegisterOperators::options() |
137 | * > .schema("my_op") |
138 | * > .kernel<my_kernel_cpu>(DispatchKey::CPU)); |
139 | * |
140 | * The functor constructor can take arguments to configure the kernel. |
141 | * The arguments are defined in the kernel registration. |
142 | * Example: |
143 | * |
144 | * > namespace { |
145 | * > class my_kernel_cpu final : public c10::OperatorKernel { |
146 | * > public: |
147 | * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b) |
148 | * > : ... {...} |
149 | * > |
150 | * > Tensor operator()(Tensor a, Tensor b) {...} |
151 | * > }; |
152 | * > } |
153 | * > |
154 | * > static auto registry = c10::RegisterOperators() |
155 | * > .op(c10::RegisterOperators::options() |
156 | * > .schema("my_op") |
157 | * > .kernel<my_kernel_cpu>(DispatchKey::CPU, "some_configuration", 3, true)); |
158 | */ |
159 | template<class KernelFunctor, class... ConstructorParameters> |
160 | // enable_if: only enable it if KernelFunctor is actually a functor |
161 | std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && { |
162 | static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." ); |
163 | static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor." ); |
164 | |
165 | return std::move(*this).kernel( |
166 | std::move(dispatch_key), |
167 | KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)), |
168 | impl::CppSignature::make<KernelFunctor>(), |
169 | detail::inferFunctionSchemaFromFunctor<KernelFunctor>() |
170 | ); |
171 | } |
172 | |
173 | /** |
174 | * Use this to register an operator whose kernel is implemented as a functor. |
175 | * The kernel is a catch-all kernel, meaning it's called independent from |
176 | * the input. Dispatch is disabled for this operator. |
177 | * |
178 | * Example: |
179 | * |
180 | * > namespace { |
181 | * > class my_kernel_cpu final : public c10::OperatorKernel { |
182 | * > public: |
183 | * > Tensor operator()(Tensor a, Tensor b) {...} |
184 | * > }; |
185 | * > } |
186 | * > |
187 | * > static auto registry = c10::RegisterOperators() |
188 | * > .op(c10::RegisterOperators::options() |
189 | * > .schema("my_op") |
190 | * > .catchAllKernel<my_kernel_cpu>()); |
191 | * |
192 | * The functor constructor can take arguments to configure the kernel. |
193 | * The arguments are defined in the kernel registration. |
194 | * Example: |
195 | * |
196 | * > namespace { |
197 | * > class my_kernel_cpu final : public c10::OperatorKernel { |
198 | * > public: |
199 | * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b) |
200 | * > : ... {...} |
201 | * > |
202 | * > Tensor operator()(Tensor a, Tensor b) {...} |
203 | * > }; |
204 | * > } |
205 | * > |
206 | * > static auto registry = c10::RegisterOperators() |
207 | * > .op(c10::RegisterOperators::options() |
208 | * > .schema("my_op") |
209 | * > .catchAllKernel<my_kernel_cpu>("some_configuration", 3, true)); |
210 | */ |
211 | template<class KernelFunctor, class... ConstructorParameters> |
212 | // enable_if: only enable it if KernelFunctor is actually a functor |
213 | std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && { |
214 | static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." ); |
215 | static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor." ); |
216 | |
217 | return std::move(*this).kernel( |
218 | c10::nullopt, |
219 | KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)), |
220 | impl::CppSignature::make<KernelFunctor>(), |
221 | detail::inferFunctionSchemaFromFunctor<KernelFunctor>() |
222 | ); |
223 | } |
224 | |
225 | /** |
226 | * Use this to register an operator whose kernel is implemented by a function. |
227 | * The kernel is only called for inputs matching the given dispatch key. |
228 | * You can register multiple kernels for different dispatch keys. |
229 | * |
230 | * Example: |
231 | * |
232 | * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } |
233 | * > |
234 | * > static auto registry = c10::RegisterOperators() |
235 | * > .op(c10::RegisterOperators::options() |
236 | * > .schema("my_op") |
237 | * > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(DispatchKey::CPU)); |
238 | */ |
239 | template<class FuncType, FuncType* kernel_func> |
240 | // enable_if: only enable it if FuncType is actually a function |
241 | std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key) && { |
242 | static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API." ); |
243 | static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr" ); |
244 | |
245 | return std::move(*this).kernel( |
246 | std::move(dispatch_key), |
247 | KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), |
248 | impl::CppSignature::make<FuncType>(), |
249 | // TODO Do schema inference without relying on WrapFunctionIntoFunctor |
250 | detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>() |
251 | ); |
252 | } |
253 | |
254 | /** |
255 | * Use this to register an operator whose kernel is implemented by a function. |
256 | * The kernel is a catch-all kernel, meaning it's called independent from |
257 | * the input. Dispatch is disabled for this operator. |
258 | * |
259 | * Example: |
260 | * |
261 | * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } |
262 | * > |
263 | * > static auto registry = c10::RegisterOperators() |
264 | * > .op(c10::RegisterOperators::options() |
265 | * > .schema("my_op") |
266 | * > .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>()); |
267 | */ |
268 | template<class FuncType, FuncType* kernel_func> |
269 | // enable_if: only enable it if FuncType is actually a function |
270 | std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel() && { |
271 | static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API." ); |
272 | static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr" ); |
273 | |
274 | return std::move(*this).kernel( |
275 | c10::nullopt, |
276 | KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), |
277 | impl::CppSignature::make<FuncType>(), |
278 | // TODO Do schema inference without relying on WrapFunctionIntoFunctor |
279 | detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>() |
280 | ); |
281 | } |
282 | |
283 | template<class FuncType> |
284 | // enable_if: only enable it if FuncType is actually a function |
285 | std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && { |
286 | static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API." ); |
287 | TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr" ); |
288 | |
289 | return std::move(*this).kernel( |
290 | std::move(dispatch_key), |
291 | KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func), |
292 | impl::CppSignature::make<FuncType>(), |
293 | // TODO Do schema inference without relying on WrapFunctionIntoFunctor |
294 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>() |
295 | ); |
296 | } |
297 | |
298 | template<class FuncType> |
299 | // enable_if: only enable it if FuncType is actually a function |
300 | std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel(FuncType* kernel_func) && { |
301 | static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API." ); |
302 | TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr" ); |
303 | |
304 | return std::move(*this).kernel( |
305 | c10::nullopt, |
306 | KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func), |
307 | impl::CppSignature::make<FuncType>(), |
308 | // TODO Do schema inference without relying on WrapFunctionIntoFunctor |
309 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>() |
310 | ); |
311 | } |
312 | |
313 | /** |
314 | * Use this to register an operator whose kernel is implemented as a lambda. |
315 | * The kernel is only called for inputs matching the given dispatch key. |
316 | * You can register multiple kernels for different dispatch keys. |
317 | * |
318 | * The lambda must be stateless, i.e. not have a capture. If your kernel |
319 | * needs to store some configuration parameters, write the kernel as a |
320 | * functor instead. |
321 | * |
322 | * Example: |
323 | * |
324 | * > static auto registry = c10::RegisterOperators() |
325 | * > .op(c10::RegisterOperators::options() |
326 | * > .schema("my_op") |
327 | * > .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...})); |
328 | */ |
329 | template<class Lambda> |
330 | // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) |
331 | std::enable_if_t< |
332 | guts::is_functor<std::decay_t<Lambda>>::value |
333 | && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value, |
334 | Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && { |
335 | static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead." ); |
336 | |
337 | // We don't support stateful lambdas (i.e. lambdas with a capture), because their |
338 | // behavior would be nonobvious. A functor kernel with cache gets a new instance of |
339 | // its cache each time the kernel is looked up from the dispatch table. |
340 | // A lambda with a capture would be global and share its capture between all kernel lookups. |
341 | // So, instead of making users having to think about it (including the thread-safety |
342 | // issues this causes), let's just forbid stateful lambdas altogether. |
343 | static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead." ); |
344 | |
345 | return std::move(*this).kernel( |
346 | std::move(dispatch_key), |
347 | KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(functor)), |
348 | impl::CppSignature::make<Lambda>(), |
349 | // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor |
350 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>() |
351 | ); |
352 | } |
353 | |
354 | /** |
355 | * Use this to register an operator whose kernel is implemented as a lambda. |
356 | * The kernel is a catch-all kernel, meaning it's called independent from |
357 | * the input. Dispatch is disabled for this operator. |
358 | * |
359 | * The lambda must be stateless, i.e. not have a capture. If your kernel |
360 | * needs to store some configuration parameters, write the kernel as a |
361 | * functor instead. |
362 | * |
363 | * Example: |
364 | * |
365 | * > static auto registry = c10::RegisterOperators() |
366 | * > .op(c10::RegisterOperators::options() |
367 | * > .schema("my_op") |
368 | * > .catchAllKernel([] (Tensor a) -> Tensor {...})); |
369 | */ |
370 | template<class Lambda> |
371 | // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) |
372 | std::enable_if_t< |
373 | guts::is_functor<std::decay_t<Lambda>>::value |
374 | && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value, |
375 | Options&&> catchAllKernel(Lambda&& lambda) && { |
376 | static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead." ); |
377 | |
378 | // We don't support stateful lambdas (i.e. lambdas with a capture), because their |
379 | // behavior would be nonobvious. |
380 | // A lambda with a capture would be global and share its capture between all kernel lookups. |
381 | // This would be a likely source for unexpected race conditions, so we forbid it. |
382 | // If a kernel really needs global state, they can just have regular global state |
383 | // in their .cpp file next to the kernel lambda. |
384 | static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead." ); |
385 | |
386 | return std::move(*this).kernel( |
387 | c10::nullopt, |
388 | KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)), |
389 | impl::CppSignature::make<Lambda>(), |
390 | // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor |
391 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>() |
392 | ); |
393 | } |
394 | |
395 | Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && { |
396 | TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration." ); |
397 | aliasAnalysisKind_ = aliasAnalysisKind; |
398 | return std::move(*this); |
399 | } |
400 | |
401 | private: |
402 | Options&& kernel(c10::optional<DispatchKey> dispatch_key, KernelFunction&& func, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && { |
403 | KernelRegistrationConfig config; |
404 | config.dispatch_key = dispatch_key; |
405 | config.func = std::move(func); |
406 | config.cpp_signature = std::move(cpp_signature); |
407 | config.inferred_function_schema = std::move(inferred_function_schema); |
408 | kernels.push_back(std::move(config)); |
409 | return std::move(*this); |
410 | } |
411 | |
412 | Options() |
413 | : schemaOrName_(c10::nullopt) |
414 | , kernels() |
415 | , aliasAnalysisKind_(c10::nullopt) |
416 | {} |
417 | |
418 | // KernelRegistrationConfig accumulates all information from the config |
419 | // parameters passed to a RegisterOperators::op() call into one object. |
420 | struct KernelRegistrationConfig final { |
421 | KernelRegistrationConfig() |
422 | : dispatch_key(c10::nullopt) |
423 | , func() |
424 | , cpp_signature(c10::nullopt) |
425 | , inferred_function_schema(nullptr) |
426 | {} |
427 | |
428 | c10::optional<DispatchKey> dispatch_key; |
429 | KernelFunction func; |
430 | c10::optional<impl::CppSignature> cpp_signature; |
431 | std::unique_ptr<FunctionSchema> inferred_function_schema; |
432 | }; |
433 | |
434 | c10::optional<c10::either<OperatorName, FunctionSchema>> schemaOrName_; |
435 | |
436 | std::vector<KernelRegistrationConfig> kernels; |
437 | optional<AliasAnalysisKind> aliasAnalysisKind_; |
438 | friend class RegisterOperators; |
439 | friend class Library; |
440 | }; |
441 | |
442 | /** |
443 | * Call this to get an instance of registration options, which |
444 | * can be passed to a call to RegisterOperators::op() to specify |
445 | * these options for the operator registration. |
446 | * See class doc comment for examples. |
447 | */ |
448 | static Options options() { |
449 | return {}; |
450 | } |
451 | |
452 | /** |
453 | * Call this to register an operator. See class doc comment for examples. |
454 | */ |
455 | RegisterOperators&& op(Options&& options) && { |
456 | checkSchemaAndRegisterOp_(std::move(options)); |
457 | return std::move(*this); |
458 | } |
459 | |
460 | // Regular mutator version of the && version above |
461 | RegisterOperators& op(Options&& options) & { |
462 | checkSchemaAndRegisterOp_(std::move(options)); |
463 | return *this; |
464 | } |
465 | |
466 | /** |
467 | * This is a shorthand for RegisterOperators::op(Options) where you can |
468 | * specify the operator schema outside of the options parameter. |
469 | * See class doc comment for examples. |
470 | */ |
471 | RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && { |
472 | return std::move(*this).op(std::move(options).schema(schemaOrName)); |
473 | } |
474 | |
475 | // internal only for registering caffe2 ops |
476 | RegisterOperators&& op(FunctionSchema schema, Options&& options) && { |
477 | return std::move(*this).op(std::move(options).schema(std::move(schema))); |
478 | } |
479 | |
480 | template<class FuncType> |
481 | explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options()) |
482 | : RegisterOperators() { |
483 | std::move(*this).op(schemaOrName, std::forward<FuncType>(func), std::move(options)); |
484 | } |
485 | |
486 | /** |
487 | * This API registers an operator based on a kernel function pointer. |
488 | * |
489 | * Given a kernel |
490 | * |
491 | * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } |
492 | * |
493 | * This API looks like: |
494 | * |
495 | * > static auto registry = c10::RegisterOperators() |
496 | * > .op("my_op", &my_kernel_cpu); |
497 | * |
498 | * If your kernel is small and the overhead of calling it matters, |
499 | * then this API might be the wrong choice since the following API |
500 | * has a slightly lower overhead for calling into the kernel: |
501 | * |
502 | * > static auto registry = c10::RegisterOperators() |
503 | * > .op("my_op", c10::RegisterOperators::options() |
504 | * > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>()); |
505 | * |
506 | * Or, alternatively, write your kernel as a functor: |
507 | * |
508 | * > namespace { |
509 | * > class my_kernel_cpu final : public c10::OperatorKernel { |
510 | * > public: |
511 | * > Tensor operator()(Tensor a, Tensor b) {...} |
512 | * > }; |
513 | * > } |
514 | * > |
515 | * > static auto registry = c10::RegisterOperators() |
516 | * > .op("my_op", c10::RegisterOperators::options() |
517 | * > .kernel<my_kernel_cpu>()); |
518 | */ |
519 | template<class FuncType> |
520 | // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction. |
521 | std::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, RegisterOperators&&> |
522 | op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && { |
523 | constexpr bool AllowLegacyTypes = true; |
524 | return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( |
525 | c10::nullopt, |
526 | KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func), |
527 | impl::CppSignature::make<FuncType>(), |
528 | // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor |
529 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>() |
530 | )); |
531 | } |
532 | |
533 | /** |
534 | * This API registers an operator based on a kernel lambda. |
535 | * |
536 | * This API looks like: |
537 | * |
538 | * > static auto registry = c10::RegisterOperators() |
539 | * > .op("my_op", [] (Tensor a, Tensor b) {...}); |
540 | * |
541 | * This is equivalent to: |
542 | * |
543 | * > static auto registry = c10::RegisterOperators() |
544 | * > .op("my_op", c10::RegisterOperators::options() |
545 | * > .catchAllKernel([] (Tensor a, Tensor b) {...})); |
546 | * |
547 | */ |
548 | template<class Lambda> |
549 | // enable_if: only enable it if Lambda is actually a stateless lambda |
550 | std::enable_if_t<guts::is_functor<Lambda>::value && guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&> |
551 | op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && { |
552 | static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead." ); |
553 | |
554 | constexpr bool AllowLegacyTypes = true; |
555 | return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( |
556 | c10::nullopt, |
557 | KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)), |
558 | impl::CppSignature::make<Lambda>(), |
559 | // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor |
560 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>() |
561 | )); |
562 | } |
563 | |
564 | template<class Lambda> |
565 | C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead." ) |
566 | // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda |
567 | std::enable_if_t<guts::is_functor<Lambda>::value && !guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&> |
568 | op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && { |
569 | static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead." ); |
570 | |
571 | constexpr bool AllowLegacyTypes = true; |
572 | return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( |
573 | c10::nullopt, |
574 | KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)), |
575 | impl::CppSignature::make<Lambda>(), |
576 | // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor |
577 | detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>() |
578 | )); |
579 | } |
580 | |
581 | private: |
582 | void checkSchemaAndRegisterOp_(Options&& config); |
583 | |
584 | static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options); |
585 | void checkNoDuplicateKernels_(const Options& options); |
586 | void registerOp_(Options&& options); |
587 | |
588 | std::vector<RegistrationHandleRAII> registrars_; |
589 | }; |
590 | |
591 | } // namespace c10 |
592 | |
593 | namespace torch { |
594 | // Old-style API |
595 | using RegisterOperators = c10::RegisterOperators; |
596 | } |
597 | |