1 | #pragma once |
2 | |
3 | #include <ATen/core/builtin_function.h> |
4 | #include <ATen/core/function_schema.h> |
5 | #include <ATen/core/ivalue.h> |
6 | #include <ATen/core/class_type.h> |
7 | #include <ATen/core/op_registration/infer_schema.h> |
8 | #include <ATen/core/stack.h> |
9 | #include <c10/util/C++17.h> |
10 | #include <c10/util/Metaprogramming.h> |
11 | #include <c10/util/TypeList.h> |
12 | #include <c10/util/TypeTraits.h> |
13 | #include <torch/custom_class_detail.h> |
14 | #include <torch/library.h> |
15 | #include <iostream> |
16 | #include <sstream> |
17 | |
18 | namespace torch { |
19 | |
20 | /// This function is used in conjunction with `class_::def()` to register |
21 | /// a constructor for a given C++ class type. For example, |
22 | /// `torch::init<int, std::string>()` would register a two-argument constructor |
23 | /// taking an `int` and a `std::string` as argument. |
24 | template <class... Types> |
25 | detail::types<void, Types...> init() { |
26 | return detail::types<void, Types...>{}; |
27 | } |
28 | |
/// Pairs a user-supplied constructor callable `f` with a compile-time list
/// of its parameter types. `torch::init(f)` produces one and the matching
/// `class_::def()` overload consumes it. The trailing template parameters
/// hold no data; they exist only so `def()` can deduce the parameter types.
template <typename Func, typename... ParameterTypeList>
struct InitLambda { Func f; };
33 | |
34 | template <typename Func> |
35 | decltype(auto) init(Func&& f) { |
36 | using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>; |
37 | using ParameterTypeList = typename InitTraits::parameter_types; |
38 | |
39 | InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)}; |
40 | return init; |
41 | } |
42 | |
43 | /// Entry point for custom C++ class registration. To register a C++ class |
44 | /// in PyTorch, instantiate `torch::class_` with the desired class as the |
45 | /// template parameter. Typically, this instantiation should be done in |
46 | /// the initialization of a global variable, so that the class will be |
47 | /// made available on dynamic library loading without any additional API |
48 | /// calls needed. For example, to register a class named Foo, you might |
49 | /// create a global variable like so: |
50 | /// |
51 | /// static auto register_foo = torch::class_<Foo>("myclasses", "Foo") |
52 | /// .def("myMethod", &Foo::myMethod) |
53 | /// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) { |
54 | /// // Do something with `self` |
55 | /// }); |
56 | /// |
57 | /// In addition to registering the class, this registration also chains |
58 | /// `def()` calls to register methods. `myMethod()` is registered with |
59 | /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()` |
60 | /// is registered with a C++ lambda expression. |
61 | template <class CurClass> |
62 | class class_ : public ::torch::detail::class_base { |
63 | static_assert( |
64 | std::is_base_of<CustomClassHolder, CurClass>::value, |
65 | "torch::class_<T> requires T to inherit from CustomClassHolder" ); |
66 | |
67 | public: |
68 | /// This constructor actually registers the class type. |
69 | /// String argument `namespaceName` is an identifier for the |
70 | /// namespace you would like this class to appear in. |
71 | /// String argument `className` is the name you would like to |
72 | /// see this class exposed as in Python and TorchScript. For example, if |
73 | /// you pass `foo` as the namespace name and `Bar` as the className, the |
74 | /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript |
75 | explicit class_( |
76 | const std::string& namespaceName, |
77 | const std::string& className, |
78 | std::string doc_string = "" ) |
79 | : class_base( |
80 | namespaceName, |
81 | className, |
82 | std::move(doc_string), |
83 | typeid(c10::intrusive_ptr<CurClass>), |
84 | typeid(c10::tagged_capsule<CurClass>)) {} |
85 | |
86 | /// def() can be used in conjunction with `torch::init()` to register |
87 | /// a constructor for a given C++ class type. For example, passing |
88 | /// `torch::init<int, std::string>()` would register a two-argument |
89 | /// constructor taking an `int` and a `std::string` as argument. |
90 | template <typename... Types> |
91 | class_& def( |
92 | torch::detail::types<void, Types...>, |
93 | std::string doc_string = "" , |
94 | std::initializer_list<arg> default_args = |
95 | {}) { // Used in combination with |
96 | // torch::init<...>() |
97 | auto func = [](c10::tagged_capsule<CurClass> self, Types... args) { |
98 | auto classObj = c10::make_intrusive<CurClass>(args...); |
99 | auto object = self.ivalue.toObject(); |
100 | object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); |
101 | }; |
102 | |
103 | defineMethod( |
104 | "__init__" , |
105 | std::move(func), |
106 | std::move(doc_string), |
107 | std::move(default_args)); |
108 | return *this; |
109 | } |
110 | |
111 | // Used in combination with torch::init([]lambda(){......}) |
112 | template <typename Func, typename... ParameterTypes> |
113 | class_& def( |
114 | InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init, |
115 | std::string doc_string = "" , |
116 | std::initializer_list<arg> default_args = {}) { |
117 | auto init_lambda_wrapper = [func = std::move(init.f)]( |
118 | c10::tagged_capsule<CurClass> self, |
119 | ParameterTypes... arg) { |
120 | c10::intrusive_ptr<CurClass> classObj = |
121 | at::guts::invoke(func, std::forward<ParameterTypes>(arg)...); |
122 | auto object = self.ivalue.toObject(); |
123 | object->setSlot(0, c10::IValue::make_capsule(classObj)); |
124 | }; |
125 | |
126 | defineMethod( |
127 | "__init__" , |
128 | std::move(init_lambda_wrapper), |
129 | std::move(doc_string), |
130 | std::move(default_args)); |
131 | |
132 | return *this; |
133 | } |
134 | |
135 | /// This is the normal method registration API. `name` is the name that |
136 | /// the method will be made accessible by in Python and TorchScript. |
137 | /// `f` is a callable object that defines the method. Typically `f` |
138 | /// will either be a pointer to a method on `CurClass`, or a lambda |
139 | /// expression that takes a `c10::intrusive_ptr<CurClass>` as the first |
140 | /// argument (emulating a `this` argument in a C++ method.) |
141 | /// |
142 | /// Examples: |
143 | /// |
144 | /// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in |
145 | /// // Python and TorchScript |
146 | /// .def("call_foo", &Foo::foo) |
147 | /// |
148 | /// // Exposes the given lambda expression as method `call_lambda()` |
149 | /// // in Python and TorchScript. |
150 | /// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) { |
151 | /// // do something |
152 | /// }) |
153 | template <typename Func> |
154 | class_& def( |
155 | std::string name, |
156 | Func f, |
157 | std::string doc_string = "" , |
158 | std::initializer_list<arg> default_args = {}) { |
159 | auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f)); |
160 | defineMethod( |
161 | std::move(name), |
162 | std::move(wrapped_f), |
163 | std::move(doc_string), |
164 | std::move(default_args)); |
165 | return *this; |
166 | } |
167 | |
168 | /// Method registration API for static methods. |
169 | template <typename Func> |
170 | class_& def_static(std::string name, Func func, std::string doc_string = "" ) { |
171 | auto qualMethodName = qualClassName + "." + name; |
172 | auto schema = |
173 | c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "" ); |
174 | |
175 | auto wrapped_func = |
176 | [func = std::move(func)](jit::Stack& stack) mutable -> void { |
177 | using RetType = |
178 | typename c10::guts::infer_function_traits_t<Func>::return_type; |
179 | detail::BoxedProxy<RetType, Func>()(stack, func); |
180 | }; |
181 | auto method = std::make_unique<jit::BuiltinOpFunction>( |
182 | std::move(qualMethodName), |
183 | std::move(schema), |
184 | std::move(wrapped_func), |
185 | std::move(doc_string)); |
186 | |
187 | classTypePtr->addStaticMethod(method.get()); |
188 | registerCustomClassMethod(std::move(method)); |
189 | return *this; |
190 | } |
191 | |
192 | /// Property registration API for properties with both getter and setter |
193 | /// functions. |
194 | template <typename GetterFunc, typename SetterFunc> |
195 | class_& def_property( |
196 | const std::string& name, |
197 | GetterFunc getter_func, |
198 | SetterFunc setter_func, |
199 | std::string doc_string = "" ) { |
200 | torch::jit::Function* getter{}; |
201 | torch::jit::Function* setter{}; |
202 | |
203 | auto wrapped_getter = |
204 | detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func)); |
205 | getter = defineMethod(name + "_getter" , wrapped_getter, doc_string); |
206 | |
207 | auto wrapped_setter = |
208 | detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func)); |
209 | setter = defineMethod(name + "_setter" , wrapped_setter, doc_string); |
210 | |
211 | classTypePtr->addProperty(name, getter, setter); |
212 | return *this; |
213 | } |
214 | |
215 | /// Property registration API for properties with only getter function. |
216 | template <typename GetterFunc> |
217 | class_& def_property( |
218 | const std::string& name, |
219 | GetterFunc getter_func, |
220 | std::string doc_string = "" ) { |
221 | torch::jit::Function* getter{}; |
222 | |
223 | auto wrapped_getter = |
224 | detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func)); |
225 | getter = defineMethod(name + "_getter" , wrapped_getter, doc_string); |
226 | |
227 | classTypePtr->addProperty(name, getter, nullptr); |
228 | return *this; |
229 | } |
230 | |
231 | /// Property registration API for properties with read-write access. |
232 | template <typename T> |
233 | class_& def_readwrite(const std::string& name, T CurClass::*field) { |
234 | auto getter_func = [field = |
235 | field](const c10::intrusive_ptr<CurClass>& self) { |
236 | return self.get()->*field; |
237 | }; |
238 | |
239 | auto setter_func = [field = field]( |
240 | const c10::intrusive_ptr<CurClass>& self, T value) { |
241 | self.get()->*field = value; |
242 | }; |
243 | |
244 | return def_property(name, getter_func, setter_func); |
245 | } |
246 | |
247 | /// Property registration API for properties with read-only access. |
248 | template <typename T> |
249 | class_& def_readonly(const std::string& name, T CurClass::*field) { |
250 | auto getter_func = |
251 | [field = std::move(field)](const c10::intrusive_ptr<CurClass>& self) { |
252 | return self.get()->*field; |
253 | }; |
254 | |
255 | return def_property(name, getter_func); |
256 | } |
257 | |
258 | /// This is an unsafe method registration API added for adding custom JIT |
259 | /// backend support via custom C++ classes. It is not for general purpose use. |
260 | class_& _def_unboxed( |
261 | std::string name, |
262 | std::function<void(jit::Stack&)> func, |
263 | c10::FunctionSchema schema, |
264 | std::string doc_string = "" ) { |
265 | auto method = std::make_unique<jit::BuiltinOpFunction>( |
266 | qualClassName + "." + name, |
267 | std::move(schema), |
268 | std::move(func), |
269 | std::move(doc_string)); |
270 | classTypePtr->addMethod(method.get()); |
271 | registerCustomClassMethod(std::move(method)); |
272 | return *this; |
273 | } |
274 | |
275 | /// def_pickle() is used to define exactly what state gets serialized |
276 | /// or deserialized for a given instance of a custom C++ class in |
277 | /// Python or TorchScript. This protocol is equivalent to the Pickle |
278 | /// concept of `__getstate__` and `__setstate__` from Python |
279 | /// (https://docs.python.org/2/library/pickle.html#object.__getstate__) |
280 | /// |
281 | /// Currently, both the `get_state` and `set_state` callables must be |
282 | /// C++ lambda expressions. They should have the following signatures, |
283 | /// where `CurClass` is the class you're registering and `T1` is some object |
284 | /// that encapsulates the state of the object. |
285 | /// |
286 | /// __getstate__(intrusive_ptr<CurClass>) -> T1 |
287 | /// __setstate__(T2) -> intrusive_ptr<CurClass> |
288 | /// |
289 | /// `T1` must be an object that is convertable to IValue by the same rules |
290 | /// for custom op/method registration. |
291 | /// |
292 | /// For the common case, T1 == T2. T1 can also be a subtype of T2. An |
293 | /// example where it makes sense for T1 and T2 to differ is if __setstate__ |
294 | /// handles legacy formats in a backwards compatible way. |
295 | /// |
296 | /// Example: |
297 | /// |
298 | /// .def_pickle( |
299 | /// // __getstate__ |
300 | /// [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) { |
301 | /// return self->stack_; |
302 | /// }, |
303 | /// [](std::vector<std::string> state) { // __setstate__ |
304 | /// return c10::make_intrusive<MyStackClass<std::string>>( |
305 | /// std::vector<std::string>{"i", "was", "deserialized"}); |
306 | /// }) |
307 | template <typename GetStateFn, typename SetStateFn> |
308 | class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) { |
309 | static_assert( |
310 | c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value && |
311 | c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value, |
312 | "def_pickle() currently only supports lambdas as " |
313 | "__getstate__ and __setstate__ arguments." ); |
314 | def("__getstate__" , std::forward<GetStateFn>(get_state)); |
315 | |
316 | // __setstate__ needs to be registered with some custom handling: |
317 | // We need to wrap the invocation of the user-provided function |
318 | // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>) |
319 | // and assign it to the `capsule` attribute. |
320 | using SetStateTraits = |
321 | c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>; |
322 | using SetStateArg = typename c10::guts::typelist::head_t< |
323 | typename SetStateTraits::parameter_types>; |
324 | auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)]( |
325 | c10::tagged_capsule<CurClass> self, |
326 | SetStateArg&& arg) { |
327 | c10::intrusive_ptr<CurClass> classObj = |
328 | at::guts::invoke(set_state, std::forward<SetStateArg>(arg)); |
329 | auto object = self.ivalue.toObject(); |
330 | object->setSlot(0, c10::IValue::make_capsule(classObj)); |
331 | }; |
332 | defineMethod( |
333 | "__setstate__" , |
334 | detail::wrap_func<CurClass, decltype(setstate_wrapper)>( |
335 | std::move(setstate_wrapper))); |
336 | |
337 | // type validation |
338 | auto getstate_schema = classTypePtr->getMethod("__getstate__" ).getSchema(); |
339 | auto format_getstate_schema = [&getstate_schema]() { |
340 | std::stringstream ss; |
341 | ss << getstate_schema; |
342 | return ss.str(); |
343 | }; |
344 | TORCH_CHECK( |
345 | getstate_schema.arguments().size() == 1, |
346 | "__getstate__ should take exactly one argument: self. Got: " , |
347 | format_getstate_schema()); |
348 | auto first_arg_type = getstate_schema.arguments().at(0).type(); |
349 | TORCH_CHECK( |
350 | *first_arg_type == *classTypePtr, |
351 | "self argument of __getstate__ must be the custom class type. Got " , |
352 | first_arg_type->repr_str()); |
353 | TORCH_CHECK( |
354 | getstate_schema.returns().size() == 1, |
355 | "__getstate__ should return exactly one value for serialization. Got: " , |
356 | format_getstate_schema()); |
357 | |
358 | auto ser_type = getstate_schema.returns().at(0).type(); |
359 | auto setstate_schema = classTypePtr->getMethod("__setstate__" ).getSchema(); |
360 | auto arg_type = setstate_schema.arguments().at(1).type(); |
361 | TORCH_CHECK( |
362 | ser_type->isSubtypeOf(*arg_type), |
363 | "__getstate__'s return type should be a subtype of " |
364 | "input argument of __setstate__. Got " , |
365 | ser_type->repr_str(), |
366 | " but expected " , |
367 | arg_type->repr_str()); |
368 | |
369 | return *this; |
370 | } |
371 | |
372 | private: |
373 | template <typename Func> |
374 | torch::jit::Function* defineMethod( |
375 | std::string name, |
376 | Func func, |
377 | std::string doc_string = "" , |
378 | std::initializer_list<arg> default_args = {}) { |
379 | auto qualMethodName = qualClassName + "." + name; |
380 | auto schema = |
381 | c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "" ); |
382 | |
383 | // If default values are provided for function arguments, there must be |
384 | // none (no default values) or default values for all function |
385 | // arguments, except for self. This is because argument names are not |
386 | // extracted by inferFunctionSchemaSingleReturn, and so there must be a |
387 | // torch::arg instance in default_args even for arguments that do not |
388 | // have an actual default value provided. |
389 | TORCH_CHECK( |
390 | default_args.size() == 0 || |
391 | default_args.size() == schema.arguments().size() - 1, |
392 | "Default values must be specified for none or all arguments" ); |
393 | |
394 | // If there are default args, copy the argument names and default values to |
395 | // the function schema. |
396 | if (default_args.size() > 0) { |
397 | schema = withNewArguments(schema, default_args); |
398 | } |
399 | |
400 | auto wrapped_func = |
401 | [func = std::move(func)](jit::Stack& stack) mutable -> void { |
402 | // TODO: we need to figure out how to profile calls to custom functions |
403 | // like this! Currently can't do it because the profiler stuff is in |
404 | // libtorch and not ATen |
405 | using RetType = |
406 | typename c10::guts::infer_function_traits_t<Func>::return_type; |
407 | detail::BoxedProxy<RetType, Func>()(stack, func); |
408 | }; |
409 | auto method = std::make_unique<jit::BuiltinOpFunction>( |
410 | qualMethodName, |
411 | std::move(schema), |
412 | std::move(wrapped_func), |
413 | std::move(doc_string)); |
414 | |
415 | // Register the method here to keep the Method alive. |
416 | // ClassTypes do not hold ownership of their methods (normally it |
417 | // those are held by the CompilationUnit), so we need a proxy for |
418 | // that behavior here. |
419 | auto method_val = method.get(); |
420 | classTypePtr->addMethod(method_val); |
421 | registerCustomClassMethod(std::move(method)); |
422 | return method_val; |
423 | } |
424 | }; |
425 | |
426 | /// make_custom_class() is a convenient way to create an instance of a |
427 | /// registered custom class and wrap it in an IValue, for example when you want |
428 | /// to pass the object to TorchScript. Its syntax is equivalent to APIs like |
429 | /// `std::make_shared<>` or `c10::make_intrusive<>`. |
430 | /// |
431 | /// For example, if you have a custom C++ class that can be constructed from an |
432 | /// `int` and `std::string`, you might use this API like so: |
433 | /// |
434 | /// IValue custom_class_iv = torch::make_custom_class<MyClass>(3, |
435 | /// "foobarbaz"); |
436 | template <typename CurClass, typename... CtorArgs> |
437 | c10::IValue make_custom_class(CtorArgs&&... args) { |
438 | auto userClassInstance = |
439 | c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...); |
440 | return c10::IValue(std::move(userClassInstance)); |
441 | } |
442 | |
443 | // Alternative api for creating a torchbind class over torch::class_ this api is |
444 | // preffered to prevent size regressions on Edge usecases. Must be used in |
445 | // conjunction with TORCH_SELECTIVE_CLASS macro aka |
446 | // selective_class<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo")) |
447 | template <class CurClass> |
448 | inline class_<CurClass> selective_class_( |
449 | const std::string& namespace_name, |
450 | detail::SelectiveStr<true> className) { |
451 | auto class_name = std::string(className.operator const char*()); |
452 | return torch::class_<CurClass>(namespace_name, class_name); |
453 | } |
454 | |
455 | template <class CurClass> |
456 | inline detail::ClassNotSelected selective_class_( |
457 | const std::string&, |
458 | detail::SelectiveStr<false>) { |
459 | return detail::ClassNotSelected(); |
460 | } |
461 | |
462 | // jit namespace for backward-compatibility |
463 | // We previously defined everything in torch::jit but moved it out to |
464 | // better reflect that these features are not limited only to TorchScript |
465 | namespace jit { |
466 | |
467 | using ::torch::class_; |
468 | using ::torch::getCustomClass; |
469 | using ::torch::init; |
470 | using ::torch::isCustomClass; |
471 | |
472 | } // namespace jit |
473 | |
474 | template <class CurClass> |
475 | inline class_<CurClass> Library::class_(const std::string& className) { |
476 | TORCH_CHECK( |
477 | kind_ == DEF || kind_ == FRAGMENT, |
478 | "class_(\"" , |
479 | className, |
480 | "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " |
481 | "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " |
482 | "(Error occurred at " , |
483 | file_, |
484 | ":" , |
485 | line_, |
486 | ")" ); |
487 | TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":" , line_); |
488 | return torch::class_<CurClass>(*ns_, className); |
489 | } |
490 | |
// Declared here; the definition lives out of line and is not part of this
// header.
// NOTE(review): presumably returns the names of all registered custom
// classes. Confirm against the definition.
// The top-level `const` on the by-value return type can prevent moving from
// the result, but it has to stay in sync with the out-of-line definition.
const std::unordered_set<std::string> getAllCustomClassesNames();
492 | |
493 | template <class CurClass> |
494 | inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) { |
495 | auto class_name = std::string(className.operator const char*()); |
496 | TORCH_CHECK( |
497 | kind_ == DEF || kind_ == FRAGMENT, |
498 | "class_(\"" , |
499 | class_name, |
500 | "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " |
501 | "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " |
502 | "(Error occurred at " , |
503 | file_, |
504 | ":" , |
505 | line_, |
506 | ")" ); |
507 | TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":" , line_); |
508 | return torch::class_<CurClass>(*ns_, class_name); |
509 | } |
510 | |
511 | template <class CurClass> |
512 | inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) { |
513 | return detail::ClassNotSelected(); |
514 | } |
515 | |
516 | } // namespace torch |
517 | |