1 | #pragma once |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/Export.h> |
5 | #include <torch/serialize/archive.h> |
6 | #include <torch/serialize/tensor.h> |
7 | |
8 | #include <utility> |
9 | |
10 | namespace torch { |
11 | |
12 | /// Serializes the given `value`. |
13 | /// There must be an overload of `operator<<` between `serialize::OutputArchive` |
14 | /// and `Value` for this method to be well-formed. Currently, such an overload |
15 | /// is provided for (subclasses of): |
16 | /// |
17 | /// - `torch::nn::Module`, |
18 | /// - `torch::optim::Optimizer` |
19 | /// - `torch::Tensor` |
20 | /// |
21 | /// To perform the serialization, a `serialize::OutputArchive` is constructed, |
22 | /// and all arguments after the `value` are forwarded to its `save_to` method. |
23 | /// For example, you can pass a filename, or an `ostream`. |
24 | /// |
25 | /// \rst |
26 | /// .. code-block:: cpp |
27 | /// |
28 | /// torch::nn::Linear model(3, 4); |
29 | /// torch::save(model, "model.pt"); |
30 | /// |
31 | /// torch::optim::SGD sgd(/*lr=*/0.9); |
32 | /// std::ostringstream stream; |
33 | /// // Note that the same stream cannot be used in multiple torch::save(...) |
34 | /// // invocations, otherwise the header will be corrupted. |
35 | /// torch::save(sgd, stream); |
36 | /// |
37 | /// auto tensor = torch::ones({3, 4}); |
38 | /// torch::save(tensor, "my_tensor.pt"); |
39 | /// \endrst |
40 | template <typename Value, typename... SaveToArgs> |
41 | void save(const Value& value, SaveToArgs&&... args) { |
42 | serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>()); |
43 | archive << value; |
44 | archive.save_to(std::forward<SaveToArgs>(args)...); |
45 | } |
46 | |
47 | /// Serializes the given `tensor_vec` of type `std::vector<torch::Tensor>`. |
48 | /// |
49 | /// To perform the serialization, a `serialize::OutputArchive` is constructed, |
50 | /// and all arguments after the `tensor_vec` are forwarded to its `save_to` |
51 | /// method. For example, you can pass a filename, or an `ostream`. |
52 | /// |
53 | /// \rst |
54 | /// .. code-block:: cpp |
55 | /// |
56 | /// std::vector<torch::Tensor> tensor_vec = { torch::randn({1, 2}), |
57 | /// torch::randn({3, 4}) }; torch::save(tensor_vec, "my_tensor_vec.pt"); |
58 | /// |
59 | /// std::vector<torch::Tensor> tensor_vec = { torch::randn({5, 6}), |
60 | /// torch::randn({7, 8}) }; std::ostringstream stream; |
61 | /// // Note that the same stream cannot be used in multiple torch::save(...) |
62 | /// // invocations, otherwise the header will be corrupted. |
63 | /// torch::save(tensor_vec, stream); |
64 | /// \endrst |
65 | template <typename... SaveToArgs> |
66 | void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) { |
67 | serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>()); |
68 | for (const auto i : c10::irange(tensor_vec.size())) { |
69 | auto& value = tensor_vec[i]; |
70 | archive.write(c10::to_string(i), value); |
71 | } |
72 | archive.save_to(std::forward<SaveToArgs>(args)...); |
73 | } |
74 | |
/// Takes a `torch::IValue` and returns a `std::vector<char>`. Only the
/// declaration is visible here; the definition is provided elsewhere.
/// NOTE(review): judging by the name, the returned bytes are presumably the
/// pickle-encoded form of `ivalue` -- confirm against the implementation.
TORCH_API std::vector<char> pickle_save(const torch::IValue& ivalue);
/// Takes a `std::vector<char>` and returns a `torch::IValue`. Only the
/// declaration is visible here; the definition is provided elsewhere.
/// NOTE(review): presumably the inverse of `pickle_save`, i.e. `data` is
/// expected to hold bytes produced by it -- confirm against the
/// implementation.
TORCH_API torch::IValue pickle_load(const std::vector<char>& data);
77 | |
78 | /// Deserializes the given `value`. |
79 | /// There must be an overload of `operator>>` between `serialize::InputArchive` |
80 | /// and `Value` for this method to be well-formed. Currently, such an overload |
81 | /// is provided for (subclasses of): |
82 | /// |
83 | /// - `torch::nn::Module`, |
84 | /// - `torch::optim::Optimizer` |
85 | /// - `torch::Tensor` |
86 | /// |
87 | /// To perform the serialization, a `serialize::InputArchive` is constructed, |
88 | /// and all arguments after the `value` are forwarded to its `load_from` method. |
89 | /// For example, you can pass a filename, or an `istream`. |
90 | /// |
91 | /// \rst |
92 | /// .. code-block:: cpp |
93 | /// |
94 | /// torch::nn::Linear model(3, 4); |
95 | /// torch::load(model, "model.pt"); |
96 | /// |
97 | /// torch::optim::SGD sgd(/*lr=*/0.9); |
98 | /// std::istringstream stream("..."); |
99 | /// torch::load(sgd, stream); |
100 | /// |
101 | /// auto tensor = torch::ones({3, 4}); |
102 | /// torch::load(tensor, "my_tensor.pt"); |
103 | /// \endrst |
104 | template <typename Value, typename... LoadFromArgs> |
105 | void load(Value& value, LoadFromArgs&&... args) { |
106 | serialize::InputArchive archive; |
107 | archive.load_from(std::forward<LoadFromArgs>(args)...); |
108 | archive >> value; |
109 | } |
110 | |
111 | /// Deserializes the given `tensor_vec` of type `std::vector<torch::Tensor>`. |
112 | /// |
113 | /// To perform the serialization, a `serialize::InputArchive` is constructed, |
114 | /// and all arguments after the `value` are forwarded to its `load_from` method. |
115 | /// For example, you can pass a filename, or an `istream`. |
116 | /// |
117 | /// \rst |
118 | /// .. code-block:: cpp |
119 | /// |
120 | /// std::vector<torch::Tensor> tensor_vec; |
121 | /// torch::load(tensor_vec, "my_tensor_vec.pt"); |
122 | /// |
123 | /// std::vector<torch::Tensor> tensor_vec; |
124 | /// std::istringstream stream("..."); |
125 | /// torch::load(tensor_vec, stream); |
126 | /// \endrst |
127 | template <typename... LoadFromArgs> |
128 | void load(std::vector<torch::Tensor>& tensor_vec, LoadFromArgs&&... args) { |
129 | serialize::InputArchive archive; |
130 | archive.load_from(std::forward<LoadFromArgs>(args)...); |
131 | |
132 | // NOTE: The number of elements in the serialized `std::vector<torch::Tensor>` |
133 | // is not known ahead of time, so we need a while-loop to increment the index, |
134 | // and use `archive.try_read(...)` to check whether we have reached the end of |
135 | // the serialized `std::vector<torch::Tensor>`. |
136 | size_t index = 0; |
137 | torch::Tensor value; |
138 | while (archive.try_read(c10::to_string(index), value)) { |
139 | tensor_vec.push_back(std::move(value)); |
140 | value = torch::Tensor(); |
141 | index++; |
142 | } |
143 | } |
144 | } // namespace torch |
145 | |