1#pragma once
2
3#include <torch/detail/static.h>
4#include <torch/nn/module.h>
5#include <torch/ordered_dict.h>
6#include <torch/types.h>
7
8#include <torch/csrc/Device.h>
9#include <torch/csrc/Dtype.h>
10#include <torch/csrc/DynamicTypes.h>
11#include <torch/csrc/Exceptions.h>
12#include <torch/csrc/autograd/python_variable.h>
13#include <torch/csrc/python_headers.h>
14#include <torch/csrc/utils/pybind.h>
15#include <torch/csrc/utils/python_numbers.h>
16#include <torch/csrc/utils/python_tuples.h>
17
18#include <iterator>
19#include <string>
20#include <unordered_map>
21#include <utility>
22#include <vector>
23
24namespace torch {
25namespace python {
26namespace detail {
27inline Device py_object_to_device(py::object object) {
28 PyObject* obj = object.ptr();
29 if (THPDevice_Check(obj)) {
30 return reinterpret_cast<THPDevice*>(obj)->device;
31 }
32 throw TypeError("Expected device");
33}
34
35inline Dtype py_object_to_dtype(py::object object) {
36 PyObject* obj = object.ptr();
37 if (THPDtype_Check(obj)) {
38 return reinterpret_cast<THPDtype*>(obj)->scalar_type;
39 }
40 throw TypeError("Expected dtype");
41}
42
43template <typename ModuleType>
44using PyModuleClass =
45 py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
46
47/// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also
48/// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module
49/// to which it delegates all calls.
50template <typename ModuleType>
51void bind_cpp_module_wrapper(
52 py::module module,
53 PyModuleClass<ModuleType> cpp_class,
54 const char* name) {
55 // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
56 // with a dynamically created class below.
57 py::object cpp_module =
58 py::module::import("torch.nn.cpp").attr("ModuleWrapper");
59
60 // Grab the `type` class which we'll use as a metaclass to create a new class
61 // dynamically.
62 py::object type_metaclass =
63 py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
64
65 // The `ModuleWrapper` constructor copies all functions to its own `__dict__`
66 // in its constructor, but we do need to give our dynamic class a constructor.
67 // Inside, we construct an instance of the original C++ module we're binding
68 // (the `torch::nn::Module` subclass), and then forward it to the
69 // `ModuleWrapper` constructor.
70 py::dict attributes;
71
72 // `type()` always needs a `str`, but pybind11's `str()` method always creates
73 // a `unicode` object.
74 py::object name_str = py::str(name);
75
76 // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
77 // `torch.nn.Module`, and will delegate all calls to the C++ module we're
78 // binding.
79 py::object wrapper_class =
80 type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
81
82 // The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
83 // which replaces its methods with those of the C++ module.
84 wrapper_class.attr("__init__") = py::cpp_function(
85 [cpp_module, cpp_class](
86 py::object self, py::args args, py::kwargs kwargs) {
87 cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
88 },
89 py::is_method(wrapper_class));
90
91 // Calling `my_module.my_class` now means that `my_class` is a subclass of
92 // `ModuleWrapper`, and whose methods call into the C++ module we're binding.
93 module.attr(name) = wrapper_class;
94}
95} // namespace detail
96
97/// Adds method bindings for a pybind11 `class_` that binds an `nn::Module`
98/// subclass.
99///
100/// Say you have a pybind11 class object created with `py::class_<Net>(m,
101/// "Net")`. This function will add all the necessary `.def()` calls to bind the
102/// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into
103/// Python.
104///
105/// Users should prefer to use `bind_module` if possible.
106template <typename ModuleType, typename... Extra>
107py::class_<ModuleType, Extra...> add_module_bindings(
108 py::class_<ModuleType, Extra...> module) {
109 // clang-format off
110 return module
111 .def("train",
112 [](ModuleType& module, bool mode) { module.train(mode); },
113 py::arg("mode") = true)
114 .def("eval", [](ModuleType& module) { module.eval(); })
115 .def("clone", [](ModuleType& module) { return module.clone(); })
116 .def_property_readonly(
117 "training", [](ModuleType& module) { return module.is_training(); })
118 .def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
119 .def_property_readonly( "_parameters", [](ModuleType& module) {
120 return module.named_parameters(/*recurse=*/false);
121 })
122 .def("parameters", [](ModuleType& module, bool recurse) {
123 return module.parameters(recurse);
124 },
125 py::arg("recurse") = true)
126 .def("named_parameters", [](ModuleType& module, bool recurse) {
127 return module.named_parameters(recurse);
128 },
129 py::arg("recurse") = true)
130 .def_property_readonly("_buffers", [](ModuleType& module) {
131 return module.named_buffers(/*recurse=*/false);
132 })
133 .def("buffers", [](ModuleType& module, bool recurse) {
134 return module.buffers(recurse); },
135 py::arg("recurse") = true)
136 .def("named_buffers", [](ModuleType& module, bool recurse) {
137 return module.named_buffers(recurse);
138 },
139 py::arg("recurse") = true)
140 .def_property_readonly(
141 "_modules", [](ModuleType& module) { return module.named_children(); })
142 .def("modules", [](ModuleType& module) { return module.modules(); })
143 .def("named_modules",
144 [](ModuleType& module, py::object /* unused */, std::string prefix, bool remove_duplicate /* unused */) {
145 return module.named_modules(std::move(prefix));
146 },
147 py::arg("memo") = py::none(),
148 py::arg("prefix") = std::string(),
149 py::arg("remove_duplicate") = true)
150 .def("children", [](ModuleType& module) { return module.children(); })
151 .def("named_children",
152 [](ModuleType& module) { return module.named_children(); })
153 .def("to", [](ModuleType& module, py::object object, bool non_blocking) {
154 if (THPDevice_Check(object.ptr())) {
155 module.to(
156 reinterpret_cast<THPDevice*>(object.ptr())->device,
157 non_blocking);
158 } else {
159 module.to(detail::py_object_to_dtype(object), non_blocking);
160 }
161 },
162 py::arg("dtype_or_device"),
163 py::arg("non_blocking") = false)
164 .def("to",
165 [](ModuleType& module,
166 py::object device,
167 py::object dtype,
168 bool non_blocking) {
169 if (device.is_none()) {
170 module.to(detail::py_object_to_dtype(dtype), non_blocking);
171 } else if (dtype.is_none()) {
172 module.to(detail::py_object_to_device(device), non_blocking);
173 } else {
174 module.to(
175 detail::py_object_to_device(device),
176 detail::py_object_to_dtype(dtype),
177 non_blocking);
178 }
179 },
180 py::arg("device"),
181 py::arg("dtype"),
182 py::arg("non_blocking") = false)
183 .def("cuda", [](ModuleType& module) { module.to(kCUDA); })
184 .def("cpu", [](ModuleType& module) { module.to(kCPU); })
185 .def("float", [](ModuleType& module) { module.to(kFloat32); })
186 .def("double", [](ModuleType& module) { module.to(kFloat64); })
187 .def("half", [](ModuleType& module) { module.to(kFloat16); })
188 .def("__str__", [](ModuleType& module) { return module.name(); })
189 .def("__repr__", [](ModuleType& module) { return module.name(); });
190 // clang-format on
191}
192
193/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
194/// default bindings.
195///
196/// After adding the default bindings, the class object is returned, such that
197/// you can add more bindings.
198///
199/// Example usage:
200/// \rst
201/// .. code-block:: cpp
202///
203/// struct Net : torch::nn::Module {
204/// Net(int in, int out) { }
205/// torch::Tensor forward(torch::Tensor x) { return x; }
206/// };
207///
208/// PYBIND11_MODULE(my_module, m) {
209/// torch::python::bind_module<Net>(m, "Net")
210/// .def(py::init<int, int>())
211/// .def("forward", &Net::forward);
212/// }
213/// \endrst
214template <typename ModuleType, bool force_enable = false>
215torch::disable_if_t<
216 torch::detail::has_forward<ModuleType>::value && !force_enable,
217 detail::PyModuleClass<ModuleType>>
218bind_module(py::module module, const char* name) {
219 py::module cpp = module.def_submodule("cpp");
220 auto cpp_class =
221 add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
222 detail::bind_cpp_module_wrapper(module, cpp_class, name);
223 return cpp_class;
224}
225
226/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
227/// default bindings.
228///
229/// After adding the default bindings, the class object is returned, such that
230/// you can add more bindings.
231///
232/// If the class has a `forward()` method, it is automatically exposed as
233/// `forward()` and `__call__` in Python.
234///
235/// Example usage:
236/// \rst
237/// .. code-block:: cpp
238///
239/// struct Net : torch::nn::Module {
240/// Net(int in, int out) { }
241/// torch::Tensor forward(torch::Tensor x) { return x; }
242/// };
243///
244/// PYBIND11_MODULE(my_module, m) {
245/// torch::python::bind_module<Net>(m, "Net")
246/// .def(py::init<int, int>())
247/// .def("forward", &Net::forward);
248/// }
249/// \endrst
250template <
251 typename ModuleType,
252 typename =
253 torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
254detail::PyModuleClass<ModuleType> bind_module(
255 py::module module,
256 const char* name) {
257 return bind_module<ModuleType, /*force_enable=*/true>(module, name)
258 .def("forward", &ModuleType::forward)
259 .def("__call__", &ModuleType::forward);
260}
261} // namespace python
262} // namespace torch
263