#pragma once

/// \file
///
/// This header provides an API for extending PyTorch's core library
/// of operators with user-defined operators and data types. This
/// API can be used in a few ways:
///
/// * You can define new custom operators and classes with TORCH_LIBRARY(),
///   making them available for use in both eager Python and TorchScript.
///   This API is modeled off of pybind11's `PYBIND11_MODULE` macro, as
///   the provided functionality is similar (pybind11 lets you bind C++
///   to Python only; `torch/library.h` lets you bind C++ simultaneously to
///   Python and TorchScript).
///
/// * You can override existing operators with TORCH_LIBRARY_IMPL(),
///   providing a new implementation for these operators for a custom
///   backend (e.g., XLA). When you pass operators with tensors of your custom
///   backend, your overridden implementations will be called instead
///   of the standard implementations.
///
/// * You can use both capabilities at the same time, allowing you
///   to write custom operators that register CPU/CUDA/Autograd
///   implementations without having to write the boilerplate
///   conditionals yourself.
///
/// For a tutorial-style introduction to the library API, check
/// out the [Extending TorchScript with Custom C++
/// Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)
/// tutorial.
///
/// ```
/// // Define a library whose operators live in the namespace 'myops'.
/// // You must define all of the operators for this library in
/// // this namespace.
/// TORCH_LIBRARY(myops, m) {
///   // Define an operator with exactly one implementation for all backends.
///   m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
///
///   // Define a schema for an operator, but provide no implementation
///   // (use this syntax if you want to use the dispatcher).
///   m.def("mul(Tensor self, Tensor other) -> Tensor");
///
///   // Provide an implementation for a defined operator (you can
///   // provide multiple; one per backend). The dispatcher takes care of
///   // calling the correct implementation depending on whether we get a
///   // CPU tensor or a CUDA tensor.
///   m.impl("mul", torch::kCPU, &mul_cpu_impl);
///   m.impl("mul", torch::kCUDA, &mul_cuda_impl);
/// }
///
/// // Define implementations for operators for a non-standard backend,
/// // e.g., XLA (valid values are entries of DispatchKey). This can
/// // be used to define operators in a different file than the initial
/// // TORCH_LIBRARY definition (e.g., if it is in an external library).
/// TORCH_LIBRARY_IMPL(myops, XLA, m) {
///   m.impl("mul", &mul_xla_impl);
/// }
/// ```

#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>

// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/enum_tag.h>

namespace torch {

#if defined C10_MOBILE
/**
 * The NoInferSchemaTag is a type name used to indicate that this call to the
 * CppFunction constructor should not trigger schema inference from functor.
 * Schema inference from functor utilizes template meta-programming, and is
 * costly from a size perspective. Ideally, one would expect that the schema
 * inference would require very little binary size since most of the
 * computation can be done by the compiler at build time, but that isn't
 * necessarily the case.
 *
 * Schema inference is elided only for mobile use-cases where we don't need
 * the additional runtime cost or size overhead on client devices.
 */
struct NoInferSchemaTag {};
#endif
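
// For example, on mobile one might construct a CppFunction without schema
// inference like this (a minimal sketch; `add_impl` is a hypothetical kernel,
// and the full schema must then be supplied separately via def(), since
// nothing is inferred here):
//
//   auto fn = torch::CppFunction(&add_impl, torch::NoInferSchemaTag());
//
// You rarely write this yourself: on C10_MOBILE builds, Library::impl()
// (below) constructs its CppFunction with NoInferSchemaTag() automatically.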

// For multipy/torchdeploy use case
enum class _RegisterOrVerify {
  REGISTER,
  VERIFY
};

template <class CurClass>
class class_;

/// Represents a C++ function that implements an operator. Most users won't
/// interact directly with this class, except via error messages: the
/// constructors of this class define the set of permissible "function"-like
/// things you can bind via the interface.
///
/// This class erases the type of the passed-in function, but durably records
/// the type via an inferred schema for the function.
class TORCH_API CppFunction final {
  // TODO: This is morally the same thing as KernelRegistrationConfig, but it's
  // opaque to the user.

 public:
  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>()),
        schema_(
            c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl))`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                typename FuncPtr::FuncType>()),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... })`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                std::decay_t<Lambda>>()),
        debug_() {}
#if defined C10_MOBILE
  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
  /// NoInferSchemaTag())`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... }, NoInferSchemaTag())`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}
#endif

  ~CppFunction();

  CppFunction(CppFunction&&) noexcept = default;

  CppFunction& operator=(CppFunction&&) = default;

  /// \private
  /// Creates a function from a type-erased boxed kernel.
  static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
    return CppFunction(
        c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
        /* cpp_signature */ c10::nullopt, // not known for boxed functions
        /* schema */ nullptr);
  }

  /// This creates a fallthrough function. Fallthrough functions
  /// immediately redispatch to the next available dispatch key,
  /// but are implemented more efficiently than a hand-written
  /// function that does the same thing.
  static CppFunction makeFallthrough() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
  }

  /// \private
  ///
  /// Creates a function that, when called, raises an error saying that
  /// named tensors are not supported.
  static CppFunction makeNamedNotSupported() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
  }

  /// Create a function from a boxed kernel function with signature
  /// `void(const OperatorHandle&, Stack*)`; i.e., it receives a
  /// stack of arguments in a boxed calling convention, rather than
  /// in the native C++ calling convention. Boxed functions are
  /// typically only used to register backend fallbacks via
  /// torch::Library::fallback().
  template <c10::BoxedKernel::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunction<func>());
  }
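
  // For example, a boxed fallback might look like this (a minimal sketch;
  // `my_fallback` and the use of PrivateUse1 are hypothetical). The
  // ExcludeDispatchKeyGuard keeps the fallback from re-entering itself when
  // it redispatches:
  //
  //   void my_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  //     c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::PrivateUse1);
  //     // ... inspect or adjust the boxed arguments on the stack here ...
  //     op.callBoxed(stack);
  //   }
  //
  //   TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  //     m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_fallback>());
  //   }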

  // Variant that takes in a boxed kernel function with a plumbed
  // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
  // details.
  template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunction<func>());
  }

  /// Create a function from a boxed kernel functor which defines
  /// `operator()(const OperatorHandle&, DispatchKeySet, Stack*)`
  /// (receiving arguments from boxed calling convention) and inherits
  /// from `c10::OperatorKernel`. Unlike makeFromBoxedFunction, functions
  /// registered in this way can also carry additional state which
  /// is managed by the functor; this is useful if you're writing an
  /// adapter to some other implementation, e.g., a Python callable, which
  /// is dynamically associated with the registered kernel.
  template <class KernelFunctor>
  static CppFunction makeFromBoxedFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor) {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
  }
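
  // For example, a stateful boxed functor might look like this (a minimal
  // sketch; `NotImplementedKernel` is hypothetical):
  //
  //   class NotImplementedKernel : public c10::OperatorKernel {
  //    public:
  //     explicit NotImplementedKernel(std::string msg) : msg_(std::move(msg)) {}
  //     void operator()(
  //         const c10::OperatorHandle& op,
  //         c10::DispatchKeySet,
  //         torch::jit::Stack*) {
  //       TORCH_CHECK(false, op.operator_name(), ": ", msg_);
  //     }
  //    private:
  //     std::string msg_;  // state carried by the functor
  //   };
  //
  //   m.fallback(torch::CppFunction::makeFromBoxedFunctor(
  //       std::make_unique<NotImplementedKernel>("not supported in this build")));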

  /// Create a function from an unboxed kernel function.
  /// This is typically used to register common operators.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::guts::is_function_type<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
    return CppFunction(f);
  }

  /// Create a function from a compile time unboxed kernel function pointer.
  /// This is typically used to register common operators.
  /// Compile time function pointers can be used to allow the compiler
  /// to optimize (e.g., inline) calls to the kernel.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr f) {
    return CppFunction(f);
  }
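
  // For example (a minimal sketch; `add_impl` is a hypothetical kernel):
  //
  //   m.impl(
  //       "add",
  //       torch::CppFunction::makeFromUnboxedFunction(TORCH_FN(add_impl)));
  //
  // In practice you can usually pass `&add_impl` or `TORCH_FN(add_impl)` to
  // impl() directly, which constructs the CppFunction for you.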

  /// Attach a human-readable debug string to this function; it is surfaced
  /// in dispatcher diagnostics and error messages to help identify where a
  /// kernel registration came from.
  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

 private:
  c10::optional<c10::DispatchKey> dispatch_key_;
  c10::KernelFunction func_;
  c10::optional<c10::impl::CppSignature> cpp_signature_;
  std::unique_ptr<c10::FunctionSchema> schema_;
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (it does
  // so destructively; we were too lazy to write accessors that we don't even
  // want users to use)
  friend class Library;

  CppFunction(
      c10::KernelFunction func,
      c10::optional<c10::impl::CppSignature> cpp_signature,
      std::unique_ptr<c10::FunctionSchema> schema);
};

/// \defgroup torch-dispatch-overloads torch::dispatch overloads

/// Create a torch::CppFunction which is associated with a specific
/// dispatch key. torch::CppFunctions that are tagged with a
/// c10::DispatchKey don't get invoked unless the dispatcher determines
/// that this particular c10::DispatchKey is the one that should be
/// dispatched to.
///
/// This function is generally not used directly; instead, prefer using
/// TORCH_LIBRARY_IMPL(), which will implicitly set the c10::DispatchKey
/// for all registration calls inside of its body.
///
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
  CppFunction f(std::forward<Func>(raw_f));
  if (k == c10::DispatchKey::CatchAll) {
    f.dispatch_key_ = c10::nullopt;
  } else {
    f.dispatch_key_ = k;
  }
  return f;
}
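
// For example (a minimal sketch; `mul_cpu_impl` is a hypothetical kernel):
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("mul(Tensor self, Tensor other) -> Tensor");
//     m.impl("mul", torch::dispatch(c10::DispatchKey::CPU, &mul_cpu_impl));
//   }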

/// Convenience overload of dispatch() which accepts c10::DeviceType
///
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
  auto deviceTypeToDispatchKey = [](c10::DeviceType t) {
    switch (t) {
      // This list is synchronized with the k-constants in
      // c10/core/DeviceType.h
      case c10::DeviceType::CPU:
        return c10::DispatchKey::CPU;
      case c10::DeviceType::CUDA:
        return c10::DispatchKey::CUDA;
      case c10::DeviceType::IPU:
        return c10::DispatchKey::IPU;
      case c10::DeviceType::XLA:
        return c10::DispatchKey::XLA;
      case c10::DeviceType::Lazy:
        return c10::DispatchKey::Lazy;
      case c10::DeviceType::MPS:
        return c10::DispatchKey::MPS;
      case c10::DeviceType::Meta:
        return c10::DispatchKey::Meta;
      case c10::DeviceType::HIP:
        return c10::DispatchKey::HIP;
      case c10::DeviceType::ORT:
        return c10::DispatchKey::ORT;
      case c10::DeviceType::HPU:
        return c10::DispatchKey::HPU;
      case c10::DeviceType::PrivateUse1:
        return c10::DispatchKey::PrivateUse1;
      default:
        TORCH_CHECK(
            false,
            "Device type ",
            t,
            " cannot be overloaded at dispatch time, "
            "please file a bug report explaining what you were trying to do.");
    }
  };
  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
}
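
// For example, per the mapping above, the two calls below are equivalent
// (a sketch; `mul_cpu_impl` is hypothetical):
//
//   m.impl("mul", torch::dispatch(c10::DeviceType::CPU, &mul_cpu_impl));
//   m.impl("mul", torch::dispatch(c10::DispatchKey::CPU, &mul_cpu_impl));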

/// \defgroup torch-schema-overloads torch::schema overloads

/// Construct a c10::FunctionSchema from a string, with an explicitly
/// specified c10::AliasAnalysisKind. Ordinarily, schemas are simply
/// passed in as strings, but if you need to specify a custom alias
/// analysis, you can replace the string with a call to this function.
///
/// ```
/// // Default alias analysis (FROM_SCHEMA)
/// m.def("def3(Tensor self) -> Tensor");
/// // Pure function alias analysis
/// m.def(torch::schema(
///     "def3(Tensor self) -> Tensor",
///     c10::AliasAnalysisKind::PURE_FUNCTION));
/// ```
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
  c10::FunctionSchema s = torch::jit::parseSchema(str);
  s.setAliasAnalysis(k);
  return s;
}

/// Function schemas can be directly constructed from string literals.
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema schema(const char* s) {
  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
}

/// \private
///
/// Already-constructed function schemas are accepted if they are
/// rvalues.
///
/// \ingroup torch-schema-overloads
inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
  return std::move(s);
}

namespace detail {

inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    c10::FunctionSchema&& s) {
  return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s));
}
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    c10::OperatorName&& n) {
  return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n));
}
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    const char* str) {
  auto s = torch::jit::parseSchemaOrName(str);
  if (s.is_right()) {
    s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
  }
  return s;
}

class TorchLibraryInit;

} // namespace detail

// Note [Selective build]
// ~~~~~~~~~~~~~~~~~~~~~~
// In some settings, especially mobile, it is important to avoid compiling any
// references to functions that you aren't actually going to use, so that they
// can be eliminated by the linker. We call this capability "selective build".
//
// A very easy way to implement selective build is to add an ifdef around
// every registration call, but this results in a lot of boilerplate: you have
// to write extra lines of code at every registration site, and you have to
// define some munging scheme to map operators to macros.
//
// Instead of doing this, we have a different mechanism centered around the
// concept of a SelectiveStr. A selective name is like a const char* string,
// except it also carries at compile time a boolean saying whether or not the
// registration should actually happen. We then have extra overloads which
// bypass registration entirely if a selective name is disabled. We do a
// constexpr test to see if an operator should be enabled or not; this is
// currently implemented in ATen/core/op_registration/op_allowlist.h

namespace detail {

// dummy class for non selected custom torchbind classes
class ClassNotSelected {
 public:
  ClassNotSelected& def_pickle(...) {
    return *this;
  }
  ClassNotSelected& def(...) {
    return *this;
  }
};

// A SelectiveStr is like a const char*, except that it also comes
// with a type brand that says whether or not the name is enabled. If
// the name is disabled, then (at compile time) we DON'T generate a
// registration call for it. This class is not intended to be constructed
// directly; use the TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros
// below to create it.
template <bool enabled>
class SelectiveStr {
 public:
  constexpr explicit SelectiveStr(const char* name) : name_(name) {}
  constexpr operator const char*() {
    return name_;
  }

 private:
  const char* name_;
};

#define TORCH_SELECTIVE_CLASS(n) \
  torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_NAME(n) \
  torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_SCHEMA(n) \
  torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)

} // namespace detail
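
// For example, selective registration looks like this (a minimal sketch; the
// operator name and kernel are hypothetical, and whether each registration
// actually fires depends on the allowlist configured at build time):
//
//   TORCH_LIBRARY_FRAGMENT(myops, m) {
//     m.def(TORCH_SELECTIVE_SCHEMA(
//         "myops::mul(Tensor self, Tensor other) -> Tensor"));
//   }
//
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl(TORCH_SELECTIVE_NAME("myops::mul"), &mul_cpu_impl);
//   }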

/// This object provides the API for defining operators and providing
/// implementations at dispatch keys. Typically, a torch::Library
/// is not allocated directly; instead it is created by the
/// TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() macros.
///
/// Most methods on torch::Library return a reference to itself,
/// supporting method chaining.
///
/// ```
/// // Examples:
///
/// TORCH_LIBRARY(torchvision, m) {
///   // m is a torch::Library
///   m.def("roi_align", ...);
///   ...
/// }
///
/// TORCH_LIBRARY_IMPL(aten, XLA, m) {
///   // m is a torch::Library
///   m.impl("add", ...);
///   ...
/// }
/// ```
///
class TORCH_API Library final {
 public:
  /// \private
  ///
  /// Which type of macro produced this Library
  enum Kind {
    DEF, // from TORCH_LIBRARY (no qualifier)
    IMPL,
    FRAGMENT,
  };

  /// \private
  ///
  /// Use TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() instead of using these
  /// constructors directly
  Library(
      Kind kind,
      std::string ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line);

  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;
  Library(Library&&) = default;
  Library& operator=(Library&&) = default;

  // Some notes about the API design here. We had the following constraints:
  //
  // - We need to support multiple "types" of arguments for schema and
  //   functions (e.g., unnamed lambda types, regular functions, const char*,
  //   fully instantiated schemas)
  // - We don't want to write exponentially many overloads
  // - We don't want to rely on implicit conversion to a common type,
  //   because the C++ compiler will only be willing to do a single
  //   implicit conversion (reducing the set of valid types which you
  //   can invoke with); also error messages are worse when an implicit
  //   conversion is not selected (as the compiler will not explain
  //   why it didn't select an implicit conversion; this is different
  //   from overloads where it will explain each candidate overload and
  //   why it didn't apply)
  //
  // To solve all of these constraints at the same time, we use a trick taken
  // from the pybind11 library: template over the argument in the user visible
  // API, and inside of the templated function explicitly call an overloaded
  // function to resolve the argument to a real type. You get the good error
  // messages from overloads, but at the same time you only need to write the
  // overload for any given argument type once.

  /// Declare an operator with a schema, but don't provide any implementations
  /// for it. You're expected to then provide implementations using the
  /// impl() method. All template arguments are inferred.
  ///
  /// \param raw_schema The schema of the operator to be defined.
  ///     Typically, this is a `const char*` string literal, but any type
  ///     accepted by torch::schema() is accepted here.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY(myops, m) {
  ///   m.def("add(Tensor self, Tensor other) -> Tensor");
  /// }
  /// ```
  template <typename Schema>
  Library& def(
      Schema&& raw_schema,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s), nullptr, tags, rv);
  }
  /// Define an operator for a schema and then register an implementation for
  /// it. This is typically what you would use if you aren't planning
  /// on making use of the dispatcher to structure your operator
  /// implementation. It's roughly equivalent to calling def() and
  /// then impl(), but if you omit the schema of the operator, we will
  /// infer it from the type of your C++ function. All template
  /// arguments are inferred.
  ///
  /// \param raw_name_or_schema The schema of the operator to be
  ///     defined, or just the name of the operator if the schema is to be
  ///     inferred from `raw_f`. Typically a `const char*` literal.
  /// \param raw_f The C++ function that implements this operator.
  ///     Any valid constructor of torch::CppFunction is accepted here;
  ///     typically you provide a function pointer or lambda.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY(myops, m) {
  ///   m.def("add", add_fn);
  /// }
  /// ```
  template <typename NameOrSchema, typename Func>
  Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
    CppFunction f(std::forward<Func>(raw_f));
    auto name_or_schema = detail::constructSchemaOrName(
        std::forward<NameOrSchema>(raw_name_or_schema));
    return _def(std::move(name_or_schema), std::move(f));
  }

  /// Register an implementation for an operator. You may register multiple
  /// implementations for a single operator at different dispatch keys
  /// (see torch::dispatch()). Implementations must have a corresponding
  /// declaration (from def()), otherwise they are invalid. If you plan
  /// to register multiple implementations, DO NOT provide a function
  /// implementation when you def() the operator.
  ///
  /// \param name The name of the operator to implement. Do NOT provide
  ///     schema here.
  /// \param raw_f The C++ function that implements this operator. Any
  ///     valid constructor of torch::CppFunction is accepted here;
  ///     typically you provide a function pointer or lambda.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  ///   m.impl("add", add_cuda);
  /// }
  /// ```
  template <typename Name, typename Func>
  Library& impl(
      Name name,
      Func&& raw_f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
#if defined C10_MOBILE
    CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
    CppFunction f(std::forward<Func>(raw_f));
#endif
    return _impl(name, std::move(f), rv);
  }

#if defined C10_MOBILE
  // Note: This overload is needed only for C10_MOBILE, since the automatically
  // defined copy constructor for the CppFunction doesn't have the additional
  // NoInferSchemaTag argument. We define the overload for the impl() function
  // to accept a CppFunction&& argument. The already-constructed CppFunction
  // object may or may not have the inferred schema, but it doesn't matter
  // for our purposes since if it already has the inferred schema, then we
  // might as well just pass it through directly.
  template <typename Name>
  Library& impl(Name name, CppFunction&& raw_f) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
    CppFunction f(std::forward<CppFunction>(raw_f));
    return _impl(name, std::move(f));
  }
#endif

  // Helper for getting an OperatorName for a const char*. You probably
  // don't need this.
  c10::OperatorName _resolve(const char* name) const;

  /// \private
  ///
  /// Convenience overload for directly specifying the dispatch key when
  /// calling impl(). You probably don't need this; instead, prefer specifying
  /// the dispatch key for the entire block in TORCH_LIBRARY_IMPL()
  template <typename Name, typename Dispatch, typename Func>
  Library& impl(Name name, Dispatch&& key, Func&& raw_f) & {
    return impl(
        name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
  }

  template <typename Name, typename Func>
  Library& impl_UNBOXED(Name /*name*/, Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }

  // These overloads cover cases when a SelectiveStr (see Note [Selective
  // build]) has been disabled at compile time. In that case, don't generate
  // any code referencing the passed in functions at all.
  Library& def(detail::SelectiveStr<false>) & {
    return *this;
  }
  Library& def(detail::SelectiveStr<true> raw_schema) & {
    return def(raw_schema.operator const char*());
  }
  template <typename Func>
  Library& def(detail::SelectiveStr<false>, Func&& /*raw_f*/) & {
    return *this;
  }
  template <typename Func>
  Library& def(detail::SelectiveStr<true> raw_name_or_schema, Func&& raw_f) & {
    return def(
        raw_name_or_schema.operator const char*(), std::forward<Func>(raw_f));
  }

  template <typename Func>
  Library& impl(detail::SelectiveStr<false>, Func&& /*raw_f*/) & {
    return *this;
  }
  template <typename Dispatch, typename Func>
  Library& impl(
      detail::SelectiveStr<false>,
      Dispatch&& /*key*/,
      Func&& /*raw_f*/) & {
    return *this;
  }
  template <typename Func>
  Library& impl_UNBOXED(
      detail::SelectiveStr<false> /*name*/,
      Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }

  template <typename Func>
  Library& impl(detail::SelectiveStr<true> name, Func&& raw_f) & {
    return impl(name.operator const char*(), std::forward<Func>(raw_f));
  }
  template <typename Dispatch, typename Func>
  Library& impl(
      detail::SelectiveStr<true> name,
      Dispatch&& key,
      Func&& raw_f) & {
    return impl(
        name.operator const char*(),
        std::forward<Dispatch>(key),
        std::forward<Func>(raw_f));
  }
  template <typename Func>
  Library& impl_UNBOXED(
      detail::SelectiveStr<true> /*name*/,
      Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }

  /// Register a fallback implementation for all operators, which will be used
  /// if there is not a specific implementation for an operator available.
  /// There MUST be a DispatchKey associated with a fallback; e.g.,
  /// only call this from TORCH_LIBRARY_IMPL() with namespace `_`.
  ///
  /// \param raw_f The function that implements the fallback. Unboxed
  ///     functions typically do not work as fallback functions, as
  ///     fallback functions must work for every operator (even though
  ///     they have varying type signatures). Typical arguments are
  ///     CppFunction::makeFallthrough() or
  ///     CppFunction::makeFromBoxedFunction()
  ///
  /// ```
  /// // Example:
  ///
  /// TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
  ///   // If there is not a kernel explicitly registered
  ///   // for AutogradXLA, fall through to the next
  ///   // available kernel
  ///   m.fallback(torch::CppFunction::makeFallthrough());
  /// }
  ///
  /// // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp
  /// // for a full example of boxed fallback
  /// ```
  template <typename Func>
  Library& fallback(Func&& raw_f) & {
    CppFunction f((std::forward<Func>(raw_f)));
    return _fallback(std::move(f));
  }

  template <class CurClass>
  inline torch::class_<CurClass> class_(const std::string& className);

  // These overloads enable the use of selective build on classes registered
  // within a library. The API is the same as before, with one minor change.
  // Instead of m.class_<foo>("foo"), you instead do
  // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
  template <class CurClass>
  inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);

  template <class CurClass>
  inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);
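
  // For example (a minimal sketch; `MyStackClass` is a hypothetical custom
  // class; see torch/custom_class.h for the full class_ API):
  //
  //   TORCH_LIBRARY(myclasses, m) {
  //     m.class_<MyStackClass>(TORCH_SELECTIVE_CLASS("MyStackClass"))
  //         .def("push", &MyStackClass::push);
  //   }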

 private:
  Kind kind_;
  c10::optional<std::string> ns_;
  c10::optional<c10::DispatchKey> dispatch_key_;
  const char* file_;
  uint32_t line_;

  std::vector<c10::RegistrationHandleRAII> registrars_;

  friend class detail::TorchLibraryInit;

  // Non-user visible actual implementations of functions. These aren't
  // public because we only implement the & qualifier and not the && qualifier
  Library& _def(
      c10::FunctionSchema&& schema,
      c10::OperatorName* out_name = nullptr,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _def(
      c10::either<c10::OperatorName, c10::FunctionSchema>&&,
      CppFunction&& f) &;
  Library& _impl(
      const char* name,
      CppFunction&& f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _fallback(CppFunction&& f) &;

  at::OperatorName _parseNameForLib(const char* name_str) const;
};

namespace detail {

class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

} // namespace detail

} // namespace torch

// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh

/// Macro for defining a function that will be run at static
/// initialization time to define a library of operators in the
/// namespace `ns` (must be a valid C++ identifier, no quotes).
/// Use this macro when you want to define a new set of custom operators
/// that do not already exist in PyTorch.
///
/// Example usage:
///
/// ```
/// TORCH_LIBRARY(myops, m) {
///   // m is a torch::Library; methods on it will define
///   // operators in the myops namespace
///   m.def("add", add_impl);
/// }
/// ```
///
/// The `m` argument is bound to a torch::Library that is used to
/// register operators. There may only be one TORCH_LIBRARY()
/// for any given namespace.
#define TORCH_LIBRARY(ns, m)                                                   \
  static void TORCH_LIBRARY_init_##ns(torch::Library&);                        \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
      torch::Library::DEF,                                                     \
      &TORCH_LIBRARY_init_##ns,                                                \
      #ns,                                                                     \
      c10::nullopt,                                                            \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void TORCH_LIBRARY_init_##ns(torch::Library& m)

/// \private
///
/// This macro is a version of TORCH_LIBRARY() that doesn't enforce that there
/// is only one library (it is a "fragment"). This is used inside the
/// PerOpRegistration.cpp file, as well as in places where all op registrations
/// within the same namespace cannot be easily put into one macro block
/// (this is mostly the case for custom ops in fbcode that were ported from
/// the old API)
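///
/// For example, a fragment might look like this (a sketch; `myops` is a
/// hypothetical namespace whose TORCH_LIBRARY() block lives in another
/// translation unit):
///
/// ```
/// TORCH_LIBRARY_FRAGMENT(myops, m) {
///   m.def("sub(Tensor self, Tensor other) -> Tensor");
/// }
/// ```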
#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)

/// \private
///
/// The above macro requires an extra unique identifier (uid) to prevent
/// variable name collisions. This can happen if TORCH_LIBRARY_FRAGMENT is
/// called multiple times with the same namespace in the same translation
/// unit. Note that the TORCH_LIBRARY variant doesn't run into this problem,
/// because it enforces that it can only be called once for a given namespace.
#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid)                       \
  static void C10_CONCATENATE(                                    \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(   \
      TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(           \
      torch::Library::FRAGMENT,                                   \
      &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
      #ns,                                                        \
      c10::nullopt,                                               \
      __FILE__,                                                   \
      __LINE__);                                                  \
  void C10_CONCATENATE(                                           \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)

/// Macro for defining a function that will be run at static
/// initialization time to define operator overrides for dispatch key
/// `k` (must be an unqualified enum member of c10::DispatchKey) in
/// namespace `ns` (must be a valid C++ identifier, no quotes). Use this
/// macro when you want to implement a preexisting set of custom
/// operators on a new dispatch key (e.g., you want to provide CUDA
/// implementations of already existing operators). One common usage
/// pattern is to use TORCH_LIBRARY() to define schema for all new
/// operators you want to define, and then use several
/// TORCH_LIBRARY_IMPL() blocks to provide implementations of the
/// operator for CPU, CUDA and Autograd.
///
/// In some cases, you need to define something that applies to all namespaces,
/// not just one namespace (usually a fallback). In that case, use the reserved
/// namespace _, e.g.,
///
/// ```
/// TORCH_LIBRARY_IMPL(_, XLA, m) {
///   m.fallback(xla_fallback);
/// }
/// ```
///
/// Example usage:
///
/// ```
/// TORCH_LIBRARY_IMPL(myops, CPU, m) {
///   // m is a torch::Library; methods on it will define
///   // CPU implementations of operators in the myops namespace.
///   // It is NOT valid to call torch::Library::def()
///   // in this context.
///   m.impl("add", add_cpu_impl);
/// }
/// ```
///
/// If ``add_cpu_impl`` is an overloaded function, use a
/// ``static_cast`` to specify which overload you want
/// (by providing the full type).
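///
/// ```
/// // A sketch, assuming add_cpu_impl has an overload set:
/// m.impl(
///     "add",
///     static_cast<at::Tensor (*)(const at::Tensor&, const at::Tensor&)>(
///         &add_cpu_impl));
/// ```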
///
// NB: if the dispatch key is not whitelisted, we simply omit the Library
// call entirely
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

/// \private
///
/// The above macro requires an extra unique identifier (uid) to prevent
/// variable name collisions. This can happen if TORCH_LIBRARY_IMPL is called
/// multiple times with the same namespace and dispatch key in the same
/// translation unit.
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                             \
  static void C10_CONCATENATE(                                         \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);    \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \
  void C10_CONCATENATE(                                                \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

// These are variants of the macros above which are to be used for testing
// (they don't set up the static initializer, so you can control the
// visibility of the allocated library yourself).
//
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.

/// \private
#define MAKE_TORCH_LIBRARY(ns) \
  torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
/// \private
#define MAKE_TORCH_LIBRARY_IMPL(ns, k)         \
  torch::Library(                              \
      torch::Library::IMPL,                    \
      #ns,                                     \
      c10::make_optional(c10::DispatchKey::k), \
      __FILE__,                                \
      __LINE__)
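
// For example, in a test (a minimal sketch; `sq_cpu_impl` is a hypothetical
// kernel):
//
//   auto lib = MAKE_TORCH_LIBRARY(myops);
//   lib.def("sq(Tensor self) -> Tensor");
//   auto lib_cpu = MAKE_TORCH_LIBRARY_IMPL(myops, CPU);
//   lib_cpu.impl("sq", &sq_cpu_impl);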

// Make the custom class API visible, so it is available from
// torch::Library.

#include <torch/custom_class.h>