1 | #pragma once |
2 | |
3 | #include <torch/detail/static.h> |
4 | #include <torch/nn/module.h> |
5 | #include <torch/ordered_dict.h> |
6 | #include <torch/types.h> |
7 | |
8 | #include <torch/csrc/Device.h> |
9 | #include <torch/csrc/Dtype.h> |
10 | #include <torch/csrc/DynamicTypes.h> |
11 | #include <torch/csrc/Exceptions.h> |
12 | #include <torch/csrc/autograd/python_variable.h> |
13 | #include <torch/csrc/python_headers.h> |
14 | #include <torch/csrc/utils/pybind.h> |
15 | #include <torch/csrc/utils/python_numbers.h> |
16 | #include <torch/csrc/utils/python_tuples.h> |
17 | |
18 | #include <iterator> |
19 | #include <string> |
20 | #include <unordered_map> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | namespace torch { |
25 | namespace python { |
26 | namespace detail { |
27 | inline Device py_object_to_device(py::object object) { |
28 | PyObject* obj = object.ptr(); |
29 | if (THPDevice_Check(obj)) { |
30 | return reinterpret_cast<THPDevice*>(obj)->device; |
31 | } |
32 | throw TypeError("Expected device" ); |
33 | } |
34 | |
35 | inline Dtype py_object_to_dtype(py::object object) { |
36 | PyObject* obj = object.ptr(); |
37 | if (THPDtype_Check(obj)) { |
38 | return reinterpret_cast<THPDtype*>(obj)->scalar_type; |
39 | } |
40 | throw TypeError("Expected dtype" ); |
41 | } |
42 | |
43 | template <typename ModuleType> |
44 | using PyModuleClass = |
45 | py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>; |
46 | |
47 | /// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also |
48 | /// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module |
49 | /// to which it delegates all calls. |
50 | template <typename ModuleType> |
51 | void bind_cpp_module_wrapper( |
52 | py::module module, |
53 | PyModuleClass<ModuleType> cpp_class, |
54 | const char* name) { |
55 | // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass |
56 | // with a dynamically created class below. |
57 | py::object cpp_module = |
58 | py::module::import("torch.nn.cpp" ).attr("ModuleWrapper" ); |
59 | |
60 | // Grab the `type` class which we'll use as a metaclass to create a new class |
61 | // dynamically. |
62 | py::object type_metaclass = |
63 | py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type); |
64 | |
65 | // The `ModuleWrapper` constructor copies all functions to its own `__dict__` |
66 | // in its constructor, but we do need to give our dynamic class a constructor. |
67 | // Inside, we construct an instance of the original C++ module we're binding |
68 | // (the `torch::nn::Module` subclass), and then forward it to the |
69 | // `ModuleWrapper` constructor. |
70 | py::dict attributes; |
71 | |
72 | // `type()` always needs a `str`, but pybind11's `str()` method always creates |
73 | // a `unicode` object. |
74 | py::object name_str = py::str(name); |
75 | |
76 | // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of |
77 | // `torch.nn.Module`, and will delegate all calls to the C++ module we're |
78 | // binding. |
79 | py::object wrapper_class = |
80 | type_metaclass(name_str, py::make_tuple(cpp_module), attributes); |
81 | |
82 | // The constructor of the dynamic class calls `ModuleWrapper.__init__()`, |
83 | // which replaces its methods with those of the C++ module. |
84 | wrapper_class.attr("__init__" ) = py::cpp_function( |
85 | [cpp_module, cpp_class]( |
86 | py::object self, py::args args, py::kwargs kwargs) { |
87 | cpp_module.attr("__init__" )(self, cpp_class(*args, **kwargs)); |
88 | }, |
89 | py::is_method(wrapper_class)); |
90 | |
91 | // Calling `my_module.my_class` now means that `my_class` is a subclass of |
92 | // `ModuleWrapper`, and whose methods call into the C++ module we're binding. |
93 | module.attr(name) = wrapper_class; |
94 | } |
95 | } // namespace detail |
96 | |
97 | /// Adds method bindings for a pybind11 `class_` that binds an `nn::Module` |
98 | /// subclass. |
99 | /// |
100 | /// Say you have a pybind11 class object created with `py::class_<Net>(m, |
101 | /// "Net")`. This function will add all the necessary `.def()` calls to bind the |
102 | /// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into |
103 | /// Python. |
104 | /// |
105 | /// Users should prefer to use `bind_module` if possible. |
106 | template <typename ModuleType, typename... Extra> |
107 | py::class_<ModuleType, Extra...> add_module_bindings( |
108 | py::class_<ModuleType, Extra...> module) { |
109 | // clang-format off |
110 | return module |
111 | .def("train" , |
112 | [](ModuleType& module, bool mode) { module.train(mode); }, |
113 | py::arg("mode" ) = true) |
114 | .def("eval" , [](ModuleType& module) { module.eval(); }) |
115 | .def("clone" , [](ModuleType& module) { return module.clone(); }) |
116 | .def_property_readonly( |
117 | "training" , [](ModuleType& module) { return module.is_training(); }) |
118 | .def("zero_grad" , [](ModuleType& module) { module.zero_grad(); }) |
119 | .def_property_readonly( "_parameters" , [](ModuleType& module) { |
120 | return module.named_parameters(/*recurse=*/false); |
121 | }) |
122 | .def("parameters" , [](ModuleType& module, bool recurse) { |
123 | return module.parameters(recurse); |
124 | }, |
125 | py::arg("recurse" ) = true) |
126 | .def("named_parameters" , [](ModuleType& module, bool recurse) { |
127 | return module.named_parameters(recurse); |
128 | }, |
129 | py::arg("recurse" ) = true) |
130 | .def_property_readonly("_buffers" , [](ModuleType& module) { |
131 | return module.named_buffers(/*recurse=*/false); |
132 | }) |
133 | .def("buffers" , [](ModuleType& module, bool recurse) { |
134 | return module.buffers(recurse); }, |
135 | py::arg("recurse" ) = true) |
136 | .def("named_buffers" , [](ModuleType& module, bool recurse) { |
137 | return module.named_buffers(recurse); |
138 | }, |
139 | py::arg("recurse" ) = true) |
140 | .def_property_readonly( |
141 | "_modules" , [](ModuleType& module) { return module.named_children(); }) |
142 | .def("modules" , [](ModuleType& module) { return module.modules(); }) |
143 | .def("named_modules" , |
144 | [](ModuleType& module, py::object /* unused */, std::string prefix, bool remove_duplicate /* unused */) { |
145 | return module.named_modules(std::move(prefix)); |
146 | }, |
147 | py::arg("memo" ) = py::none(), |
148 | py::arg("prefix" ) = std::string(), |
149 | py::arg("remove_duplicate" ) = true) |
150 | .def("children" , [](ModuleType& module) { return module.children(); }) |
151 | .def("named_children" , |
152 | [](ModuleType& module) { return module.named_children(); }) |
153 | .def("to" , [](ModuleType& module, py::object object, bool non_blocking) { |
154 | if (THPDevice_Check(object.ptr())) { |
155 | module.to( |
156 | reinterpret_cast<THPDevice*>(object.ptr())->device, |
157 | non_blocking); |
158 | } else { |
159 | module.to(detail::py_object_to_dtype(object), non_blocking); |
160 | } |
161 | }, |
162 | py::arg("dtype_or_device" ), |
163 | py::arg("non_blocking" ) = false) |
164 | .def("to" , |
165 | [](ModuleType& module, |
166 | py::object device, |
167 | py::object dtype, |
168 | bool non_blocking) { |
169 | if (device.is_none()) { |
170 | module.to(detail::py_object_to_dtype(dtype), non_blocking); |
171 | } else if (dtype.is_none()) { |
172 | module.to(detail::py_object_to_device(device), non_blocking); |
173 | } else { |
174 | module.to( |
175 | detail::py_object_to_device(device), |
176 | detail::py_object_to_dtype(dtype), |
177 | non_blocking); |
178 | } |
179 | }, |
180 | py::arg("device" ), |
181 | py::arg("dtype" ), |
182 | py::arg("non_blocking" ) = false) |
183 | .def("cuda" , [](ModuleType& module) { module.to(kCUDA); }) |
184 | .def("cpu" , [](ModuleType& module) { module.to(kCPU); }) |
185 | .def("float" , [](ModuleType& module) { module.to(kFloat32); }) |
186 | .def("double" , [](ModuleType& module) { module.to(kFloat64); }) |
187 | .def("half" , [](ModuleType& module) { module.to(kFloat16); }) |
188 | .def("__str__" , [](ModuleType& module) { return module.name(); }) |
189 | .def("__repr__" , [](ModuleType& module) { return module.name(); }); |
190 | // clang-format on |
191 | } |
192 | |
193 | /// Creates a pybind11 class object for an `nn::Module` subclass type and adds |
194 | /// default bindings. |
195 | /// |
196 | /// After adding the default bindings, the class object is returned, such that |
197 | /// you can add more bindings. |
198 | /// |
199 | /// Example usage: |
200 | /// \rst |
201 | /// .. code-block:: cpp |
202 | /// |
203 | /// struct Net : torch::nn::Module { |
204 | /// Net(int in, int out) { } |
205 | /// torch::Tensor forward(torch::Tensor x) { return x; } |
206 | /// }; |
207 | /// |
208 | /// PYBIND11_MODULE(my_module, m) { |
209 | /// torch::python::bind_module<Net>(m, "Net") |
210 | /// .def(py::init<int, int>()) |
211 | /// .def("forward", &Net::forward); |
212 | /// } |
213 | /// \endrst |
214 | template <typename ModuleType, bool force_enable = false> |
215 | torch::disable_if_t< |
216 | torch::detail::has_forward<ModuleType>::value && !force_enable, |
217 | detail::PyModuleClass<ModuleType>> |
218 | bind_module(py::module module, const char* name) { |
219 | py::module cpp = module.def_submodule("cpp" ); |
220 | auto cpp_class = |
221 | add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name)); |
222 | detail::bind_cpp_module_wrapper(module, cpp_class, name); |
223 | return cpp_class; |
224 | } |
225 | |
226 | /// Creates a pybind11 class object for an `nn::Module` subclass type and adds |
227 | /// default bindings. |
228 | /// |
229 | /// After adding the default bindings, the class object is returned, such that |
230 | /// you can add more bindings. |
231 | /// |
232 | /// If the class has a `forward()` method, it is automatically exposed as |
233 | /// `forward()` and `__call__` in Python. |
234 | /// |
235 | /// Example usage: |
236 | /// \rst |
237 | /// .. code-block:: cpp |
238 | /// |
239 | /// struct Net : torch::nn::Module { |
240 | /// Net(int in, int out) { } |
241 | /// torch::Tensor forward(torch::Tensor x) { return x; } |
242 | /// }; |
243 | /// |
244 | /// PYBIND11_MODULE(my_module, m) { |
245 | /// torch::python::bind_module<Net>(m, "Net") |
246 | /// .def(py::init<int, int>()) |
247 | /// .def("forward", &Net::forward); |
248 | /// } |
249 | /// \endrst |
250 | template < |
251 | typename ModuleType, |
252 | typename = |
253 | torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>> |
254 | detail::PyModuleClass<ModuleType> bind_module( |
255 | py::module module, |
256 | const char* name) { |
257 | return bind_module<ModuleType, /*force_enable=*/true>(module, name) |
258 | .def("forward" , &ModuleType::forward) |
259 | .def("__call__" , &ModuleType::forward); |
260 | } |
261 | } // namespace python |
262 | } // namespace torch |
263 | |