1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/flat_hash_map.h> |
4 | #include <c10/util/irange.h> |
5 | #include <c10/util/tempfile.h> |
6 | |
7 | #include <torch/torch.h> |
8 | |
9 | #include <test/cpp/api/support.h> |
10 | |
11 | #include <cstdio> |
12 | #include <memory> |
13 | #include <sstream> |
14 | #include <string> |
15 | #include <vector> |
16 | |
17 | using namespace torch::test; |
18 | using namespace torch::nn; |
19 | using namespace torch::optim; |
20 | |
21 | namespace { |
22 | Sequential xor_model() { |
23 | return Sequential( |
24 | Linear(2, 8), |
25 | Functional(at::sigmoid), |
26 | Linear(8, 1), |
27 | Functional(at::sigmoid)); |
28 | } |
29 | |
30 | torch::Tensor save_and_load(torch::Tensor input) { |
31 | std::stringstream stream; |
32 | torch::save(input, stream); |
33 | torch::Tensor tensor; |
34 | torch::load(tensor, stream); |
35 | return tensor; |
36 | } |
37 | } // namespace |
38 | |
39 | template <typename DerivedOptions> |
40 | void is_optimizer_param_group_equal( |
41 | const OptimizerParamGroup& lhs, |
42 | const OptimizerParamGroup& rhs) { |
43 | const auto& lhs_params = lhs.params(); |
44 | const auto& rhs_params = rhs.params(); |
45 | |
46 | ASSERT_TRUE(lhs_params.size() == rhs_params.size()); |
47 | for (const auto j : c10::irange(lhs_params.size())) { |
48 | ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j])); |
49 | } |
50 | ASSERT_TRUE( |
51 | static_cast<const DerivedOptions&>(lhs.options()) == |
52 | static_cast<const DerivedOptions&>(rhs.options())); |
53 | } |
54 | |
55 | template <typename DerivedOptimizerParamState> |
56 | void is_optimizer_state_equal( |
57 | const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& |
58 | lhs_state, |
59 | const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& |
60 | rhs_state) { |
61 | ASSERT_TRUE(lhs_state.size() == rhs_state.size()); |
62 | for (const auto& value : lhs_state) { |
63 | auto found = rhs_state.find(value.first); |
64 | ASSERT_TRUE(found != rhs_state.end()); |
65 | const DerivedOptimizerParamState& lhs_curr_state = |
66 | static_cast<const DerivedOptimizerParamState&>(*(value.second.get())); |
67 | const DerivedOptimizerParamState& rhs_curr_state = |
68 | static_cast<const DerivedOptimizerParamState&>(*(found->second.get())); |
69 | ASSERT_TRUE(lhs_curr_state == rhs_curr_state); |
70 | } |
71 | } |
72 | |
73 | template < |
74 | typename OptimizerClass, |
75 | typename DerivedOptimizerOptions, |
76 | typename DerivedOptimizerParamState> |
77 | void test_serialize_optimizer( |
78 | DerivedOptimizerOptions options, |
79 | bool only_has_global_state = false) { |
80 | torch::manual_seed(0); |
81 | auto model1 = Linear(5, 2); |
82 | auto model2 = Linear(5, 2); |
83 | auto model3 = Linear(5, 2); |
84 | |
85 | // Models 1, 2, 3 will have the same parameters. |
86 | auto model_tempfile = c10::make_tempfile(); |
87 | torch::save(model1, model_tempfile.name); |
88 | torch::load(model2, model_tempfile.name); |
89 | torch::load(model3, model_tempfile.name); |
90 | |
91 | auto param1 = model1->named_parameters(); |
92 | auto param2 = model2->named_parameters(); |
93 | auto param3 = model3->named_parameters(); |
94 | for (const auto& p : param1) { |
95 | ASSERT_TRUE(p->allclose(param2[p.key()])); |
96 | ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()])); |
97 | } |
98 | // Make some optimizers |
99 | auto optim1 = OptimizerClass( |
100 | {torch::optim::OptimizerParamGroup(model1->parameters())}, options); |
101 | auto optim2 = OptimizerClass(model2->parameters(), options); |
102 | auto optim2_2 = OptimizerClass(model2->parameters(), options); |
103 | auto optim3 = OptimizerClass(model3->parameters(), options); |
104 | auto optim3_2 = OptimizerClass(model3->parameters(), options); |
105 | |
106 | auto x = torch::ones({10, 5}); |
107 | |
108 | auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { |
109 | optimizer.zero_grad(); |
110 | auto y = model->forward(x).sum(); |
111 | y.backward(); |
112 | auto closure = []() { return torch::tensor({10}); }; |
113 | optimizer.step(closure); |
114 | }; |
115 | |
116 | // Do 2 steps of model1 |
117 | step(optim1, model1); |
118 | step(optim1, model1); |
119 | |
120 | // Do 2 steps of model 2 without saving the optimizer |
121 | step(optim2, model2); |
122 | step(optim2_2, model2); |
123 | |
124 | // Do 1 step of model 3 |
125 | step(optim3, model3); |
126 | |
127 | // save the optimizer |
128 | auto optim_tempfile = c10::make_tempfile(); |
129 | torch::save(optim3, optim_tempfile.name); |
130 | torch::load(optim3_2, optim_tempfile.name); |
131 | |
132 | auto& optim3_2_param_groups = optim3_2.param_groups(); |
133 | auto& optim3_param_groups = optim3.param_groups(); |
134 | auto& optim3_2_state = optim3_2.state(); |
135 | auto& optim3_state = optim3.state(); |
136 | |
137 | // optim3_2 and optim1 should have param_groups and state of size 1 and |
138 | // state_size respectively |
139 | ASSERT_TRUE(optim3_2_param_groups.size() == 1); |
140 | // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one |
141 | // global state |
142 | unsigned state_size = only_has_global_state ? 1 : 2; |
143 | ASSERT_TRUE(optim3_2_state.size() == state_size); |
144 | |
145 | // optim3_2 and optim1 should have param_groups and state of same size |
146 | ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size()); |
147 | ASSERT_TRUE(optim3_2_state.size() == optim3_state.size()); |
148 | |
149 | // checking correctness of serialization logic for optimizer.param_groups_ and |
150 | // optimizer.state_ |
151 | for (const auto i : c10::irange(optim3_2_param_groups.size())) { |
152 | is_optimizer_param_group_equal<DerivedOptimizerOptions>( |
153 | optim3_2_param_groups[i], optim3_param_groups[i]); |
154 | is_optimizer_state_equal<DerivedOptimizerParamState>( |
155 | optim3_2_state, optim3_state); |
156 | } |
157 | |
158 | // Do step2 for model 3 |
159 | step(optim3_2, model3); |
160 | |
161 | param1 = model1->named_parameters(); |
162 | param2 = model2->named_parameters(); |
163 | param3 = model3->named_parameters(); |
164 | for (const auto& p : param1) { |
165 | const auto& name = p.key(); |
166 | // Model 1 and 3 should be the same |
167 | ASSERT_TRUE( |
168 | param1[name].norm().item<float>() == param3[name].norm().item<float>()); |
169 | ASSERT_TRUE( |
170 | param1[name].norm().item<float>() != param2[name].norm().item<float>()); |
171 | } |
172 | } |
173 | |
/// Utility function to save a value of `int64_t` type.
///
/// Wraps `value` in a c10::IValue and writes it to `archive` under `key`.
/// Used below to emit the legacy global `iteration_` counter of the old
/// SGD serialization format.
void write_int_value(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const int64_t& value) {
  archive.write(key, c10::IValue(value));
}
181 | // Utility function to save a vector of buffers. |
182 | template <typename BufferContainer> |
183 | void write_tensors_to_archive( |
184 | torch::serialize::OutputArchive& archive, |
185 | const std::string& key, |
186 | const BufferContainer& buffers) { |
187 | archive.write( |
188 | key + "/size" , torch::tensor(static_cast<int64_t>(buffers.size()))); |
189 | for (const auto index : c10::irange(buffers.size())) { |
190 | archive.write( |
191 | key + "/" + c10::to_string(index), buffers[index], /*is_buffer=*/true); |
192 | } |
193 | } |
194 | |
195 | // Utility function to save a vector of step buffers. |
196 | void write_step_buffers( |
197 | torch::serialize::OutputArchive& archive, |
198 | const std::string& key, |
199 | const std::vector<int64_t>& steps) { |
200 | std::vector<torch::Tensor> tensors; |
201 | tensors.reserve(steps.size()); |
202 | for (const auto& step : steps) { |
203 | tensors.push_back(torch::tensor(static_cast<int64_t>(step))); |
204 | } |
205 | write_tensors_to_archive(archive, key, tensors); |
206 | } |
207 | |
// Loads `optimizer` from `filename` via `funcname` and asserts that exactly
// one "old serialization" deprecation warning is emitted — i.e. the legacy
// (pre-param-group) deserialization path was taken. Kept as a macro (not a
// function) so gtest assertion failures point at the call site.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  {                                                                          \
    WarningCapture warnings;                                                 \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);   \
  }
215 | |
216 | TEST(SerializeTest, KeysFunc) { |
217 | auto tempfile = c10::make_tempfile(); |
218 | torch::serialize::OutputArchive output_archive; |
219 | for (const auto i : c10::irange(3)) { |
220 | output_archive.write( |
221 | "element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i))); |
222 | } |
223 | output_archive.save_to(tempfile.name); |
224 | torch::serialize::InputArchive input_archive; |
225 | input_archive.load_from(tempfile.name); |
226 | std::vector<std::string> keys = input_archive.keys(); |
227 | ASSERT_EQ(keys.size(), 3); |
228 | for (const auto i : c10::irange(keys.size())) { |
229 | ASSERT_EQ(keys[i], "element/" + c10::to_string(i)); |
230 | } |
231 | } |
232 | |
233 | TEST(SerializeTest, TryReadFunc) { |
234 | auto tempfile = c10::make_tempfile(); |
235 | torch::serialize::OutputArchive output_archive; |
236 | for (const auto i : c10::irange(3)) { |
237 | output_archive.write( |
238 | "element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i))); |
239 | } |
240 | output_archive.save_to(tempfile.name); |
241 | torch::serialize::InputArchive input_archive; |
242 | input_archive.load_from(tempfile.name); |
243 | c10::IValue ivalue; |
244 | ASSERT_FALSE(input_archive.try_read("1" , ivalue)); |
245 | ASSERT_TRUE(input_archive.try_read("element/1" , ivalue)); |
246 | ASSERT_EQ(ivalue.toInt(), 1); |
247 | } |
248 | |
249 | TEST(SerializeTest, Basic) { |
250 | torch::manual_seed(0); |
251 | |
252 | auto x = torch::randn({5, 5}); |
253 | auto y = save_and_load(x); |
254 | |
255 | ASSERT_TRUE(y.defined()); |
256 | ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); |
257 | ASSERT_TRUE(x.allclose(y)); |
258 | } |
259 | |
260 | TEST(SerializeTest, MathBits) { |
261 | torch::manual_seed(0); |
262 | |
263 | auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat); |
264 | auto x = torch::randn({5, 5}, options); |
265 | { |
266 | auto expected = torch::conj(x); |
267 | auto actual = save_and_load(expected); |
268 | |
269 | ASSERT_TRUE(actual.defined()); |
270 | ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); |
271 | ASSERT_TRUE(actual.allclose(expected)); |
272 | } |
273 | |
274 | { |
275 | auto expected = torch::_neg_view(x); |
276 | auto actual = save_and_load(expected); |
277 | |
278 | ASSERT_TRUE(actual.defined()); |
279 | ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); |
280 | ASSERT_TRUE(actual.allclose(expected)); |
281 | } |
282 | |
283 | { |
284 | auto expected = torch::conj(torch::_neg_view(x)); |
285 | auto actual = save_and_load(expected); |
286 | |
287 | ASSERT_TRUE(actual.defined()); |
288 | ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); |
289 | ASSERT_TRUE(actual.allclose(expected)); |
290 | } |
291 | |
292 | { |
293 | // We don't support serializing `ZeroTensor` as it is not public facing yet. |
294 | // If in future, `ZeroTensor` serialization is supported, this test should |
295 | // start failing! |
296 | auto t = torch::_efficientzerotensor({5, 5}); |
297 | ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable," ); |
298 | } |
299 | } |
300 | |
301 | TEST(SerializeTest, BasicToFile) { |
302 | torch::manual_seed(0); |
303 | |
304 | auto x = torch::randn({5, 5}); |
305 | |
306 | auto tempfile = c10::make_tempfile(); |
307 | torch::save(x, tempfile.name); |
308 | |
309 | torch::Tensor y; |
310 | torch::load(y, tempfile.name); |
311 | |
312 | ASSERT_TRUE(y.defined()); |
313 | ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); |
314 | ASSERT_TRUE(x.allclose(y)); |
315 | } |
316 | |
// Round-trips a tensor through the callback-based save/load overloads:
// first save-to-writer plus load-from-flat-buffer, then load via a
// (position, buffer, size) reader callback paired with a total-size
// callback.
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  std::string serialized;
  // Writer callback: append each chunk to `serialized` and report the
  // number of bytes consumed.
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  });
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  torch::Tensor z;
  torch::load(
      z,
      // Reader callback: copy up to `n` bytes starting at `pos` into `buf`;
      // returning 0 signals end-of-data. The std::min clamps the read so it
      // never runs past the end of `serialized`.
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size())
          return 0;
        size_t nbytes =
            std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
        memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      // Size callback: total number of serialized bytes.
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
350 | |
351 | TEST(SerializeTest, Resized) { |
352 | torch::manual_seed(0); |
353 | |
354 | auto x = torch::randn({11, 5}); |
355 | x.resize_({5, 5}); |
356 | auto y = save_and_load(x); |
357 | |
358 | ASSERT_TRUE(y.defined()); |
359 | ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); |
360 | ASSERT_TRUE(x.allclose(y)); |
361 | } |
362 | |
363 | TEST(SerializeTest, Sliced) { |
364 | torch::manual_seed(0); |
365 | |
366 | auto x = torch::randn({11, 5}); |
367 | x = x.slice(0, 1, 5); |
368 | auto y = save_and_load(x); |
369 | |
370 | ASSERT_TRUE(y.defined()); |
371 | ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); |
372 | ASSERT_TRUE(x.allclose(y)); |
373 | } |
374 | |
375 | TEST(SerializeTest, NonContiguous) { |
376 | torch::manual_seed(0); |
377 | |
378 | auto x = torch::randn({11, 5}); |
379 | x = x.slice(1, 1, 4); |
380 | auto y = save_and_load(x); |
381 | |
382 | ASSERT_TRUE(y.defined()); |
383 | ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); |
384 | ASSERT_TRUE(x.allclose(y)); |
385 | } |
386 | |
// Loading into a module hierarchy whose submodule/buffer names do not match
// the serialized ones must fail with an error that spells out the full
// dotted path of the missing key.
TEST(SerializeTest, ErrorOnMissingKey) {
  // Leaf module holding a single named buffer.
  struct B : torch::nn::Module {
    B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  // Mid-level module wrapping a B under a configurable name.
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  // Top-level module wrapping an A under a configurable name.
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x"); // buffer name differs
  auto model3 = std::make_shared<M>("a", "x", "c"); // submodule name differs

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  // Rewind so the second load reads the archive from the beginning.
  stream.seekg(0, stream.beg);
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}
420 | |
TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  // Computes binary-cross-entropy loss of `model` on `batch_size` random
  // XOR examples.
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train `model` until the smoothed loss drops below 0.1, failing the test
  // if that takes 3000 epochs or more.
  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // Exponential moving average of the loss.
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  // After a save/load round-trip through a file, the reloaded model must
  // still solve XOR.
  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}
463 | |
// End-to-end check that SGD momentum state survives serialization: model 1
// takes two plain steps; model 3 takes one step, its optimizer is saved and
// reloaded into a fresh optimizer, then a second step is taken. Models 1
// and 3 must match exactly afterwards, while model 2 — stepped with two
// *different* optimizer instances, so no momentum carries over — must
// differ.
TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }

  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  // One gradient step on `model` using `optimizer`.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3 while saving the optimizer
  // (first step here; the second happens after the save/load below)
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}
532 | |
// Serialization round-trip for Adagrad, plus a backward-compatibility check
// that the old (pre-param-group) file format still loads and reproduces the
// same per-parameter state.
TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  // One gradient step so optim1 accumulates real state.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // fill up with optim1 sum_buffers
  std::vector<torch::Tensor> sum_buffers;
  // fill up with optim1 state_buffers
  std::vector<int64_t> step_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  // State is keyed by the stringified address of each parameter's
  // TensorImpl.
  for (const auto& param : params_) {
    auto key_ = c10::guts::to_string(param.unsafeGetTensorImpl());
    const AdagradParamState& curr_state_ =
        static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
    sum_buffers.emplace_back(curr_state_.sum());
    step_buffers.emplace_back(curr_state_.step());
  }
  // write sum_buffers and step_buffers to the file in the old format, then
  // load into optim1_2 (expecting the deprecation warning) and compare.
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}
576 | |
// Serialization round-trip for SGD, plus a backward-compatibility check for
// the old momentum_buffers/iteration_ file format.
TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have a momentum
  // buffer entry
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> momentum_buffers;
  // The legacy format stored a single global iteration counter.
  int64_t iteration_{0};
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    // Skip the extra last parameter: it is not used by the model's forward
    // pass, so it never received a gradient and has no state entry.
    if (i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const SGDParamState& curr_state_ =
          static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
      momentum_buffers.emplace_back(curr_state_.momentum_buffer());
    }
  }
  ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));
  // write momentum_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_int_value(output_archive, "iteration_", iteration_);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}
625 | |
// Serialization round-trip for Adam, plus a backward-compatibility check
// for the old per-buffer file format (step / exp_avg / exp_avg_sq /
// max_exp_avg_sq vectors).
TEST(SerializeTest, Optim_Adam) {
  test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
      AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::Adam(
      model1_params, torch::optim::AdamOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  // Collect per-parameter buffers in the layout of the old serialization
  // format. The extra last parameter is skipped: it is not used by the
  // model's forward pass, so it has no state entry.
  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const AdamParamState& curr_state_ =
          static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      // max_exp_avg_sq is only defined when the amsgrad variant is active.
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}
683 | |
// Serialization round-trip for AdamW, plus a backward-compatibility check
// for the old per-buffer file format (step / exp_avg / exp_avg_sq /
// max_exp_avg_sq vectors). Mirrors the Adam test above.
TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(
          std::make_tuple(0.999, 0.1)));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::AdamW(
      model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  // Collect per-parameter buffers in the layout of the old serialization
  // format. The extra last parameter is skipped: it is not used by the
  // model's forward pass, so it has no state entry.
  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const AdamWParamState& curr_state_ =
          static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      // max_exp_avg_sq is only defined when the amsgrad variant is active.
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}
742 | |
743 | TEST(SerializeTest, Optim_RMSprop) { |
744 | auto options = RMSpropOptions(0.1).momentum(0.9).centered(true); |
745 | test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options); |
746 | |
747 | // bc compatibility check |
748 | auto model1 = Linear(5, 2); |
749 | auto model1_params = model1->parameters(); |
750 | |
751 | // added a tensor for lazy init check - when all params do not have a momentum |
752 | // buffer entry |
753 | model1_params.emplace_back(torch::randn({2, 3})); |
754 | auto optim1 = torch::optim::RMSprop(model1_params, options); |
755 | |
756 | auto x = torch::ones({10, 5}); |
757 | auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { |
758 | optimizer.zero_grad(); |
759 | auto y = model->forward(x).sum(); |
760 | y.backward(); |
761 | optimizer.step(); |
762 | }; |
763 | step(optim1, model1); |
764 | |
765 | std::vector<at::Tensor> square_average_buffers; |
766 | std::vector<at::Tensor> momentum_buffers; |
767 | std::vector<at::Tensor> grad_average_buffers; |
768 | const auto& params_ = optim1.param_groups()[0].params(); |
769 | const auto& optim1_state = optim1.state(); |
770 | for (const auto i : c10::irange(params_.size())) { |
771 | if (i != (params_.size() - 1)) { |
772 | auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); |
773 | const RMSpropParamState& curr_state_ = |
774 | static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get())); |
775 | square_average_buffers.emplace_back(curr_state_.square_avg()); |
776 | if (curr_state_.momentum_buffer().defined()) { |
777 | momentum_buffers.emplace_back(curr_state_.momentum_buffer()); |
778 | } |
779 | if (curr_state_.grad_avg().defined()) { |
780 | grad_average_buffers.emplace_back(curr_state_.grad_avg()); |
781 | } |
782 | } |
783 | } |
784 | // write buffers to the file |
785 | auto optim_tempfile_old_format = c10::make_tempfile(); |
786 | torch::serialize::OutputArchive output_archive; |
787 | write_tensors_to_archive( |
788 | output_archive, "square_average_buffers" , square_average_buffers); |
789 | write_tensors_to_archive( |
790 | output_archive, "momentum_buffers" , momentum_buffers); |
791 | write_tensors_to_archive( |
792 | output_archive, "grad_average_buffers" , grad_average_buffers); |
793 | output_archive.save_to(optim_tempfile_old_format.name); |
794 | auto optim1_2 = RMSprop(model1_params, options); |
795 | OLD_SERIALIZATION_LOGIC_WARNING_CHECK( |
796 | torch::load, optim1_2, optim_tempfile_old_format.name); |
797 | const auto& params1_2_ = optim1_2.param_groups()[0].params(); |
798 | auto& optim1_2_state = optim1_2.state(); |
799 | // old RMSprop didn't track step value |
800 | for (const auto i : c10::irange(params1_2_.size())) { |
801 | if (i != (params1_2_.size() - 1)) { |
802 | auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); |
803 | auto key1_2_ = c10::guts::to_string(params1_2_[i].unsafeGetTensorImpl()); |
804 | const RMSpropParamState& curr_state_ = |
805 | static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get())); |
806 | RMSpropParamState& curr_state1_2_ = |
807 | static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get())); |
808 | curr_state1_2_.step(curr_state_.step()); |
809 | } |
810 | } |
811 | is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state()); |
812 | } |
813 | |
814 | TEST(SerializeTest, Optim_LBFGS) { |
815 | test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>( |
816 | LBFGSOptions(), true); |
817 | // bc compatibility check |
818 | auto model1 = Linear(5, 2); |
819 | auto model1_params = model1->parameters(); |
820 | // added a tensor for lazy init check - when all params do not have entry in |
821 | // buffers |
822 | model1_params.emplace_back(torch::randn({2, 3})); |
823 | auto optim1 = |
824 | torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions()); |
825 | |
826 | auto x = torch::ones({10, 5}); |
827 | auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { |
828 | optimizer.zero_grad(); |
829 | auto y = model->forward(x).sum(); |
830 | y.backward(); |
831 | auto closure = []() { return torch::tensor({10}); }; |
832 | optimizer.step(closure); |
833 | }; |
834 | |
835 | step(optim1, model1); |
836 | |
837 | at::Tensor d, t, H_diag, prev_flat_grad, prev_loss; |
838 | std::deque<at::Tensor> old_dirs, old_stps; |
839 | |
840 | const auto& params_ = optim1.param_groups()[0].params(); |
841 | auto key_ = c10::guts::to_string(params_[0].unsafeGetTensorImpl()); |
842 | const auto& optim1_state = |
843 | static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get())); |
844 | d = optim1_state.d(); |
845 | t = at::tensor(optim1_state.t()); |
846 | H_diag = optim1_state.H_diag(); |
847 | prev_flat_grad = optim1_state.prev_flat_grad(); |
848 | prev_loss = at::tensor(optim1_state.prev_loss()); |
849 | old_dirs = optim1_state.old_dirs(); |
850 | |
851 | // write buffers to the file |
852 | auto optim_tempfile_old_format = c10::make_tempfile(); |
853 | torch::serialize::OutputArchive output_archive; |
854 | output_archive.write("d" , d, /*is_buffer=*/true); |
855 | output_archive.write("t" , t, /*is_buffer=*/true); |
856 | output_archive.write("H_diag" , H_diag, /*is_buffer=*/true); |
857 | output_archive.write("prev_flat_grad" , prev_flat_grad, /*is_buffer=*/true); |
858 | output_archive.write("prev_loss" , prev_loss, /*is_buffer=*/true); |
859 | write_tensors_to_archive(output_archive, "old_dirs" , old_dirs); |
860 | write_tensors_to_archive(output_archive, "old_stps" , old_stps); |
861 | output_archive.save_to(optim_tempfile_old_format.name); |
862 | |
863 | auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions()); |
864 | OLD_SERIALIZATION_LOGIC_WARNING_CHECK( |
865 | torch::load, optim1_2, optim_tempfile_old_format.name); |
866 | |
867 | const auto& params1_2_ = optim1_2.param_groups()[0].params(); |
868 | auto param_key = c10::guts::to_string(params1_2_[0].unsafeGetTensorImpl()); |
869 | auto& optim1_2_state = |
870 | static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get())); |
871 | |
872 | // old LBFGS didn't track func_evals, n_iter, ro, al values |
873 | optim1_2_state.func_evals(optim1_state.func_evals()); |
874 | optim1_2_state.n_iter(optim1_state.n_iter()); |
875 | optim1_2_state.ro(optim1_state.ro()); |
876 | optim1_2_state.al(optim1_state.al()); |
877 | |
878 | is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state()); |
879 | } |
880 | |
881 | TEST(SerializeTest, XOR_CUDA) { |
882 | torch::manual_seed(0); |
883 | // We better be able to save and load a XOR model! |
884 | auto getLoss = [](Sequential model, |
885 | uint32_t batch_size, |
886 | bool is_cuda = false) { |
887 | auto inputs = torch::empty({batch_size, 2}); |
888 | auto labels = torch::empty({batch_size}); |
889 | if (is_cuda) { |
890 | inputs = inputs.cuda(); |
891 | labels = labels.cuda(); |
892 | } |
893 | for (const auto i : c10::irange(batch_size)) { |
894 | inputs[i] = torch::randint(2, {2}, torch::kInt64); |
895 | labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>(); |
896 | } |
897 | auto x = model->forward<torch::Tensor>(inputs); |
898 | return torch::binary_cross_entropy(x, labels); |
899 | }; |
900 | |
901 | auto model = xor_model(); |
902 | auto model2 = xor_model(); |
903 | auto model3 = xor_model(); |
904 | auto optimizer = torch::optim::SGD( |
905 | model->parameters(), |
906 | torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay( |
907 | 1e-6)); |
908 | |
909 | float running_loss = 1; |
910 | int epoch = 0; |
911 | while (running_loss > 0.1) { |
912 | torch::Tensor loss = getLoss(model, 4); |
913 | optimizer.zero_grad(); |
914 | loss.backward(); |
915 | optimizer.step(); |
916 | |
917 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
918 | running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01; |
919 | ASSERT_LT(epoch, 3000); |
920 | epoch++; |
921 | } |
922 | |
923 | auto tempfile = c10::make_tempfile(); |
924 | torch::save(model, tempfile.name); |
925 | torch::load(model2, tempfile.name); |
926 | |
927 | auto loss = getLoss(model2, 100); |
928 | ASSERT_LT(loss.item<float>(), 0.1); |
929 | |
930 | model2->to(torch::kCUDA); |
931 | loss = getLoss(model2, 100, true); |
932 | ASSERT_LT(loss.item<float>(), 0.1); |
933 | |
934 | auto tempfile2 = c10::make_tempfile(); |
935 | torch::save(model2, tempfile2.name); |
936 | torch::load(model3, tempfile2.name); |
937 | |
938 | loss = getLoss(model3, 100, true); |
939 | ASSERT_LT(loss.item<float>(), 0.1); |
940 | } |
941 | |
942 | TEST( |
943 | SerializeTest, |
944 | CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) { |
945 | struct C : torch::nn::Module { |
946 | C() { |
947 | register_buffer("foo" , torch::ones(5, torch::kInt32)); |
948 | } |
949 | }; |
950 | struct B : torch::nn::Module {}; |
951 | struct A : torch::nn::Module { |
952 | A() { |
953 | register_module("b" , std::make_shared<B>()); |
954 | register_module("c" , std::make_shared<C>()); |
955 | } |
956 | }; |
957 | struct M : torch::nn::Module { |
958 | M() { |
959 | register_module("a" , std::make_shared<A>()); |
960 | } |
961 | }; |
962 | |
963 | auto out = std::make_shared<M>(); |
964 | std::stringstream ss; |
965 | torch::save(out, ss); |
966 | auto in = std::make_shared<M>(); |
967 | torch::load(in, ss); |
968 | |
969 | const int output = in->named_buffers()["a.c.foo" ].sum().item<int>(); |
970 | ASSERT_EQ(output, 5); |
971 | } |
972 | |
973 | TEST(SerializeTest, VectorOfTensors) { |
974 | torch::manual_seed(0); |
975 | |
976 | std::vector<torch::Tensor> x_vec = { |
977 | torch::randn({1, 2}), torch::randn({3, 4})}; |
978 | |
979 | std::stringstream stream; |
980 | torch::save(x_vec, stream); |
981 | |
982 | std::vector<torch::Tensor> y_vec; |
983 | torch::load(y_vec, stream); |
984 | |
985 | for (const auto i : c10::irange(x_vec.size())) { |
986 | auto& x = x_vec[i]; |
987 | auto& y = y_vec[i]; |
988 | ASSERT_TRUE(y.defined()); |
989 | ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); |
990 | ASSERT_TRUE(x.allclose(y)); |
991 | } |
992 | } |
993 | |
994 | TEST(SerializeTest, IValue) { |
995 | c10::IValue ivalue(1); |
996 | auto tempfile = c10::make_tempfile(); |
997 | torch::serialize::OutputArchive output_archive; |
998 | output_archive.write("value" , ivalue); |
999 | output_archive.save_to(tempfile.name); |
1000 | |
1001 | torch::serialize::InputArchive input_archive; |
1002 | input_archive.load_from(tempfile.name); |
1003 | c10::IValue ivalue_out; |
1004 | input_archive.read("value" , ivalue_out); |
1005 | ASSERT_EQ(ivalue_out.toInt(), 1); |
1006 | |
1007 | ASSERT_THROWS_WITH( |
1008 | input_archive.read("bad_key" , ivalue_out), |
1009 | "does not have a field with name" ); |
1010 | } |
1011 | |
1012 | // NOTE: if a `Module` contains unserializable submodules (e.g. |
1013 | // `nn::Functional`), we expect those submodules to be skipped when the `Module` |
1014 | // is being serialized. |
1015 | TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) { |
1016 | struct A : torch::nn::Module { |
1017 | A() { |
1018 | register_module("relu" , torch::nn::Functional(torch::relu)); |
1019 | } |
1020 | }; |
1021 | |
1022 | auto out = std::make_shared<A>(); |
1023 | std::stringstream ss; |
1024 | torch::save(out, ss); |
1025 | |
1026 | torch::serialize::InputArchive archive; |
1027 | archive.load_from(ss); |
1028 | torch::serialize::InputArchive relu_archive; |
1029 | |
1030 | // Submodule with name "relu" should not exist in the `InputArchive`, |
1031 | // because the "relu" submodule is an `nn::Functional` and is not |
1032 | // serializable. |
1033 | ASSERT_FALSE(archive.try_read("relu" , relu_archive)); |
1034 | } |
1035 | |
1036 | // NOTE: If a `Module` contains unserializable submodules (e.g. |
1037 | // `nn::Functional`), we don't check the existence of those submodules in the |
1038 | // `InputArchive` when deserializing. |
1039 | TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) { |
1040 | struct B : torch::nn::Module { |
1041 | B() { |
1042 | register_module("relu1" , torch::nn::Functional(torch::relu)); |
1043 | register_buffer("foo" , torch::zeros(5, torch::kInt32)); |
1044 | } |
1045 | }; |
1046 | struct A : torch::nn::Module { |
1047 | A() { |
1048 | register_module("b" , std::make_shared<B>()); |
1049 | register_module("relu2" , torch::nn::Functional(torch::relu)); |
1050 | } |
1051 | }; |
1052 | |
1053 | auto out = std::make_shared<A>(); |
1054 | // Manually change the values of "b.foo", so that we can check whether the |
1055 | // buffer contains these values after deserialization. |
1056 | out->named_buffers()["b.foo" ].fill_(1); |
1057 | auto tempfile = c10::make_tempfile(); |
1058 | torch::save(out, tempfile.name); |
1059 | |
1060 | torch::serialize::InputArchive archive; |
1061 | archive.load_from(tempfile.name); |
1062 | torch::serialize::InputArchive archive_b; |
1063 | torch::serialize::InputArchive archive_relu; |
1064 | torch::Tensor tensor_foo; |
1065 | |
1066 | ASSERT_TRUE(archive.try_read("b" , archive_b)); |
1067 | ASSERT_TRUE(archive_b.try_read("foo" , tensor_foo, /*is_buffer=*/true)); |
1068 | |
1069 | // Submodule with name "relu1" should not exist in `archive_b`, because the |
1070 | // "relu1" submodule is an `nn::Functional` and is not serializable. |
1071 | ASSERT_FALSE(archive_b.try_read("relu1" , archive_relu)); |
1072 | |
1073 | // Submodule with name "relu2" should not exist in `archive`, because the |
1074 | // "relu2" submodule is an `nn::Functional` and is not serializable. |
1075 | ASSERT_FALSE(archive.try_read("relu2" , archive_relu)); |
1076 | |
1077 | auto in = std::make_shared<A>(); |
1078 | // `torch::load(...)` works without error, even though `A` contains the |
1079 | // `nn::Functional` submodules while the serialized file doesn't, because the |
1080 | // `nn::Functional` submodules are not serializable and thus ignored when |
1081 | // deserializing. |
1082 | torch::load(in, tempfile.name); |
1083 | |
1084 | // Check that the "b.foo" buffer is correctly deserialized from the file. |
1085 | const int output = in->named_buffers()["b.foo" ].sum().item<int>(); |
1086 | // `output` should equal to the sum of the values we manually assigned to |
1087 | // "b.foo" before serialization. |
1088 | ASSERT_EQ(output, 5); |
1089 | } |
1090 | |