1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <test/cpp/api/support.h> |
7 | |
8 | #include <algorithm> |
9 | #include <random> |
10 | #include <sstream> |
11 | #include <string> |
12 | |
13 | using namespace torch::nn; |
14 | |
15 | namespace rnn_utils = torch::nn::utils::rnn; |
16 | |
// Test fixtures; SeedingFixture seeds the RNG so tests that draw random
// tensors (randn/rand/shuffle) are reproducible across runs.
struct NNUtilsTest : torch::test::SeedingFixture {};
struct PackedSequenceTest : torch::test::SeedingFixture {};
19 | |
// Exercises utils::clip_grad_norm_ across several norm types:
//  - the value it returns equals an independently computed pre-clip norm,
//  - after clipping, the total norm is (approximately) max_norm,
//  - all gradients are scaled by one common factor,
//  - gradients already below max_norm are left unchanged,
//  - the single-tensor overload matches the tensor-list overload.
TEST_F(NNUtilsTest, ClipGradNorm) {
  auto l = Linear(10, 10);
  float max_norm = 2;
  // Reference implementation of the total gradient norm over all parameters.
  auto compute_norm = [&](float norm_type) -> float {
    float total_norm = 0.0;
    if (norm_type != std::numeric_limits<float>::infinity()) {
      // p-norm: (sum_i |g_i|^p)^(1/p) over every parameter element.
      for (const auto& p : l->parameters()) {
        total_norm +=
            p.grad().data().abs().pow(norm_type).sum().item().toFloat();
      }
      return std::pow(total_norm, 1.0 / norm_type);
    } else {
      // inf-norm: largest absolute gradient entry across all parameters.
      for (const auto& p : l->parameters()) {
        auto param_max = p.grad().data().abs().max().item().toFloat();
        if (param_max > total_norm) {
          total_norm = param_max;
        }
      }
      return total_norm;
    }
  };
  // Returns per-element ratio current_grad / original `grads`, flattened and
  // concatenated. Clipping must scale every grad by the same factor, so the
  // caller asserts the std of this tensor is ~0.
  auto compare_scaling =
      [&](const std::vector<torch::Tensor>& grads) -> torch::Tensor {
    std::vector<torch::Tensor> p_scale;
    for (const auto i : c10::irange(grads.size())) {
      auto param = l->parameters()[i];
      auto grad = grads[i];
      p_scale.push_back(param.grad().data().div(grad).view(-1));
    }
    auto scale = torch::cat(p_scale);
    return scale; // need to assert std is 0.
  };

  // Large gradients: clipping should rescale them down to max_norm.
  std::vector<torch::Tensor> grads = {
      torch::arange(1.0, 101).view({10, 10}),
      torch::ones({10}).div(1000),
  };
  std::vector<float> norm_types = {
      0.5,
      1.5,
      2.0,
      4.0,
      std::numeric_limits<float>::infinity(),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].mutable_grad() =
          grads[i].clone().view_as(l->parameters()[i].data());
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_NEAR(norm_after, max_norm, 1e-6);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
  }
  // Small gradients should be left unchanged
  grads = {
      torch::rand({10, 10}).div(10000),
      torch::ones(10).div(500),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].grad().data().copy_(grads[i]);
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_FLOAT_EQ(norm_before, norm_after);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
    // A uniform scale factor of exactly 1 confirms the grads were untouched.
    ASSERT_FLOAT_EQ(scaled[0].item().toFloat(), 1);
  }
  // should accept a single tensor as input
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(1., 101).view({10, 10});
  p1.mutable_grad() = g.clone();
  p2.mutable_grad() = g.clone();
  for (const auto norm_type : norm_types) {
    utils::clip_grad_norm_(p1, max_norm, norm_type);
    utils::clip_grad_norm_({p2}, max_norm, norm_type);
    ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
  }
}
109 | |
110 | // Check that clip_grad_norm_ raises an error if the norm of a gradient |
111 | // is non-finite |
112 | TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { |
113 | double inf = std::numeric_limits<double>::infinity(); |
114 | double nan = std::numeric_limits<double>::quiet_NaN(); |
115 | |
116 | using Vector = std::vector<double>; |
117 | |
118 | Vector norms_pos = {0.1, 1, 2, 3.5, inf}; |
119 | Vector norms_neg = {-0.1, -1, -2, -3.5}; |
120 | Vector norms_neg_plus_0 = {0, -0.1, -1, -2, -3.5}; |
121 | Vector norms_except_0 = {0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5}; |
122 | Vector norms_all = {0, 0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5}; |
123 | |
124 | // Each entry in test_cases has the following values, in this order: |
125 | // |
126 | // grad_only_one_elem If True, only one element of the parameter's |
127 | // gradient is set to the scalar grad, and the |
128 | // rest of the elements are 0. If False, all grad |
129 | // elements are equal to the scalar. |
130 | // |
131 | // prefix_finite_grad_param If True, prefix a parameter that has a grad |
132 | // of 1. |
133 | // |
134 | // scalars Scalars to use as the parameter's grad, through |
135 | // multiplication |
136 | // |
137 | // norms_nonfinite Norm types that should produce nonfinite total norm |
138 | // |
139 | // norms_finite Norm types that should produce finite total norm |
140 | std::vector<std::tuple<bool, bool, Vector, Vector, Vector>> test_cases({ |
141 | // Test errors from an infinite grad |
142 | std::make_tuple( |
143 | false, false, Vector({inf, -inf}), norms_except_0, Vector({0})), |
144 | std::make_tuple( |
145 | false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), |
146 | std::make_tuple( |
147 | true, false, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), |
148 | std::make_tuple( |
149 | false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), |
150 | |
151 | // Test errors from a NaN grad |
152 | std::make_tuple(false, false, Vector({nan}), norms_except_0, Vector({0})), |
153 | std::make_tuple(false, true, Vector({nan}), norms_except_0, Vector({0})), |
154 | std::make_tuple(true, false, Vector({nan}), norms_except_0, Vector({0})), |
155 | std::make_tuple(true, true, Vector({nan}), norms_except_0, Vector({0})), |
156 | |
157 | // Test a grad that should never error |
158 | std::make_tuple(false, false, Vector({2e22, -2e22}), Vector(), norms_all), |
159 | std::make_tuple(false, true, Vector({2e22, -2e22}), Vector(), norms_all), |
160 | std::make_tuple(true, false, Vector({2e22, -2e22}), Vector(), norms_all), |
161 | std::make_tuple(true, true, Vector({2e22, -2e22}), Vector(), norms_all), |
162 | |
163 | // Test a grad that will overflow to inf for only some norm orders |
164 | std::make_tuple( |
165 | false, |
166 | false, |
167 | Vector({2e200, -2e200}), |
168 | Vector({3.5, 2, -2, -3.5}), |
169 | Vector({inf, 1, 0.1, 0, -1, -0.1})), |
170 | std::make_tuple( |
171 | false, |
172 | true, |
173 | Vector({2e200, -2e200}), |
174 | Vector({3.5, 2}), |
175 | Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), |
176 | std::make_tuple( |
177 | true, |
178 | false, |
179 | Vector({2e200, -2e200}), |
180 | Vector({3.5, 2}), |
181 | Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), |
182 | std::make_tuple( |
183 | false, |
184 | true, |
185 | Vector({2e200, -2e200}), |
186 | Vector({3.5, 2}), |
187 | Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), |
188 | }); |
189 | |
190 | auto gen_parameters = [](double scalar, |
191 | bool grad_only_one_elem, |
192 | bool prefix_finite_grad_param, |
193 | torch::DeviceType device_type) { |
194 | auto param = torch::ones( |
195 | 10, |
196 | torch::TensorOptions() |
197 | .dtype(torch::kDouble) |
198 | .device(device_type) |
199 | .requires_grad(true)); |
200 | if (grad_only_one_elem) { |
201 | param[1].mul(scalar).sum().backward(); |
202 | } else { |
203 | param.mul(scalar).sum().backward(); |
204 | } |
205 | |
206 | std::vector<torch::Tensor> parameters; |
207 | if (prefix_finite_grad_param) { |
208 | auto prefix_param = torch::ones( |
209 | 1, |
210 | torch::TensorOptions() |
211 | .dtype(torch::kDouble) |
212 | .device(device_type) |
213 | .requires_grad(true)); |
214 | prefix_param.mul(1).sum().backward(); |
215 | parameters.push_back(prefix_param); |
216 | } |
217 | parameters.push_back(param); |
218 | |
219 | return parameters; |
220 | }; |
221 | |
222 | auto run_test_case = [&gen_parameters]( |
223 | double norm_type, |
224 | bool error_if_nonfinite, |
225 | double scalar, |
226 | bool grad_only_one_elem, |
227 | bool prefix_finite_grad_param, |
228 | bool is_norm_nonfinite, |
229 | torch::DeviceType device_type) { |
230 | std::stringstream ss; |
231 | ss << "device: " << device_type << ", norm_type: " << norm_type |
232 | << ", error_if_nonfinite: " << error_if_nonfinite |
233 | << ", scalar: " << scalar |
234 | << ", grad_only_one_elem: " << grad_only_one_elem |
235 | << ", prefix_finite_grad_param: " << prefix_finite_grad_param |
236 | << ", is_norm_nonfinite: " << is_norm_nonfinite; |
237 | std::string msg = ss.str(); |
238 | |
239 | auto parameters = gen_parameters( |
240 | scalar, grad_only_one_elem, prefix_finite_grad_param, device_type); |
241 | |
242 | if (is_norm_nonfinite && error_if_nonfinite) { |
243 | std::vector<torch::Tensor> grads_before; |
244 | // NOLINTNEXTLINE(performance-for-range-copy) |
245 | for (auto p : parameters) { |
246 | // NOLINTNEXTLINE(performance-inefficient-vector-operation) |
247 | grads_before.push_back(p.grad().clone()); |
248 | } |
249 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
250 | EXPECT_THROW( |
251 | utils::clip_grad_norm_(parameters, 1., norm_type, true), |
252 | std::exception) |
253 | << msg; |
254 | // Grads should not change if error is thrown |
255 | for (const auto p_idx : c10::irange(parameters.size())) { |
256 | ASSERT_TRUE(torch::allclose( |
257 | parameters[p_idx].grad(), |
258 | grads_before[p_idx], |
259 | 1.0, |
260 | 0.0, |
261 | /*equal_nan*/ true)) |
262 | << msg; |
263 | } |
264 | } else { |
265 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
266 | EXPECT_NO_THROW( |
267 | utils::clip_grad_norm_(parameters, 1., norm_type, error_if_nonfinite)) |
268 | << msg; |
269 | } |
270 | }; |
271 | |
272 | for (auto device_type : {torch::kCPU, torch::kCUDA}) { |
273 | if (device_type == torch::kCUDA && !torch::cuda::is_available()) { |
274 | continue; |
275 | } |
276 | for (auto test_case : test_cases) { |
277 | auto grad_only_one_elem = std::get<0>(test_case); |
278 | auto prefix_finite_grad_param = std::get<1>(test_case); |
279 | auto scalars = std::get<2>(test_case); |
280 | auto norms_nonfinite = std::get<3>(test_case); |
281 | auto norms_finite = std::get<4>(test_case); |
282 | |
283 | for (auto error_if_nonfinite : {false, true}) { |
284 | for (auto scalar : scalars) { |
285 | for (auto norm_type : norms_nonfinite) { |
286 | run_test_case( |
287 | norm_type, |
288 | error_if_nonfinite, |
289 | scalar, |
290 | grad_only_one_elem, |
291 | prefix_finite_grad_param, |
292 | true, |
293 | device_type); |
294 | } |
295 | |
296 | for (auto norm_type : norms_finite) { |
297 | run_test_case( |
298 | norm_type, |
299 | error_if_nonfinite, |
300 | scalar, |
301 | grad_only_one_elem, |
302 | prefix_finite_grad_param, |
303 | false, |
304 | device_type); |
305 | } |
306 | } |
307 | } |
308 | } |
309 | } |
310 | } |
311 | |
312 | TEST_F(NNUtilsTest, ClipGradValue) { |
313 | auto l = Linear(10, 10); |
314 | float clip_value = 2.5; |
315 | |
316 | torch::Tensor grad_w = torch::arange(-50., 50).view({10, 10}).div_(5); |
317 | torch::Tensor grad_b = torch::ones({10}).mul_(2); |
318 | std::vector<std::vector<torch::Tensor>> grad_lists = { |
319 | {grad_w, grad_b}, {grad_w, torch::Tensor()}}; |
320 | for (auto grad_list : grad_lists) { |
321 | for (const auto i : c10::irange(grad_list.size())) { |
322 | auto p = l->parameters()[i]; |
323 | auto g = grad_list[i]; |
324 | p.mutable_grad() = g.defined() ? g.clone().view_as(p.data()) : g; |
325 | } |
326 | |
327 | utils::clip_grad_value_(l->parameters(), clip_value); |
328 | for (const auto& p : l->parameters()) { |
329 | if (p.grad().defined()) { |
330 | ASSERT_LE(p.grad().data().max().item().toFloat(), clip_value); |
331 | ASSERT_GE(p.grad().data().min().item().toFloat(), -clip_value); |
332 | } |
333 | } |
334 | } |
335 | |
336 | // Should accept a single Tensor as input |
337 | auto p1 = torch::randn({10, 10}); |
338 | auto p2 = torch::randn({10, 10}); |
339 | auto g = torch::arange(-50., 50).view({10, 10}).div_(5); |
340 | p1.mutable_grad() = g.clone(); |
341 | p2.mutable_grad() = g.clone(); |
342 | utils::clip_grad_value_(p1, clip_value); |
343 | utils::clip_grad_value_({p2}, clip_value); |
344 | ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad())); |
345 | } |
346 | |
347 | TEST_F(NNUtilsTest, ConvertParameters) { |
348 | std::vector<torch::Tensor> parameters{ |
349 | torch::arange(9, torch::kFloat32), |
350 | torch::arange(9, torch::kFloat32).view({3, 3}), |
351 | torch::arange(8, torch::kFloat32).view({2, 2, 2})}; |
352 | |
353 | auto expected = torch::cat( |
354 | {torch::arange(9, torch::kFloat32), |
355 | torch::arange(9, torch::kFloat32).view(-1), |
356 | torch::arange(8, torch::kFloat32).view(-1)}); |
357 | auto vector = utils::parameters_to_vector(parameters); |
358 | ASSERT_TRUE(vector.allclose(expected)); |
359 | |
360 | std::vector<torch::Tensor> zero_parameters{ |
361 | torch::zeros({9}, torch::kFloat32), |
362 | torch::zeros({9}, torch::kFloat32).view({3, 3}), |
363 | torch::zeros({8}, torch::kFloat32).view({2, 2, 2})}; |
364 | |
365 | utils::vector_to_parameters(vector, zero_parameters); |
366 | for (const auto i : c10::irange(zero_parameters.size())) { |
367 | ASSERT_TRUE(zero_parameters[i].allclose(parameters[i])); |
368 | } |
369 | |
370 | { |
371 | auto conv1 = Conv2d(3, 10, 5); |
372 | auto fc1 = Linear(10, 20); |
373 | auto model = Sequential(conv1, fc1); |
374 | |
375 | auto vec = utils::parameters_to_vector(model->parameters()); |
376 | ASSERT_EQ(vec.size(0), 980); |
377 | } |
378 | { |
379 | auto conv1 = Conv2d(3, 10, 5); |
380 | auto fc1 = Linear(10, 20); |
381 | auto model = Sequential(conv1, fc1); |
382 | |
383 | auto vec = torch::arange(0., 980); |
384 | utils::vector_to_parameters(vec, model->parameters()); |
385 | |
386 | auto sample = model->parameters()[0][0][0][0]; |
387 | ASSERT_TRUE(torch::equal(sample.data(), vec.data().slice(0, 0, 5))); |
388 | } |
389 | } |
390 | |
// Batch size and maximum sequence length shared by the PackedSequence
// helpers/tests below.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_batch_size = 5;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_max_length = 6;
395 | |
396 | std::vector<torch::Tensor> PackedSequenceTest_ordered_sequence( |
397 | torch::ScalarType tensor_type) { |
398 | std::vector<torch::Tensor> seqs; |
399 | seqs.reserve(PackedSequenceTest_batch_size); |
400 | for (const auto i : c10::irange(PackedSequenceTest_batch_size)) { |
401 | (void)i; // Suppress unused variable warning |
402 | seqs.emplace_back(torch::empty( |
403 | {torch::randint(1, PackedSequenceTest_max_length, {1}).item<int64_t>()}, |
404 | tensor_type)); |
405 | } |
406 | for (auto& s : seqs) { |
407 | s.random_(-128, 128); |
408 | } |
409 | sort( |
410 | seqs.begin(), |
411 | seqs.end(), |
412 | [&](const torch::Tensor& t1, const torch::Tensor& t2) { |
413 | return t1.size(0) > t2.size(0); |
414 | }); |
415 | return seqs; |
416 | } |
417 | |
418 | std::tuple<torch::Tensor, torch::Tensor> PackedSequenceTest_padded_sequence( |
419 | torch::ScalarType tensor_type) { |
420 | // Create Tensor of random padded sequences |
421 | auto ordered = PackedSequenceTest_ordered_sequence(tensor_type); |
422 | auto lengths = torch::empty({(int64_t)ordered.size()}, torch::kInt64); |
423 | for (const auto i : c10::irange(ordered.size())) { |
424 | lengths[i] = ordered[i].size(0); |
425 | } |
426 | auto padded_tensor = rnn_utils::pad_sequence(ordered); |
427 | return std::make_tuple(padded_tensor, lengths); |
428 | } |
429 | |
430 | void assert_is_equal_packed_sequence( |
431 | const rnn_utils::PackedSequence& a, |
432 | const rnn_utils::PackedSequence& b) { |
433 | ASSERT_TRUE(torch::allclose(a.data(), b.data())); |
434 | ASSERT_TRUE(torch::allclose(a.batch_sizes(), b.batch_sizes())); |
435 | ASSERT_TRUE( |
436 | (!a.sorted_indices().defined() && !b.sorted_indices().defined()) || |
437 | torch::allclose(a.sorted_indices(), b.sorted_indices())); |
438 | ASSERT_TRUE( |
439 | (!a.unsorted_indices().defined() && !b.unsorted_indices().defined()) || |
440 | torch::allclose(a.unsorted_indices(), b.unsorted_indices())); |
441 | } |
442 | |
443 | void assert_is_same_packed_sequence( |
444 | const rnn_utils::PackedSequence& a, |
445 | const rnn_utils::PackedSequence& b) { |
446 | ASSERT_TRUE(a.data().is_same(b.data())); |
447 | ASSERT_TRUE(a.batch_sizes().is_same(b.batch_sizes())); |
448 | ASSERT_TRUE(a.sorted_indices().is_same(b.sorted_indices())); |
449 | ASSERT_TRUE(a.unsorted_indices().is_same(b.unsorted_indices())); |
450 | } |
451 | |
452 | TEST_F(PackedSequenceTest, WrongOrder) { |
453 | auto a = torch::ones({25, 300}); |
454 | auto b = torch::ones({22, 300}); |
455 | auto b_a = rnn_utils::pad_sequence({b, a}); |
456 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
457 | ASSERT_THROW( |
458 | rnn_utils::pack_padded_sequence( |
459 | b_a, |
460 | torch::tensor({22, 25}), |
461 | /*batch_first=*/false, |
462 | /*enforce_sorted=*/true), |
463 | c10::Error); |
464 | } |
465 | |
466 | TEST_F(PackedSequenceTest, TotalLength) { |
467 | torch::Tensor padded, lengths; |
468 | std::tie(padded, lengths) = PackedSequenceTest_padded_sequence(torch::kFloat); |
469 | int64_t max_length = torch::max(lengths).item<int64_t>(); |
470 | rnn_utils::PackedSequence packed = |
471 | rnn_utils::pack_padded_sequence(padded, lengths); |
472 | |
473 | // test ValueError if total_length < max_length |
474 | for (int64_t total_length : std::vector<int64_t>{-1, 0, max_length - 1}) { |
475 | for (bool batch_first : std::vector<bool>{true, false}) { |
476 | auto err_fn = [&]() { |
477 | rnn_utils::pad_packed_sequence( |
478 | packed, |
479 | /*batch_first=*/batch_first, |
480 | /*padding_value=*/0.0, |
481 | /*total_length=*/total_length); |
482 | }; |
483 | ASSERT_THROWS_WITH( |
484 | err_fn(), |
485 | "Expected total_length to be at least the length of the longest sequence in input" ); |
486 | } |
487 | } |
488 | |
489 | // test that pad_packed_sequence returns results of correct length |
490 | for (bool batch_first : std::vector<bool>{true, false}) { |
491 | torch::Tensor , ignored; |
492 | std::tie(no_extra_pad, ignored) = |
493 | rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first); |
494 | for (int64_t total_length_delta : std::vector<int64_t>{0, 1, 8}) { |
495 | int64_t total_length = max_length + total_length_delta; |
496 | torch::Tensor unpacked, lengths_out; |
497 | std::tie(unpacked, lengths_out) = rnn_utils::pad_packed_sequence( |
498 | packed, |
499 | /*batch_first=*/batch_first, |
500 | /*padding_value=*/0.0, |
501 | /*total_length=*/total_length); |
502 | ASSERT_TRUE(torch::allclose(lengths, lengths_out)); |
503 | ASSERT_EQ(unpacked.size(batch_first ? 1 : 0), total_length); |
504 | torch::Tensor ref_output, ; |
505 | if (total_length_delta == 0) { |
506 | ref_output = no_extra_pad; |
507 | } else if (batch_first) { |
508 | extra_pad = torch::zeros( |
509 | {PackedSequenceTest_batch_size, total_length_delta}, |
510 | no_extra_pad.options()); |
511 | ref_output = torch::cat({no_extra_pad, extra_pad}, 1); |
512 | } else { |
513 | extra_pad = torch::zeros( |
514 | {total_length_delta, PackedSequenceTest_batch_size}, |
515 | no_extra_pad.options()); |
516 | ref_output = torch::cat({no_extra_pad, extra_pad}, 0); |
517 | } |
518 | ASSERT_TRUE(torch::allclose(unpacked, ref_output)); |
519 | } |
520 | } |
521 | } |
522 | |
523 | TEST_F(PackedSequenceTest, To) { |
524 | for (bool enforce_sorted : std::vector<bool>{true, false}) { |
525 | torch::Tensor padded, lengths; |
526 | std::tie(padded, lengths) = PackedSequenceTest_padded_sequence(torch::kInt); |
527 | rnn_utils::PackedSequence a = rnn_utils::pack_padded_sequence( |
528 | padded, |
529 | lengths, |
530 | /*batch_first=*/false, |
531 | /*enforce_sorted=*/enforce_sorted) |
532 | .cpu(); |
533 | |
534 | assert_is_same_packed_sequence(a, a.to(torch::kCPU)); |
535 | assert_is_same_packed_sequence(a, a.cpu()); |
536 | assert_is_same_packed_sequence( |
537 | a, a.to(torch::device(torch::kCPU).dtype(torch::kInt32))); |
538 | |
539 | if (torch::cuda::is_available()) { |
540 | auto b = a.cuda(); |
541 | assert_is_same_packed_sequence(b, b.to(torch::kCUDA)); |
542 | assert_is_same_packed_sequence(b, b.cuda()); |
543 | assert_is_equal_packed_sequence(a, b.to(torch::kCPU)); |
544 | assert_is_equal_packed_sequence(b, a.to(torch::kCUDA)); |
545 | assert_is_equal_packed_sequence( |
546 | a, b.to(torch::device(torch::kCPU).dtype(torch::kInt32))); |
547 | assert_is_same_packed_sequence(b, b.to(torch::kInt32)); |
548 | } |
549 | } |
550 | } |
551 | |
// End-to-end checks for rnn_utils::pack_sequence: exact packed
// data/batch_sizes/index tensors for small 1-D inputs, enforce_sorted error
// messages, and (for higher-rank inputs) consistency with pad_sequence /
// pack_padded_sequence / pad_packed_sequence.
TEST_F(NNUtilsTest, PackSequence) {
  // Packing `sequences` directly must agree with padding first and then
  // packing the padded batch with the matching lengths.
  auto _compatibility_test = [&](torch::ArrayRef<torch::Tensor> sequences,
                                 torch::Tensor lengths,
                                 bool batch_first,
                                 bool enforce_sorted = false) {
    torch::Tensor padded = rnn_utils::pad_sequence(sequences, batch_first);
    rnn_utils::PackedSequence packed =
        rnn_utils::pack_sequence(sequences, enforce_sorted);
    std::tuple<torch::Tensor, torch::Tensor> unpacked =
        rnn_utils::pad_packed_sequence(packed, batch_first);
    ASSERT_TRUE(torch::allclose(padded, std::get<0>(unpacked)));
    rnn_utils::PackedSequence pack_padded = rnn_utils::pack_padded_sequence(
        padded, lengths, batch_first, enforce_sorted);
    assert_is_equal_packed_sequence(packed, pack_padded);
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/false);
  // Time-major packing interleaves the sequences step by step:
  // step 0 -> {1,4,6}, step 1 -> {2,5}, step 2 -> {3}.
  auto expected = torch::tensor({1, 4, 6, 2, 5, 3});
  ASSERT_TRUE(torch::allclose(packed.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed.data(), expected));
  ASSERT_TRUE(
      torch::allclose(packed.sorted_indices(), torch::tensor({0, 1, 2})));
  ASSERT_TRUE(
      torch::allclose(packed.unsorted_indices(), torch::tensor({0, 1, 2})));

  // Unsorted input yields the same packed data, with index tensors that
  // record the sorting permutation.
  rnn_utils::PackedSequence packed_unsorted =
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/false);
  ASSERT_TRUE(
      torch::allclose(packed_unsorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_unsorted.data(), expected));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.sorted_indices(), torch::tensor({2, 0, 1})));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.unsorted_indices(), torch::tensor({1, 2, 0})));

  // single dimensional, enforce_sorted = True
  rnn_utils::PackedSequence packed_enforce_sorted =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/true);
  ASSERT_TRUE(torch::allclose(
      packed_enforce_sorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_enforce_sorted.data(), expected));
  // With enforce_sorted=true no permutation needs to be stored.
  ASSERT_FALSE(packed_enforce_sorted.sorted_indices().defined());
  ASSERT_FALSE(packed_enforce_sorted.unsorted_indices().defined());

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "must be sorted in decreasing order");

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "You can pass `enforce_sorted=False`");

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> lengths_vec;
    // Each sequence is (seq_len, 5, 4, ..., 4) with num_dim trailing dims.
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = maxlen; i > 0; i--) {
      int64_t seq_len = i * i;
      lengths_vec.emplace_back(seq_len);
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    // A shuffled copy exercises the enforce_sorted=false path.
    std::vector<torch::Tensor> unsorted_sequences;
    for (const auto& s : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences.emplace_back(s.clone());
    }
    std::shuffle(
        std::begin(unsorted_sequences),
        std::end(unsorted_sequences),
        std::default_random_engine{});

    std::vector<int64_t> unsorted_sequences_lengths_vec;
    for (const auto& t : unsorted_sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences_lengths_vec.emplace_back(t.size(0));
    }

    // compatibility with other utilities
    for (bool batch_first : std::vector<bool>{true, false}) {
      for (bool enforce_sorted : std::vector<bool>{true, false}) {
        _compatibility_test(
            sequences, torch::tensor(lengths_vec), batch_first, enforce_sorted);
      }
      _compatibility_test(
          unsorted_sequences,
          torch::tensor(unsorted_sequences_lengths_vec),
          batch_first);
    }
  }
}
652 | |
// Validates pack_padded_sequence against hand-built expected packed data,
// checks that pad_packed_sequence inverts it, and verifies that gradients
// flow only through the non-padding region.
TEST_F(NNUtilsTest, PackPaddedSequence) {
  // Builds (padded, lengths, expected_data, batch_sizes, unsorted_indices)
  // for a batch described by decreasing `sorted_lengths`; optionally shuffles
  // the batch dimension to exercise enforce_sorted=false.
  auto generate_test_case = [&](torch::ArrayRef<int64_t> sorted_lengths,
                                bool should_shuffle) {
    // Zero-pads `tensor` along dim 0 up to `length`.
    auto pad = [&](torch::Tensor tensor, int64_t length) {
      std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
      tensor_sizes.insert(
          tensor_sizes.end(),
          tensor.sizes().slice(1).begin(),
          tensor.sizes().slice(1).end());
      return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
    };
    int64_t max_length = sorted_lengths[0];
    // batch_sizes[t] = number of sequences still active at time step t.
    torch::Tensor batch_sizes = torch::empty({max_length}, torch::kInt64);
    for (int64_t i = 1; i < max_length + 1; i++) {
      int64_t total = 0;
      for (const auto& x : sorted_lengths) {
        if (x >= i) {
          total++;
        }
      }
      batch_sizes[i - 1] = total;
    }
    // Sequence i holds values i*100 + [1..5l], so every element is unique and
    // traceable back to its (sequence, time step) origin.
    std::vector<torch::Tensor> tensors_to_be_cat;
    for (int64_t i = 1; i < static_cast<int64_t>(sorted_lengths.size() + 1);
         i++) {
      int64_t l = sorted_lengths.at(i - 1);
      tensors_to_be_cat.emplace_back(pad(
          i * 100 + torch::arange(1., 5 * l + 1).view({l, 1, 5}), max_length));
    }
    auto padded = torch::cat(tensors_to_be_cat, 1);
    // Expected packed layout: for each time step n, one row per sequence
    // still active at n, in batch order.
    std::vector<torch::Tensor> expected_data_vec;
    for (const auto n : c10::irange(batch_sizes.size(0))) {
      int64_t batch_size = batch_sizes[n].item<int64_t>();
      for (const auto i : c10::irange(batch_size)) {
        expected_data_vec.emplace_back(
            torch::arange(1., 6) + (i + 1) * 100 + 5 * n);
      }
    }
    auto expected_data = torch::stack(expected_data_vec, /*dim=*/0);

    torch::Tensor unsorted_indices, lengths;
    if (should_shuffle) {
      // Shuffle the padded sequence to create an unsorted sequence
      std::vector<int64_t> permutation;
      for (const auto i : c10::irange(sorted_lengths.size())) {
        permutation.emplace_back(i);
      }
      std::shuffle(
          std::begin(permutation),
          std::end(permutation),
          std::default_random_engine{});

      unsorted_indices = torch::tensor(permutation);
      padded = padded.index_select(1, unsorted_indices);
      lengths = torch::tensor(sorted_lengths).index_select(0, unsorted_indices);
    } else {
      // Undefined tensor signals "already sorted" to the assertions below.
      unsorted_indices = torch::Tensor();
      lengths = torch::tensor(sorted_lengths);
    }

    return std::make_tuple(
        padded.requires_grad_(),
        lengths,
        expected_data,
        batch_sizes,
        unsorted_indices);
  };

  std::vector<std::pair<std::vector<int64_t>, bool>> test_cases = {
      // sorted_lengths, should_shuffle
      {{10, 8, 4, 2, 2, 2, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, true}};

  for (const auto& test_case : test_cases) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      std::vector<int64_t> sorted_lengths = std::get<0>(test_case);
      bool should_shuffle = std::get<1>(test_case);

      torch::Tensor padded, lengths, expected_data, batch_sizes,
          unsorted_indices;
      std::tie(padded, lengths, expected_data, batch_sizes, unsorted_indices) =
          generate_test_case(sorted_lengths, should_shuffle);

      auto src = padded;
      if (batch_first) {
        src = src.transpose(0, 1);
      }

      // check output
      rnn_utils::PackedSequence packed = rnn_utils::pack_padded_sequence(
          src,
          lengths,
          /*batch_first=*/batch_first,
          /*enforce_sorted=*/!should_shuffle);
      ASSERT_TRUE(torch::allclose(packed.data(), expected_data));
      ASSERT_TRUE(torch::allclose(packed.batch_sizes(), batch_sizes));
      ASSERT_TRUE(
          (!packed.unsorted_indices().defined() &&
           !unsorted_indices.defined()) ||
          torch::allclose(packed.unsorted_indices(), unsorted_indices));

      // test inverse
      torch::Tensor unpacked, unpacked_len;
      std::tie(unpacked, unpacked_len) =
          rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first);
      ASSERT_TRUE(torch::allclose(unpacked, src));
      ASSERT_TRUE(torch::allclose(unpacked_len, lengths));

      // check grad
      if (padded.grad().defined()) {
        // Reset grads accumulated by a previous iteration.
        torch::NoGradGuard no_grad;
        padded.grad().zero_();
      }
      torch::Tensor grad_output;
      {
        torch::NoGradGuard no_grad;
        grad_output = unpacked.clone().normal_();
      }
      unpacked.backward(grad_output);
      if (batch_first) {
        grad_output.transpose_(0, 1);
      }
      for (const auto i : c10::irange(lengths.size(0))) {
        int64_t l = lengths[i].item<int64_t>();
        // Within the valid length, grads pass through unchanged.
        ASSERT_TRUE(torch::allclose(
            padded.grad().narrow(0, 0, l).select(1, i),
            grad_output.narrow(0, 0, l).select(1, i)));
        // Beyond the valid length, grads must be exactly zero.
        // NOTE(review): the `l < 10` guard looks like a stale bound given
        // lengths up to 11 above — confirm against upstream intent.
        if (l < 10) {
          ASSERT_EQ(
              padded.grad()
                  .narrow(0, l, padded.grad().size(0) - l)
                  .select(1, i)
                  .abs()
                  .sum()
                  .item<double>(),
              0);
        }
      }
    }
  }

  // test error messages
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(
          torch::randn({3, 3}), torch::tensor({1, 3, 2})),
      "You can pass `enforce_sorted=False`");
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(torch::randn({0, 0}), torch::tensor({})),
      "empty tensor");
}
805 | |
// Checks rnn_utils::pad_sequence for 1-D inputs (batch_first on/off, custom
// padding value, pre-sorted input) and for higher-rank inputs against a
// manual zero-padding reference.
TEST_F(NNUtilsTest, PadSequence) {
  // Zero-pads `tensor` along dim 0 up to `length` (reference implementation).
  auto pad = [&](const torch::Tensor& tensor, int64_t length) {
    torch::NoGradGuard no_grad;
    std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
    tensor_sizes.insert(
        tensor_sizes.end(),
        tensor.sizes().slice(1).begin(),
        tensor.sizes().slice(1).end());
    return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});

  torch::Tensor expected, padded;

  // batch_first = true
  expected = torch::tensor({{4, 5, 0}, {1, 2, 3}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({b, a, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // batch_first = false
  padded = rnn_utils::pad_sequence({b, a, c});
  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

  // pad with non-zero value
  expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}});
  padded = rnn_utils::pad_sequence({b, a, c}, true, 1);
  ASSERT_TRUE(padded.allclose(expected));

  // Test pad sorted sequence
  expected = torch::tensor({{1, 2, 3}, {4, 5, 0}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({a, b, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    std::vector<torch::Tensor> sequences;
    // Each sequence is (seq_len, 5, 4, ..., 4) with num_dim trailing dims.
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = 1; i < maxlen + 1; i++) {
      int64_t seq_len = i * i;
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    // Shuffle so pad_sequence sees sequences in arbitrary length order.
    std::shuffle(
        std::begin(sequences),
        std::end(sequences),
        std::default_random_engine{});
    // Reference: every sequence padded to the longest length (maxlen^2).
    std::vector<torch::Tensor> expected_tensors;
    for (const torch::Tensor& seq : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      expected_tensors.emplace_back(pad(seq, maxlen * maxlen));
    }

    // batch first = true
    auto expected = torch::stack(expected_tensors);
    auto padded = rnn_utils::pad_sequence(sequences, true);
    ASSERT_TRUE(padded.allclose(expected));

    // batch first = false
    padded = rnn_utils::pad_sequence(sequences);
    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
  }
}
875 | |