1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
7 | using namespace torch::nn; |
8 | using namespace torch::test; |
9 | |
10 | template <typename R, typename Func> |
11 | bool test_RNN_xor(Func&& model_maker, bool cuda = false) { |
12 | torch::manual_seed(0); |
13 | |
14 | auto nhid = 32; |
15 | auto model = std::make_shared<SimpleContainer>(); |
16 | auto l1 = model->add(Linear(1, nhid), "l1" ); |
17 | auto rnn_model = model_maker(nhid); |
18 | auto rnn = model->add(rnn_model, "rnn" ); |
19 | auto nout = nhid; |
20 | if (rnn_model.get()->options_base.proj_size() > 0) { |
21 | nout = rnn_model.get()->options_base.proj_size(); |
22 | } |
23 | auto lo = model->add(Linear(nout, 1), "lo" ); |
24 | |
25 | torch::optim::Adam optimizer(model->parameters(), 1e-2); |
26 | auto forward_op = [&](torch::Tensor x) { |
27 | auto T = x.size(0); |
28 | auto B = x.size(1); |
29 | x = x.view({T * B, 1}); |
30 | x = l1->forward(x).view({T, B, nhid}).tanh_(); |
31 | x = std::get<0>(rnn->forward(x))[T - 1]; |
32 | x = lo->forward(x); |
33 | return x; |
34 | }; |
35 | |
36 | if (cuda) { |
37 | model->to(torch::kCUDA); |
38 | } |
39 | |
40 | float running_loss = 1; |
41 | int epoch = 0; |
42 | auto max_epoch = 1500; |
43 | while (running_loss > 1e-2) { |
44 | auto bs = 16U; |
45 | auto nlen = 5U; |
46 | |
47 | const auto backend = cuda ? torch::kCUDA : torch::kCPU; |
48 | auto inputs = |
49 | torch::rand({nlen, bs, 1}, backend).round().to(torch::kFloat32); |
50 | auto labels = inputs.sum(0).detach(); |
51 | inputs.set_requires_grad(true); |
52 | auto outputs = forward_op(inputs); |
53 | torch::Tensor loss = torch::mse_loss(outputs, labels); |
54 | |
55 | optimizer.zero_grad(); |
56 | loss.backward(); |
57 | optimizer.step(); |
58 | |
59 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions) |
60 | running_loss = running_loss * 0.99 + loss.item<float>() * 0.01; |
61 | if (epoch > max_epoch) { |
62 | return false; |
63 | } |
64 | epoch++; |
65 | } |
66 | return true; |
67 | }; |
68 | |
69 | void check_lstm_sizes( |
70 | std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> |
71 | lstm_output) { |
72 | // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch |
73 | // 10 and 16 time steps (10 x 16 x n) |
74 | |
75 | torch::Tensor output = std::get<0>(lstm_output); |
76 | std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output); |
77 | torch::Tensor hx = std::get<0>(state); |
78 | torch::Tensor cx = std::get<1>(state); |
79 | |
80 | ASSERT_EQ(output.ndimension(), 3); |
81 | ASSERT_EQ(output.size(0), 10); |
82 | ASSERT_EQ(output.size(1), 16); |
83 | ASSERT_EQ(output.size(2), 64); |
84 | |
85 | ASSERT_EQ(hx.ndimension(), 3); |
86 | ASSERT_EQ(hx.size(0), 3); // layers |
87 | ASSERT_EQ(hx.size(1), 16); // Batchsize |
88 | ASSERT_EQ(hx.size(2), 64); // 64 hidden dims |
89 | |
90 | ASSERT_EQ(cx.ndimension(), 3); |
91 | ASSERT_EQ(cx.size(0), 3); // layers |
92 | ASSERT_EQ(cx.size(1), 16); // Batchsize |
93 | ASSERT_EQ(cx.size(2), 64); // 64 hidden dims |
94 | |
95 | // Something is in the hiddens |
96 | ASSERT_GT(hx.norm().item<float>(), 0); |
97 | ASSERT_GT(cx.norm().item<float>(), 0); |
98 | } |
99 | |
100 | void check_lstm_sizes_proj( |
101 | std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> |
102 | lstm_output) { |
103 | // Expect the LSTM to have 32 outputs and 3 layers, with an input of batch |
104 | // 10 and 16 time steps (10 x 16 x n) |
105 | |
106 | torch::Tensor output = std::get<0>(lstm_output); |
107 | std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output); |
108 | torch::Tensor hx = std::get<0>(state); |
109 | torch::Tensor cx = std::get<1>(state); |
110 | |
111 | ASSERT_EQ(output.ndimension(), 3); |
112 | ASSERT_EQ(output.size(0), 10); |
113 | ASSERT_EQ(output.size(1), 16); |
114 | ASSERT_EQ(output.size(2), 32); |
115 | |
116 | ASSERT_EQ(hx.ndimension(), 3); |
117 | ASSERT_EQ(hx.size(0), 3); // layers |
118 | ASSERT_EQ(hx.size(1), 16); // Batchsize |
119 | ASSERT_EQ(hx.size(2), 32); // 32 hidden dims |
120 | |
121 | ASSERT_EQ(cx.ndimension(), 3); |
122 | ASSERT_EQ(cx.size(0), 3); // layers |
123 | ASSERT_EQ(cx.size(1), 16); // Batchsize |
124 | ASSERT_EQ(cx.size(2), 64); // 64 cell dims |
125 | |
126 | // Something is in the hiddens |
127 | ASSERT_GT(hx.norm().item<float>(), 0); |
128 | ASSERT_GT(cx.norm().item<float>(), 0); |
129 | } |
130 | |
// gtest fixture shared by every RNN test in this file. It adds nothing on
// top of torch::test::SeedingFixture.
// NOTE(review): presumably SeedingFixture seeds the RNG before each test --
// confirm in test/cpp/api/support.h.
struct RNNTest : torch::test::SeedingFixture {};
132 | |
133 | TEST_F(RNNTest, CheckOutputSizes) { |
134 | LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2)); |
135 | // Input size is: sequence length, batch size, input size |
136 | auto x = torch::randn({10, 16, 128}, torch::requires_grad()); |
137 | auto output = model->forward(x); |
138 | auto y = x.mean(); |
139 | |
140 | y.backward(); |
141 | check_lstm_sizes(output); |
142 | |
143 | auto next = model->forward(x, std::get<1>(output)); |
144 | |
145 | check_lstm_sizes(next); |
146 | |
147 | auto output_hx = std::get<0>(std::get<1>(output)); |
148 | auto output_cx = std::get<1>(std::get<1>(output)); |
149 | |
150 | auto next_hx = std::get<0>(std::get<1>(next)); |
151 | auto next_cx = std::get<1>(std::get<1>(next)); |
152 | |
153 | torch::Tensor diff = |
154 | torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0); |
155 | |
156 | // Hiddens changed |
157 | ASSERT_GT(diff.abs().sum().item<float>(), 1e-3); |
158 | } |
159 | |
160 | TEST_F(RNNTest, CheckOutputSizesProj) { |
161 | LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32)); |
162 | // Input size is: sequence length, batch size, input size |
163 | auto x = torch::randn({10, 16, 128}, torch::requires_grad()); |
164 | auto output = model->forward(x); |
165 | auto y = x.mean(); |
166 | |
167 | y.backward(); |
168 | check_lstm_sizes_proj(output); |
169 | |
170 | auto next = model->forward(x, std::get<1>(output)); |
171 | |
172 | check_lstm_sizes_proj(next); |
173 | |
174 | auto output_hx = std::get<0>(std::get<1>(output)); |
175 | auto output_cx = std::get<1>(std::get<1>(output)); |
176 | |
177 | auto next_hx = std::get<0>(std::get<1>(next)); |
178 | auto next_cx = std::get<1>(std::get<1>(next)); |
179 | |
180 | torch::Tensor diff = next_hx - output_hx; |
181 | // Hiddens changed |
182 | ASSERT_GT(diff.abs().sum().item<float>(), 1e-3); |
183 | diff = next_cx - output_cx; |
184 | ASSERT_GT(diff.abs().sum().item<float>(), 1e-3); |
185 | } |
186 | |
187 | TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) { |
188 | torch::manual_seed(0); |
189 | // Make sure the outputs match pytorch outputs |
190 | LSTM model(2, 2); |
191 | for (auto& v : model->parameters()) { |
192 | float size = v.numel(); |
193 | auto p = static_cast<float*>(v.storage().data()); |
194 | for (size_t i = 0; i < size; i++) { |
195 | p[i] = i / size; |
196 | } |
197 | } |
198 | |
199 | auto x = torch::empty({3, 4, 2}, torch::requires_grad()); |
200 | float size = x.numel(); |
201 | auto p = static_cast<float*>(x.storage().data()); |
202 | for (size_t i = 0; i < size; i++) { |
203 | p[i] = (size - i) / size; |
204 | } |
205 | |
206 | auto out = model->forward(x); |
207 | ASSERT_EQ(std::get<0>(out).ndimension(), 3); |
208 | ASSERT_EQ(std::get<0>(out).size(0), 3); |
209 | ASSERT_EQ(std::get<0>(out).size(1), 4); |
210 | ASSERT_EQ(std::get<0>(out).size(2), 2); |
211 | |
212 | auto flat = std::get<0>(out).view(3 * 4 * 2); |
213 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
214 | float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239, |
215 | 0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968, |
216 | 0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003, |
217 | 0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666}; |
218 | for (size_t i = 0; i < 3 * 4 * 2; i++) { |
219 | ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3); |
220 | } |
221 | |
222 | auto hx = std::get<0>(std::get<1>(out)); |
223 | auto cx = std::get<1>(std::get<1>(out)); |
224 | |
225 | ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2 |
226 | ASSERT_EQ(hx.size(0), 1); |
227 | ASSERT_EQ(hx.size(1), 4); |
228 | ASSERT_EQ(hx.size(2), 2); |
229 | |
230 | ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2 |
231 | ASSERT_EQ(cx.size(0), 1); |
232 | ASSERT_EQ(cx.size(1), 4); |
233 | ASSERT_EQ(cx.size(2), 2); |
234 | |
235 | flat = torch::cat({hx, cx}, 0).view(16); |
236 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
237 | float h_out[] = { |
238 | 0.7889, |
239 | 0.9003, |
240 | 0.7769, |
241 | 0.8905, |
242 | 0.7635, |
243 | 0.8794, |
244 | 0.7484, |
245 | 0.8666, |
246 | 1.1647, |
247 | 1.6106, |
248 | 1.1425, |
249 | 1.5726, |
250 | 1.1187, |
251 | 1.5329, |
252 | 1.0931, |
253 | 1.4911}; |
254 | for (size_t i = 0; i < 16; i++) { |
255 | ASSERT_LT(std::abs(flat[i].item<float>() - h_out[i]), 1e-3); |
256 | } |
257 | } |
258 | |
259 | TEST_F(RNNTest, EndToEndLSTM) { |
260 | ASSERT_TRUE(test_RNN_xor<LSTM>( |
261 | [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); })); |
262 | } |
263 | |
264 | TEST_F(RNNTest, EndToEndLSTMProj) { |
265 | ASSERT_TRUE(test_RNN_xor<LSTM>([](int s) { |
266 | return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); |
267 | })); |
268 | } |
269 | |
270 | TEST_F(RNNTest, EndToEndGRU) { |
271 | ASSERT_TRUE(test_RNN_xor<GRU>( |
272 | [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); })); |
273 | } |
274 | |
275 | TEST_F(RNNTest, EndToEndRNNRelu) { |
276 | ASSERT_TRUE(test_RNN_xor<RNN>([](int s) { |
277 | return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); |
278 | })); |
279 | } |
280 | |
281 | TEST_F(RNNTest, EndToEndRNNTanh) { |
282 | ASSERT_TRUE(test_RNN_xor<RNN>([](int s) { |
283 | return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); |
284 | })); |
285 | } |
286 | |
287 | TEST_F(RNNTest, Sizes_CUDA) { |
288 | torch::manual_seed(0); |
289 | LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2)); |
290 | model->to(torch::kCUDA); |
291 | auto x = |
292 | torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA)); |
293 | auto output = model->forward(x); |
294 | auto y = x.mean(); |
295 | |
296 | y.backward(); |
297 | check_lstm_sizes(output); |
298 | |
299 | auto next = model->forward(x, std::get<1>(output)); |
300 | |
301 | check_lstm_sizes(next); |
302 | |
303 | auto output_hx = std::get<0>(std::get<1>(output)); |
304 | auto output_cx = std::get<1>(std::get<1>(output)); |
305 | |
306 | auto next_hx = std::get<0>(std::get<1>(next)); |
307 | auto next_cx = std::get<1>(std::get<1>(next)); |
308 | |
309 | torch::Tensor diff = |
310 | torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0); |
311 | |
312 | // Hiddens changed |
313 | ASSERT_GT(diff.abs().sum().item<float>(), 1e-3); |
314 | } |
315 | |
316 | TEST_F(RNNTest, SizesProj_CUDA) { |
317 | torch::manual_seed(0); |
318 | LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32)); |
319 | model->to(torch::kCUDA); |
320 | auto x = |
321 | torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA)); |
322 | auto output = model->forward(x); |
323 | auto y = x.mean(); |
324 | |
325 | y.backward(); |
326 | check_lstm_sizes_proj(output); |
327 | |
328 | auto next = model->forward(x, std::get<1>(output)); |
329 | |
330 | check_lstm_sizes_proj(next); |
331 | |
332 | auto output_hx = std::get<0>(std::get<1>(output)); |
333 | auto output_cx = std::get<1>(std::get<1>(output)); |
334 | |
335 | auto next_hx = std::get<0>(std::get<1>(next)); |
336 | auto next_cx = std::get<1>(std::get<1>(next)); |
337 | |
338 | torch::Tensor diff = next_hx - output_hx; |
339 | // Hiddens changed |
340 | ASSERT_GT(diff.abs().sum().item<float>(), 1e-3); |
341 | diff = next_cx - output_cx; |
342 | ASSERT_GT(diff.abs().sum().item<float>(), 1e-3); |
343 | } |
344 | |
345 | TEST_F(RNNTest, EndToEndLSTM_CUDA) { |
346 | ASSERT_TRUE(test_RNN_xor<LSTM>( |
347 | [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true)); |
348 | } |
349 | |
350 | TEST_F(RNNTest, EndToEndLSTMProj_CUDA) { |
351 | ASSERT_TRUE(test_RNN_xor<LSTM>( |
352 | [](int s) { |
353 | return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); |
354 | }, |
355 | true)); |
356 | } |
357 | |
358 | TEST_F(RNNTest, EndToEndGRU_CUDA) { |
359 | ASSERT_TRUE(test_RNN_xor<GRU>( |
360 | [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true)); |
361 | } |
362 | |
363 | TEST_F(RNNTest, EndToEndRNNRelu_CUDA) { |
364 | ASSERT_TRUE(test_RNN_xor<RNN>( |
365 | [](int s) { |
366 | return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); |
367 | }, |
368 | true)); |
369 | } |
370 | TEST_F(RNNTest, EndToEndRNNTanh_CUDA) { |
371 | ASSERT_TRUE(test_RNN_xor<RNN>( |
372 | [](int s) { |
373 | return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); |
374 | }, |
375 | true)); |
376 | } |
377 | |
378 | TEST_F(RNNTest, PrettyPrintRNNs) { |
379 | ASSERT_EQ( |
380 | c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))), |
381 | "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)" ); |
382 | ASSERT_EQ( |
383 | c10::str( |
384 | LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32))), |
385 | "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false, proj_size=32)" ); |
386 | ASSERT_EQ( |
387 | c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))), |
388 | "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)" ); |
389 | ASSERT_EQ( |
390 | c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity( |
391 | torch::kTanh))), |
392 | "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)" ); |
393 | } |
394 | |
395 | // This test assures that flatten_parameters does not crash, |
396 | // when bidirectional is set to true |
397 | // https://github.com/pytorch/pytorch/issues/19545 |
398 | TEST_F(RNNTest, BidirectionalFlattenParameters) { |
399 | GRU gru(GRUOptions(100, 256).num_layers(2).bidirectional(true)); |
400 | gru->flatten_parameters(); |
401 | } |
402 | |
403 | template <typename Impl> |
404 | void copyParameters( |
405 | torch::nn::ModuleHolder<Impl>& target, |
406 | std::string t_suffix, |
407 | const torch::nn::ModuleHolder<Impl>& source, |
408 | std::string s_suffix) { |
409 | at::NoGradGuard guard; |
410 | target->named_parameters()["weight_ih_l" + t_suffix].copy_( |
411 | source->named_parameters()["weight_ih_l" + s_suffix]); |
412 | target->named_parameters()["weight_hh_l" + t_suffix].copy_( |
413 | source->named_parameters()["weight_hh_l" + s_suffix]); |
414 | target->named_parameters()["bias_ih_l" + t_suffix].copy_( |
415 | source->named_parameters()["bias_ih_l" + s_suffix]); |
416 | target->named_parameters()["bias_hh_l" + t_suffix].copy_( |
417 | source->named_parameters()["bias_hh_l" + s_suffix]); |
418 | } |
419 | |
420 | std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device( |
421 | std::tuple<torch::Tensor, torch::Tensor> gru_output, |
422 | torch::Device device) { |
423 | return std::make_tuple( |
424 | std::get<0>(gru_output).to(device), std::get<1>(gru_output).to(device)); |
425 | } |
426 | |
427 | std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> |
428 | lstm_output_to_device( |
429 | std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> |
430 | lstm_output, |
431 | torch::Device device) { |
432 | auto hidden_states = std::get<1>(lstm_output); |
433 | return std::make_tuple( |
434 | std::get<0>(lstm_output).to(device), |
435 | std::make_tuple( |
436 | std::get<0>(hidden_states).to(device), |
437 | std::get<1>(hidden_states).to(device))); |
438 | } |
439 | |
440 | // This test is a port of python code introduced here: |
441 | // https://towardsdatascience.com/understanding-bidirectional-rnn-in-pytorch-5bd25a5dd66 |
442 | // Reverse forward of bidirectional GRU should act |
443 | // as regular forward of unidirectional GRU |
444 | void BidirectionalGRUReverseForward(bool cuda) { |
445 | auto opt = torch::TensorOptions() |
446 | .dtype(torch::kFloat32) |
447 | .requires_grad(false) |
448 | .device(cuda ? torch::kCUDA : torch::kCPU); |
449 | auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1}); |
450 | auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1}); |
451 | |
452 | auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false); |
453 | GRU bi_grus{gru_options.bidirectional(true)}; |
454 | GRU reverse_gru{gru_options.bidirectional(false)}; |
455 | |
456 | if (cuda) { |
457 | bi_grus->to(torch::kCUDA); |
458 | reverse_gru->to(torch::kCUDA); |
459 | } |
460 | |
461 | // Now make sure the weights of the reverse gru layer match |
462 | // ones of the (reversed) bidirectional's: |
463 | copyParameters(reverse_gru, "0" , bi_grus, "0_reverse" ); |
464 | |
465 | auto bi_output = bi_grus->forward(input); |
466 | auto reverse_output = reverse_gru->forward(input_reversed); |
467 | |
468 | if (cuda) { |
469 | bi_output = gru_output_to_device(bi_output, torch::kCPU); |
470 | reverse_output = gru_output_to_device(reverse_output, torch::kCPU); |
471 | } |
472 | |
473 | ASSERT_EQ( |
474 | std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0)); |
475 | auto size = std::get<0>(bi_output).size(0); |
476 | for (int i = 0; i < size; i++) { |
477 | ASSERT_EQ( |
478 | std::get<0>(bi_output)[i][0][1].item<float>(), |
479 | std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>()); |
480 | } |
481 | // The hidden states of the reversed GRUs sits |
482 | // in the odd indices in the first dimension. |
483 | ASSERT_EQ( |
484 | std::get<1>(bi_output)[1][0][0].item<float>(), |
485 | std::get<1>(reverse_output)[0][0][0].item<float>()); |
486 | } |
487 | |
488 | TEST_F(RNNTest, BidirectionalGRUReverseForward) { |
489 | BidirectionalGRUReverseForward(false); |
490 | } |
491 | |
492 | TEST_F(RNNTest, BidirectionalGRUReverseForward_CUDA) { |
493 | BidirectionalGRUReverseForward(true); |
494 | } |
495 | |
496 | // Reverse forward of bidirectional LSTM should act |
497 | // as regular forward of unidirectional LSTM |
498 | void BidirectionalLSTMReverseForwardTest(bool cuda) { |
499 | auto opt = torch::TensorOptions() |
500 | .dtype(torch::kFloat32) |
501 | .requires_grad(false) |
502 | .device(cuda ? torch::kCUDA : torch::kCPU); |
503 | auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1}); |
504 | auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1}); |
505 | |
506 | auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false); |
507 | |
508 | LSTM bi_lstm{lstm_opt.bidirectional(true)}; |
509 | LSTM reverse_lstm{lstm_opt.bidirectional(false)}; |
510 | |
511 | if (cuda) { |
512 | bi_lstm->to(torch::kCUDA); |
513 | reverse_lstm->to(torch::kCUDA); |
514 | } |
515 | |
516 | // Now make sure the weights of the reverse lstm layer match |
517 | // ones of the (reversed) bidirectional's: |
518 | copyParameters(reverse_lstm, "0" , bi_lstm, "0_reverse" ); |
519 | |
520 | auto bi_output = bi_lstm->forward(input); |
521 | auto reverse_output = reverse_lstm->forward(input_reversed); |
522 | |
523 | if (cuda) { |
524 | bi_output = lstm_output_to_device(bi_output, torch::kCPU); |
525 | reverse_output = lstm_output_to_device(reverse_output, torch::kCPU); |
526 | } |
527 | |
528 | ASSERT_EQ( |
529 | std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0)); |
530 | auto size = std::get<0>(bi_output).size(0); |
531 | for (int i = 0; i < size; i++) { |
532 | ASSERT_EQ( |
533 | std::get<0>(bi_output)[i][0][1].item<float>(), |
534 | std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>()); |
535 | } |
536 | // The hidden states of the reversed LSTM sits |
537 | // in the odd indices in the first dimension. |
538 | ASSERT_EQ( |
539 | std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(), |
540 | std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>()); |
541 | ASSERT_EQ( |
542 | std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(), |
543 | std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>()); |
544 | } |
545 | |
546 | TEST_F(RNNTest, BidirectionalLSTMReverseForward) { |
547 | BidirectionalLSTMReverseForwardTest(false); |
548 | } |
549 | |
550 | TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) { |
551 | BidirectionalLSTMReverseForwardTest(true); |
552 | } |
553 | |
554 | TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) { |
555 | // Create two GRUs with the same options |
556 | auto opt = |
557 | GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true); |
558 | GRU gru_cpu{opt}; |
559 | GRU gru_cuda{opt}; |
560 | |
561 | // Copy weights and biases from CPU GRU to CUDA GRU |
562 | { |
563 | at::NoGradGuard guard; |
564 | for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) { |
565 | gru_cuda->named_parameters()[param.key()].copy_( |
566 | gru_cpu->named_parameters()[param.key()]); |
567 | } |
568 | } |
569 | |
570 | gru_cpu->flatten_parameters(); |
571 | gru_cuda->flatten_parameters(); |
572 | |
573 | // Move GRU to CUDA |
574 | gru_cuda->to(torch::kCUDA); |
575 | |
576 | // Create the same inputs |
577 | auto input_opt = |
578 | torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); |
579 | auto input_cpu = |
580 | torch::tensor({1, 2, 3, 4, 5, 6}, input_opt).reshape({3, 1, 2}); |
581 | auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, input_opt) |
582 | .reshape({3, 1, 2}) |
583 | .to(torch::kCUDA); |
584 | |
585 | // Call forward on both GRUs |
586 | auto output_cpu = gru_cpu->forward(input_cpu); |
587 | auto output_cuda = gru_cuda->forward(input_cuda); |
588 | |
589 | output_cpu = gru_output_to_device(output_cpu, torch::kCPU); |
590 | |
591 | // Assert that the output and state are equal on CPU and CUDA |
592 | ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); |
593 | for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { |
594 | ASSERT_EQ( |
595 | std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); |
596 | } |
597 | for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { |
598 | for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { |
599 | for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { |
600 | ASSERT_NEAR( |
601 | std::get<0>(output_cpu)[i][j][k].item<float>(), |
602 | std::get<0>(output_cuda)[i][j][k].item<float>(), |
603 | 1e-5); |
604 | } |
605 | } |
606 | } |
607 | } |
608 | |
609 | TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { |
610 | // Create two LSTMs with the same options |
611 | auto opt = |
612 | LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true); |
613 | LSTM lstm_cpu{opt}; |
614 | LSTM lstm_cuda{opt}; |
615 | |
616 | // Copy weights and biases from CPU LSTM to CUDA LSTM |
617 | { |
618 | at::NoGradGuard guard; |
619 | for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) { |
620 | lstm_cuda->named_parameters()[param.key()].copy_( |
621 | lstm_cpu->named_parameters()[param.key()]); |
622 | } |
623 | } |
624 | |
625 | lstm_cpu->flatten_parameters(); |
626 | lstm_cuda->flatten_parameters(); |
627 | |
628 | // Move LSTM to CUDA |
629 | lstm_cuda->to(torch::kCUDA); |
630 | |
631 | auto options = |
632 | torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); |
633 | auto input_cpu = |
634 | torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2}); |
635 | auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options) |
636 | .reshape({3, 1, 2}) |
637 | .to(torch::kCUDA); |
638 | |
639 | // Call forward on both LSTMs |
640 | auto output_cpu = lstm_cpu->forward(input_cpu); |
641 | auto output_cuda = lstm_cuda->forward(input_cuda); |
642 | |
643 | output_cpu = lstm_output_to_device(output_cpu, torch::kCPU); |
644 | |
645 | // Assert that the output and state are equal on CPU and CUDA |
646 | ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); |
647 | for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { |
648 | ASSERT_EQ( |
649 | std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); |
650 | } |
651 | for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { |
652 | for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { |
653 | for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { |
654 | ASSERT_NEAR( |
655 | std::get<0>(output_cpu)[i][j][k].item<float>(), |
656 | std::get<0>(output_cuda)[i][j][k].item<float>(), |
657 | 1e-5); |
658 | } |
659 | } |
660 | } |
661 | } |
662 | |
663 | TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) { |
664 | // Create two LSTMs with the same options |
665 | auto opt = LSTMOptions(2, 4) |
666 | .num_layers(3) |
667 | .batch_first(false) |
668 | .bidirectional(true) |
669 | .proj_size(2); |
670 | LSTM lstm_cpu{opt}; |
671 | LSTM lstm_cuda{opt}; |
672 | |
673 | // Copy weights and biases from CPU LSTM to CUDA LSTM |
674 | { |
675 | at::NoGradGuard guard; |
676 | for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) { |
677 | lstm_cuda->named_parameters()[param.key()].copy_( |
678 | lstm_cpu->named_parameters()[param.key()]); |
679 | } |
680 | } |
681 | |
682 | lstm_cpu->flatten_parameters(); |
683 | lstm_cuda->flatten_parameters(); |
684 | |
685 | // Move LSTM to CUDA |
686 | lstm_cuda->to(torch::kCUDA); |
687 | |
688 | auto options = |
689 | torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); |
690 | auto input_cpu = |
691 | torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2}); |
692 | auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options) |
693 | .reshape({3, 1, 2}) |
694 | .to(torch::kCUDA); |
695 | |
696 | // Call forward on both LSTMs |
697 | auto output_cpu = lstm_cpu->forward(input_cpu); |
698 | auto output_cuda = lstm_cuda->forward(input_cuda); |
699 | |
700 | output_cpu = lstm_output_to_device(output_cpu, torch::kCPU); |
701 | |
702 | // Assert that the output and state are equal on CPU and CUDA |
703 | ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); |
704 | for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { |
705 | ASSERT_EQ( |
706 | std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); |
707 | } |
708 | for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { |
709 | for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { |
710 | for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { |
711 | ASSERT_NEAR( |
712 | std::get<0>(output_cpu)[i][j][k].item<float>(), |
713 | std::get<0>(output_cuda)[i][j][k].item<float>(), |
714 | 1e-5); |
715 | } |
716 | } |
717 | } |
718 | } |
719 | |
720 | TEST_F(RNNTest, UsePackedSequenceAsInput) { |
721 | { |
722 | torch::manual_seed(0); |
723 | auto m = RNN(2, 3); |
724 | torch::nn::utils::rnn::PackedSequence packed_input = |
725 | torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); |
726 | auto rnn_output = m->forward_with_packed_input(packed_input); |
727 | auto expected_output = torch::tensor( |
728 | {{-0.0645, -0.7274, 0.4531}, |
729 | {-0.3970, -0.6950, 0.6009}, |
730 | {-0.3877, -0.7310, 0.6806}}); |
731 | ASSERT_TRUE(torch::allclose( |
732 | std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); |
733 | |
734 | // Test passing optional argument to `RNN::forward_with_packed_input` |
735 | rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor()); |
736 | ASSERT_TRUE(torch::allclose( |
737 | std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); |
738 | } |
739 | { |
740 | torch::manual_seed(0); |
741 | auto m = LSTM(2, 3); |
742 | torch::nn::utils::rnn::PackedSequence packed_input = |
743 | torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); |
744 | auto rnn_output = m->forward_with_packed_input(packed_input); |
745 | auto expected_output = torch::tensor( |
746 | {{-0.2693, -0.1240, 0.0744}, |
747 | {-0.3889, -0.1919, 0.1183}, |
748 | {-0.4425, -0.2314, 0.1386}}); |
749 | ASSERT_TRUE(torch::allclose( |
750 | std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); |
751 | |
752 | // Test passing optional argument to `LSTM::forward_with_packed_input` |
753 | rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt); |
754 | ASSERT_TRUE(torch::allclose( |
755 | std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); |
756 | } |
757 | { |
758 | torch::manual_seed(0); |
759 | auto m = GRU(2, 3); |
760 | torch::nn::utils::rnn::PackedSequence packed_input = |
761 | torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); |
762 | auto rnn_output = m->forward_with_packed_input(packed_input); |
763 | auto expected_output = torch::tensor( |
764 | {{-0.1134, 0.0467, 0.2336}, |
765 | {-0.1189, 0.0502, 0.2960}, |
766 | {-0.1138, 0.0484, 0.3110}}); |
767 | ASSERT_TRUE(torch::allclose( |
768 | std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); |
769 | |
770 | // Test passing optional argument to `GRU::forward_with_packed_input` |
771 | rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor()); |
772 | ASSERT_TRUE(torch::allclose( |
773 | std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); |
774 | } |
775 | } |
776 | |