1#pragma once
2
3#include <ATen/core/builtin_function.h>
4#include <ATen/core/function_schema.h>
5#include <ATen/core/ivalue.h>
6#include <ATen/core/class_type.h>
7#include <ATen/core/op_registration/infer_schema.h>
8#include <ATen/core/stack.h>
9#include <c10/util/C++17.h>
10#include <c10/util/Metaprogramming.h>
11#include <c10/util/TypeList.h>
12#include <c10/util/TypeTraits.h>
13#include <torch/custom_class_detail.h>
14#include <torch/library.h>
15#include <iostream>
16#include <sstream>
17
18namespace torch {
19
20/// This function is used in conjunction with `class_::def()` to register
21/// a constructor for a given C++ class type. For example,
22/// `torch::init<int, std::string>()` would register a two-argument constructor
23/// taking an `int` and a `std::string` as argument.
24template <class... Types>
25detail::types<void, Types...> init() {
26 return detail::types<void, Types...>{};
27}
28
/// Aggregate that carries a user-supplied factory callable from
/// `torch::init(callable)` to `class_::def()`. The trailing template
/// arguments are a compile-time tag only; the struct stores nothing but `f`.
template <typename Func, typename... ParameterTypes>
struct InitLambda {
  Func f;
};
33
34template <typename Func>
35decltype(auto) init(Func&& f) {
36 using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
37 using ParameterTypeList = typename InitTraits::parameter_types;
38
39 InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
40 return init;
41}
42
43/// Entry point for custom C++ class registration. To register a C++ class
44/// in PyTorch, instantiate `torch::class_` with the desired class as the
45/// template parameter. Typically, this instantiation should be done in
46/// the initialization of a global variable, so that the class will be
47/// made available on dynamic library loading without any additional API
48/// calls needed. For example, to register a class named Foo, you might
49/// create a global variable like so:
50///
51/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
52/// .def("myMethod", &Foo::myMethod)
53/// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
54/// // Do something with `self`
55/// });
56///
57/// In addition to registering the class, this registration also chains
58/// `def()` calls to register methods. `myMethod()` is registered with
59/// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
60/// is registered with a C++ lambda expression.
61template <class CurClass>
62class class_ : public ::torch::detail::class_base {
63 static_assert(
64 std::is_base_of<CustomClassHolder, CurClass>::value,
65 "torch::class_<T> requires T to inherit from CustomClassHolder");
66
67 public:
68 /// This constructor actually registers the class type.
69 /// String argument `namespaceName` is an identifier for the
70 /// namespace you would like this class to appear in.
71 /// String argument `className` is the name you would like to
72 /// see this class exposed as in Python and TorchScript. For example, if
73 /// you pass `foo` as the namespace name and `Bar` as the className, the
74 /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
75 explicit class_(
76 const std::string& namespaceName,
77 const std::string& className,
78 std::string doc_string = "")
79 : class_base(
80 namespaceName,
81 className,
82 std::move(doc_string),
83 typeid(c10::intrusive_ptr<CurClass>),
84 typeid(c10::tagged_capsule<CurClass>)) {}
85
86 /// def() can be used in conjunction with `torch::init()` to register
87 /// a constructor for a given C++ class type. For example, passing
88 /// `torch::init<int, std::string>()` would register a two-argument
89 /// constructor taking an `int` and a `std::string` as argument.
90 template <typename... Types>
91 class_& def(
92 torch::detail::types<void, Types...>,
93 std::string doc_string = "",
94 std::initializer_list<arg> default_args =
95 {}) { // Used in combination with
96 // torch::init<...>()
97 auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
98 auto classObj = c10::make_intrusive<CurClass>(args...);
99 auto object = self.ivalue.toObject();
100 object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
101 };
102
103 defineMethod(
104 "__init__",
105 std::move(func),
106 std::move(doc_string),
107 std::move(default_args));
108 return *this;
109 }
110
111 // Used in combination with torch::init([]lambda(){......})
112 template <typename Func, typename... ParameterTypes>
113 class_& def(
114 InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
115 std::string doc_string = "",
116 std::initializer_list<arg> default_args = {}) {
117 auto init_lambda_wrapper = [func = std::move(init.f)](
118 c10::tagged_capsule<CurClass> self,
119 ParameterTypes... arg) {
120 c10::intrusive_ptr<CurClass> classObj =
121 at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
122 auto object = self.ivalue.toObject();
123 object->setSlot(0, c10::IValue::make_capsule(classObj));
124 };
125
126 defineMethod(
127 "__init__",
128 std::move(init_lambda_wrapper),
129 std::move(doc_string),
130 std::move(default_args));
131
132 return *this;
133 }
134
135 /// This is the normal method registration API. `name` is the name that
136 /// the method will be made accessible by in Python and TorchScript.
137 /// `f` is a callable object that defines the method. Typically `f`
138 /// will either be a pointer to a method on `CurClass`, or a lambda
139 /// expression that takes a `c10::intrusive_ptr<CurClass>` as the first
140 /// argument (emulating a `this` argument in a C++ method.)
141 ///
142 /// Examples:
143 ///
144 /// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
145 /// // Python and TorchScript
146 /// .def("call_foo", &Foo::foo)
147 ///
148 /// // Exposes the given lambda expression as method `call_lambda()`
149 /// // in Python and TorchScript.
150 /// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
151 /// // do something
152 /// })
153 template <typename Func>
154 class_& def(
155 std::string name,
156 Func f,
157 std::string doc_string = "",
158 std::initializer_list<arg> default_args = {}) {
159 auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
160 defineMethod(
161 std::move(name),
162 std::move(wrapped_f),
163 std::move(doc_string),
164 std::move(default_args));
165 return *this;
166 }
167
168 /// Method registration API for static methods.
169 template <typename Func>
170 class_& def_static(std::string name, Func func, std::string doc_string = "") {
171 auto qualMethodName = qualClassName + "." + name;
172 auto schema =
173 c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
174
175 auto wrapped_func =
176 [func = std::move(func)](jit::Stack& stack) mutable -> void {
177 using RetType =
178 typename c10::guts::infer_function_traits_t<Func>::return_type;
179 detail::BoxedProxy<RetType, Func>()(stack, func);
180 };
181 auto method = std::make_unique<jit::BuiltinOpFunction>(
182 std::move(qualMethodName),
183 std::move(schema),
184 std::move(wrapped_func),
185 std::move(doc_string));
186
187 classTypePtr->addStaticMethod(method.get());
188 registerCustomClassMethod(std::move(method));
189 return *this;
190 }
191
192 /// Property registration API for properties with both getter and setter
193 /// functions.
194 template <typename GetterFunc, typename SetterFunc>
195 class_& def_property(
196 const std::string& name,
197 GetterFunc getter_func,
198 SetterFunc setter_func,
199 std::string doc_string = "") {
200 torch::jit::Function* getter{};
201 torch::jit::Function* setter{};
202
203 auto wrapped_getter =
204 detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
205 getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
206
207 auto wrapped_setter =
208 detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func));
209 setter = defineMethod(name + "_setter", wrapped_setter, doc_string);
210
211 classTypePtr->addProperty(name, getter, setter);
212 return *this;
213 }
214
215 /// Property registration API for properties with only getter function.
216 template <typename GetterFunc>
217 class_& def_property(
218 const std::string& name,
219 GetterFunc getter_func,
220 std::string doc_string = "") {
221 torch::jit::Function* getter{};
222
223 auto wrapped_getter =
224 detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
225 getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
226
227 classTypePtr->addProperty(name, getter, nullptr);
228 return *this;
229 }
230
231 /// Property registration API for properties with read-write access.
232 template <typename T>
233 class_& def_readwrite(const std::string& name, T CurClass::*field) {
234 auto getter_func = [field =
235 field](const c10::intrusive_ptr<CurClass>& self) {
236 return self.get()->*field;
237 };
238
239 auto setter_func = [field = field](
240 const c10::intrusive_ptr<CurClass>& self, T value) {
241 self.get()->*field = value;
242 };
243
244 return def_property(name, getter_func, setter_func);
245 }
246
247 /// Property registration API for properties with read-only access.
248 template <typename T>
249 class_& def_readonly(const std::string& name, T CurClass::*field) {
250 auto getter_func =
251 [field = std::move(field)](const c10::intrusive_ptr<CurClass>& self) {
252 return self.get()->*field;
253 };
254
255 return def_property(name, getter_func);
256 }
257
258 /// This is an unsafe method registration API added for adding custom JIT
259 /// backend support via custom C++ classes. It is not for general purpose use.
260 class_& _def_unboxed(
261 std::string name,
262 std::function<void(jit::Stack&)> func,
263 c10::FunctionSchema schema,
264 std::string doc_string = "") {
265 auto method = std::make_unique<jit::BuiltinOpFunction>(
266 qualClassName + "." + name,
267 std::move(schema),
268 std::move(func),
269 std::move(doc_string));
270 classTypePtr->addMethod(method.get());
271 registerCustomClassMethod(std::move(method));
272 return *this;
273 }
274
275 /// def_pickle() is used to define exactly what state gets serialized
276 /// or deserialized for a given instance of a custom C++ class in
277 /// Python or TorchScript. This protocol is equivalent to the Pickle
278 /// concept of `__getstate__` and `__setstate__` from Python
279 /// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
280 ///
281 /// Currently, both the `get_state` and `set_state` callables must be
282 /// C++ lambda expressions. They should have the following signatures,
283 /// where `CurClass` is the class you're registering and `T1` is some object
284 /// that encapsulates the state of the object.
285 ///
286 /// __getstate__(intrusive_ptr<CurClass>) -> T1
287 /// __setstate__(T2) -> intrusive_ptr<CurClass>
288 ///
289 /// `T1` must be an object that is convertable to IValue by the same rules
290 /// for custom op/method registration.
291 ///
292 /// For the common case, T1 == T2. T1 can also be a subtype of T2. An
293 /// example where it makes sense for T1 and T2 to differ is if __setstate__
294 /// handles legacy formats in a backwards compatible way.
295 ///
296 /// Example:
297 ///
298 /// .def_pickle(
299 /// // __getstate__
300 /// [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
301 /// return self->stack_;
302 /// },
303 /// [](std::vector<std::string> state) { // __setstate__
304 /// return c10::make_intrusive<MyStackClass<std::string>>(
305 /// std::vector<std::string>{"i", "was", "deserialized"});
306 /// })
307 template <typename GetStateFn, typename SetStateFn>
308 class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
309 static_assert(
310 c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
311 c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
312 "def_pickle() currently only supports lambdas as "
313 "__getstate__ and __setstate__ arguments.");
314 def("__getstate__", std::forward<GetStateFn>(get_state));
315
316 // __setstate__ needs to be registered with some custom handling:
317 // We need to wrap the invocation of the user-provided function
318 // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
319 // and assign it to the `capsule` attribute.
320 using SetStateTraits =
321 c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
322 using SetStateArg = typename c10::guts::typelist::head_t<
323 typename SetStateTraits::parameter_types>;
324 auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
325 c10::tagged_capsule<CurClass> self,
326 SetStateArg&& arg) {
327 c10::intrusive_ptr<CurClass> classObj =
328 at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
329 auto object = self.ivalue.toObject();
330 object->setSlot(0, c10::IValue::make_capsule(classObj));
331 };
332 defineMethod(
333 "__setstate__",
334 detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
335 std::move(setstate_wrapper)));
336
337 // type validation
338 auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
339 auto format_getstate_schema = [&getstate_schema]() {
340 std::stringstream ss;
341 ss << getstate_schema;
342 return ss.str();
343 };
344 TORCH_CHECK(
345 getstate_schema.arguments().size() == 1,
346 "__getstate__ should take exactly one argument: self. Got: ",
347 format_getstate_schema());
348 auto first_arg_type = getstate_schema.arguments().at(0).type();
349 TORCH_CHECK(
350 *first_arg_type == *classTypePtr,
351 "self argument of __getstate__ must be the custom class type. Got ",
352 first_arg_type->repr_str());
353 TORCH_CHECK(
354 getstate_schema.returns().size() == 1,
355 "__getstate__ should return exactly one value for serialization. Got: ",
356 format_getstate_schema());
357
358 auto ser_type = getstate_schema.returns().at(0).type();
359 auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
360 auto arg_type = setstate_schema.arguments().at(1).type();
361 TORCH_CHECK(
362 ser_type->isSubtypeOf(*arg_type),
363 "__getstate__'s return type should be a subtype of "
364 "input argument of __setstate__. Got ",
365 ser_type->repr_str(),
366 " but expected ",
367 arg_type->repr_str());
368
369 return *this;
370 }
371
372 private:
373 template <typename Func>
374 torch::jit::Function* defineMethod(
375 std::string name,
376 Func func,
377 std::string doc_string = "",
378 std::initializer_list<arg> default_args = {}) {
379 auto qualMethodName = qualClassName + "." + name;
380 auto schema =
381 c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
382
383 // If default values are provided for function arguments, there must be
384 // none (no default values) or default values for all function
385 // arguments, except for self. This is because argument names are not
386 // extracted by inferFunctionSchemaSingleReturn, and so there must be a
387 // torch::arg instance in default_args even for arguments that do not
388 // have an actual default value provided.
389 TORCH_CHECK(
390 default_args.size() == 0 ||
391 default_args.size() == schema.arguments().size() - 1,
392 "Default values must be specified for none or all arguments");
393
394 // If there are default args, copy the argument names and default values to
395 // the function schema.
396 if (default_args.size() > 0) {
397 schema = withNewArguments(schema, default_args);
398 }
399
400 auto wrapped_func =
401 [func = std::move(func)](jit::Stack& stack) mutable -> void {
402 // TODO: we need to figure out how to profile calls to custom functions
403 // like this! Currently can't do it because the profiler stuff is in
404 // libtorch and not ATen
405 using RetType =
406 typename c10::guts::infer_function_traits_t<Func>::return_type;
407 detail::BoxedProxy<RetType, Func>()(stack, func);
408 };
409 auto method = std::make_unique<jit::BuiltinOpFunction>(
410 qualMethodName,
411 std::move(schema),
412 std::move(wrapped_func),
413 std::move(doc_string));
414
415 // Register the method here to keep the Method alive.
416 // ClassTypes do not hold ownership of their methods (normally it
417 // those are held by the CompilationUnit), so we need a proxy for
418 // that behavior here.
419 auto method_val = method.get();
420 classTypePtr->addMethod(method_val);
421 registerCustomClassMethod(std::move(method));
422 return method_val;
423 }
424};
425
426/// make_custom_class() is a convenient way to create an instance of a
427/// registered custom class and wrap it in an IValue, for example when you want
428/// to pass the object to TorchScript. Its syntax is equivalent to APIs like
429/// `std::make_shared<>` or `c10::make_intrusive<>`.
430///
431/// For example, if you have a custom C++ class that can be constructed from an
432/// `int` and `std::string`, you might use this API like so:
433///
434/// IValue custom_class_iv = torch::make_custom_class<MyClass>(3,
435/// "foobarbaz");
436template <typename CurClass, typename... CtorArgs>
437c10::IValue make_custom_class(CtorArgs&&... args) {
438 auto userClassInstance =
439 c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
440 return c10::IValue(std::move(userClassInstance));
441}
442
443// Alternative api for creating a torchbind class over torch::class_ this api is
444// preffered to prevent size regressions on Edge usecases. Must be used in
445// conjunction with TORCH_SELECTIVE_CLASS macro aka
446// selective_class<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo"))
447template <class CurClass>
448inline class_<CurClass> selective_class_(
449 const std::string& namespace_name,
450 detail::SelectiveStr<true> className) {
451 auto class_name = std::string(className.operator const char*());
452 return torch::class_<CurClass>(namespace_name, class_name);
453}
454
455template <class CurClass>
456inline detail::ClassNotSelected selective_class_(
457 const std::string&,
458 detail::SelectiveStr<false>) {
459 return detail::ClassNotSelected();
460}
461
462// jit namespace for backward-compatibility
463// We previously defined everything in torch::jit but moved it out to
464// better reflect that these features are not limited only to TorchScript
465namespace jit {
466
467using ::torch::class_;
468using ::torch::getCustomClass;
469using ::torch::init;
470using ::torch::isCustomClass;
471
472} // namespace jit
473
474template <class CurClass>
475inline class_<CurClass> Library::class_(const std::string& className) {
476 TORCH_CHECK(
477 kind_ == DEF || kind_ == FRAGMENT,
478 "class_(\"",
479 className,
480 "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
481 "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
482 "(Error occurred at ",
483 file_,
484 ":",
485 line_,
486 ")");
487 TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
488 return torch::class_<CurClass>(*ns_, className);
489}
490
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably returns the names of all registered custom
// classes — confirm against the definition.
const std::unordered_set<std::string> getAllCustomClassesNames();
492
493template <class CurClass>
494inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
495 auto class_name = std::string(className.operator const char*());
496 TORCH_CHECK(
497 kind_ == DEF || kind_ == FRAGMENT,
498 "class_(\"",
499 class_name,
500 "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
501 "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
502 "(Error occurred at ",
503 file_,
504 ":",
505 line_,
506 ")");
507 TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
508 return torch::class_<CurClass>(*ns_, class_name);
509}
510
511template <class CurClass>
512inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) {
513 return detail::ClassNotSelected();
514}
515
516} // namespace torch
517