1#pragma once
2
3/**
4 * Include this file if you want to register operators. It includes all
5 * functionality needed to do so for you.
6 */
7
8#include <c10/core/DispatchKey.h>
9#include <c10/core/DispatchKeySet.h>
10#include <c10/core/CompileTimeFunctionPointer.h>
11#include <ATen/core/boxing/KernelFunction.h>
12#include <ATen/core/dispatch/CppSignature.h>
13#include <ATen/core/dispatch/RegistrationHandleRAII.h>
14#include <ATen/core/op_registration/infer_schema.h>
15#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
16#include <torch/csrc/jit/frontend/function_schema_parser.h>
17#endif
18#include <ATen/core/ATenOpList.h>
19
20namespace c10 {
21
22namespace detail {
23// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
24// We do this because every argument in a function schema is expected to be convertable
25// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
26// See Note [Plumbing Keys Through The Dispatcher]
27template<class KernelFunctor>
28std::unique_ptr<FunctionSchema> inferFunctionSchemaFromFunctor() {
29 using func_type = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::func_type;
30 return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<func_type>());
31}
32}
33
34/**
35 * An instance of this class handles the registration for one or more operators.
36 * Make sure you keep the RegisterOperators instance around since it will
37 * deregister the operator it's responsible for in its destructor.
38 *
39 * Example:
40 *
41 * > namespace {
42 * > class my_kernel_cpu final : public c10::OperatorKernel {
43 * > public:
44 * > Tensor operator()(Tensor a, Tensor b) {...}
45 * > };
46 * > }
47 * >
48 * > static auto registry = c10::RegisterOperators()
49 * > .op(c10::RegisterOperators::options()
50 * > .schema("my_op")
51 * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
52 */
53class TORCH_API RegisterOperators final {
54public:
55 RegisterOperators();
56 ~RegisterOperators();
57
58 RegisterOperators(const RegisterOperators&) = delete;
59 RegisterOperators& operator=(const RegisterOperators&) = delete;
60 RegisterOperators(RegisterOperators&&) noexcept;
61 RegisterOperators& operator=(RegisterOperators&&) noexcept;
62
63 class TORCH_API Options final {
64 public:
65 Options(const Options&) = delete;
66 Options(Options&&) noexcept = delete;
67 Options& operator=(const Options&) = delete;
68 Options& operator=(Options&&) noexcept = delete;
69
70 // internal-only for registering stack based kernels
71 template<KernelFunction::BoxedKernelFunction* kernel_func>
72 Options&& kernel(DispatchKey dispatch_key) && {
73 return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr);
74 }
75
76 // internal-only for registering stack based catch-all kernels
77 template<KernelFunction::BoxedKernelFunction* kernel_func>
78 Options&& catchAllKernel() && {
79 return std::move(*this).kernel(c10::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), nullopt, nullptr);
80 }
81
82 // internal only for registering caffe2 ops
83 Options&& schema(FunctionSchema&& schema) {
84 TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration.");
85 schemaOrName_ = c10::make_right<OperatorName, FunctionSchema>(std::move(schema));
86 return std::move(*this);
87 }
88
89 /**
90 * Use this to specify the schema for an operator. You can also specify
91 * the operator name only to have the function signature part of the
92 * schema be inferred from the kernel function.
93 *
94 * Example:
95 *
96 * > // Infer function signature from my_kernel_cpu
97 * > static auto registry = c10::RegisterOperators()
98 * > .op(c10::RegisterOperators::options()
99 * > .schema("my_op")
100 * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
101 * >
102 * >
103 * > // Explicitly specify full schema
104 * > static auto registry = c10::RegisterOperators()
105 * > .op(c10::RegisterOperators::options()
106 * > .schema("my_op(Tensor a) -> Tensor")
107 * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
108 */
109 Options&& schema(const std::string& schemaOrName) {
110 TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration.");
111
112 #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD)
113 throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
114 #else
115 schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName);
116 #endif
117
118 return std::move(*this);
119 }
120
121 /**
122 * Use this to register an operator whose kernel is implemented as a functor.
123 * The kernel is only called for inputs matching the given dispatch key.
124 * You can register multiple kernels for different dispatch keys.
125 *
126 * Example:
127 *
128 * > namespace {
129 * > class my_kernel_cpu final : public c10::OperatorKernel {
130 * > public:
131 * > Tensor operator()(Tensor a, Tensor b) {...}
132 * > };
133 * > }
134 * >
135 * > static auto registry = c10::RegisterOperators()
136 * > .op(c10::RegisterOperators::options()
137 * > .schema("my_op")
138 * > .kernel<my_kernel_cpu>(DispatchKey::CPU));
139 *
140 * The functor constructor can take arguments to configure the kernel.
141 * The arguments are defined in the kernel registration.
142 * Example:
143 *
144 * > namespace {
145 * > class my_kernel_cpu final : public c10::OperatorKernel {
146 * > public:
147 * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
148 * > : ... {...}
149 * >
150 * > Tensor operator()(Tensor a, Tensor b) {...}
151 * > };
152 * > }
153 * >
154 * > static auto registry = c10::RegisterOperators()
155 * > .op(c10::RegisterOperators::options()
156 * > .schema("my_op")
157 * > .kernel<my_kernel_cpu>(DispatchKey::CPU, "some_configuration", 3, true));
158 */
159 template<class KernelFunctor, class... ConstructorParameters>
160 // enable_if: only enable it if KernelFunctor is actually a functor
161 std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && {
162 static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
163 static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
164
165 return std::move(*this).kernel(
166 std::move(dispatch_key),
167 KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
168 impl::CppSignature::make<KernelFunctor>(),
169 detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
170 );
171 }
172
173 /**
174 * Use this to register an operator whose kernel is implemented as a functor.
175 * The kernel is a catch-all kernel, meaning it's called independent from
176 * the input. Dispatch is disabled for this operator.
177 *
178 * Example:
179 *
180 * > namespace {
181 * > class my_kernel_cpu final : public c10::OperatorKernel {
182 * > public:
183 * > Tensor operator()(Tensor a, Tensor b) {...}
184 * > };
185 * > }
186 * >
187 * > static auto registry = c10::RegisterOperators()
188 * > .op(c10::RegisterOperators::options()
189 * > .schema("my_op")
190 * > .catchAllKernel<my_kernel_cpu>());
191 *
192 * The functor constructor can take arguments to configure the kernel.
193 * The arguments are defined in the kernel registration.
194 * Example:
195 *
196 * > namespace {
197 * > class my_kernel_cpu final : public c10::OperatorKernel {
198 * > public:
199 * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
200 * > : ... {...}
201 * >
202 * > Tensor operator()(Tensor a, Tensor b) {...}
203 * > };
204 * > }
205 * >
206 * > static auto registry = c10::RegisterOperators()
207 * > .op(c10::RegisterOperators::options()
208 * > .schema("my_op")
209 * > .catchAllKernel<my_kernel_cpu>("some_configuration", 3, true));
210 */
211 template<class KernelFunctor, class... ConstructorParameters>
212 // enable_if: only enable it if KernelFunctor is actually a functor
213 std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && {
214 static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
215 static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
216
217 return std::move(*this).kernel(
218 c10::nullopt,
219 KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
220 impl::CppSignature::make<KernelFunctor>(),
221 detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
222 );
223 }
224
225 /**
226 * Use this to register an operator whose kernel is implemented by a function.
227 * The kernel is only called for inputs matching the given dispatch key.
228 * You can register multiple kernels for different dispatch keys.
229 *
230 * Example:
231 *
232 * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
233 * >
234 * > static auto registry = c10::RegisterOperators()
235 * > .op(c10::RegisterOperators::options()
236 * > .schema("my_op")
237 * > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(DispatchKey::CPU));
238 */
239 template<class FuncType, FuncType* kernel_func>
240 // enable_if: only enable it if FuncType is actually a function
241 std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key) && {
242 static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
243 static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
244
245 return std::move(*this).kernel(
246 std::move(dispatch_key),
247 KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
248 impl::CppSignature::make<FuncType>(),
249 // TODO Do schema inference without relying on WrapFunctionIntoFunctor
250 detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
251 );
252 }
253
254 /**
255 * Use this to register an operator whose kernel is implemented by a function.
256 * The kernel is a catch-all kernel, meaning it's called independent from
257 * the input. Dispatch is disabled for this operator.
258 *
259 * Example:
260 *
261 * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
262 * >
263 * > static auto registry = c10::RegisterOperators()
264 * > .op(c10::RegisterOperators::options()
265 * > .schema("my_op")
266 * > .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
267 */
268 template<class FuncType, FuncType* kernel_func>
269 // enable_if: only enable it if FuncType is actually a function
270 std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel() && {
271 static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
272 static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
273
274 return std::move(*this).kernel(
275 c10::nullopt,
276 KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
277 impl::CppSignature::make<FuncType>(),
278 // TODO Do schema inference without relying on WrapFunctionIntoFunctor
279 detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
280 );
281 }
282
283 template<class FuncType>
284 // enable_if: only enable it if FuncType is actually a function
285 std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && {
286 static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
287 TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
288
289 return std::move(*this).kernel(
290 std::move(dispatch_key),
291 KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
292 impl::CppSignature::make<FuncType>(),
293 // TODO Do schema inference without relying on WrapFunctionIntoFunctor
294 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
295 );
296 }
297
298 template<class FuncType>
299 // enable_if: only enable it if FuncType is actually a function
300 std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel(FuncType* kernel_func) && {
301 static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
302 TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
303
304 return std::move(*this).kernel(
305 c10::nullopt,
306 KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
307 impl::CppSignature::make<FuncType>(),
308 // TODO Do schema inference without relying on WrapFunctionIntoFunctor
309 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
310 );
311 }
312
313 /**
314 * Use this to register an operator whose kernel is implemented as a lambda.
315 * The kernel is only called for inputs matching the given dispatch key.
316 * You can register multiple kernels for different dispatch keys.
317 *
318 * The lambda must be stateless, i.e. not have a capture. If your kernel
319 * needs to store some configuration parameters, write the kernel as a
320 * functor instead.
321 *
322 * Example:
323 *
324 * > static auto registry = c10::RegisterOperators()
325 * > .op(c10::RegisterOperators::options()
326 * > .schema("my_op")
327 * > .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...}));
328 */
329 template<class Lambda>
330 // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
331 std::enable_if_t<
332 guts::is_functor<std::decay_t<Lambda>>::value
333 && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
334 Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && {
335 static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
336
337 // We don't support stateful lambdas (i.e. lambdas with a capture), because their
338 // behavior would be nonobvious. A functor kernel with cache gets a new instance of
339 // its cache each time the kernel is looked up from the dispatch table.
340 // A lambda with a capture would be global and share its capture between all kernel lookups.
341 // So, instead of making users having to think about it (including the thread-safety
342 // issues this causes), let's just forbid stateful lambdas altogether.
343 static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
344
345 return std::move(*this).kernel(
346 std::move(dispatch_key),
347 KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(functor)),
348 impl::CppSignature::make<Lambda>(),
349 // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
350 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
351 );
352 }
353
354 /**
355 * Use this to register an operator whose kernel is implemented as a lambda.
356 * The kernel is a catch-all kernel, meaning it's called independent from
357 * the input. Dispatch is disabled for this operator.
358 *
359 * The lambda must be stateless, i.e. not have a capture. If your kernel
360 * needs to store some configuration parameters, write the kernel as a
361 * functor instead.
362 *
363 * Example:
364 *
365 * > static auto registry = c10::RegisterOperators()
366 * > .op(c10::RegisterOperators::options()
367 * > .schema("my_op")
368 * > .catchAllKernel([] (Tensor a) -> Tensor {...}));
369 */
370 template<class Lambda>
371 // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
372 std::enable_if_t<
373 guts::is_functor<std::decay_t<Lambda>>::value
374 && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
375 Options&&> catchAllKernel(Lambda&& lambda) && {
376 static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
377
378 // We don't support stateful lambdas (i.e. lambdas with a capture), because their
379 // behavior would be nonobvious.
380 // A lambda with a capture would be global and share its capture between all kernel lookups.
381 // This would be a likely source for unexpected race conditions, so we forbid it.
382 // If a kernel really needs global state, they can just have regular global state
383 // in their .cpp file next to the kernel lambda.
384 static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
385
386 return std::move(*this).kernel(
387 c10::nullopt,
388 KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)),
389 impl::CppSignature::make<Lambda>(),
390 // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
391 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
392 );
393 }
394
395 Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && {
396 TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration.");
397 aliasAnalysisKind_ = aliasAnalysisKind;
398 return std::move(*this);
399 }
400
401 private:
402 Options&& kernel(c10::optional<DispatchKey> dispatch_key, KernelFunction&& func, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
403 KernelRegistrationConfig config;
404 config.dispatch_key = dispatch_key;
405 config.func = std::move(func);
406 config.cpp_signature = std::move(cpp_signature);
407 config.inferred_function_schema = std::move(inferred_function_schema);
408 kernels.push_back(std::move(config));
409 return std::move(*this);
410 }
411
412 Options()
413 : schemaOrName_(c10::nullopt)
414 , kernels()
415 , aliasAnalysisKind_(c10::nullopt)
416 {}
417
418 // KernelRegistrationConfig accumulates all information from the config
419 // parameters passed to a RegisterOperators::op() call into one object.
420 struct KernelRegistrationConfig final {
421 KernelRegistrationConfig()
422 : dispatch_key(c10::nullopt)
423 , func()
424 , cpp_signature(c10::nullopt)
425 , inferred_function_schema(nullptr)
426 {}
427
428 c10::optional<DispatchKey> dispatch_key;
429 KernelFunction func;
430 c10::optional<impl::CppSignature> cpp_signature;
431 std::unique_ptr<FunctionSchema> inferred_function_schema;
432 };
433
434 c10::optional<c10::either<OperatorName, FunctionSchema>> schemaOrName_;
435
436 std::vector<KernelRegistrationConfig> kernels;
437 optional<AliasAnalysisKind> aliasAnalysisKind_;
438 friend class RegisterOperators;
439 friend class Library;
440 };
441
442 /**
443 * Call this to get an instance of registration options, which
444 * can be passed to a call to RegisterOperators::op() to specify
445 * these options for the operator registration.
446 * See class doc comment for examples.
447 */
448 static Options options() {
449 return {};
450 }
451
452 /**
453 * Call this to register an operator. See class doc comment for examples.
454 */
455 RegisterOperators&& op(Options&& options) && {
456 checkSchemaAndRegisterOp_(std::move(options));
457 return std::move(*this);
458 }
459
460 // Regular mutator version of the && version above
461 RegisterOperators& op(Options&& options) & {
462 checkSchemaAndRegisterOp_(std::move(options));
463 return *this;
464 }
465
466 /**
467 * This is a shorthand for RegisterOperators::op(Options) where you can
468 * specify the operator schema outside of the options parameter.
469 * See class doc comment for examples.
470 */
471 RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && {
472 return std::move(*this).op(std::move(options).schema(schemaOrName));
473 }
474
475 // internal only for registering caffe2 ops
476 RegisterOperators&& op(FunctionSchema schema, Options&& options) && {
477 return std::move(*this).op(std::move(options).schema(std::move(schema)));
478 }
479
480 template<class FuncType>
481 explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options())
482 : RegisterOperators() {
483 std::move(*this).op(schemaOrName, std::forward<FuncType>(func), std::move(options));
484 }
485
486 /**
487 * This API registers an operator based on a kernel function pointer.
488 *
489 * Given a kernel
490 *
491 * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
492 *
493 * This API looks like:
494 *
495 * > static auto registry = c10::RegisterOperators()
496 * > .op("my_op", &my_kernel_cpu);
497 *
498 * If your kernel is small and the overhead of calling it matters,
499 * then this API might be the wrong choice since the following API
500 * has a slightly lower overhead for calling into the kernel:
501 *
502 * > static auto registry = c10::RegisterOperators()
503 * > .op("my_op", c10::RegisterOperators::options()
504 * > .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
505 *
506 * Or, alternatively, write your kernel as a functor:
507 *
508 * > namespace {
509 * > class my_kernel_cpu final : public c10::OperatorKernel {
510 * > public:
511 * > Tensor operator()(Tensor a, Tensor b) {...}
512 * > };
513 * > }
514 * >
515 * > static auto registry = c10::RegisterOperators()
516 * > .op("my_op", c10::RegisterOperators::options()
517 * > .kernel<my_kernel_cpu>());
518 */
519 template<class FuncType>
520 // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction.
521 std::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, RegisterOperators&&>
522 op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && {
523 constexpr bool AllowLegacyTypes = true;
524 return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
525 c10::nullopt,
526 KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func),
527 impl::CppSignature::make<FuncType>(),
528 // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
529 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
530 ));
531 }
532
533 /**
534 * This API registers an operator based on a kernel lambda.
535 *
536 * This API looks like:
537 *
538 * > static auto registry = c10::RegisterOperators()
539 * > .op("my_op", [] (Tensor a, Tensor b) {...});
540 *
541 * This is equivalent to:
542 *
543 * > static auto registry = c10::RegisterOperators()
544 * > .op("my_op", c10::RegisterOperators::options()
545 * > .catchAllKernel([] (Tensor a, Tensor b) {...}));
546 *
547 */
548 template<class Lambda>
549 // enable_if: only enable it if Lambda is actually a stateless lambda
550 std::enable_if_t<guts::is_functor<Lambda>::value && guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
551 op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
552 static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
553
554 constexpr bool AllowLegacyTypes = true;
555 return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
556 c10::nullopt,
557 KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
558 impl::CppSignature::make<Lambda>(),
559 // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
560 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
561 ));
562 }
563
564 template<class Lambda>
565 C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.")
566 // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda
567 std::enable_if_t<guts::is_functor<Lambda>::value && !guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
568 op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
569 static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
570
571 constexpr bool AllowLegacyTypes = true;
572 return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
573 c10::nullopt,
574 KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
575 impl::CppSignature::make<Lambda>(),
576 // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
577 detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
578 ));
579 }
580
581private:
582 void checkSchemaAndRegisterOp_(Options&& config);
583
584 static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
585 void checkNoDuplicateKernels_(const Options& options);
586 void registerOp_(Options&& options);
587
588 std::vector<RegistrationHandleRAII> registrars_;
589};
590
591} // namespace c10
592
593namespace torch {
594 // Old-style API
595 using RegisterOperators = c10::RegisterOperators;
596}
597