1#include <gtest/gtest.h>
2
3#include <c10/util/flat_hash_map.h>
4#include <c10/util/irange.h>
5#include <c10/util/tempfile.h>
6
7#include <torch/torch.h>
8
9#include <test/cpp/api/support.h>
10
11#include <cstdio>
12#include <memory>
13#include <sstream>
14#include <string>
15#include <vector>
16
17using namespace torch::test;
18using namespace torch::nn;
19using namespace torch::optim;
20
namespace {
// Builds a small 2-8-1 MLP with sigmoid activations; used by the XOR
// serialization round-trip tests below.
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}

// Serializes `input` into an in-memory stream and immediately deserializes it
// into a fresh tensor, returning the round-tripped result.
torch::Tensor save_and_load(torch::Tensor input) {
  std::stringstream stream;
  torch::save(input, stream);
  torch::Tensor tensor;
  torch::load(tensor, stream);
  return tensor;
}
} // namespace
38
39template <typename DerivedOptions>
40void is_optimizer_param_group_equal(
41 const OptimizerParamGroup& lhs,
42 const OptimizerParamGroup& rhs) {
43 const auto& lhs_params = lhs.params();
44 const auto& rhs_params = rhs.params();
45
46 ASSERT_TRUE(lhs_params.size() == rhs_params.size());
47 for (const auto j : c10::irange(lhs_params.size())) {
48 ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
49 }
50 ASSERT_TRUE(
51 static_cast<const DerivedOptions&>(lhs.options()) ==
52 static_cast<const DerivedOptions&>(rhs.options()));
53}
54
55template <typename DerivedOptimizerParamState>
56void is_optimizer_state_equal(
57 const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
58 lhs_state,
59 const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
60 rhs_state) {
61 ASSERT_TRUE(lhs_state.size() == rhs_state.size());
62 for (const auto& value : lhs_state) {
63 auto found = rhs_state.find(value.first);
64 ASSERT_TRUE(found != rhs_state.end());
65 const DerivedOptimizerParamState& lhs_curr_state =
66 static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
67 const DerivedOptimizerParamState& rhs_curr_state =
68 static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
69 ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
70 }
71}
72
// Generic round-trip test for optimizer serialization.
//
// Three identically-initialized Linear models are trained with separate
// OptimizerClass instances. optim3 is serialized after one step and reloaded
// into optim3_2; the test then checks that (a) the reloaded param_groups and
// per-parameter state match the saved optimizer exactly, and (b) continuing
// training through the reloaded optimizer keeps model3 in lockstep with
// model1 (trained two steps with a single optimizer) while diverging from
// model2 (trained with two *independent* optimizers, so momentum-like state
// was lost between its two steps).
//
// `only_has_global_state` is true for LBFGS, which keeps a single global
// state entry instead of one entry per parameter.
template <
    typename OptimizerClass,
    typename DerivedOptimizerOptions,
    typename DerivedOptimizerParamState>
void test_serialize_optimizer(
    DerivedOptimizerOptions options,
    bool only_has_global_state = false) {
  torch::manual_seed(0);
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }
  // Make some optimizers
  auto optim1 = OptimizerClass(
      {torch::optim::OptimizerParamGroup(model1->parameters())}, options);
  auto optim2 = OptimizerClass(model2->parameters(), options);
  auto optim2_2 = OptimizerClass(model2->parameters(), options);
  auto optim3 = OptimizerClass(model3->parameters(), options);
  auto optim3_2 = OptimizerClass(model3->parameters(), options);

  auto x = torch::ones({10, 5});

  // One optimization step on `model`. The closure form of step() is used so
  // this also works for LBFGS, which requires a loss closure.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 1 step of model 3
  step(optim3, model3);

  // save the optimizer
  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);

  auto& optim3_2_param_groups = optim3_2.param_groups();
  auto& optim3_param_groups = optim3.param_groups();
  auto& optim3_2_state = optim3_2.state();
  auto& optim3_state = optim3.state();

  // optim3_2 and optim1 should have param_groups and state of size 1 and
  // state_size respectively
  ASSERT_TRUE(optim3_2_param_groups.size() == 1);
  // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one
  // global state
  unsigned state_size = only_has_global_state ? 1 : 2;
  ASSERT_TRUE(optim3_2_state.size() == state_size);

  // optim3_2 and optim1 should have param_groups and state of same size
  ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
  ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());

  // checking correctness of serialization logic for optimizer.param_groups_ and
  // optimizer.state_
  for (const auto i : c10::irange(optim3_2_param_groups.size())) {
    is_optimizer_param_group_equal<DerivedOptimizerOptions>(
        optim3_2_param_groups[i], optim3_param_groups[i]);
    is_optimizer_state_equal<DerivedOptimizerParamState>(
        optim3_2_state, optim3_state);
  }

  // Do step2 for model 3
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}
173
/// Utility function to save a value of `int64_t` type.
/// Writes `value` under `key` into `archive` wrapped in a c10::IValue; used
/// below to emulate the legacy optimizer serialization format.
void write_int_value(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const int64_t& value) {
  archive.write(key, c10::IValue(value));
}
181// Utility function to save a vector of buffers.
182template <typename BufferContainer>
183void write_tensors_to_archive(
184 torch::serialize::OutputArchive& archive,
185 const std::string& key,
186 const BufferContainer& buffers) {
187 archive.write(
188 key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
189 for (const auto index : c10::irange(buffers.size())) {
190 archive.write(
191 key + "/" + c10::to_string(index), buffers[index], /*is_buffer=*/true);
192 }
193}
194
195// Utility function to save a vector of step buffers.
196void write_step_buffers(
197 torch::serialize::OutputArchive& archive,
198 const std::string& key,
199 const std::vector<int64_t>& steps) {
200 std::vector<torch::Tensor> tensors;
201 tensors.reserve(steps.size());
202 for (const auto& step : steps) {
203 tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
204 }
205 write_tensors_to_archive(archive, key, tensors);
206}
207
// Invokes `funcname(optimizer, filename)` (typically torch::load) while
// capturing warnings, and asserts that exactly one "old serialization"
// deprecation warning was emitted — i.e. the legacy-format code path ran.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  {                                                                          \
    WarningCapture warnings;                                                 \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);   \
  }
215
216TEST(SerializeTest, KeysFunc) {
217 auto tempfile = c10::make_tempfile();
218 torch::serialize::OutputArchive output_archive;
219 for (const auto i : c10::irange(3)) {
220 output_archive.write(
221 "element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i)));
222 }
223 output_archive.save_to(tempfile.name);
224 torch::serialize::InputArchive input_archive;
225 input_archive.load_from(tempfile.name);
226 std::vector<std::string> keys = input_archive.keys();
227 ASSERT_EQ(keys.size(), 3);
228 for (const auto i : c10::irange(keys.size())) {
229 ASSERT_EQ(keys[i], "element/" + c10::to_string(i));
230 }
231}
232
233TEST(SerializeTest, TryReadFunc) {
234 auto tempfile = c10::make_tempfile();
235 torch::serialize::OutputArchive output_archive;
236 for (const auto i : c10::irange(3)) {
237 output_archive.write(
238 "element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i)));
239 }
240 output_archive.save_to(tempfile.name);
241 torch::serialize::InputArchive input_archive;
242 input_archive.load_from(tempfile.name);
243 c10::IValue ivalue;
244 ASSERT_FALSE(input_archive.try_read("1", ivalue));
245 ASSERT_TRUE(input_archive.try_read("element/1", ivalue));
246 ASSERT_EQ(ivalue.toInt(), 1);
247}
248
249TEST(SerializeTest, Basic) {
250 torch::manual_seed(0);
251
252 auto x = torch::randn({5, 5});
253 auto y = save_and_load(x);
254
255 ASSERT_TRUE(y.defined());
256 ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
257 ASSERT_TRUE(x.allclose(y));
258}
259
// Round-trips complex tensors that carry "math bits": the conjugate bit
// (torch::conj), the negative bit (torch::_neg_view), and both combined.
// The materialized values must survive serialization. Also pins the current
// behavior that ZeroTensor is NOT serializable.
TEST(SerializeTest, MathBits) {
  torch::manual_seed(0);

  auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
  auto x = torch::randn({5, 5}, options);
  {
    // Conjugate view only.
    auto expected = torch::conj(x);
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // Negative view only.
    auto expected = torch::_neg_view(x);
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // Both bits set at once.
    auto expected = torch::conj(torch::_neg_view(x));
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // We don't support serializing `ZeroTensor` as it is not public facing yet.
    // If in future, `ZeroTensor` serialization is supported, this test should
    // start failing!
    auto t = torch::_efficientzerotensor({5, 5});
    ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
  }
}
300
301TEST(SerializeTest, BasicToFile) {
302 torch::manual_seed(0);
303
304 auto x = torch::randn({5, 5});
305
306 auto tempfile = c10::make_tempfile();
307 torch::save(x, tempfile.name);
308
309 torch::Tensor y;
310 torch::load(y, tempfile.name);
311
312 ASSERT_TRUE(y.defined());
313 ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
314 ASSERT_TRUE(x.allclose(y));
315}
316
// Exercises the callback-based save/load overloads: saving through a custom
// writer into a std::string, loading from a raw (data, size) buffer, and
// loading through custom read/size callbacks.
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  std::string serialized;
  // Writer callback: append each chunk to `serialized` and report the number
  // of bytes consumed.
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  });
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  torch::Tensor z;
  torch::load(
      z,
      // Reader callback: copy up to `n` bytes starting at offset `pos` into
      // `buf`, returning the number of bytes copied (0 signals EOF).
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size())
          return 0;
        size_t nbytes =
            std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
        memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      // Size callback: total length of the serialized payload.
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
350
351TEST(SerializeTest, Resized) {
352 torch::manual_seed(0);
353
354 auto x = torch::randn({11, 5});
355 x.resize_({5, 5});
356 auto y = save_and_load(x);
357
358 ASSERT_TRUE(y.defined());
359 ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
360 ASSERT_TRUE(x.allclose(y));
361}
362
363TEST(SerializeTest, Sliced) {
364 torch::manual_seed(0);
365
366 auto x = torch::randn({11, 5});
367 x = x.slice(0, 1, 5);
368 auto y = save_and_load(x);
369
370 ASSERT_TRUE(y.defined());
371 ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
372 ASSERT_TRUE(x.allclose(y));
373}
374
375TEST(SerializeTest, NonContiguous) {
376 torch::manual_seed(0);
377
378 auto x = torch::randn({11, 5});
379 x = x.slice(1, 1, 4);
380 auto y = save_and_load(x);
381
382 ASSERT_TRUE(y.defined());
383 ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
384 ASSERT_TRUE(x.allclose(y));
385}
386
// Loading into a module hierarchy whose names differ from the serialized one
// must fail with an error message that names the full dotted path of the
// missing tensor/submodule.
TEST(SerializeTest, ErrorOnMissingKey) {
  struct B : torch::nn::Module {
    B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x");
  auto model3 = std::make_shared<M>("a", "x", "c");

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  // Rewind the stream before the second load attempt.
  stream.seekg(0, stream.beg);
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}
420
// End-to-end check: train a tiny XOR MLP to convergence, save it to a file,
// load into a fresh model, and verify the loaded model still achieves low
// loss on freshly generated XOR data.
TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  // Generates `batch_size` random XOR samples and returns BCE loss of the
  // model's predictions on them.
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train until the smoothed loss drops below 0.1 (bounded at 3000 epochs).
  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}
463
// SGD-with-momentum serialization check: an optimizer saved after one step
// and reloaded must continue training exactly like the original (model3
// matches model1), while a model stepped through two *independent*
// optimizers (model2, which loses momentum state between steps) diverges.
TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }

  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  // One plain optimization step on `model`.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3 while saving the optimizer
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}
532
// Adagrad: new-format round trip via test_serialize_optimizer, plus a
// backward-compatibility check — writes an archive in the legacy layout
// (raw sum/step buffer lists) and verifies torch::load reconstructs
// equivalent per-parameter state while emitting the deprecation warning.
TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  // One step so optim1 actually has per-parameter state to export.
  step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // fill up with optim1 sum_buffers
  std::vector<torch::Tensor> sum_buffers;
  // fill up with optim1 state_buffers
  std::vector<int64_t> step_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  // State is keyed by the stringified TensorImpl pointer of each parameter.
  for (const auto& param : params_) {
    auto key_ = c10::guts::to_string(param.unsafeGetTensorImpl());
    const AdagradParamState& curr_state_ =
        static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
    sum_buffers.emplace_back(curr_state_.sum());
    step_buffers.emplace_back(curr_state_.step());
  }
  // write sum_buffers and step_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}
576
// SGD: new-format round trip via test_serialize_optimizer, plus a
// backward-compatibility check — writes a legacy-layout archive
// (momentum_buffers list + iteration counter) and verifies torch::load
// reconstructs equivalent state with the deprecation warning. An extra
// parameter without a momentum buffer exercises lazy state init.
TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have a momentum
  // buffer entry
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> momentum_buffers;
  int64_t iteration_{0};
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  // The last parameter got no gradient (it is outside the model), so it has
  // no state entry — skip it when exporting momentum buffers.
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const SGDParamState& curr_state_ =
          static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
      momentum_buffers.emplace_back(curr_state_.momentum_buffer());
    }
  }
  ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));
  // write momentum_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_int_value(output_archive, "iteration_", iteration_);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}
625
// Adam: new-format round trip via test_serialize_optimizer (with amsgrad so
// max_exp_avg_sq is exercised), plus a backward-compatibility check using a
// legacy-layout archive of step/exp_avg/exp_avg_sq/max_exp_avg_sq buffers.
TEST(SerializeTest, Optim_Adam) {
  test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
      AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::Adam(
      model1_params, torch::optim::AdamOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  // The last parameter has no state (no gradient flowed to it) — skip it.
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const AdamParamState& curr_state_ =
          static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      // max_exp_avg_sq only exists when amsgrad was enabled for the run.
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}
683
// AdamW: mirrors the Optim_Adam test — new-format round trip (with amsgrad
// and custom betas), plus a backward-compatibility check using a
// legacy-layout archive of step/exp_avg/exp_avg_sq/max_exp_avg_sq buffers.
TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(
          std::make_tuple(0.999, 0.1)));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::AdamW(
      model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  // The last parameter has no state (no gradient flowed to it) — skip it.
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
      const AdamWParamState& curr_state_ =
          static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      // max_exp_avg_sq only exists when amsgrad was enabled for the run.
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}
742
743TEST(SerializeTest, Optim_RMSprop) {
744 auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
745 test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);
746
747 // bc compatibility check
748 auto model1 = Linear(5, 2);
749 auto model1_params = model1->parameters();
750
751 // added a tensor for lazy init check - when all params do not have a momentum
752 // buffer entry
753 model1_params.emplace_back(torch::randn({2, 3}));
754 auto optim1 = torch::optim::RMSprop(model1_params, options);
755
756 auto x = torch::ones({10, 5});
757 auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
758 optimizer.zero_grad();
759 auto y = model->forward(x).sum();
760 y.backward();
761 optimizer.step();
762 };
763 step(optim1, model1);
764
765 std::vector<at::Tensor> square_average_buffers;
766 std::vector<at::Tensor> momentum_buffers;
767 std::vector<at::Tensor> grad_average_buffers;
768 const auto& params_ = optim1.param_groups()[0].params();
769 const auto& optim1_state = optim1.state();
770 for (const auto i : c10::irange(params_.size())) {
771 if (i != (params_.size() - 1)) {
772 auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
773 const RMSpropParamState& curr_state_ =
774 static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
775 square_average_buffers.emplace_back(curr_state_.square_avg());
776 if (curr_state_.momentum_buffer().defined()) {
777 momentum_buffers.emplace_back(curr_state_.momentum_buffer());
778 }
779 if (curr_state_.grad_avg().defined()) {
780 grad_average_buffers.emplace_back(curr_state_.grad_avg());
781 }
782 }
783 }
784 // write buffers to the file
785 auto optim_tempfile_old_format = c10::make_tempfile();
786 torch::serialize::OutputArchive output_archive;
787 write_tensors_to_archive(
788 output_archive, "square_average_buffers", square_average_buffers);
789 write_tensors_to_archive(
790 output_archive, "momentum_buffers", momentum_buffers);
791 write_tensors_to_archive(
792 output_archive, "grad_average_buffers", grad_average_buffers);
793 output_archive.save_to(optim_tempfile_old_format.name);
794 auto optim1_2 = RMSprop(model1_params, options);
795 OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
796 torch::load, optim1_2, optim_tempfile_old_format.name);
797 const auto& params1_2_ = optim1_2.param_groups()[0].params();
798 auto& optim1_2_state = optim1_2.state();
799 // old RMSprop didn't track step value
800 for (const auto i : c10::irange(params1_2_.size())) {
801 if (i != (params1_2_.size() - 1)) {
802 auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
803 auto key1_2_ = c10::guts::to_string(params1_2_[i].unsafeGetTensorImpl());
804 const RMSpropParamState& curr_state_ =
805 static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
806 RMSpropParamState& curr_state1_2_ =
807 static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get()));
808 curr_state1_2_.step(curr_state_.step());
809 }
810 }
811 is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
812}
813
814TEST(SerializeTest, Optim_LBFGS) {
815 test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
816 LBFGSOptions(), true);
817 // bc compatibility check
818 auto model1 = Linear(5, 2);
819 auto model1_params = model1->parameters();
820 // added a tensor for lazy init check - when all params do not have entry in
821 // buffers
822 model1_params.emplace_back(torch::randn({2, 3}));
823 auto optim1 =
824 torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());
825
826 auto x = torch::ones({10, 5});
827 auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
828 optimizer.zero_grad();
829 auto y = model->forward(x).sum();
830 y.backward();
831 auto closure = []() { return torch::tensor({10}); };
832 optimizer.step(closure);
833 };
834
835 step(optim1, model1);
836
837 at::Tensor d, t, H_diag, prev_flat_grad, prev_loss;
838 std::deque<at::Tensor> old_dirs, old_stps;
839
840 const auto& params_ = optim1.param_groups()[0].params();
841 auto key_ = c10::guts::to_string(params_[0].unsafeGetTensorImpl());
842 const auto& optim1_state =
843 static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
844 d = optim1_state.d();
845 t = at::tensor(optim1_state.t());
846 H_diag = optim1_state.H_diag();
847 prev_flat_grad = optim1_state.prev_flat_grad();
848 prev_loss = at::tensor(optim1_state.prev_loss());
849 old_dirs = optim1_state.old_dirs();
850
851 // write buffers to the file
852 auto optim_tempfile_old_format = c10::make_tempfile();
853 torch::serialize::OutputArchive output_archive;
854 output_archive.write("d", d, /*is_buffer=*/true);
855 output_archive.write("t", t, /*is_buffer=*/true);
856 output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
857 output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
858 output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
859 write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
860 write_tensors_to_archive(output_archive, "old_stps", old_stps);
861 output_archive.save_to(optim_tempfile_old_format.name);
862
863 auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
864 OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
865 torch::load, optim1_2, optim_tempfile_old_format.name);
866
867 const auto& params1_2_ = optim1_2.param_groups()[0].params();
868 auto param_key = c10::guts::to_string(params1_2_[0].unsafeGetTensorImpl());
869 auto& optim1_2_state =
870 static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));
871
872 // old LBFGS didn't track func_evals, n_iter, ro, al values
873 optim1_2_state.func_evals(optim1_state.func_evals());
874 optim1_2_state.n_iter(optim1_state.n_iter());
875 optim1_2_state.ro(optim1_state.ro());
876 optim1_2_state.al(optim1_state.al());
877
878 is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
879}
880
// Like SerializeTest.XOR, but additionally moves the trained model to CUDA
// and round-trips a CUDA-resident model through save/load, checking loss
// both on CPU and on the GPU in each case.
TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!
  // Generates `batch_size` random XOR samples (optionally on CUDA) and
  // returns BCE loss of the model's predictions on them.
  auto getLoss = [](Sequential model,
                    uint32_t batch_size,
                    bool is_cuda = false) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    if (is_cuda) {
      inputs = inputs.cuda();
      labels = labels.cuda();
    }
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Train on CPU until the smoothed loss drops below 0.1 (max 3000 epochs).
  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  // CPU round trip.
  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);

  // Loaded model still works after being moved to CUDA.
  model2->to(torch::kCUDA);
  loss = getLoss(model2, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);

  // Round trip of a CUDA-resident model into a fresh (CPU-constructed) one.
  auto tempfile2 = c10::make_tempfile();
  torch::save(model2, tempfile2.name);
  torch::load(model3, tempfile2.name);

  loss = getLoss(model3, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);
}
941
942TEST(
943 SerializeTest,
944 CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
945 struct C : torch::nn::Module {
946 C() {
947 register_buffer("foo", torch::ones(5, torch::kInt32));
948 }
949 };
950 struct B : torch::nn::Module {};
951 struct A : torch::nn::Module {
952 A() {
953 register_module("b", std::make_shared<B>());
954 register_module("c", std::make_shared<C>());
955 }
956 };
957 struct M : torch::nn::Module {
958 M() {
959 register_module("a", std::make_shared<A>());
960 }
961 };
962
963 auto out = std::make_shared<M>();
964 std::stringstream ss;
965 torch::save(out, ss);
966 auto in = std::make_shared<M>();
967 torch::load(in, ss);
968
969 const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
970 ASSERT_EQ(output, 5);
971}
972
973TEST(SerializeTest, VectorOfTensors) {
974 torch::manual_seed(0);
975
976 std::vector<torch::Tensor> x_vec = {
977 torch::randn({1, 2}), torch::randn({3, 4})};
978
979 std::stringstream stream;
980 torch::save(x_vec, stream);
981
982 std::vector<torch::Tensor> y_vec;
983 torch::load(y_vec, stream);
984
985 for (const auto i : c10::irange(x_vec.size())) {
986 auto& x = x_vec[i];
987 auto& y = y_vec[i];
988 ASSERT_TRUE(y.defined());
989 ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
990 ASSERT_TRUE(x.allclose(y));
991 }
992}
993
994TEST(SerializeTest, IValue) {
995 c10::IValue ivalue(1);
996 auto tempfile = c10::make_tempfile();
997 torch::serialize::OutputArchive output_archive;
998 output_archive.write("value", ivalue);
999 output_archive.save_to(tempfile.name);
1000
1001 torch::serialize::InputArchive input_archive;
1002 input_archive.load_from(tempfile.name);
1003 c10::IValue ivalue_out;
1004 input_archive.read("value", ivalue_out);
1005 ASSERT_EQ(ivalue_out.toInt(), 1);
1006
1007 ASSERT_THROWS_WITH(
1008 input_archive.read("bad_key", ivalue_out),
1009 "does not have a field with name");
1010}
1011
1012// NOTE: if a `Module` contains unserializable submodules (e.g.
1013// `nn::Functional`), we expect those submodules to be skipped when the `Module`
1014// is being serialized.
1015TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
1016 struct A : torch::nn::Module {
1017 A() {
1018 register_module("relu", torch::nn::Functional(torch::relu));
1019 }
1020 };
1021
1022 auto out = std::make_shared<A>();
1023 std::stringstream ss;
1024 torch::save(out, ss);
1025
1026 torch::serialize::InputArchive archive;
1027 archive.load_from(ss);
1028 torch::serialize::InputArchive relu_archive;
1029
1030 // Submodule with name "relu" should not exist in the `InputArchive`,
1031 // because the "relu" submodule is an `nn::Functional` and is not
1032 // serializable.
1033 ASSERT_FALSE(archive.try_read("relu", relu_archive));
1034}
1035
1036// NOTE: If a `Module` contains unserializable submodules (e.g.
1037// `nn::Functional`), we don't check the existence of those submodules in the
1038// `InputArchive` when deserializing.
1039TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
1040 struct B : torch::nn::Module {
1041 B() {
1042 register_module("relu1", torch::nn::Functional(torch::relu));
1043 register_buffer("foo", torch::zeros(5, torch::kInt32));
1044 }
1045 };
1046 struct A : torch::nn::Module {
1047 A() {
1048 register_module("b", std::make_shared<B>());
1049 register_module("relu2", torch::nn::Functional(torch::relu));
1050 }
1051 };
1052
1053 auto out = std::make_shared<A>();
1054 // Manually change the values of "b.foo", so that we can check whether the
1055 // buffer contains these values after deserialization.
1056 out->named_buffers()["b.foo"].fill_(1);
1057 auto tempfile = c10::make_tempfile();
1058 torch::save(out, tempfile.name);
1059
1060 torch::serialize::InputArchive archive;
1061 archive.load_from(tempfile.name);
1062 torch::serialize::InputArchive archive_b;
1063 torch::serialize::InputArchive archive_relu;
1064 torch::Tensor tensor_foo;
1065
1066 ASSERT_TRUE(archive.try_read("b", archive_b));
1067 ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));
1068
1069 // Submodule with name "relu1" should not exist in `archive_b`, because the
1070 // "relu1" submodule is an `nn::Functional` and is not serializable.
1071 ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));
1072
1073 // Submodule with name "relu2" should not exist in `archive`, because the
1074 // "relu2" submodule is an `nn::Functional` and is not serializable.
1075 ASSERT_FALSE(archive.try_read("relu2", archive_relu));
1076
1077 auto in = std::make_shared<A>();
1078 // `torch::load(...)` works without error, even though `A` contains the
1079 // `nn::Functional` submodules while the serialized file doesn't, because the
1080 // `nn::Functional` submodules are not serializable and thus ignored when
1081 // deserializing.
1082 torch::load(in, tempfile.name);
1083
1084 // Check that the "b.foo" buffer is correctly deserialized from the file.
1085 const int output = in->named_buffers()["b.foo"].sum().item<int>();
1086 // `output` should equal to the sum of the values we manually assigned to
1087 // "b.foo" before serialization.
1088 ASSERT_EQ(output, 5);
1089}
1090