1 | #pragma once |
2 | |
3 | #include <torch/nn/modules/container/any_module_holder.h> |
4 | #include <torch/nn/modules/container/any_value.h> |
5 | #include <torch/nn/pimpl.h> |
6 | #include <torch/ordered_dict.h> |
7 | #include <torch/serialize/archive.h> |
8 | #include <torch/types.h> |
9 | |
10 | #include <ATen/ATen.h> |
11 | |
12 | #include <functional> |
13 | #include <iosfwd> |
14 | #include <map> |
15 | #include <memory> |
16 | #include <string> |
17 | #include <type_traits> |
18 | |
19 | namespace torch { |
20 | namespace nn { |
21 | |
22 | /// The base class for all modules in PyTorch. |
23 | /// |
24 | /// \rst |
25 | /// .. note:: |
26 | /// The design and implementation of this class is largely based on the Python |
/// API. You may want to consult the Python documentation for
28 | /// :py:class:`pytorch:torch.nn.Module` for further clarification on certain |
29 | /// methods or behavior. |
30 | /// \endrst |
31 | /// |
32 | /// A `Module` is an abstraction over the implementation of some function or |
33 | /// algorithm, possibly associated with some persistent data. A `Module` may |
34 | /// contain further `Module`s ("submodules"), each with their own |
35 | /// implementation, persistent data and further submodules. `Module`s can thus |
36 | /// be said to form a recursive tree structure. A `Module` is registered as a |
37 | /// submodule to another `Module` by calling `register_module()`, typically from |
38 | /// within a parent module's constructor. |
39 | /// |
40 | /// A distinction is made between three kinds of persistent data that may be |
41 | /// associated with a `Module`: |
42 | /// |
/// 1. *Parameters*: tensors that record gradients, typically learnable
///    weights that are updated during training (e.g. the `weight` of a
///    `Linear` module),
45 | /// 2. *Buffers*: tensors that do not record gradients, typically updated during |
46 | /// the forward step, such as running statistics (e.g. `mean` and `variance` |
47 | /// in the `BatchNorm` module), |
48 | /// 3. Any additional state, not necessarily tensors, required for the |
49 | /// implementation or configuration of a `Module`. |
50 | /// |
51 | /// The first two kinds of state are special in that they may be registered |
52 | /// with the `Module` system to allow convenient access and batch configuration. |
53 | /// For example, registered parameters in any `Module` may be iterated over via |
54 | /// the `parameters()` accessor. Further, changing the data type of a `Module`'s |
55 | /// registered parameters can be done conveniently via `Module::to()`, e.g. |
56 | /// `module->to(torch::kCUDA)` to move all parameters to GPU memory. Lastly, |
57 | /// registered parameters and buffers are handled specially during a `clone()` |
58 | /// operation, which performs a deepcopy of a cloneable `Module` hierarchy. |
59 | /// |
60 | /// Parameters are registered with a `Module` via `register_parameter`. Buffers |
61 | /// are registered separately via `register_buffer`. These methods are part of |
62 | /// the public API of `Module` and are typically invoked from within a |
/// concrete `Module`'s constructor.
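///
/// A minimal sketch of such a constructor (the module, names and shapes here
/// are hypothetical, for illustration only):
///
/// \rst
/// .. code-block:: cpp
///
///   struct MyModuleImpl : torch::nn::Module {
///     MyModuleImpl() {
///       // A parameter: records gradients and is updated during training.
///       weight_ = register_parameter("weight", torch::randn({4, 4}));
///       // A buffer: persistent state that does not record gradients.
///       running_mean_ = register_buffer("running_mean", torch::zeros({4}));
///       // A submodule: lets `parameters()`, `to()` etc. recurse into it.
///       linear_ = register_module("linear", torch::nn::Linear(4, 4));
///     }
///     torch::Tensor weight_, running_mean_;
///     torch::nn::Linear linear_{nullptr};
///   };
/// \endrst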
64 | class TORCH_API Module : public std::enable_shared_from_this<Module> { |
65 | public: |
66 | using ModuleApplyFunction = std::function<void(Module&)>; |
67 | using ConstModuleApplyFunction = std::function<void(const Module&)>; |
68 | using NamedModuleApplyFunction = |
69 | std::function<void(const std::string&, Module&)>; |
70 | using ConstNamedModuleApplyFunction = |
71 | std::function<void(const std::string&, const Module&)>; |
72 | using ModulePointerApplyFunction = |
73 | std::function<void(const std::shared_ptr<Module>&)>; |
74 | using NamedModulePointerApplyFunction = |
75 | std::function<void(const std::string&, const std::shared_ptr<Module>&)>; |
76 | |
77 | /// Tells the base `Module` about the name of the submodule. |
78 | explicit Module(std::string name); |
79 | |
80 | /// Constructs the module without immediate knowledge of the submodule's name. |
81 | /// The name of the submodule is inferred via RTTI (if possible) the first |
82 | /// time `.name()` is invoked. |
83 | Module(); |
84 | |
85 | virtual ~Module() = default; |
86 | |
87 | /// Returns the name of the `Module`. |
88 | /// |
89 | /// A `Module` has an associated `name`, which is a string representation of |
90 | /// the kind of concrete `Module` it represents, such as `"Linear"` for the |
91 | /// `Linear` module. Under most circumstances, this name is automatically |
92 | /// inferred via runtime type information (RTTI). In the unusual circumstance |
93 | /// that you have this feature disabled, you may want to manually name your |
94 | /// `Module`s by passing the string name to the `Module` base class' |
95 | /// constructor. |
96 | const std::string& name() const noexcept; |
97 | |
98 | /// Performs a recursive deep copy of the module and all its registered |
99 | /// parameters, buffers and submodules. |
100 | /// |
101 | /// Optionally, this method sets the current device |
102 | /// to the one supplied before cloning. If no device is given, each |
103 | /// parameter and buffer will be moved to the device of its source. |
104 | /// |
105 | /// \rst |
106 | /// .. attention:: |
107 | /// Attempting to call the `clone()` method inherited from the base `Module` |
108 | /// class (the one documented here) will fail. To inherit an actual |
109 | /// implementation of `clone()`, you must subclass `Cloneable`. `Cloneable` |
110 | /// is templatized on the concrete module type, and can thus properly copy a |
111 | /// `Module`. This method is provided on the base class' API solely for an |
112 | /// easier-to-use polymorphic interface. |
113 | /// \endrst |
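///
/// A sketch of the intended usage, assuming `MyModule` derives from
/// `Cloneable<MyModule>`:
///
/// \rst
/// .. code-block:: cpp
///
///   auto module = std::make_shared<MyModule>();
///   // Deep copy; all parameters and buffers are placed on the GPU.
///   auto copy = module->clone(torch::kCUDA);
/// \endrst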
114 | virtual std::shared_ptr<Module> clone( |
115 | const optional<Device>& device = nullopt) const; |
116 | |
117 | /// Applies the `function` to the `Module` and recursively to every submodule. |
118 | /// The function must accept a `Module&`. |
119 | /// |
120 | /// \rst |
121 | /// .. code-block:: cpp |
122 | /// MyModule module; |
123 | /// module->apply([](nn::Module& module) { |
124 | /// std::cout << module.name() << std::endl; |
125 | /// }); |
126 | /// \endrst |
127 | void apply(const ModuleApplyFunction& function); |
128 | |
129 | /// Applies the `function` to the `Module` and recursively to every submodule. |
130 | /// The function must accept a `const Module&`. |
131 | /// |
132 | /// \rst |
133 | /// .. code-block:: cpp |
134 | /// MyModule module; |
135 | /// module->apply([](const nn::Module& module) { |
136 | /// std::cout << module.name() << std::endl; |
137 | /// }); |
138 | /// \endrst |
139 | void apply(const ConstModuleApplyFunction& function) const; |
140 | |
141 | /// Applies the `function` to the `Module` and recursively to every submodule. |
142 | /// The function must accept a `const std::string&` for the key of the module, |
143 | /// and a `Module&`. The key of the module itself is the empty string. If |
144 | /// `name_prefix` is given, it is prepended to every key as |
145 | /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself). |
146 | /// |
147 | /// \rst |
148 | /// .. code-block:: cpp |
149 | /// MyModule module; |
150 | /// module->apply([](const std::string& key, nn::Module& module) { |
151 | /// std::cout << key << ": " << module.name() << std::endl; |
152 | /// }); |
153 | /// \endrst |
154 | void apply( |
155 | const NamedModuleApplyFunction& function, |
156 | const std::string& name_prefix = std::string()); |
157 | |
158 | /// Applies the `function` to the `Module` and recursively to every submodule. |
159 | /// The function must accept a `const std::string&` for the key of the module, |
160 | /// and a `const Module&`. The key of the module itself is the empty string. |
161 | /// If `name_prefix` is given, it is prepended to every key as |
162 | /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself). |
163 | /// |
164 | /// \rst |
165 | /// .. code-block:: cpp |
166 | /// MyModule module; |
167 | /// module->apply([](const std::string& key, const nn::Module& module) { |
168 | /// std::cout << key << ": " << module.name() << std::endl; |
169 | /// }); |
170 | /// \endrst |
171 | void apply( |
172 | const ConstNamedModuleApplyFunction& function, |
173 | const std::string& name_prefix = std::string()) const; |
174 | |
175 | /// Applies the `function` to the `Module` and recursively to every submodule. |
176 | /// The function must accept a `const std::shared_ptr<Module>&`. |
177 | /// |
178 | /// \rst |
179 | /// .. code-block:: cpp |
180 | /// MyModule module; |
181 | /// module->apply([](const std::shared_ptr<nn::Module>& module) { |
182 | /// std::cout << module->name() << std::endl; |
183 | /// }); |
184 | /// \endrst |
185 | void apply(const ModulePointerApplyFunction& function) const; |
186 | |
187 | /// Applies the `function` to the `Module` and recursively to every submodule. |
188 | /// The function must accept a `const std::string&` for the key of the module, |
189 | /// and a `const std::shared_ptr<Module>&`. The key of the module itself is |
/// the empty string. If `name_prefix` is given, it is prepended to every key
/// as `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
193 | /// |
194 | /// \rst |
195 | /// .. code-block:: cpp |
196 | /// MyModule module; |
197 | /// module->apply([](const std::string& key, |
198 | /// const std::shared_ptr<nn::Module>& module) { |
199 | /// std::cout << key << ": " << module->name() << std::endl; |
200 | /// }); |
201 | /// \endrst |
202 | void apply( |
203 | const NamedModulePointerApplyFunction& function, |
204 | const std::string& name_prefix = std::string()) const; |
205 | |
206 | /// Returns the parameters of this `Module` and if `recurse` is true, also |
207 | /// recursively of every submodule. |
208 | std::vector<Tensor> parameters(bool recurse = true) const; |
209 | |
210 | /// Returns an `OrderedDict` with the parameters of this `Module` along with |
211 | /// their keys, and if `recurse` is true also recursively of every submodule. |
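///
/// For example, the following sketch prints the key and shape of every
/// registered parameter:
///
/// \rst
/// .. code-block:: cpp
///
///   for (const auto& pair : module->named_parameters()) {
///     std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
///   }
/// \endrst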
212 | OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const; |
213 | |
214 | /// Returns the buffers of this `Module` and if `recurse` is true, also |
215 | /// recursively of every submodule. |
216 | std::vector<Tensor> buffers(bool recurse = true) const; |
217 | |
218 | /// Returns an `OrderedDict` with the buffers of this `Module` along with |
219 | /// their keys, and if `recurse` is true also recursively of every submodule. |
220 | OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const; |
221 | |
222 | /// Returns the submodules of this `Module` (the entire submodule hierarchy) |
223 | /// and if `include_self` is true, also inserts a `shared_ptr` to this module |
224 | /// in the first position. |
225 | /// |
226 | /// \rst |
227 | /// .. warning:: |
228 | /// Only pass `include_self` as `true` if this `Module` is stored in a |
229 | /// `shared_ptr`! Otherwise an exception will be thrown. You may still call |
/// this method with `include_self` set to `false` if your `Module` is not
231 | /// stored in a `shared_ptr`. |
232 | /// \endrst |
233 | std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const; |
234 | |
235 | /// Returns an `OrderedDict` of the submodules of this `Module` (the entire |
236 | /// submodule hierarchy) and their keys, and if `include_self` is true, also |
237 | /// inserts a `shared_ptr` to this module in the first position. If |
238 | /// `name_prefix` is given, it is prepended to every key as |
239 | /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself). |
240 | /// |
241 | /// \rst |
242 | /// .. warning:: |
243 | /// Only pass `include_self` as `true` if this `Module` is stored in a |
244 | /// `shared_ptr`! Otherwise an exception will be thrown. You may still call |
/// this method with `include_self` set to `false` if your `Module` is not
246 | /// stored in a `shared_ptr`. |
247 | /// \endrst |
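///
/// For example (a sketch; `model` is assumed to be stored in a `shared_ptr`
/// or a `ModuleHolder`):
///
/// \rst
/// .. code-block:: cpp
///
///   for (const auto& pair : model->named_modules("net")) {
///     std::cout << pair.key() << std::endl; // "net", "net.linear", ...
///   }
/// \endrst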
248 | OrderedDict<std::string, std::shared_ptr<Module>> named_modules( |
249 | const std::string& name_prefix = std::string(), |
250 | bool include_self = true) const; |
251 | |
252 | /// Returns the direct submodules of this `Module`. |
253 | std::vector<std::shared_ptr<Module>> children() const; |
254 | |
255 | /// Returns an `OrderedDict` of the direct submodules of this `Module` and |
256 | /// their keys. |
257 | OrderedDict<std::string, std::shared_ptr<Module>> named_children() const; |
258 | |
259 | /// Enables "training" mode. |
260 | virtual void train(bool on = true); |
261 | |
/// Calls `train(false)` to enable "eval" mode.
263 | /// Do not override this method, override `train()` instead. |
264 | void eval(); |
265 | |
266 | /// True if the module is in training mode. |
267 | /// |
268 | /// Every `Module` has a boolean associated with it that determines whether |
269 | /// the `Module` is currently in *training* mode (set via `.train()`) or in |
270 | /// *evaluation* (inference) mode (set via `.eval()`). This property is |
271 | /// exposed via `is_training()`, and may be used by the implementation of a |
272 | /// concrete module to modify its runtime behavior. See the `BatchNorm` or |
273 | /// `Dropout` modules for examples of `Module`s that use different code paths |
274 | /// depending on this property. |
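///
/// A sketch of how a hypothetical module might branch on this flag:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::Tensor MyDropoutImpl::forward(torch::Tensor input) {
///     if (is_training()) {
///       // Only perturb activations while in training mode.
///       return torch::dropout(input, /*p=*/0.5, /*train=*/true);
///     }
///     return input;
///   }
/// \endrst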
275 | virtual bool is_training() const noexcept; |
276 | |
277 | /// Recursively casts all parameters to the given `dtype` and `device`. |
278 | /// |
279 | /// If `non_blocking` is true and the source is in pinned memory and |
280 | /// destination is on the GPU or vice versa, the copy is performed |
281 | /// asynchronously with respect to the host. Otherwise, the argument has no |
282 | /// effect. |
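///
/// For example (a sketch; `MyModule` is a hypothetical `ModuleHolder`):
///
/// \rst
/// .. code-block:: cpp
///
///   MyModule module;
///   // Move all parameters and buffers to the GPU in half precision.
///   module->to(torch::kCUDA, torch::kHalf);
/// \endrst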
283 | virtual void to( |
284 | torch::Device device, |
285 | torch::Dtype dtype, |
286 | bool non_blocking = false); |
287 | |
288 | /// Recursively casts all parameters to the given dtype. |
289 | /// |
290 | /// If `non_blocking` is true and the source is in pinned memory and |
291 | /// destination is on the GPU or vice versa, the copy is performed |
292 | /// asynchronously with respect to the host. Otherwise, the argument has no |
293 | /// effect. |
294 | virtual void to(torch::Dtype dtype, bool non_blocking = false); |
295 | |
296 | /// Recursively moves all parameters to the given device. |
297 | /// |
298 | /// If `non_blocking` is true and the source is in pinned memory and |
299 | /// destination is on the GPU or vice versa, the copy is performed |
300 | /// asynchronously with respect to the host. Otherwise, the argument has no |
301 | /// effect. |
302 | virtual void to(torch::Device device, bool non_blocking = false); |
303 | |
304 | /// Recursively zeros out the `grad` value of each registered parameter. |
305 | virtual void zero_grad(bool set_to_none = true); |
306 | |
307 | /// Attempts to cast this `Module` to the given `ModuleType`. |
308 | /// |
309 | /// This method is useful when calling `apply()`. |
310 | /// \rst |
311 | /// .. code-block:: cpp |
312 | /// |
313 | /// void initialize_weights(nn::Module& module) { |
314 | /// torch::NoGradGuard no_grad; |
315 | /// if (auto* linear = module.as<nn::Linear>()) { |
316 | /// linear->weight.normal_(0.0, 0.02); |
317 | /// } |
318 | /// } |
319 | /// |
320 | /// MyModule module; |
321 | /// module->apply(initialize_weights); |
322 | /// \endrst |
323 | template <typename ModuleType> |
324 | typename ModuleType::ContainedType* as() noexcept; |
325 | |
326 | /// Attempts to cast this `Module` to the given `ModuleType`. |
327 | /// |
328 | /// This method is useful when calling `apply()`. |
329 | /// \rst |
330 | /// .. code-block:: cpp |
331 | /// void initialize_weights(nn::Module& module) { |
332 | /// torch::NoGradGuard no_grad; |
333 | /// if (auto* linear = module.as<nn::Linear>()) { |
334 | /// linear->weight.normal_(0.0, 0.02); |
335 | /// } |
336 | /// } |
337 | /// |
338 | /// MyModule module; |
339 | /// module->apply(initialize_weights); |
340 | /// \endrst |
341 | template <typename ModuleType> |
342 | const typename ModuleType::ContainedType* as() const noexcept; |
343 | |
344 | /// Attempts to cast this `Module` to the given `ModuleType`. |
345 | /// |
346 | /// This method is useful when calling `apply()`. |
347 | /// \rst |
348 | /// .. code-block:: cpp |
349 | /// |
350 | /// void initialize_weights(nn::Module& module) { |
351 | /// torch::NoGradGuard no_grad; |
352 | /// if (auto* linear = module.as<nn::Linear>()) { |
353 | /// linear->weight.normal_(0.0, 0.02); |
354 | /// } |
355 | /// } |
356 | /// |
357 | /// MyModule module; |
/// module->apply(initialize_weights);
359 | /// \endrst |
360 | template < |
361 | typename ModuleType, |
362 | typename = torch::detail::disable_if_module_holder_t<ModuleType>> |
363 | ModuleType* as() noexcept; |
364 | |
365 | /// Attempts to cast this `Module` to the given `ModuleType`. |
366 | /// |
367 | /// This method is useful when calling `apply()`. |
368 | /// \rst |
369 | /// .. code-block:: cpp |
370 | /// |
371 | /// void initialize_weights(nn::Module& module) { |
372 | /// torch::NoGradGuard no_grad; |
373 | /// if (auto* linear = module.as<nn::Linear>()) { |
374 | /// linear->weight.normal_(0.0, 0.02); |
375 | /// } |
376 | /// } |
377 | /// |
378 | /// MyModule module; |
/// module->apply(initialize_weights);
380 | /// \endrst |
381 | template < |
382 | typename ModuleType, |
383 | typename = torch::detail::disable_if_module_holder_t<ModuleType>> |
384 | const ModuleType* as() const noexcept; |
385 | |
386 | /// Serializes the `Module` into the given `OutputArchive`. |
387 | /// |
388 | /// If the `Module` contains unserializable submodules (e.g. |
389 | /// `nn::Functional`), those submodules are skipped when serializing. |
390 | virtual void save(serialize::OutputArchive& archive) const; |
391 | |
392 | /// Deserializes the `Module` from the given `InputArchive`. |
393 | /// |
394 | /// If the `Module` contains unserializable submodules (e.g. |
395 | /// `nn::Functional`), we don't check the existence of those submodules in the |
396 | /// `InputArchive` when deserializing. |
397 | virtual void load(serialize::InputArchive& archive); |
398 | |
399 | /// Streams a pretty representation of the `Module` into the given `stream`. |
400 | /// By default, this representation will be the name of the module (taken from |
401 | /// `name()`), followed by a recursive pretty print of all of the `Module`'s |
402 | /// submodules. |
403 | /// |
/// Override this method to change the pretty print.
406 | virtual void pretty_print(std::ostream& stream) const; |
407 | |
408 | /// Returns whether the `Module` is serializable. |
409 | virtual bool is_serializable() const; |
410 | |
411 | /// Registers a parameter with this `Module`. |
412 | /// |
413 | /// A parameter should be any gradient-recording tensor used in the |
414 | /// implementation of your `Module`. Registering it makes it available to |
/// methods such as `parameters()`, `clone()` or `to()`.
416 | /// |
417 | /// Note that registering an undefined Tensor (e.g. |
418 | /// `module.register_parameter("param", Tensor())`) is allowed, and is |
/// equivalent to `module.register_parameter("param", None)` in the Python API.
420 | /// |
421 | /// \rst |
422 | /// .. code-block:: cpp |
423 | /// |
424 | /// MyModule::MyModule() { |
425 | /// weight_ = register_parameter("weight", torch::randn({A, B})); |
426 | /// } |
427 | /// \endrst |
428 | Tensor& register_parameter( |
429 | std::string name, |
430 | Tensor tensor, |
431 | bool requires_grad = true); |
432 | |
433 | /// Registers a buffer with this `Module`. |
434 | /// |
435 | /// A buffer is intended to be state in your module that does not record |
436 | /// gradients, such as running statistics. Registering it makes it available |
/// to methods such as `buffers()`, `clone()` or `to()`.
438 | /// |
439 | /// \rst |
440 | /// .. code-block:: cpp |
441 | /// |
442 | /// MyModule::MyModule() { |
443 | /// mean_ = register_buffer("mean", torch::empty({num_features_})); |
444 | /// } |
445 | /// \endrst |
446 | Tensor& register_buffer(std::string name, Tensor tensor); |
447 | |
448 | /// Registers a submodule with this `Module`. |
449 | /// |
450 | /// Registering a module makes it available to methods such as `modules()`, |
451 | /// `clone()` or `to()`. |
452 | /// |
453 | /// \rst |
454 | /// .. code-block:: cpp |
455 | /// |
456 | /// MyModule::MyModule() { |
457 | /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); |
458 | /// } |
459 | /// \endrst |
460 | template <typename ModuleType> |
461 | std::shared_ptr<ModuleType> register_module( |
462 | std::string name, |
463 | std::shared_ptr<ModuleType> module); |
464 | |
465 | /// Registers a submodule with this `Module`. |
466 | /// |
467 | /// This method deals with `ModuleHolder`s. |
468 | /// |
469 | /// Registering a module makes it available to methods such as `modules()`, |
470 | /// `clone()` or `to()`. |
471 | /// |
472 | /// \rst |
473 | /// .. code-block:: cpp |
474 | /// |
475 | /// MyModule::MyModule() { |
476 | /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); |
477 | /// } |
478 | /// \endrst |
479 | template <typename ModuleType> |
480 | std::shared_ptr<ModuleType> register_module( |
481 | std::string name, |
482 | ModuleHolder<ModuleType> module_holder); |
483 | |
/// Replaces a registered submodule of this `Module` with a new module.
///
/// This takes care of the registration. If you also keep the submodule in a
/// member variable, remember to reassign that member as well, i.e. use it as
///     module->submodule_ = module->replace_module("linear",
///         torch::nn::Linear(3, 4));
/// This only works if a module with the given name is already registered.
///
/// This is useful for replacing a module after initialization, e.g. for
/// fine-tuning.
495 | template <typename ModuleType> |
496 | std::shared_ptr<ModuleType> replace_module( |
497 | const std::string& name, |
498 | std::shared_ptr<ModuleType> module); |
499 | |
/// Replaces a registered submodule of this `Module` with a new module.
/// This method deals with `ModuleHolder`s.
///
/// This takes care of the registration. If you also keep the submodule in a
/// member variable, remember to reassign that member as well, i.e. use it as
///     module->submodule_ = module->replace_module("linear", linear_holder);
/// This only works if a module with the given name is already registered.
///
/// This is useful for replacing a module after initialization, e.g. for
/// fine-tuning.
511 | template <typename ModuleType> |
512 | std::shared_ptr<ModuleType> replace_module( |
513 | const std::string& name, |
514 | ModuleHolder<ModuleType> module_holder); |
515 | |
516 | /// Unregisters a submodule from this `Module`. If there is no such module |
517 | /// with `name` an exception is thrown. |
518 | void unregister_module(const std::string& name); |
519 | |
520 | protected: |
  /// The following three functions allow a module with default arguments in
  /// its `forward` method to be used in a `Sequential` module.
  /// You should NEVER override these functions manually. Instead, use the
  /// `FORWARD_HAS_DEFAULT_ARGS` macro.
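  ///
  /// A sketch of the macro's intended usage: each entry pairs the index of a
  /// defaulted `forward` argument with its default value wrapped in an
  /// `AnyValue` (the module and argument here are hypothetical):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct MImpl : torch::nn::Module {
  ///     torch::Tensor forward(torch::Tensor x, int64_t scale = 2) {
  ///       return x * scale;
  ///     }
  ///
  ///    protected:
  ///     FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(int64_t(2))})
  ///   };
  /// \endrst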
525 | virtual bool _forward_has_default_args() { |
526 | return false; |
527 | } |
528 | |
529 | virtual unsigned int _forward_num_required_args() { |
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_num_required_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
535 | } |
536 | |
537 | virtual std::vector<AnyValue> _forward_populate_default_args( |
538 | std::vector<AnyValue>&& arguments) { |
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_populate_default_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
544 | } |
545 | |
546 | /// The registered parameters of this `Module`. |
  /// Kept protected so that `ParameterDict` and `ParameterList` can access
  /// `parameters_` directly.
548 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
549 | OrderedDict<std::string, Tensor> parameters_; |
550 | |
551 | private: |
552 | // Friend classes. |
553 | |
554 | template <typename Derived> |
555 | friend class Cloneable; |
556 | |
557 | template <typename ModuleType, typename... ArgumentTypes> |
558 | friend struct AnyModuleHolder; |
559 | |
560 | /// Pretty prints the given `Module` into the `ostream`. |
561 | TORCH_API friend std::ostream& operator<<( |
562 | std::ostream& stream, |
563 | const nn::Module& module); |
564 | |
  // Data parallel uses this method to configure gradient edges during the
  // replicate step.
567 | template <typename ModuleType> |
568 | friend void replicate_grad_edges( |
569 | const std::shared_ptr<Module>& module, |
570 | const std::vector<std::shared_ptr<ModuleType>>& replicas, |
571 | const std::vector<Device>& devices); |
572 | |
573 | // Private methods. |
574 | |
575 | /// Used in the implementation of `Cloneable`. |
576 | virtual void clone_(Module& other, const optional<Device>& device); |
577 | |
578 | /// The implementation of the various `to()` methods. |
579 | template <typename... Ts> |
580 | void to_impl(Ts&&... ts); |
581 | |
582 | /// Implements pretty printing the module hierarchy. |
583 | void pretty_print_recursive( |
584 | std::ostream& stream, |
585 | const std::string& indentation) const; |
586 | |
587 | /// Applies the `function` to every submodule recursively, starting at this |
588 | /// `Module`'s children (thus not including the module itself). |
589 | void apply_to_submodules( |
590 | const NamedModulePointerApplyFunction& function, |
591 | const std::string& name_prefix = std::string()) const; |
592 | |
593 | /// Returns a shared_ptr to `this` in a safe (checked) way. |
594 | std::shared_ptr<Module> shared_from_this_checked() const; |
595 | |
596 | /// The registered buffers of this `Module`. |
597 | OrderedDict<std::string, Tensor> buffers_; |
598 | |
599 | /// The registered (direct) submodules of this `Module`. |
600 | OrderedDict<std::string, std::shared_ptr<Module>> children_; |
601 | |
602 | /// The module's name (e.g. "LSTM"). |
603 | mutable optional<std::string> name_; |
604 | |
605 | /// Whether the module is in training mode. |
606 | bool is_training_{true}; |
607 | }; |
608 | |
/// Serializes a `Module` pointer into an `OutputArchive`.
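///
/// A sketch of the intended usage (the filename is illustrative):
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::Linear linear(3, 4);
///   torch::serialize::OutputArchive archive;
///   archive << linear.ptr();
///   archive.save_to("linear.pt");
/// \endrst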
610 | TORCH_API serialize::OutputArchive& operator<<( |
611 | serialize::OutputArchive& archive, |
612 | const std::shared_ptr<nn::Module>& module); |
613 | |
614 | /// Deserializes a `Module` from an `InputArchive`. |
615 | TORCH_API serialize::InputArchive& operator>>( |
616 | serialize::InputArchive& archive, |
617 | const std::shared_ptr<nn::Module>& module); |
618 | |
619 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
620 | |
621 | template <typename ModuleType> |
622 | typename ModuleType::ContainedType* Module::as() noexcept { |
623 | // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for |
624 | // `Linear`, since `LinearImpl` inherits `nn::Module`. |
625 | return as<typename ModuleType::ContainedType>(); |
626 | } |
627 | |
628 | template <typename ModuleType> |
629 | const typename ModuleType::ContainedType* Module::as() const noexcept { |
630 | // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for |
631 | // `Linear`, since `LinearImpl` inherits `nn::Module`. |
632 | return as<typename ModuleType::ContainedType>(); |
633 | } |
634 | |
635 | template <typename ModuleType, typename> |
636 | ModuleType* Module::as() noexcept { |
637 | return dynamic_cast<ModuleType*>(this); |
638 | } |
639 | |
640 | template <typename ModuleType, typename> |
641 | const ModuleType* Module::as() const noexcept { |
642 | return dynamic_cast<const ModuleType*>(this); |
643 | } |
644 | |
645 | template <typename ModuleType> |
646 | std::shared_ptr<ModuleType> Module::register_module( |
647 | std::string name, |
648 | std::shared_ptr<ModuleType> module) { |
  TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Submodule name must not contain a dot (got '",
      name,
      "')");
655 | auto& base_module = children_.insert(std::move(name), std::move(module)); |
656 | return std::dynamic_pointer_cast<ModuleType>(base_module); |
657 | } |
658 | |
659 | template <typename ModuleType> |
660 | std::shared_ptr<ModuleType> Module::register_module( |
661 | std::string name, |
662 | ModuleHolder<ModuleType> module_holder) { |
663 | return register_module(std::move(name), module_holder.ptr()); |
664 | } |
665 | |
666 | template <typename ModuleType> |
667 | std::shared_ptr<ModuleType> Module::replace_module( |
668 | const std::string& name, |
669 | std::shared_ptr<ModuleType> module) { |
670 | auto& base_module = (children_[name] = std::move(module)); |
671 | return std::dynamic_pointer_cast<ModuleType>(base_module); |
672 | } |
673 | |
674 | template <typename ModuleType> |
675 | std::shared_ptr<ModuleType> Module::replace_module( |
676 | const std::string& name, |
677 | ModuleHolder<ModuleType> module_holder) { |
678 | return replace_module(name, module_holder.ptr()); |
679 | } |
680 | |
681 | template <typename... Ts> |
682 | void Module::to_impl(Ts&&... ts) { |
683 | // First call `to()` on every child module. |
684 | for (auto& child : children_) { |
685 | child.value()->to(ts...); |
686 | } |
687 | // Then move every parameter to the new dtype/device. |
688 | for (auto& parameter : named_parameters(/*recurse=*/false)) { |
689 | parameter->set_data(autograd::Variable(*parameter).to(ts...)); |
690 | } |
691 | // Then move every buffer to the new dtype/device. |
692 | for (auto& buffer : named_buffers(/*recurse=*/false)) { |
693 | buffer->set_data(autograd::Variable(*buffer).to(ts...)); |
694 | } |
695 | } |
696 | |
697 | } // namespace nn |
698 | } // namespace torch |
699 | |