#pragma once

#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <ATen/ATen.h>

#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
namespace torch {
namespace nn {

/// The base class for all modules in PyTorch.
///
/// \rst
/// .. note::
///   The design and implementation of this class are largely based on the
///   Python API. You may want to consult the Python documentation for
///   :py:class:`pytorch:torch.nn.Module` for further clarification on certain
///   methods or behavior.
/// \endrst
///
/// A `Module` is an abstraction over the implementation of some function or
/// algorithm, possibly associated with some persistent data. A `Module` may
/// contain further `Module`s ("submodules"), each with their own
/// implementation, persistent data and further submodules. `Module`s can thus
/// be said to form a recursive tree structure. A `Module` is registered as a
/// submodule to another `Module` by calling `register_module()`, typically
/// from within a parent module's constructor.
///
/// A distinction is made between three kinds of persistent data that may be
/// associated with a `Module`:
///
/// 1. *Parameters*: tensors that record gradients, typically weights updated
///    during the backward step (e.g. the `weight` of a `Linear` module),
/// 2. *Buffers*: tensors that do not record gradients, typically updated
///    during the forward step, such as running statistics (e.g. `mean` and
///    `variance` in the `BatchNorm` module),
/// 3. Any additional state, not necessarily tensors, required for the
///    implementation or configuration of a `Module`.
///
/// The first two kinds of state are special in that they may be registered
/// with the `Module` system to allow convenient access and batch
/// configuration. For example, registered parameters in any `Module` may be
/// iterated over via the `parameters()` accessor. Further, changing the data
/// type of a `Module`'s registered parameters can be done conveniently via
/// `Module::to()`, e.g. `module->to(torch::kCUDA)` to move all parameters to
/// GPU memory. Lastly, registered parameters and buffers are handled
/// specially during a `clone()` operation, which performs a deepcopy of a
/// cloneable `Module` hierarchy.
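///
/// For example, a minimal sketch of such batch access (assuming `MyModule`
/// is a module holder defined with `TORCH_MODULE`):
///
/// \rst
/// .. code-block:: cpp
///
///   MyModule module;
///   // Iterate over every registered parameter in the hierarchy.
///   for (const auto& parameter : module->parameters()) {
///     std::cout << parameter.sizes() << std::endl;
///   }
///   // Move every registered parameter and buffer to the GPU.
///   module->to(torch::kCUDA);
/// \endrst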
///
/// Parameters are registered with a `Module` via `register_parameter`. Buffers
/// are registered separately via `register_buffer`. These methods are part of
/// the public API of `Module` and are typically invoked from within a
/// concrete `Module`'s constructor.
class TORCH_API Module : public std::enable_shared_from_this<Module> {
 public:
  using ModuleApplyFunction = std::function<void(Module&)>;
  using ConstModuleApplyFunction = std::function<void(const Module&)>;
  using NamedModuleApplyFunction =
      std::function<void(const std::string&, Module&)>;
  using ConstNamedModuleApplyFunction =
      std::function<void(const std::string&, const Module&)>;
  using ModulePointerApplyFunction =
      std::function<void(const std::shared_ptr<Module>&)>;
  using NamedModulePointerApplyFunction =
      std::function<void(const std::string&, const std::shared_ptr<Module>&)>;

  /// Tells the base `Module` about the name of the submodule.
  explicit Module(std::string name);

  /// Constructs the module without immediate knowledge of the submodule's
  /// name. The name of the submodule is inferred via RTTI (if possible) the
  /// first time `.name()` is invoked.
  Module();

  virtual ~Module() = default;

  /// Returns the name of the `Module`.
  ///
  /// A `Module` has an associated `name`, which is a string representation of
  /// the kind of concrete `Module` it represents, such as `"Linear"` for the
  /// `Linear` module. Under most circumstances, this name is automatically
  /// inferred via runtime type information (RTTI). In the unusual circumstance
  /// that you have this feature disabled, you may want to manually name your
  /// `Module`s by passing the string name to the `Module` base class'
  /// constructor.
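  ///
  /// For example, a sketch of manual naming when RTTI is unavailable
  /// (`MyModuleImpl` is a hypothetical concrete module):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct MyModuleImpl : torch::nn::Module {
  ///     MyModuleImpl() : torch::nn::Module("MyModule") {}
  ///   };
  /// \endrst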
  const std::string& name() const noexcept;

  /// Performs a recursive deep copy of the module and all its registered
  /// parameters, buffers and submodules.
  ///
  /// Optionally, this method sets the current device
  /// to the one supplied before cloning. If no device is given, each
  /// parameter and buffer will be moved to the device of its source.
  ///
  /// \rst
  /// .. attention::
  ///   Attempting to call the `clone()` method inherited from the base
  ///   `Module` class (the one documented here) will fail. To inherit an
  ///   actual implementation of `clone()`, you must subclass `Cloneable`.
  ///   `Cloneable` is templatized on the concrete module type, and can thus
  ///   properly copy a `Module`. This method is provided on the base class'
  ///   API solely for an easier-to-use polymorphic interface.
  /// \endrst
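  ///
  /// A minimal sketch (assuming `MyModule` derives from
  /// `Cloneable<MyModule>`):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   auto module = std::make_shared<MyModule>();
  ///   // Deep-copies parameters, buffers and submodules onto the given device.
  ///   auto copy = module->clone(torch::kCUDA);
  /// \endrst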
  virtual std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const;

  /// Applies the `function` to the `Module` and recursively to every
  /// submodule. The function must accept a `Module&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](nn::Module& module) {
  ///     std::cout << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ModuleApplyFunction& function);

  /// Applies the `function` to the `Module` and recursively to every
  /// submodule. The function must accept a `const Module&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const nn::Module& module) {
  ///     std::cout << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ConstModuleApplyFunction& function) const;

  /// Applies the `function` to the `Module` and recursively to every
  /// submodule. The function must accept a `const std::string&` for the key
  /// of the module, and a `Module&`. The key of the module itself is the
  /// empty string. If `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::string& key, nn::Module& module) {
  ///     std::cout << key << ": " << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const NamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string());

  /// Applies the `function` to the `Module` and recursively to every
  /// submodule. The function must accept a `const std::string&` for the key
  /// of the module, and a `const Module&`. The key of the module itself is
  /// the empty string. If `name_prefix` is given, it is prepended to every
  /// key as `<name_prefix>.<key>` (and just `name_prefix` for the module
  /// itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::string& key, const nn::Module& module) {
  ///     std::cout << key << ": " << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const ConstNamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Applies the `function` to the `Module` and recursively to every
  /// submodule. The function must accept a `const std::shared_ptr<Module>&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::shared_ptr<nn::Module>& module) {
  ///     std::cout << module->name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ModulePointerApplyFunction& function) const;

  /// Applies the `function` to the `Module` and recursively to every
  /// submodule. The function must accept a `const std::string&` for the key
  /// of the module, and a `const std::shared_ptr<Module>&`. The key of the
  /// module itself is the empty string. If `name_prefix` is given, it is
  /// prepended to every key as `<name_prefix>.<key>` (and just `name_prefix`
  /// for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::string& key,
  ///                    const std::shared_ptr<nn::Module>& module) {
  ///     std::cout << key << ": " << module->name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns the parameters of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> parameters(bool recurse = true) const;

  /// Returns an `OrderedDict` with the parameters of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
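  ///
  /// A sketch of iterating over the returned dictionary (`module` is assumed
  /// to hold a concrete module):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   for (const auto& pair : module->named_parameters()) {
  ///     std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
  ///   }
  /// \endrst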
  OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;

  /// Returns the buffers of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> buffers(bool recurse = true) const;

  /// Returns an `OrderedDict` with the buffers of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
  OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const;

  /// Returns the submodules of this `Module` (the entire submodule hierarchy)
  /// and if `include_self` is true, also inserts a `shared_ptr` to this module
  /// in the first position.
  ///
  /// \rst
  /// .. warning::
  ///   Only pass `include_self` as `true` if this `Module` is stored in a
  ///   `shared_ptr`! Otherwise an exception will be thrown. You may still call
  ///   this method with `include_self` set to false if your `Module` is not
  ///   stored in a `shared_ptr`.
  /// \endrst
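  ///
  /// A sketch of walking the hierarchy (assuming this module is held in a
  /// `shared_ptr`, e.g. via a module holder):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   for (const auto& submodule : module->modules()) {
  ///     std::cout << submodule->name() << std::endl;
  ///   }
  /// \endrst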
  std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const;

  /// Returns an `OrderedDict` of the submodules of this `Module` (the entire
  /// submodule hierarchy) and their keys, and if `include_self` is true, also
  /// inserts a `shared_ptr` to this module in the first position. If
  /// `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. warning::
  ///   Only pass `include_self` as `true` if this `Module` is stored in a
  ///   `shared_ptr`! Otherwise an exception will be thrown. You may still call
  ///   this method with `include_self` set to false if your `Module` is not
  ///   stored in a `shared_ptr`.
  /// \endrst
  OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
      const std::string& name_prefix = std::string(),
      bool include_self = true) const;

  /// Returns the direct submodules of this `Module`.
  std::vector<std::shared_ptr<Module>> children() const;

  /// Returns an `OrderedDict` of the direct submodules of this `Module` and
  /// their keys.
  OrderedDict<std::string, std::shared_ptr<Module>> named_children() const;

  /// Enables "training" mode.
  virtual void train(bool on = true);

  /// Calls train(false) to enable "eval" mode.
  /// Do not override this method; override `train()` instead.
  void eval();

  /// True if the module is in training mode.
  ///
  /// Every `Module` has a boolean associated with it that determines whether
  /// the `Module` is currently in *training* mode (set via `.train()`) or in
  /// *evaluation* (inference) mode (set via `.eval()`). This property is
  /// exposed via `is_training()`, and may be used by the implementation of a
  /// concrete module to modify its runtime behavior. See the `BatchNorm` or
  /// `Dropout` modules for examples of `Module`s that use different code paths
  /// depending on this property.
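  ///
  /// For example, a sketch of toggling the mode around inference:
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   module->eval();
  ///   TORCH_CHECK(!module->is_training());
  ///   // ... run inference ...
  ///   module->train();
  ///   TORCH_CHECK(module->is_training());
  /// \endrst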
  virtual bool is_training() const noexcept;

  /// Recursively casts all parameters to the given `dtype` and `device`.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
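  ///
  /// A sketch of a combined device and dtype cast (the `non_blocking` flag
  /// only matters for pinned-memory/GPU transfers):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   module->to(torch::kCUDA, torch::kHalf, /*non_blocking=*/true);
  /// \endrst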
  virtual void to(
      torch::Device device,
      torch::Dtype dtype,
      bool non_blocking = false);

  /// Recursively casts all parameters to the given dtype.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  virtual void to(torch::Dtype dtype, bool non_blocking = false);

  /// Recursively moves all parameters to the given device.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  virtual void to(torch::Device device, bool non_blocking = false);

  /// Recursively zeros out the `grad` value of each registered parameter.
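  ///
  /// A sketch of the typical place this is called in a training loop
  /// (`optimizer`, `input` and the module's `forward` are assumed to exist):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   module->zero_grad();
  ///   auto loss = module->forward(input).sum();
  ///   loss.backward();
  ///   optimizer.step();
  /// \endrst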
  virtual void zero_grad(bool set_to_none = true);

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module->apply(initialize_weights);
  /// \endrst
  template <typename ModuleType>
  typename ModuleType::ContainedType* as() noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module->apply(initialize_weights);
  /// \endrst
  template <typename ModuleType>
  const typename ModuleType::ContainedType* as() const noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module.apply(initialize_weights);
  /// \endrst
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  ModuleType* as() noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module.apply(initialize_weights);
  /// \endrst
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  const ModuleType* as() const noexcept;

  /// Serializes the `Module` into the given `OutputArchive`.
  ///
  /// If the `Module` contains unserializable submodules (e.g.
  /// `nn::Functional`), those submodules are skipped when serializing.
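  ///
  /// A sketch of writing the module to a file through an archive (the file
  /// name is illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::serialize::OutputArchive archive;
  ///   module->save(archive);
  ///   archive.save_to("module.pt");
  /// \endrst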
  virtual void save(serialize::OutputArchive& archive) const;

  /// Deserializes the `Module` from the given `InputArchive`.
  ///
  /// If the `Module` contains unserializable submodules (e.g.
  /// `nn::Functional`), we don't check the existence of those submodules in
  /// the `InputArchive` when deserializing.
  virtual void load(serialize::InputArchive& archive);

  /// Streams a pretty representation of the `Module` into the given `stream`.
  /// By default, this representation will be the name of the module (taken
  /// from `name()`), followed by a recursive pretty print of all of the
  /// `Module`'s submodules.
  ///
  /// Override this method to change the pretty print. The input
  /// `stream` should be returned from the method, to allow easy chaining.
  virtual void pretty_print(std::ostream& stream) const;

  /// Returns whether the `Module` is serializable.
  virtual bool is_serializable() const;

  /// Registers a parameter with this `Module`.
  ///
  /// A parameter should be any gradient-recording tensor used in the
  /// implementation of your `Module`. Registering it makes it available to
  /// methods such as `parameters()`, `clone()` or `to()`.
  ///
  /// Note that registering an undefined Tensor (e.g.
  /// `module.register_parameter("param", Tensor())`) is allowed, and is
  /// equivalent to `module.register_parameter("param", None)` in the Python
  /// API.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     weight_ = register_parameter("weight", torch::randn({A, B}));
  ///   }
  /// \endrst
  Tensor& register_parameter(
      std::string name,
      Tensor tensor,
      bool requires_grad = true);

  /// Registers a buffer with this `Module`.
  ///
  /// A buffer is intended to be state in your module that does not record
  /// gradients, such as running statistics. Registering it makes it available
  /// to methods such as `buffers()`, `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     mean_ = register_buffer("mean", torch::empty({num_features_}));
  ///   }
  /// \endrst
  Tensor& register_buffer(std::string name, Tensor tensor);

  /// Registers a submodule with this `Module`.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      std::shared_ptr<ModuleType> module);

  /// Registers a submodule with this `Module`.
  ///
  /// This method deals with `ModuleHolder`s.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      ModuleHolder<ModuleType> module_holder);

  /// Replaces a submodule registered with this `Module`.
  ///
  /// This method only works when a module of the given `name` is already
  /// registered. It takes care of re-registering the submodule, but if you
  /// also store the submodule in a member variable, you should reassign that
  /// member as well, as shown below.
  ///
  /// This is useful for replacing a module after initialization, e.g. for
  /// fine-tuning.
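  ///
  /// A sketch (assuming `submodule_` is the member that previously held the
  /// `"linear"` submodule):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   module->submodule_ =
  ///       module->replace_module("linear", torch::nn::Linear(3, 4));
  /// \endrst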
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      std::shared_ptr<ModuleType> module);

  /// Replaces a submodule registered with this `Module`.
  ///
  /// This method deals with `ModuleHolder`s.
  ///
  /// It only works when a module of the given `name` is already registered.
  /// It takes care of re-registering the submodule, but if you also store the
  /// submodule in a member variable, you should reassign that member as well,
  /// e.g. `module->submodule_ = module->replace_module("linear",
  /// linear_holder);`.
  ///
  /// This is useful for replacing a module after initialization, e.g. for
  /// fine-tuning.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      ModuleHolder<ModuleType> module_holder);

  /// Unregisters a submodule from this `Module`. If there is no such module
  /// with `name` an exception is thrown.
  void unregister_module(const std::string& name);

 protected:
  /// The following three functions allow a module with default arguments in
  /// its forward method to be used in a Sequential module.
  /// You should NEVER override these functions manually. Instead, you should
  /// use the `FORWARD_HAS_DEFAULT_ARGS` macro.
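  ///
  /// A sketch of the intended usage (the argument index and default value
  /// below are purely illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct MImpl : torch::nn::Module {
  ///     torch::Tensor forward(torch::Tensor x, int scale = 2) {
  ///       return x * scale;
  ///     }
  ///
  ///    protected:
  ///     // The second argument (index 1) defaults to 2.
  ///     FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)})
  ///   };
  /// \endrst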
  virtual bool _forward_has_default_args() {
    return false;
  }

  virtual unsigned int _forward_num_required_args() {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_num_required_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  virtual std::vector<AnyValue> _forward_populate_default_args(
      std::vector<AnyValue>&& arguments) {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_populate_default_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  /// The registered parameters of this `Module`.
  /// Kept in the protected section so that `ParameterDict` and
  /// `ParameterList` can access `parameters_`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  OrderedDict<std::string, Tensor> parameters_;

 private:
  // Friend classes.

  template <typename Derived>
  friend class Cloneable;

  template <typename ModuleType, typename... ArgumentTypes>
  friend struct AnyModuleHolder;

  /// Pretty prints the given `Module` into the `ostream`.
  TORCH_API friend std::ostream& operator<<(
      std::ostream& stream,
      const nn::Module& module);

  // Data parallelism uses this method to configure gradient edges during the
  // replicate step.
  template <typename ModuleType>
  friend void replicate_grad_edges(
      const std::shared_ptr<Module>& module,
      const std::vector<std::shared_ptr<ModuleType>>& replicas,
      const std::vector<Device>& devices);

  // Private methods.

  /// Used in the implementation of `Cloneable`.
  virtual void clone_(Module& other, const optional<Device>& device);

  /// The implementation of the various `to()` methods.
  template <typename... Ts>
  void to_impl(Ts&&... ts);

  /// Implements pretty printing the module hierarchy.
  void pretty_print_recursive(
      std::ostream& stream,
      const std::string& indentation) const;

  /// Applies the `function` to every submodule recursively, starting at this
  /// `Module`'s children (thus not including the module itself).
  void apply_to_submodules(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns a shared_ptr to `this` in a safe (checked) way.
  std::shared_ptr<Module> shared_from_this_checked() const;

  /// The registered buffers of this `Module`.
  OrderedDict<std::string, Tensor> buffers_;

  /// The registered (direct) submodules of this `Module`.
  OrderedDict<std::string, std::shared_ptr<Module>> children_;

  /// The module's name (e.g. "LSTM").
  mutable optional<std::string> name_;

  /// Whether the module is in training mode.
  bool is_training_{true};
};

/// Serializes a `Module` pointer into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

/// Deserializes a `Module` from an `InputArchive`.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename ModuleType>
typename ModuleType::ContainedType* Module::as() noexcept {
  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
  // `Linear`, since `LinearImpl` inherits `nn::Module`.
  return as<typename ModuleType::ContainedType>();
}

template <typename ModuleType>
const typename ModuleType::ContainedType* Module::as() const noexcept {
  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
  // `Linear`, since `LinearImpl` inherits `nn::Module`.
  return as<typename ModuleType::ContainedType>();
}

template <typename ModuleType, typename>
ModuleType* Module::as() noexcept {
  return dynamic_cast<ModuleType*>(this);
}

template <typename ModuleType, typename>
const ModuleType* Module::as() const noexcept {
  return dynamic_cast<const ModuleType*>(this);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    std::shared_ptr<ModuleType> module) {
  TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Submodule name must not contain a dot (got '",
      name,
      "')");
  auto& base_module = children_.insert(std::move(name), std::move(module));
  return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder) {
  return register_module(std::move(name), module_holder.ptr());
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    std::shared_ptr<ModuleType> module) {
  auto& base_module = (children_[name] = std::move(module));
  return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    ModuleHolder<ModuleType> module_holder) {
  return replace_module(name, module_holder.ptr());
}

template <typename... Ts>
void Module::to_impl(Ts&&... ts) {
  // First call `to()` on every child module.
  for (auto& child : children_) {
    child.value()->to(ts...);
  }
  // Then move every parameter to the new dtype/device.
  for (auto& parameter : named_parameters(/*recurse=*/false)) {
    parameter->set_data(autograd::Variable(*parameter).to(ts...));
  }
  // Then move every buffer to the new dtype/device.
  for (auto& buffer : named_buffers(/*recurse=*/false)) {
    buffer->set_data(autograd::Variable(*buffer).to(ts...));
  }
}

} // namespace nn
} // namespace torch