#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

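// Trains a small MLP on the XOR problem with the given optimizer options.
// Returns true once the exponentially smoothed loss drops below 0.1, or false
// if it fails to converge within kMaximumNumberOfEpochs.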
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
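    // Each sample is a pair of random bits; its label is their XOR.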
    for (const auto i : c10::irange(kBatchSize)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }

    inputs.set_requires_grad(true);

    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    torch::Tensor inputs,
                    torch::Tensor labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

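    // Exponential moving average of the loss (decay 0.99).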
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}

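// Overwrites the named parameter with `new_tensor` (flattened copy), toggling
// requires_grad off and back on so the copy is not recorded by autograd.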
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

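// Runs kIterations optimization steps on a small fixed model and input, and
// every kSampleEvery steps compares the parameters against the expected values
// from optim_baseline.h.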
template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor(
          {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
          torch::kFloat64));
  assign_parameter(
      parameters,
      "0.bias",
      torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters,
      "2.weight",
      torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(
      parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (const auto i : c10::irange(kIterations)) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

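    // A dummy closure: step() takes an optional loss closure (required by
    // LBFGS); the real gradients were already computed above.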
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (const auto p : c10::irange(parameters.size())) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (const auto i : c10::irange(3)) {
    (void)i; // Suppress unused variable warning
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  // NOLINTNEXTLINE(modernize-use-emplace)
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  for (const auto i : c10::irange(params_1.size())) {
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer
            .state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
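  // test for defaults() with const reference return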
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}

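// Runs `func` while capturing warnings and asserts that exactly one
// deprecation warning (containing "will be removed") was emitted.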
#define OLD_INTERFACE_WARNING_CHECK(func)      \
  {                                            \
    torch::test::WarningCapture warnings;      \
    func;                                      \
    ASSERT_EQ(                                 \
        torch::test::count_substr_occurrences( \
            warnings.str(), "will be removed"),\
        1);                                    \
  }

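// Minimal options type used by the legacy-interface test below.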
struct MyOptimizerOptions
    : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}
  TORCH_ARG(double, lr) = 1.0;
};

TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : Optimizer(
              {OptimizerParamGroup(std::move(params))},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}

TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}

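// zero_grad() should leave every parameter with an undefined gradient.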
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}

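// With lr = 1.0 and gradients of all ones, one SGD step should subtract
// exactly 1.0 from every parameter, even when the parameters live in a plain
// std::vector rather than a module.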
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

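// Parameters added through the deprecated add_parameters() interface should
// still be usable by LBFGS.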
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  optimizer.step([]() { return torch::tensor(1); });

  // The test passes as long as the step above does not throw.
}

// Checks that the learning rates of the optimizer's parameter groups match the
// expected learning rates given in the epoch -> learning-rate map
void check_lr_change(
    Optimizer& optimizer,
    LRScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the maximum epoch in the map
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.first < b.first;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // Compare the two floating point learning rates within epsilon
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step();
  }
}

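// Drive a StepLR scheduler over an Adam optimizer and verify the expected
// learning-rate schedule.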
TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}