#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

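// Trains a small 2-8-1 sigmoid MLP on randomly generated XOR batches with the
// given optimizer and returns true if the running loss drops below 0.1 within
// kMaximumNumberOfEpochs epochs.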
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (const auto i : c10::irange(kBatchSize)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }

    inputs.set_requires_grad(true);

    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    torch::Tensor inputs,
                    torch::Tensor labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}

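// Overwrites a named parameter with the given (flattened) values, with
// requires_grad temporarily disabled so the copy is not tracked by autograd.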
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

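// Runs the given optimizer on a fixed 2-3-1 network with deterministic
// parameters and inputs, and asserts that the parameters after every
// kSampleEvery-th iteration match the reference values from optim_baseline.h.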
template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor(
          {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976},
          torch::kFloat64));
  assign_parameter(
      parameters,
      "0.bias",
      torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters,
      "2.weight",
      torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(
      parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (const auto i : c10::irange(kIterations)) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

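    // step() needs a closure for optimizers such as LBFGS; supply a trivial
    // one since the gradients were already computed above.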
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (const auto p : c10::irange(parameters.size())) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (const auto i : c10::irange(3)) {
    (void)i; // Suppress unused variable warning
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  // NOLINTNEXTLINE(modernize-use-emplace)
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  for (const auto i : c10::irange(params_1.size())) {
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params existing in another
  // param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer
            .state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  (void)optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}

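// Runs `func` and asserts that exactly one deprecation warning containing
// "will be removed" was emitted.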
#define OLD_INTERFACE_WARNING_CHECK(func)       \
  {                                             \
    torch::test::WarningCapture warnings;       \
    func;                                       \
    ASSERT_EQ(                                  \
        torch::test::count_substr_occurrences(  \
            warnings.str(), "will be removed"), \
        1);                                     \
  }

struct MyOptimizerOptions
    : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}
  TORCH_ARG(double, lr) = 1.0;
};

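// Exercises the deprecated Optimizer interface (size(), parameters(),
// add_parameters()), each call of which should emit a deprecation warning.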
TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : // NOLINTNEXTLINE(performance-move-const-arg)
          Optimizer(
              {std::move(OptimizerParamGroup(params))},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (const auto p : c10::irange(size)) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}

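// Convergence smoke tests: each optimizer configuration below must drive the
// XOR running loss below 0.1 (see test_optimizer_xor above).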
TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

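// ProducesPyTorchValues_* tests compare parameter trajectories against the
// reference values recorded in optim_baseline.h via check_exact_values().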
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, XORConvergence_AdamW) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1)));
}

TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<AdamW>(AdamWOptions(0.1).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_AdamW) {
  check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithoutWeightDecay) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).weight_decay(0),
      expected_parameters::AdamW_without_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWWithAMSGrad) {
  check_exact_values<AdamW>(
      AdamWOptions(1.0).amsgrad(true),
      expected_parameters::AdamW_with_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS());
}

TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}

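// After zero_grad(), parameter gradients should be undefined again.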
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }
}

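// With a learning rate of 1.0 and all gradients set to one, a single plain
// SGD step should subtract exactly 1.0 from every parameter.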
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

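// Parameters added after construction via the deprecated add_parameters()
// should be picked up by LBFGS; stepping must not throw.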
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.mutable_grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

  optimizer.step([]() { return torch::tensor(1); });

  // The step above must complete without throwing.
}

// Check that the learning rate of the optimizer's parameter groups matches
// the expected learning rates given in the epoch -> learning-rate map
void check_lr_change(
    Optimizer& optimizer,
    LRScheduler& lr_scheduler,
    std::map<unsigned, double> expected_epoch_lrs) {
  // Find the maximum epoch in the map so we know how many steps to run
  unsigned kIterations = std::max_element(
                             expected_epoch_lrs.begin(),
                             expected_epoch_lrs.end(),
                             [](const std::pair<unsigned, double>& a,
                                const std::pair<unsigned, double>& b) -> bool {
                               return a.first < b.first;
                             })
                             ->first;

  for (unsigned i = 0; i <= kIterations; i++) {
    const auto epoch_iter = expected_epoch_lrs.find(i);
    if (epoch_iter != expected_epoch_lrs.end()) {
      // The optimizer's learning rate must match the expected value for this
      // epoch (up to double-precision epsilon)
      ASSERT_TRUE(
          fabs(
              epoch_iter->second -
              optimizer.param_groups()[0].options().get_lr()) <
          std::numeric_limits<double>::epsilon());
    }
    optimizer.step();
    lr_scheduler.step();
  }
}

TEST(OptimTest, CheckLRChange_StepLR_Adam) {
  torch::Tensor parameters = torch::zeros({1});
  auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3));

  const unsigned step_size = 20;
  const double gamma = 0.5;
  StepLR step_lr_scheduler(optimizer, step_size, gamma);

  // The learning rate should have halved at epoch 20
  const std::map<unsigned, double> expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}};

  check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs);
}
