#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

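// Trains a small Linear -> RNN -> Linear network on randomly generated
// binary sequences, where the target is the per-sequence sum. `model_maker`
// builds the recurrent module under test from the hidden size. Returns true
// if the running loss drops below 1e-2 within `max_epoch` iterations.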
template <typename R, typename Func>
bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
  torch::manual_seed(0);

  auto nhid = 32;
  auto model = std::make_shared<SimpleContainer>();
  auto l1 = model->add(Linear(1, nhid), "l1");
  auto rnn_model = model_maker(nhid);
  auto rnn = model->add(rnn_model, "rnn");
  auto nout = nhid;
  if (rnn_model.get()->options_base.proj_size() > 0) {
    nout = rnn_model.get()->options_base.proj_size();
  }
  auto lo = model->add(Linear(nout, 1), "lo");

  torch::optim::Adam optimizer(model->parameters(), 1e-2);
  auto forward_op = [&](torch::Tensor x) {
    auto T = x.size(0);
    auto B = x.size(1);
    x = x.view({T * B, 1});
    x = l1->forward(x).view({T, B, nhid}).tanh_();
    x = std::get<0>(rnn->forward(x))[T - 1];
    x = lo->forward(x);
    return x;
  };

  if (cuda) {
    model->to(torch::kCUDA);
  }

  float running_loss = 1;
  int epoch = 0;
  auto max_epoch = 1500;
  while (running_loss > 1e-2) {
    auto bs = 16U;
    auto nlen = 5U;

    const auto backend = cuda ? torch::kCUDA : torch::kCPU;
    auto inputs =
        torch::rand({nlen, bs, 1}, backend).round().to(torch::kFloat32);
    auto labels = inputs.sum(0).detach();
    inputs.set_requires_grad(true);
    auto outputs = forward_op(inputs);
    torch::Tensor loss = torch::mse_loss(outputs, labels);

    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > max_epoch) {
      return false;
    }
    epoch++;
  }
  return true;
}

void check_lstm_sizes(
    std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
        lstm_output) {
  // Expect the LSTM to have 64 outputs and 3 layers, with an input of
  // 10 time steps and batch size 16 (10 x 16 x n)

  torch::Tensor output = std::get<0>(lstm_output);
  std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
  torch::Tensor hx = std::get<0>(state);
  torch::Tensor cx = std::get<1>(state);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.size(0), 10);
  ASSERT_EQ(output.size(1), 16);
  ASSERT_EQ(output.size(2), 64);

  ASSERT_EQ(hx.ndimension(), 3);
  ASSERT_EQ(hx.size(0), 3); // layers
  ASSERT_EQ(hx.size(1), 16); // batch size
  ASSERT_EQ(hx.size(2), 64); // 64 hidden dims

  ASSERT_EQ(cx.ndimension(), 3);
  ASSERT_EQ(cx.size(0), 3); // layers
  ASSERT_EQ(cx.size(1), 16); // batch size
  ASSERT_EQ(cx.size(2), 64); // 64 hidden dims

  // Something is in the hiddens
  ASSERT_GT(hx.norm().item<float>(), 0);
  ASSERT_GT(cx.norm().item<float>(), 0);
}

void check_lstm_sizes_proj(
    std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
        lstm_output) {
  // Expect the LSTM to have 32 (projected) outputs and 3 layers, with an
  // input of 10 time steps and batch size 16 (10 x 16 x n)

  torch::Tensor output = std::get<0>(lstm_output);
  std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
  torch::Tensor hx = std::get<0>(state);
  torch::Tensor cx = std::get<1>(state);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.size(0), 10);
  ASSERT_EQ(output.size(1), 16);
  ASSERT_EQ(output.size(2), 32);

  ASSERT_EQ(hx.ndimension(), 3);
  ASSERT_EQ(hx.size(0), 3); // layers
  ASSERT_EQ(hx.size(1), 16); // batch size
  ASSERT_EQ(hx.size(2), 32); // 32 projected hidden dims

  ASSERT_EQ(cx.ndimension(), 3);
  ASSERT_EQ(cx.size(0), 3); // layers
  ASSERT_EQ(cx.size(1), 16); // batch size
  ASSERT_EQ(cx.size(2), 64); // 64 cell dims

  // Something is in the hiddens
  ASSERT_GT(hx.norm().item<float>(), 0);
  ASSERT_GT(cx.norm().item<float>(), 0);
}

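// All tests below use SeedingFixture (from test/cpp/api/support.h), which is
// assumed here to reseed the RNG before each test for reproducibility.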
struct RNNTest : torch::test::SeedingFixture {};

TEST_F(RNNTest, CheckOutputSizes) {
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
  // Input size is: sequence length, batch size, input size
  auto x = torch::randn({10, 16, 128}, torch::requires_grad());
  auto output = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes(output);

  auto next = model->forward(x, std::get<1>(output));

  check_lstm_sizes(next);

  auto output_hx = std::get<0>(std::get<1>(output));
  auto output_cx = std::get<1>(std::get<1>(output));

  auto next_hx = std::get<0>(std::get<1>(next));
  auto next_cx = std::get<1>(std::get<1>(next));

  torch::Tensor diff =
      torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);

  // Hiddens changed
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}

TEST_F(RNNTest, CheckOutputSizesProj) {
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32));
  // Input size is: sequence length, batch size, input size
  auto x = torch::randn({10, 16, 128}, torch::requires_grad());
  auto output = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes_proj(output);

  auto next = model->forward(x, std::get<1>(output));

  check_lstm_sizes_proj(next);

  auto output_hx = std::get<0>(std::get<1>(output));
  auto output_cx = std::get<1>(std::get<1>(output));

  auto next_hx = std::get<0>(std::get<1>(next));
  auto next_cx = std::get<1>(std::get<1>(next));

  torch::Tensor diff = next_hx - output_hx;
  // Hiddens changed
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
  diff = next_cx - output_cx;
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}

TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
  torch::manual_seed(0);
  // Make sure the outputs match PyTorch's reference outputs
  LSTM model(2, 2);
  for (auto& v : model->parameters()) {
    float size = v.numel();
    auto p = static_cast<float*>(v.storage().data());
    for (size_t i = 0; i < size; i++) {
      p[i] = i / size;
    }
  }

  auto x = torch::empty({3, 4, 2}, torch::requires_grad());
  float size = x.numel();
  auto p = static_cast<float*>(x.storage().data());
  for (size_t i = 0; i < size; i++) {
    p[i] = (size - i) / size;
  }

  auto out = model->forward(x);
  ASSERT_EQ(std::get<0>(out).ndimension(), 3);
  ASSERT_EQ(std::get<0>(out).size(0), 3);
  ASSERT_EQ(std::get<0>(out).size(1), 4);
  ASSERT_EQ(std::get<0>(out).size(2), 2);

  auto flat = std::get<0>(out).view(3 * 4 * 2);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
                   0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
                   0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
                   0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
  for (size_t i = 0; i < 3 * 4 * 2; i++) {
    ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
  }

  auto hx = std::get<0>(std::get<1>(out));
  auto cx = std::get<1>(std::get<1>(out));

  ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2
  ASSERT_EQ(hx.size(0), 1);
  ASSERT_EQ(hx.size(1), 4);
  ASSERT_EQ(hx.size(2), 2);

  ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2
  ASSERT_EQ(cx.size(0), 1);
  ASSERT_EQ(cx.size(1), 4);
  ASSERT_EQ(cx.size(2), 2);

  flat = torch::cat({hx, cx}, 0).view(16);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  float h_out[] = {
      0.7889,
      0.9003,
      0.7769,
      0.8905,
      0.7635,
      0.8794,
      0.7484,
      0.8666,
      1.1647,
      1.6106,
      1.1425,
      1.5726,
      1.1187,
      1.5329,
      1.0931,
      1.4911};
  for (size_t i = 0; i < 16; i++) {
    ASSERT_LT(std::abs(flat[i].item<float>() - h_out[i]), 1e-3);
  }
}

TEST_F(RNNTest, EndToEndLSTM) {
  ASSERT_TRUE(test_RNN_xor<LSTM>(
      [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }));
}

TEST_F(RNNTest, EndToEndLSTMProj) {
  ASSERT_TRUE(test_RNN_xor<LSTM>([](int s) {
    return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2));
  }));
}

TEST_F(RNNTest, EndToEndGRU) {
  ASSERT_TRUE(test_RNN_xor<GRU>(
      [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }));
}

TEST_F(RNNTest, EndToEndRNNRelu) {
  ASSERT_TRUE(test_RNN_xor<RNN>([](int s) {
    return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2));
  }));
}

TEST_F(RNNTest, EndToEndRNNTanh) {
  ASSERT_TRUE(test_RNN_xor<RNN>([](int s) {
    return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2));
  }));
}

TEST_F(RNNTest, Sizes_CUDA) {
  torch::manual_seed(0);
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
  model->to(torch::kCUDA);
  auto x =
      torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
  auto output = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes(output);

  auto next = model->forward(x, std::get<1>(output));

  check_lstm_sizes(next);

  auto output_hx = std::get<0>(std::get<1>(output));
  auto output_cx = std::get<1>(std::get<1>(output));

  auto next_hx = std::get<0>(std::get<1>(next));
  auto next_cx = std::get<1>(std::get<1>(next));

  torch::Tensor diff =
      torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);

  // Hiddens changed
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}

TEST_F(RNNTest, SizesProj_CUDA) {
  torch::manual_seed(0);
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32));
  model->to(torch::kCUDA);
  auto x =
      torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
  auto output = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes_proj(output);

  auto next = model->forward(x, std::get<1>(output));

  check_lstm_sizes_proj(next);

  auto output_hx = std::get<0>(std::get<1>(output));
  auto output_cx = std::get<1>(std::get<1>(output));

  auto next_hx = std::get<0>(std::get<1>(next));
  auto next_cx = std::get<1>(std::get<1>(next));

  torch::Tensor diff = next_hx - output_hx;
  // Hiddens changed
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
  diff = next_cx - output_cx;
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}

TEST_F(RNNTest, EndToEndLSTM_CUDA) {
  ASSERT_TRUE(test_RNN_xor<LSTM>(
      [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true));
}

TEST_F(RNNTest, EndToEndLSTMProj_CUDA) {
  ASSERT_TRUE(test_RNN_xor<LSTM>(
      [](int s) {
        return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2));
      },
      true));
}

TEST_F(RNNTest, EndToEndGRU_CUDA) {
  ASSERT_TRUE(test_RNN_xor<GRU>(
      [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true));
}

TEST_F(RNNTest, EndToEndRNNRelu_CUDA) {
  ASSERT_TRUE(test_RNN_xor<RNN>(
      [](int s) {
        return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2));
      },
      true));
}

TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
  ASSERT_TRUE(test_RNN_xor<RNN>(
      [](int s) {
        return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2));
      },
      true));
}

TEST_F(RNNTest, PrettyPrintRNNs) {
  ASSERT_EQ(
      c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))),
      "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
  ASSERT_EQ(
      c10::str(
          LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32))),
      "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false, proj_size=32)");
  ASSERT_EQ(
      c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))),
      "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)");
  ASSERT_EQ(
      c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(
          torch::kTanh))),
      "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
}

// This test ensures that flatten_parameters does not crash
// when bidirectional is set to true.
// https://github.com/pytorch/pytorch/issues/19545
TEST_F(RNNTest, BidirectionalFlattenParameters) {
  GRU gru(GRUOptions(100, 256).num_layers(2).bidirectional(true));
  gru->flatten_parameters();
}

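// Copies the layer-`s_suffix` weight and bias tensors of `source` into the
// layer-`t_suffix` parameters of `target` (e.g. "0" or "0_reverse"), without
// tracking gradients.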
template <typename Impl>
void copyParameters(
    torch::nn::ModuleHolder<Impl>& target,
    std::string t_suffix,
    const torch::nn::ModuleHolder<Impl>& source,
    std::string s_suffix) {
  at::NoGradGuard guard;
  target->named_parameters()["weight_ih_l" + t_suffix].copy_(
      source->named_parameters()["weight_ih_l" + s_suffix]);
  target->named_parameters()["weight_hh_l" + t_suffix].copy_(
      source->named_parameters()["weight_hh_l" + s_suffix]);
  target->named_parameters()["bias_ih_l" + t_suffix].copy_(
      source->named_parameters()["bias_ih_l" + s_suffix]);
  target->named_parameters()["bias_hh_l" + t_suffix].copy_(
      source->named_parameters()["bias_hh_l" + s_suffix]);
}

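// Helpers that move the (output, hidden state) tuples returned by
// GRU::forward and LSTM::forward to another device, so CUDA results can be
// compared against CPU results below.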
std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device(
    std::tuple<torch::Tensor, torch::Tensor> gru_output,
    torch::Device device) {
  return std::make_tuple(
      std::get<0>(gru_output).to(device), std::get<1>(gru_output).to(device));
}

std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
lstm_output_to_device(
    std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
        lstm_output,
    torch::Device device) {
  auto hidden_states = std::get<1>(lstm_output);
  return std::make_tuple(
      std::get<0>(lstm_output).to(device),
      std::make_tuple(
          std::get<0>(hidden_states).to(device),
          std::get<1>(hidden_states).to(device)));
}

// This test is a port of the Python code introduced here:
// https://towardsdatascience.com/understanding-bidirectional-rnn-in-pytorch-5bd25a5dd66
// The reverse pass of a bidirectional GRU should act
// like the regular forward pass of a unidirectional GRU.
void BidirectionalGRUReverseForward(bool cuda) {
  auto opt = torch::TensorOptions()
                 .dtype(torch::kFloat32)
                 .requires_grad(false)
                 .device(cuda ? torch::kCUDA : torch::kCPU);
  auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
  auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});

  auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false);
  GRU bi_grus{gru_options.bidirectional(true)};
  GRU reverse_gru{gru_options.bidirectional(false)};

  if (cuda) {
    bi_grus->to(torch::kCUDA);
    reverse_gru->to(torch::kCUDA);
  }

  // Now make sure the weights of the reverse GRU layer match
  // those of the reverse direction of the bidirectional GRU:
  copyParameters(reverse_gru, "0", bi_grus, "0_reverse");

  auto bi_output = bi_grus->forward(input);
  auto reverse_output = reverse_gru->forward(input_reversed);

  if (cuda) {
    bi_output = gru_output_to_device(bi_output, torch::kCPU);
    reverse_output = gru_output_to_device(reverse_output, torch::kCPU);
  }

  ASSERT_EQ(
      std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
  auto size = std::get<0>(bi_output).size(0);
  for (int i = 0; i < size; i++) {
    ASSERT_EQ(
        std::get<0>(bi_output)[i][0][1].item<float>(),
        std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
  }
  // The hidden state of the reversed GRU sits
  // at the odd indices of the first dimension.
  ASSERT_EQ(
      std::get<1>(bi_output)[1][0][0].item<float>(),
      std::get<1>(reverse_output)[0][0][0].item<float>());
}

TEST_F(RNNTest, BidirectionalGRUReverseForward) {
  BidirectionalGRUReverseForward(false);
}

TEST_F(RNNTest, BidirectionalGRUReverseForward_CUDA) {
  BidirectionalGRUReverseForward(true);
}

// The reverse pass of a bidirectional LSTM should act
// like the regular forward pass of a unidirectional LSTM.
void BidirectionalLSTMReverseForwardTest(bool cuda) {
  auto opt = torch::TensorOptions()
                 .dtype(torch::kFloat32)
                 .requires_grad(false)
                 .device(cuda ? torch::kCUDA : torch::kCPU);
  auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
  auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});

  auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false);

  LSTM bi_lstm{lstm_opt.bidirectional(true)};
  LSTM reverse_lstm{lstm_opt.bidirectional(false)};

  if (cuda) {
    bi_lstm->to(torch::kCUDA);
    reverse_lstm->to(torch::kCUDA);
  }

  // Now make sure the weights of the reverse LSTM layer match
  // those of the reverse direction of the bidirectional LSTM:
  copyParameters(reverse_lstm, "0", bi_lstm, "0_reverse");

  auto bi_output = bi_lstm->forward(input);
  auto reverse_output = reverse_lstm->forward(input_reversed);

  if (cuda) {
    bi_output = lstm_output_to_device(bi_output, torch::kCPU);
    reverse_output = lstm_output_to_device(reverse_output, torch::kCPU);
  }

  ASSERT_EQ(
      std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
  auto size = std::get<0>(bi_output).size(0);
  for (int i = 0; i < size; i++) {
    ASSERT_EQ(
        std::get<0>(bi_output)[i][0][1].item<float>(),
        std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
  }
  // The hidden and cell states of the reversed LSTM sit
  // at the odd indices of the first dimension.
  ASSERT_EQ(
      std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(),
      std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>());
  ASSERT_EQ(
      std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(),
      std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>());
}

TEST_F(RNNTest, BidirectionalLSTMReverseForward) {
  BidirectionalLSTMReverseForwardTest(false);
}

TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) {
  BidirectionalLSTMReverseForwardTest(true);
}

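// The CPU-vs-CUDA tests below run the same bidirectional, multi-layer module
// with identical weights and inputs on both devices and check that the
// outputs agree element-wise.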
TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
  // Create two GRUs with the same options
  auto opt =
      GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
  GRU gru_cpu{opt};
  GRU gru_cuda{opt};

  // Copy weights and biases from CPU GRU to CUDA GRU
  {
    at::NoGradGuard guard;
    for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) {
      gru_cuda->named_parameters()[param.key()].copy_(
          gru_cpu->named_parameters()[param.key()]);
    }
  }

  gru_cpu->flatten_parameters();
  gru_cuda->flatten_parameters();

  // Move GRU to CUDA
  gru_cuda->to(torch::kCUDA);

  // Create the same inputs
  auto input_opt =
      torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
  auto input_cpu =
      torch::tensor({1, 2, 3, 4, 5, 6}, input_opt).reshape({3, 1, 2});
  auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, input_opt)
                        .reshape({3, 1, 2})
                        .to(torch::kCUDA);

  // Call forward on both GRUs
  auto output_cpu = gru_cpu->forward(input_cpu);
  auto output_cuda = gru_cuda->forward(input_cuda);

  output_cpu = gru_output_to_device(output_cpu, torch::kCPU);

  // Assert that the output and state are equal on CPU and CUDA
  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
    ASSERT_EQ(
        std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
  }
  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
        ASSERT_NEAR(
            std::get<0>(output_cpu)[i][j][k].item<float>(),
            std::get<0>(output_cuda)[i][j][k].item<float>(),
            1e-5);
      }
    }
  }
}

TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
  // Create two LSTMs with the same options
  auto opt =
      LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
  LSTM lstm_cpu{opt};
  LSTM lstm_cuda{opt};

  // Copy weights and biases from CPU LSTM to CUDA LSTM
  {
    at::NoGradGuard guard;
    for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
      lstm_cuda->named_parameters()[param.key()].copy_(
          lstm_cpu->named_parameters()[param.key()]);
    }
  }

  lstm_cpu->flatten_parameters();
  lstm_cuda->flatten_parameters();

  // Move LSTM to CUDA
  lstm_cuda->to(torch::kCUDA);

  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
  auto input_cpu =
      torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2});
  auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options)
                        .reshape({3, 1, 2})
                        .to(torch::kCUDA);

  // Call forward on both LSTMs
  auto output_cpu = lstm_cpu->forward(input_cpu);
  auto output_cuda = lstm_cuda->forward(input_cuda);

  output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);

  // Assert that the output and state are equal on CPU and CUDA
  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
    ASSERT_EQ(
        std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
  }
  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
        ASSERT_NEAR(
            std::get<0>(output_cpu)[i][j][k].item<float>(),
            std::get<0>(output_cuda)[i][j][k].item<float>(),
            1e-5);
      }
    }
  }
}

TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) {
  // Create two LSTMs with the same options
  auto opt = LSTMOptions(2, 4)
                 .num_layers(3)
                 .batch_first(false)
                 .bidirectional(true)
                 .proj_size(2);
  LSTM lstm_cpu{opt};
  LSTM lstm_cuda{opt};

  // Copy weights and biases from CPU LSTM to CUDA LSTM
  {
    at::NoGradGuard guard;
    for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
      lstm_cuda->named_parameters()[param.key()].copy_(
          lstm_cpu->named_parameters()[param.key()]);
    }
  }

  lstm_cpu->flatten_parameters();
  lstm_cuda->flatten_parameters();

  // Move LSTM to CUDA
  lstm_cuda->to(torch::kCUDA);

  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
  auto input_cpu =
      torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2});
  auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options)
                        .reshape({3, 1, 2})
                        .to(torch::kCUDA);

  // Call forward on both LSTMs
  auto output_cpu = lstm_cpu->forward(input_cpu);
  auto output_cuda = lstm_cuda->forward(input_cuda);

  output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);

  // Assert that the output and state are equal on CPU and CUDA
  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
    ASSERT_EQ(
        std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
  }
  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
        ASSERT_NEAR(
            std::get<0>(output_cpu)[i][j][k].item<float>(),
            std::get<0>(output_cuda)[i][j][k].item<float>(),
            1e-5);
      }
    }
  }
}

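// The hard-coded expected outputs below correspond to modules constructed
// under torch::manual_seed(0); presumably they were captured from a reference
// run of the same configuration.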
TEST_F(RNNTest, UsePackedSequenceAsInput) {
  {
    torch::manual_seed(0);
    auto m = RNN(2, 3);
    torch::nn::utils::rnn::PackedSequence packed_input =
        torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
    auto rnn_output = m->forward_with_packed_input(packed_input);
    auto expected_output = torch::tensor(
        {{-0.0645, -0.7274, 0.4531},
         {-0.3970, -0.6950, 0.6009},
         {-0.3877, -0.7310, 0.6806}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));

    // Test passing optional argument to `RNN::forward_with_packed_input`
    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    auto m = LSTM(2, 3);
    torch::nn::utils::rnn::PackedSequence packed_input =
        torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
    auto rnn_output = m->forward_with_packed_input(packed_input);
    auto expected_output = torch::tensor(
        {{-0.2693, -0.1240, 0.0744},
         {-0.3889, -0.1919, 0.1183},
         {-0.4425, -0.2314, 0.1386}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));

    // Test passing optional argument to `LSTM::forward_with_packed_input`
    rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    auto m = GRU(2, 3);
    torch::nn::utils::rnn::PackedSequence packed_input =
        torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
    auto rnn_output = m->forward_with_packed_input(packed_input);
    auto expected_output = torch::tensor(
        {{-0.1134, 0.0467, 0.2336},
         {-0.1189, 0.0502, 0.2960},
         {-0.1138, 0.0484, 0.3110}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));

    // Test passing optional argument to `GRU::forward_with_packed_input`
    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
}