1#pragma once
2
3#include <c10/util/irange.h>
4#include <torch/csrc/Export.h>
5#include <torch/serialize/archive.h>
6#include <torch/serialize/tensor.h>
7
8#include <utility>
9
10namespace torch {
11
12/// Serializes the given `value`.
13/// There must be an overload of `operator<<` between `serialize::OutputArchive`
14/// and `Value` for this method to be well-formed. Currently, such an overload
15/// is provided for (subclasses of):
16///
17/// - `torch::nn::Module`,
18/// - `torch::optim::Optimizer`
19/// - `torch::Tensor`
20///
21/// To perform the serialization, a `serialize::OutputArchive` is constructed,
22/// and all arguments after the `value` are forwarded to its `save_to` method.
23/// For example, you can pass a filename, or an `ostream`.
24///
25/// \rst
26/// .. code-block:: cpp
27///
28/// torch::nn::Linear model(3, 4);
29/// torch::save(model, "model.pt");
30///
31/// torch::optim::SGD sgd(/*lr=*/0.9);
32/// std::ostringstream stream;
33/// // Note that the same stream cannot be used in multiple torch::save(...)
34/// // invocations, otherwise the header will be corrupted.
35/// torch::save(sgd, stream);
36///
37/// auto tensor = torch::ones({3, 4});
38/// torch::save(tensor, "my_tensor.pt");
39/// \endrst
40template <typename Value, typename... SaveToArgs>
41void save(const Value& value, SaveToArgs&&... args) {
42 serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>());
43 archive << value;
44 archive.save_to(std::forward<SaveToArgs>(args)...);
45}
46
47/// Serializes the given `tensor_vec` of type `std::vector<torch::Tensor>`.
48///
49/// To perform the serialization, a `serialize::OutputArchive` is constructed,
50/// and all arguments after the `tensor_vec` are forwarded to its `save_to`
51/// method. For example, you can pass a filename, or an `ostream`.
52///
53/// \rst
54/// .. code-block:: cpp
55///
56/// std::vector<torch::Tensor> tensor_vec = { torch::randn({1, 2}),
57/// torch::randn({3, 4}) }; torch::save(tensor_vec, "my_tensor_vec.pt");
58///
59/// std::vector<torch::Tensor> tensor_vec = { torch::randn({5, 6}),
60/// torch::randn({7, 8}) }; std::ostringstream stream;
61/// // Note that the same stream cannot be used in multiple torch::save(...)
62/// // invocations, otherwise the header will be corrupted.
63/// torch::save(tensor_vec, stream);
64/// \endrst
65template <typename... SaveToArgs>
66void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
67 serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>());
68 for (const auto i : c10::irange(tensor_vec.size())) {
69 auto& value = tensor_vec[i];
70 archive.write(c10::to_string(i), value);
71 }
72 archive.save_to(std::forward<SaveToArgs>(args)...);
73}
74
/// Serializes `ivalue` into a byte buffer and returns it.
/// NOTE(review): the encoding (presumably pickle-based, going by the name) is
/// defined by the out-of-line implementation, not this header; confirm there.
TORCH_API std::vector<char> pickle_save(const torch::IValue& ivalue);
/// Reconstructs a `torch::IValue` from the bytes in `data`.
/// NOTE(review): presumably the inverse of `pickle_save`; verify against the
/// implementation before relying on round-trip guarantees.
TORCH_API torch::IValue pickle_load(const std::vector<char>& data);
77
78/// Deserializes the given `value`.
79/// There must be an overload of `operator>>` between `serialize::InputArchive`
80/// and `Value` for this method to be well-formed. Currently, such an overload
81/// is provided for (subclasses of):
82///
83/// - `torch::nn::Module`,
84/// - `torch::optim::Optimizer`
85/// - `torch::Tensor`
86///
87/// To perform the serialization, a `serialize::InputArchive` is constructed,
88/// and all arguments after the `value` are forwarded to its `load_from` method.
89/// For example, you can pass a filename, or an `istream`.
90///
91/// \rst
92/// .. code-block:: cpp
93///
94/// torch::nn::Linear model(3, 4);
95/// torch::load(model, "model.pt");
96///
97/// torch::optim::SGD sgd(/*lr=*/0.9);
98/// std::istringstream stream("...");
99/// torch::load(sgd, stream);
100///
101/// auto tensor = torch::ones({3, 4});
102/// torch::load(tensor, "my_tensor.pt");
103/// \endrst
104template <typename Value, typename... LoadFromArgs>
105void load(Value& value, LoadFromArgs&&... args) {
106 serialize::InputArchive archive;
107 archive.load_from(std::forward<LoadFromArgs>(args)...);
108 archive >> value;
109}
110
111/// Deserializes the given `tensor_vec` of type `std::vector<torch::Tensor>`.
112///
113/// To perform the serialization, a `serialize::InputArchive` is constructed,
114/// and all arguments after the `value` are forwarded to its `load_from` method.
115/// For example, you can pass a filename, or an `istream`.
116///
117/// \rst
118/// .. code-block:: cpp
119///
120/// std::vector<torch::Tensor> tensor_vec;
121/// torch::load(tensor_vec, "my_tensor_vec.pt");
122///
123/// std::vector<torch::Tensor> tensor_vec;
124/// std::istringstream stream("...");
125/// torch::load(tensor_vec, stream);
126/// \endrst
127template <typename... LoadFromArgs>
128void load(std::vector<torch::Tensor>& tensor_vec, LoadFromArgs&&... args) {
129 serialize::InputArchive archive;
130 archive.load_from(std::forward<LoadFromArgs>(args)...);
131
132 // NOTE: The number of elements in the serialized `std::vector<torch::Tensor>`
133 // is not known ahead of time, so we need a while-loop to increment the index,
134 // and use `archive.try_read(...)` to check whether we have reached the end of
135 // the serialized `std::vector<torch::Tensor>`.
136 size_t index = 0;
137 torch::Tensor value;
138 while (archive.try_read(c10::to_string(index), value)) {
139 tensor_vec.push_back(std::move(value));
140 value = torch::Tensor();
141 index++;
142 }
143}
144} // namespace torch
145