1 | #pragma once |
2 | |
3 | /// \file |
4 | /// |
5 | /// This header provides an API for extending PyTorch's core library |
6 | /// of operators with user defined operators and data types. This |
7 | /// API can be used in a few ways: |
8 | /// |
9 | /// * You can define new custom operators and classes with TORCH_LIBRARY(), |
10 | /// making them available for use in both eager Python as well as in |
11 | /// TorchScript. This API is modeled off of pybind11's `PYBIND11_MODULE` |
12 | /// macro, as the provided functionality is similar (pybind11 lets you bind |
13 | /// C++ to Python only; `torch/library.h` lets you bind C++ simultaneously to |
14 | /// Python and TorchScript). |
15 | /// |
16 | /// * You can override existing operators with TORCH_LIBRARY_IMPL(), |
17 | /// providing a new implementation for these operators for a custom |
18 | /// backend (e.g., XLA). When you pass operators with tensors of your custom |
19 | /// backend, your overridden implementations will be called instead |
20 | /// of the standard implementations. |
21 | /// |
22 | /// * You can use both capabilities at the same time, allowing you |
23 | /// to write custom operators that register CPU/CUDA/Autograd |
24 | /// implementations without having to write the boilerplate |
25 | /// conditionals yourself. |
26 | /// |
27 | /// For a tutorial style introduction to the library API, check |
28 | /// out the [Extending TorchScript with Custom C++ |
29 | /// Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html) |
30 | /// tutorial. |
31 | /// |
32 | /// ``` |
33 | /// // Define a library whose operators live in the namespace 'myops'. |
34 | /// // You must define all of the operators for this library in |
35 | /// // this namespace. |
36 | /// TORCH_LIBRARY(myops, m) { |
///   // Define an operator with exactly one implementation for all backends.
38 | /// m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl); |
39 | /// |
40 | /// // Define a schema for an operator, but provide no implementation |
41 | /// // (use this syntax if you want to use the dispatcher) |
42 | /// m.def("mul(Tensor self, Tensor other) -> Tensor"); |
43 | /// |
44 | /// // Provide an implementation for a defined operator (you can |
45 | /// // provide multiple; one per backend). The dispatcher takes care of |
46 | /// // calling the correct implementation depending on if we get a CPU |
47 | /// // tensor or a CUDA tensor |
48 | /// m.impl("mul", torch::kCPU, &mul_cpu_impl); |
49 | /// m.impl("mul", torch::kCUDA, &mul_cuda_impl); |
50 | /// } |
51 | /// |
52 | /// // Define implementations for operators for a non-standard backend, |
53 | /// // e.g., XLA (valid values are entries of DispatchKey). This can |
54 | /// // be used to define operators in a different file than the initial |
55 | /// // TORCH_LIBRARY definition (e.g., if it is in an external library) |
56 | /// TORCH_LIBRARY_IMPL(myops, XLA, m) { |
57 | /// m.impl("mul", &mul_xla_impl); |
58 | /// } |
59 | /// ``` |
60 | |
61 | #include <ATen/core/op_registration/infer_schema.h> |
62 | #include <ATen/core/op_registration/op_allowlist.h> |
63 | #include <c10/core/DispatchKey.h> |
64 | #include <torch/csrc/jit/frontend/function_schema_parser.h> |
65 | |
66 | // Just for inferFunctionSchemaFromFunctor |
67 | #include <ATen/core/op_registration/op_registration.h> |
68 | #include <ATen/core/enum_tag.h> |
69 | |
70 | namespace torch { |
71 | |
#if defined C10_MOBILE
/**
 * The NoInferSchemaTag is a type name used to indicate that this call to the
 * CppFunction constructor should not trigger schema inference from functor.
 * Schema inference from functor utilizes template meta-programming, and is
 * costly from a size perspective. Ideally, one would expect that the schema
 * inference would require very little binary size since most of the
 * computation can be done by the compiler at build time, but that isn't
 * necessarily the case.
 *
 * Schema inference is elided only for mobile use-cases where we don't need
 * the additional runtime cost or size overhead on client devices.
 */
// Empty tag type: its only job is overload selection in CppFunction's
// constructors below.
struct NoInferSchemaTag {};
#endif
88 | |
// For multipy/torchdeploy use case
// Selects whether a def()/impl() call actually performs the registration or
// only verifies it. (NOTE(review): the exact VERIFY semantics live in the
// out-of-line _def/_impl implementations — confirm there.)
// Leading underscore: internal API, not intended for end users.
enum class _RegisterOrVerify {
  REGISTER,
  VERIFY
};
94 | |
// Forward declaration of the custom-class binder returned by
// Library::class_() below.
template <class CurClass>
class class_;
97 | |
98 | /// Represents a C++ function that implements an operator. Most users won't |
99 | /// interact directly with this class, except via error messages: the |
/// constructors of this class define the set of permissible "function"-like
101 | /// things you can bind via the interface. |
102 | /// |
103 | /// This class erases the type of the passed in function, but durably records |
104 | /// the type via an inferred schema for the function. |
class TORCH_API CppFunction final {
  // TODO: This is morally the same thing as KernelRegistrationConfig, but it's
  // opaque to the user.

 public:
  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>()),
        // Schema is inferred from the C++ signature via template
        // metaprogramming (see infer_schema.h).
        schema_(
            c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl))`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                typename FuncPtr::FuncType>()),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... })`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                std::decay_t<Lambda>>()),
        debug_() {}

#if defined C10_MOBILE
  // The NoInferSchemaTag overloads mirror the three constructors above but
  // skip schema inference entirely (schema_ == nullptr) to save binary size
  // on mobile; see the NoInferSchemaTag docs.

  /// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
  /// NoInferSchemaTag())`
  template <typename Func>
  explicit CppFunction(
      Func* f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}

  /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
  /// ... }, NoInferSchemaTag())`
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
      NoInferSchemaTag,
      std::enable_if_t<
          c10::guts::is_functor<std::decay_t<Lambda>>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
            std::forward<Lambda>(f))),
        cpp_signature_(c10::impl::CppSignature::make<Lambda>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        ,
        schema_(nullptr),
        debug_() {}
#endif

  ~CppFunction();

  // Move-only: the kernel and schema are owned uniquely.
  CppFunction(CppFunction&&) noexcept = default;

  CppFunction& operator=(CppFunction&&) = default;

  /// \private
  /// Creates a function from a type-erased boxed kernel.
  static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
    return CppFunction(
        c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)),
        /* cpp_signature */ c10::nullopt, // not known for boxed functions
        /* schema */ nullptr);
  }

  /// This creates a fallthrough function. Fallthrough functions
  /// immediately redispatch to the next available dispatch key,
  /// but are implemented more efficiently than a hand written
  /// function done in the same way.
  static CppFunction makeFallthrough() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
  }

  /// \private
  ///
  /// Creates a function that raises an error saying that named tensors
  /// are not supported when called.
  static CppFunction makeNamedNotSupported() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
  }

  /// Create a function from a boxed kernel function with signature
  /// `void(const OperatorHandle&, Stack*)`; i.e., they receive a
  /// stack of arguments in a boxed calling convention, rather than
  /// in the native C++ calling convention. Boxed functions are
  /// typically only used to register backend fallbacks via
  /// torch::Library::fallback().
  template <c10::BoxedKernel::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunction<func>());
  }

  // Variant that takes in a boxed kernel function with a plumbed
  // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
  // details.
  template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunction<func>());
  }

  /// Create a function from a boxed kernel functor which defines
  /// `operator()(const OperatorHandle&, DispatchKeySet, Stack*)`
  /// (receiving arguments from boxed calling convention) and inherits
  /// from `c10::OperatorKernel`. Unlike makeFromBoxedFunction, functions
  /// registered in this way can also carry additional state which
  /// is managed by the functor; this is useful if you're writing an
  /// adapter to some other implementation, e.g., a Python callable, which
  /// is dynamically associated with the registered kernel.
  template <class KernelFunctor>
  static CppFunction makeFromBoxedFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor) {
    return makeFromBoxedKernel(
        c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
  }

  /// Create a function from an unboxed kernel function.
  /// This is typically used to register common operators.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::guts::is_function_type<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
    return CppFunction(f);
  }

  /// Create a function from a compile time unboxed kernel function pointer.
  /// This is typically used to register common operators.
  /// Compile time function pointers can be used to allow the compiler
  /// to optimize (e.g. inline) calls to it.
  template <
      typename FuncPtr,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr f) {
    return CppFunction(f);
  }

  /// Attach a debug string to this function (rvalue-qualified so it can be
  /// chained onto a temporary, e.g. `CppFunction(...).debug("...")`).
  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

 private:
  // Set only via the friend dispatch() helper; nullopt means "no particular
  // dispatch key" (catch-all).
  c10::optional<c10::DispatchKey> dispatch_key_;
  // The type-erased kernel itself.
  c10::KernelFunction func_;
  // nullopt for boxed kernels, where the C++ signature is not known.
  c10::optional<c10::impl::CppSignature> cpp_signature_;
  // nullptr when schema inference was skipped (boxed kernels, mobile
  // NoInferSchemaTag constructors).
  std::unique_ptr<c10::FunctionSchema> schema_;
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (does so
  // destructively, felt too lazy to write accessors that I don't even
  // want users to use)
  friend class Library;

  // Private "raw" constructor used by the static factory functions above.
  CppFunction(
      c10::KernelFunction func,
      c10::optional<c10::impl::CppSignature> cpp_signature,
      std::unique_ptr<c10::FunctionSchema> schema);
};
322 | |
323 | /// \defgroup torch-dispatch-overloads torch::dispatch overloads |
324 | |
325 | /// Create a torch::CppFunction which is associated with a specific |
326 | /// dispatch key. torch::CppFunctions that are tagged with a |
327 | /// c10::DispatchKey don't get invoked unless the dispatcher determines |
328 | /// that this particular c10::DispatchKey is the one that should be |
329 | /// dispatched to. |
330 | /// |
331 | /// This function is generally not used directly, instead, prefer using |
332 | /// TORCH_LIBRARY_IMPL(), which will implicitly set the c10::DispatchKey |
333 | /// for all registration calls inside of its body. |
334 | /// |
335 | /// \ingroup torch-dispatch-overloads |
336 | template <typename Func> |
337 | inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) { |
338 | CppFunction f(std::forward<Func>(raw_f)); |
339 | if (k == c10::DispatchKey::CatchAll) { |
340 | f.dispatch_key_ = c10::nullopt; |
341 | } else { |
342 | f.dispatch_key_ = k; |
343 | } |
344 | return f; |
345 | } |
346 | |
347 | /// Convenience overload of dispatch() which accepts c10::DeviceType |
348 | /// |
349 | /// \ingroup torch-dispatch-overloads |
350 | template <typename Func> |
351 | inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) { |
352 | auto deviceTypeToDispatchKey = [](c10::DeviceType t) { |
353 | switch (t) { |
354 | // This list is synchronized with the k-constants in c10/core/DeviceType.h |
355 | case c10::DeviceType::CPU: |
356 | return c10::DispatchKey::CPU; |
357 | case c10::DeviceType::CUDA: |
358 | return c10::DispatchKey::CUDA; |
359 | case c10::DeviceType::IPU: |
360 | return c10::DispatchKey::IPU; |
361 | case c10::DeviceType::XLA: |
362 | return c10::DispatchKey::XLA; |
363 | case c10::DeviceType::Lazy: |
364 | return c10::DispatchKey::Lazy; |
365 | case c10::DeviceType::MPS: |
366 | return c10::DispatchKey::MPS; |
367 | case c10::DeviceType::Meta: |
368 | return c10::DispatchKey::Meta; |
369 | case c10::DeviceType::HIP: |
370 | return c10::DispatchKey::HIP; |
371 | case c10::DeviceType::ORT: |
372 | return c10::DispatchKey::ORT; |
373 | case c10::DeviceType::HPU: |
374 | return c10::DispatchKey::HPU; |
375 | case c10::DeviceType::PrivateUse1: |
376 | return c10::DispatchKey::PrivateUse1; |
377 | default: |
378 | TORCH_CHECK( |
379 | false, |
380 | "Device type " , |
381 | t, |
382 | " cannot be overloaded at dispatch time, " |
383 | "please file a bug report explaining what you were trying to do." ); |
384 | } |
385 | }; |
386 | return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f)); |
387 | } |
388 | |
389 | /// \defgroup torch-schema-overloads torch::schema overloads |
390 | |
391 | /// Construct a c10::FunctionSchema from a string, with an explicitly |
392 | /// specified c10::AliasAnalysisKind. Ordinarily, schemas are simply |
393 | /// passed in as strings, but if you need to specify a custom alias |
394 | /// analysis, you can replace the string with a call to this function. |
395 | /// |
396 | /// ``` |
397 | /// // Default alias analysis (FROM_SCHEMA) |
398 | /// m.def("def3(Tensor self) -> Tensor"); |
399 | /// // Pure function alias analysis |
400 | /// m.def(torch::schema("def3(Tensor self) -> Tensor", |
401 | /// c10::AliasAnalysisKind::PURE_FUNCTION)); |
402 | /// ``` |
403 | /// |
404 | /// \ingroup torch-schema-overloads |
405 | inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) { |
406 | c10::FunctionSchema s = torch::jit::parseSchema(str); |
407 | s.setAliasAnalysis(k); |
408 | return s; |
409 | } |
410 | |
411 | /// Function schemas can be directly constructed from string literals. |
412 | /// |
413 | /// \ingroup torch-schema-overloads |
414 | inline c10::FunctionSchema schema(const char* s) { |
415 | return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA); |
416 | } |
417 | |
418 | /// \private |
419 | /// |
420 | /// Already constructed function schemas are accepted if they are |
421 | /// rvalues. |
422 | /// |
423 | /// \ingroup torch-schema-overloads |
inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
  // Identity overload: pass an already-constructed schema through unchanged.
  return std::move(s);
}
427 | |
428 | namespace detail { |
429 | |
430 | inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName( |
431 | c10::FunctionSchema&& s) { |
432 | return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s)); |
433 | } |
434 | inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName( |
435 | c10::OperatorName&& n) { |
436 | return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n)); |
437 | } |
438 | inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName( |
439 | const char* str) { |
440 | auto s = torch::jit::parseSchemaOrName(str); |
441 | if (s.is_right()) { |
442 | s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); |
443 | } |
444 | return s; |
445 | } |
446 | |
447 | class TorchLibraryInit; |
448 | |
449 | } // namespace detail |
450 | |
451 | // Note [Selective build] |
452 | // ~~~~~~~~~~~~~~~~~~~~~~ |
453 | // In some settings, especially mobile, it is important to avoid compiling any |
454 | // references to functions that you aren't actually going to use, so that they |
455 | // can be eliminated by the linker. We call this capability "selective build". |
456 | // |
457 | // A very easy way to implement selective build which results in a lot of |
458 | // boilerplate is to just add ifdef's around every registration call, but this |
459 | // means you have to write a lot of extra lines of code at every registration |
460 | // site, and it also means you have to define some munging scheme to map |
461 | // operators to macros. |
462 | // |
463 | // Instead of doing this, we have a different mechanism centered around the |
464 | // concept of a SelectiveStr. A selective name is like a const char* string, |
465 | // except it also carries at compile time a boolean saying whether or not a |
466 | // registration should actually happen or not. We then have extra overloads |
467 | // which bypass registration entirely if a selective name is disabled. We do a |
// constexpr test to see if an operator should be enabled or not; this is
469 | // currently implemented in ATen/core/op_registration/op_allowlist.h |
470 | |
namespace detail {

// dummy class for non selected custom torchbind classes
// Its methods accept anything and do nothing, so registration call sites
// still compile when the class is compiled out by selective build.
class ClassNotSelected {
 public:
  ClassNotSelected& def_pickle(...) {
    return *this;
  }
  ClassNotSelected& def(...) {
    return *this;
  }
};

// A SelectiveStr is like a const char*, except that it also comes
// with a type brand that says whether or not the name is enabled or
// not. If the string is disabled, then (at compile time) we DON'T generate
// a registration call for it. This class is not intended to be called
// directly; use TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros below
// to create it.
template <bool enabled>
class SelectiveStr {
 public:
  constexpr explicit SelectiveStr(const char* name) : name_(name) {}
  // Implicit conversion back to the raw string, used by the SelectiveStr
  // overloads of Library::def()/impl() below.
  constexpr operator const char*() {
    return name_;
  }

 private:
  const char* name_;
};

// These macros evaluate the allowlist check at compile time, baking the
// result into SelectiveStr's template parameter.
#define TORCH_SELECTIVE_CLASS(n) \
  torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_NAME(n) \
  torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
#define TORCH_SELECTIVE_SCHEMA(n) \
  torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)

} // namespace detail
510 | |
511 | /// This object provides the API for defining operators and providing |
512 | /// implementations at dispatch keys. Typically, a torch::Library |
513 | /// is not allocated directly; instead it is created by the |
514 | /// TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() macros. |
515 | /// |
516 | /// Most methods on torch::Library return a reference to itself, |
517 | /// supporting method chaining. |
518 | /// |
519 | /// ``` |
520 | /// // Examples: |
521 | /// |
522 | /// TORCH_LIBRARY(torchvision, m) { |
523 | /// // m is a torch::Library |
524 | /// m.def("roi_align", ...); |
525 | /// ... |
526 | /// } |
527 | /// |
528 | /// TORCH_LIBRARY_IMPL(aten, XLA, m) { |
529 | /// // m is a torch::Library |
530 | /// m.impl("add", ...); |
531 | /// ... |
532 | /// } |
533 | /// ``` |
534 | /// |
class TORCH_API Library final {
 public:
  /// \private
  ///
  /// Which type of macro produced this Library
  enum Kind {
    DEF, // from TORCH_LIBRARY (no qualifier)
    IMPL, // NOTE(review): presumably from TORCH_LIBRARY_IMPL — confirm at the
          // macro definitions
    FRAGMENT,
  };

  /// \private
  ///
  /// Use TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() instead of using these
  /// constructors directly
  Library(
      Kind kind,
      std::string ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line);

  // Libraries are move-only; copying is disallowed.
  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;
  Library(Library&&) = default;
  Library& operator=(Library&&) = default;
561 | |
562 | // Some notes about the API design here. We had the following constraints: |
563 | // |
564 | // - We need to support multiple "types" of arguments for schema and |
565 | // functions (e.g., unnamed lambda types, regular functions, const char*, |
566 | // fully instantiated schemas) |
567 | // - We don't want to write exponentially many overloads |
568 | // - We don't want to rely on implicit conversion to a common type, |
569 | // because the C++ compiler will only be willing to do a single |
570 | // implicit conversion (reducing the set of valid types which you |
571 | // can invoke with); also error messages are worse when an implicit |
572 | // conversion is not selected (as the compiler will not explain |
573 | // why it didn't select an implicit conversion; this is different |
574 | // from overloads where it will explain each candidate overload and |
575 | // why it didn't apply) |
576 | // |
577 | // To solve all of these constraints at the same time, we use a trick taken |
578 | // from the pybind11 library: template over the argument in the user visible |
579 | // API, and inside of the templated function explicitly call an overloaded |
580 | // function to resolve the argument to a real type. You get the good error |
581 | // messages from overloads, but at the same time you only need to write the |
582 | // overload for any given argument type once. |
583 | |
584 | /// Declare an operator with a schema, but don't provide any implementations |
585 | /// for it. You're expected to then provide implementations using the |
586 | /// impl() method. All template arguments are inferred. |
587 | /// |
588 | /// \param raw_schema The schema of the operator to be defined. |
589 | /// Typically, this is a `const char*` string literal, but any type |
590 | /// accepted by torch::schema() is accepted here. |
591 | /// |
592 | /// ``` |
593 | /// // Example: |
594 | /// TORCH_LIBRARY(myops, m) { |
595 | /// m.def("add(Tensor self, Tensor other) -> Tensor"); |
596 | /// } |
597 | /// ``` |
598 | |
599 | template <typename Schema> |
600 | Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & { |
601 | c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema)); |
602 | return _def(std::move(s), nullptr, tags, rv); |
603 | } |
604 | /// Define an operator for a schema and then register an implementation for |
605 | /// it. This is typically what you would use if you aren't planning |
606 | /// on making use of the dispatcher to structure your operator |
607 | /// implementation. It's roughly equivalent to calling def() and |
608 | /// then impl(), but if you omit the schema of the operator, we will |
609 | /// infer it from the type of your C++ function. All template |
610 | /// arguments are inferred. |
611 | /// |
612 | /// \param raw_name_or_schema The schema of the operator to be |
613 | /// defined, or just the name of the operator if the schema is to be |
614 | /// inferred from `raw_f`. Typically a `const char*` literal. |
615 | /// \param raw_f The C++ function that implements this operator. |
616 | /// Any valid constructor of torch::CppFunction is accepted here; |
617 | /// typically you provide a function pointer or lambda. |
618 | /// |
619 | /// ``` |
620 | /// // Example: |
621 | /// TORCH_LIBRARY(myops, m) { |
622 | /// m.def("add", add_fn); |
623 | /// } |
624 | /// ``` |
625 | template <typename NameOrSchema, typename Func> |
626 | Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & { |
627 | CppFunction f(std::forward<Func>(raw_f)); |
628 | auto name_or_schema = detail::constructSchemaOrName( |
629 | std::forward<NameOrSchema>(raw_name_or_schema)); |
630 | return _def(std::move(name_or_schema), std::move(f)); |
631 | } |
632 | |
633 | /// Register an implementation for an operator. You may register multiple |
634 | /// implementations for a single operator at different dispatch keys |
635 | /// (see torch::dispatch()). Implementations must have a corresponding |
636 | /// declaration (from def()), otherwise they are invalid. If you plan |
637 | /// to register multiple implementations, DO NOT provide a function |
638 | /// implementation when you def() the operator. |
639 | /// |
640 | /// \param name The name of the operator to implement. Do NOT provide |
641 | /// schema here. |
642 | /// \param raw_f The C++ function that implements this operator. Any |
643 | /// valid constructor of torch::CppFunction is accepted here; |
644 | /// typically you provide a function pointer or lambda. |
645 | /// |
646 | /// ``` |
647 | /// // Example: |
648 | /// TORCH_LIBRARY_IMPL(myops, CUDA, m) { |
649 | /// m.impl("add", add_cuda); |
650 | /// } |
651 | /// ``` |
652 | template <typename Name, typename Func> |
653 | Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & { |
654 | // TODO: need to raise an error when you impl a function that has a |
655 | // catch all def |
656 | #if defined C10_MOBILE |
657 | CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag()); |
658 | #else |
659 | CppFunction f(std::forward<Func>(raw_f)); |
660 | #endif |
661 | return _impl(name, std::move(f), rv); |
662 | } |
663 | |
664 | #if defined C10_MOBILE |
665 | // Note: This overload is needed only for C10_MOBILE, since the automatically |
666 | // defined copy constructor for the CppFunction doesn't have the additional |
667 | // NoInferSchemaTag argument. We define the overload for the impl() function |
668 | // to accept a CppFunction&& argument. The already constructed CppFunction |
669 | // object may or may not have the inferred schema, but it doesn't matter |
670 | // for our purposes since if it already has the inferred schema, then we |
671 | // might as well just pass it through directly. |
672 | // |
673 | template <typename Name> |
674 | Library& impl(Name name, CppFunction&& raw_f) & { |
675 | // TODO: need to raise an error when you impl a function that has a |
676 | // catch all def |
677 | CppFunction f(std::forward<CppFunction>(raw_f)); |
678 | return _impl(name, std::move(f)); |
679 | } |
680 | #endif |
681 | |
  // Helper for getting an OperatorName for a const char*. You probably
  // don't need this. Leading underscore: internal API.
  c10::OperatorName _resolve(const char* name) const;
685 | |
686 | /// \private |
687 | /// |
688 | /// Convenience overload for directly specifying the dispatch key when |
689 | /// impl(). You probably don't need this; instead, prefer specifying |
690 | /// the dispatch key for the entire block in TORCH_LIBRARY_IMPL() |
691 | template <typename Name, typename Dispatch, typename Func> |
692 | Library& impl(Name name, Dispatch&& key, Func&& raw_f) & { |
693 | return impl( |
694 | name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f))); |
695 | } |
696 | |
  // Deprecated stub: kept only so old call sites get a readable compile-time
  // error (via the always-false static_assert) instead of "no member named
  // impl_UNBOXED".
  template <typename Name, typename Func>
  Library& impl_UNBOXED(Name /*name*/, Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;
  }
704 | |
705 | // These overloads cover cases when a SelectiveStr (see Note [Selective |
706 | // build]) has been disabled at compile time. In that case, don't generate |
707 | // any code referencing the passed in functions at all. |
708 | Library& def(detail::SelectiveStr<false>) & { |
709 | return *this; |
710 | } |
711 | Library& def(detail::SelectiveStr<true> raw_schema) & { |
712 | return def(raw_schema.operator const char*()); |
713 | } |
714 | template <typename Func> |
715 | Library& def(detail::SelectiveStr<false>, Func&& /*raw_f*/) & { |
716 | return *this; |
717 | } |
718 | template <typename Func> |
719 | Library& def(detail::SelectiveStr<true> raw_name_or_schema, Func&& raw_f) & { |
720 | return def( |
721 | raw_name_or_schema.operator const char*(), std::forward<Func>(raw_f)); |
722 | } |
723 | |
724 | template <typename Func> |
725 | Library& impl(detail::SelectiveStr<false>, Func&& /*raw_f*/) & { |
726 | return *this; |
727 | } |
728 | template <typename Dispatch, typename Func> |
729 | Library& impl(detail::SelectiveStr<false>, Dispatch&& /*key*/, Func&& /*raw_f*/) & { |
730 | return *this; |
731 | } |
732 | template <typename Func> |
733 | Library& impl_UNBOXED(detail::SelectiveStr<false> /*name*/, Func* /*raw_f*/) & { |
734 | static_assert( |
735 | c10::guts::false_t<Func>(), |
736 | ".impl_UNBOXED(...) was removed. Please use .impl(...) instead." ); |
737 | return *this; |
738 | } |
739 | |
740 | template <typename Func> |
741 | Library& impl(detail::SelectiveStr<true> name, Func&& raw_f) & { |
742 | return impl(name.operator const char*(), std::forward<Func>(raw_f)); |
743 | } |
744 | template <typename Dispatch, typename Func> |
745 | Library& impl( |
746 | detail::SelectiveStr<true> name, |
747 | Dispatch&& key, |
748 | Func&& raw_f) & { |
749 | return impl( |
750 | name.operator const char*(), |
751 | std::forward<Dispatch>(key), |
752 | std::forward<Func>(raw_f)); |
753 | } |
  // Deprecated: see the impl_UNBOXED overloads above. Dependent false makes
  // the static_assert fire only on instantiation.
  template <typename Func>
  Library& impl_UNBOXED(detail::SelectiveStr<true> /*name*/, Func* /*raw_f*/) & {
    static_assert(
        c10::guts::false_t<Func>(),
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead." );
    return *this;
  }
761 | |
762 | /// Register a fallback implementation for all operators which will be used |
763 | /// if there is not a specific implementation for an operator available. |
764 | /// There MUST be a DispatchKey associated with a fallback; e.g., |
765 | /// only call this from TORCH_LIBRARY_IMPL() with namespace `_`. |
766 | /// |
767 | /// \param raw_f The function that implements the fallback. Unboxed |
768 | /// functions typically do not work as fallback functions, as |
769 | /// fallback functions must work for every operator (even though |
770 | /// they have varying type signatures). Typical arguments are |
771 | /// CppFunction::makeFallthrough() or |
772 | /// CppFunction::makeFromBoxedFunction() |
773 | /// |
774 | /// ``` |
775 | /// // Example: |
776 | /// |
777 | /// TORCH_LIBRARY_IMPL(_, AutogradXLA, m) { |
778 | /// // If there is not a kernel explicitly registered |
779 | /// // for AutogradXLA, fallthrough to the next |
780 | /// // available kernel |
781 | /// m.fallback(torch::CppFunction::makeFallthrough()); |
782 | /// } |
783 | /// |
784 | /// // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp |
785 | /// // for a full example of boxed fallback |
786 | /// ``` |
787 | template <typename Func> |
788 | Library& fallback(Func&& raw_f) & { |
789 | CppFunction f((std::forward<Func>(raw_f))); |
790 | return _fallback(std::move(f)); |
791 | } |
792 | |
  /// Register a custom class named `className` with this library; returns a
  /// torch::class_ builder for defining its methods (defined in
  /// torch/custom_class.h, included at the bottom of this header).
  template <class CurClass>
  inline torch::class_<CurClass> class_(const std::string& className);

  // These overloads enable the use of selective build on classes registered
  // within a library. The API is the same as before with 1 minor change.
  // Instead of m.class_<foo>("foo") you instead do
  // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
  template <class CurClass>
  inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);

  // Selectively-disabled variant: returns a detail::ClassNotSelected —
  // presumably a stub that swallows subsequent registration calls; verify
  // against its definition.
  template <class CurClass>
  inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);
805 | |
 private:
  // DEF / FRAGMENT / IMPL — which TORCH_LIBRARY* macro created this Library.
  Kind kind_;
  // Namespace this library registers into (nullopt when none was given).
  c10::optional<std::string> ns_;
  // Dispatch key for IMPL libraries; the DEF/FRAGMENT macros below pass
  // c10::nullopt here.
  c10::optional<c10::DispatchKey> dispatch_key_;
  // Source location (__FILE__/__LINE__) captured by the macros — presumably
  // used for diagnostics; confirm in the .cpp implementation.
  const char* file_;
  uint32_t line_;

  // RAII handles for everything registered through this Library; their
  // destruction undoes the corresponding registrations.
  std::vector<c10::RegistrationHandleRAII> registrars_;

  friend class detail::TorchLibraryInit;

  // Non-user visible actual implementations of functions. These aren't
  // public because we only implement & qualifier and not && qualifier
  Library& _def(
      c10::FunctionSchema&& schema,
      c10::OperatorName* out_name = nullptr,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER
  ) &;
  Library& _def(
      c10::either<c10::OperatorName, c10::FunctionSchema>&&,
      CppFunction&& f) &;
  Library& _impl(const char* name, CppFunction&& f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _fallback(CppFunction&& f) &;

  at::OperatorName _parseNameForLib(const char* name_str) const;
};
833 | }; |
834 | |
835 | namespace detail { |
836 | |
837 | class TorchLibraryInit final { |
838 | private: |
839 | using InitFn = void(Library&); |
840 | Library lib_; |
841 | |
842 | public: |
843 | TorchLibraryInit( |
844 | Library::Kind kind, |
845 | InitFn* fn, |
846 | const char* ns, |
847 | c10::optional<c10::DispatchKey> k, |
848 | const char* file, |
849 | uint32_t line) |
850 | : lib_(kind, ns, k, file, line) { |
851 | fn(lib_); |
852 | } |
853 | }; |
854 | |
855 | } // namespace detail |
856 | |
857 | } // namespace torch |
858 | |
859 | // NB: The EXACT NAMING of the initializer functions (e.g., |
860 | // TORCH_LIBRARY_init_aten) matters for the code analyzer; |
861 | // see the regexes at tools/code_analyzer/run_analyzer.sh |
862 | |
863 | /// Macro for defining a function that will be run at static |
864 | /// initialization time to define a library of operators in the |
865 | /// namespace `ns` (must be a valid C++ identifier, no quotes). |
866 | /// Use this macro when you want to define a new set of custom operators |
867 | /// that do not already exist in PyTorch. |
868 | /// |
869 | /// Example usage: |
870 | /// |
871 | /// ``` |
872 | /// TORCH_LIBRARY(myops, m) { |
873 | /// // m is a torch::Library; methods on it will define |
874 | /// // operators in the myops namespace |
875 | /// m.def("add", add_impl); |
876 | /// } |
877 | /// ``` |
878 | /// |
879 | /// The `m` argument is bound to a torch::Library that is used to |
880 | /// register operators. There may only be one TORCH_LIBRARY() |
881 | /// for any given namespace. |
// NB: Two names are minted from `ns`: TORCH_LIBRARY_init_<ns> (the function
// whose body the user writes after the macro invocation) and
// TORCH_LIBRARY_static_init_<ns> (the static whose constructor runs that
// function at static init time). The exact spelling is load-bearing for the
// code analyzer; see the note above.
#define TORCH_LIBRARY(ns, m)                                                  \
  static void TORCH_LIBRARY_init_##ns(torch::Library&);                       \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
      torch::Library::DEF,                                                    \
      &TORCH_LIBRARY_init_##ns,                                               \
      #ns,                                                                    \
      c10::nullopt,                                                           \
      __FILE__,                                                               \
      __LINE__);                                                              \
  void TORCH_LIBRARY_init_##ns(torch::Library& m)
892 | |
893 | /// \private |
894 | /// |
895 | /// This macro is a version of TORCH_LIBRARY() that doesn't enforce that there |
896 | /// is only one library (it is a "fragment"). This is used inside the |
897 | /// PerOpRegistration.cpp file, as well as in places where all op registrations |
898 | /// within the same namespace cannot be easily put into one macro block |
899 | /// (this is mostly the case for custom ops in fbcode that were ported from |
900 | /// the old API) |
901 | #define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID) |
902 | |
903 | /// \private |
904 | /// |
905 | /// The above macro requires an extra unique identifier (uid) to prevent |
/// variable name collisions. This can happen if TORCH_LIBRARY_FRAGMENT is called
907 | /// multiple times with the same namespace in the same translation unit. Note |
908 | /// that the TORCH_LIBRARY variant doesn't run into this problem, because it |
909 | /// enforces that it can only be called once for a given namespace. |
// NB: Same shape as TORCH_LIBRARY above (forward-declared init function +
// static TorchLibraryInit that runs it), except the Kind is FRAGMENT and both
// generated names carry `uid` to keep repeated expansions distinct.
#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid)                       \
  static void C10_CONCATENATE(                                    \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(   \
      TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(           \
      torch::Library::FRAGMENT,                                   \
      &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
      #ns,                                                        \
      c10::nullopt,                                               \
      __FILE__,                                                   \
      __LINE__);                                                  \
  void C10_CONCATENATE(                                           \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)
923 | |
924 | /// Macro for defining a function that will be run at static |
925 | /// initialization time to define operator overrides for dispatch key |
926 | /// `k` (must be an unqualified enum member of c10::DispatchKey) in |
/// namespace `ns` (must be a valid C++ identifier, no quotes). Use this
928 | /// macro when you want to implement a preexisting set of custom |
929 | /// operators on a new dispatch key (e.g., you want to provide CUDA |
930 | /// implementations of already existing operators). One common usage |
931 | /// pattern is to use TORCH_LIBRARY() to define schema for all new |
932 | /// operators you want to define, and then use several |
933 | /// TORCH_LIBRARY_IMPL() blocks to provide implementations of the |
934 | /// operator for CPU, CUDA and Autograd. |
935 | /// |
936 | /// In some cases, you need to define something that applies to all namespaces, |
937 | /// not just one namespace (usually a fallback). In that case, use the reserved |
938 | /// namespace _, e.g., |
939 | /// |
940 | /// ``` |
941 | /// TORCH_LIBRARY_IMPL(_, XLA, m) { |
942 | /// m.fallback(xla_fallback); |
943 | /// } |
944 | /// ``` |
945 | /// |
946 | /// Example usage: |
947 | /// |
948 | /// ``` |
949 | /// TORCH_LIBRARY_IMPL(myops, CPU, m) { |
950 | /// // m is a torch::Library; methods on it will define |
951 | /// // CPU implementations of operators in the myops namespace. |
952 | /// // It is NOT valid to call torch::Library::def() |
953 | /// // in this context. |
954 | /// m.impl("add", add_cpu_impl); |
955 | /// } |
956 | /// ``` |
957 | /// |
958 | /// If ``add_cpu_impl`` is an overloaded function, use a |
959 | /// ``static_cast`` to specify which overload you want |
960 | /// (by providing the full type). |
961 | /// |
962 | // NB: if the dispatch key is not whitelisted, we simply omit the Library |
963 | // call entirely |
964 | #define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID) |
965 | |
966 | /// \private |
967 | /// |
968 | /// The above macro requires an extra unique identifier (uid) to prevent |
969 | /// variable name collisions. This can happen if TORCH_LIBRARY_IMPL is called |
970 | /// multiple times with the same namespace and dispatch key in the same |
971 | /// translation unit. |
// NB: The if_constexpr checks the dispatch-key allowlist at compile time:
// when key `k` is allowlisted, the address of the user's init function is
// passed to TorchLibraryInit; otherwise a no-op lambda is passed instead, so
// no code referencing the user's registrations is generated at all.
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                             \
  static void C10_CONCATENATE(                                         \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);    \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \
  void C10_CONCATENATE(                                                \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)
991 | |
992 | // These are variants of the macros above which are to be used for testing (they |
993 | // don't setup the static initializer, so you can control the visibility of |
994 | // the allocated library yourself). |
995 | // |
996 | // DO NOT use these in production code, they are NOT understood by the |
997 | // code analyzer and will be incorrectly analyzed in those situations. |
998 | |
/// \private
// Testing-only (see note above): constructs a DEF-kind torch::Library for
// namespace `ns` as a plain value, so the caller controls its lifetime and
// visibility instead of relying on a static initializer.
#define MAKE_TORCH_LIBRARY(ns) \
  torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
/// \private
// Testing-only (see note above): constructs an IMPL-kind torch::Library for
// namespace `ns` and dispatch key `k` as a plain value, so the caller
// controls its lifetime and visibility.
#define MAKE_TORCH_LIBRARY_IMPL(ns, k)           \
  torch::Library(                                \
      torch::Library::IMPL,                      \
      #ns,                                       \
      c10::make_optional(c10::DispatchKey::k),   \
      __FILE__,                                  \
      __LINE__)
1010 | |
1011 | // Make the custom class API visible, so it is available from |
1012 | // torch::Library. |
1013 | |
1014 | #include <torch/custom_class.h> |
1015 | |