1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <test/cpp/api/support.h>
7
8#include <algorithm>
9#include <random>
10#include <sstream>
11#include <string>
12
13using namespace torch::nn;
14
15namespace rnn_utils = torch::nn::utils::rnn;
16
// Test fixtures for the nn::utils and rnn_utils tests below.
// SeedingFixture presumably reseeds the RNG per test for reproducible random
// tensors — TODO confirm against torch::test::SeedingFixture.
struct NNUtilsTest : torch::test::SeedingFixture {};
struct PackedSequenceTest : torch::test::SeedingFixture {};
19
TEST_F(NNUtilsTest, ClipGradNorm) {
  auto l = Linear(10, 10);
  float max_norm = 2;
  // Reference norm: the p-norm of all of l's gradients taken together, or the
  // largest absolute gradient entry when norm_type is infinity.
  auto compute_norm = [&](float norm_type) -> float {
    float total_norm = 0.0;
    if (norm_type != std::numeric_limits<float>::infinity()) {
      for (const auto& p : l->parameters()) {
        total_norm +=
            p.grad().data().abs().pow(norm_type).sum().item().toFloat();
      }
      return std::pow(total_norm, 1.0 / norm_type);
    } else {
      for (const auto& p : l->parameters()) {
        auto param_max = p.grad().data().abs().max().item().toFloat();
        if (param_max > total_norm) {
          total_norm = param_max;
        }
      }
      return total_norm;
    }
  };
  // Element-wise ratio of the current grads to the original grads. Clipping
  // must scale every element by the same factor, so the std of the returned
  // tensor should be ~0.
  auto compare_scaling =
      [&](const std::vector<torch::Tensor>& grads) -> torch::Tensor {
    std::vector<torch::Tensor> p_scale;
    for (const auto i : c10::irange(grads.size())) {
      auto param = l->parameters()[i];
      auto grad = grads[i];
      p_scale.push_back(param.grad().data().div(grad).view(-1));
    }
    auto scale = torch::cat(p_scale);
    return scale; // need to assert std is 0.
  };

  // Gradients whose norm exceeds max_norm, so clipping must rescale them.
  std::vector<torch::Tensor> grads = {
      torch::arange(1.0, 101).view({10, 10}),
      torch::ones({10}).div(1000),
  };
  std::vector<float> norm_types = {
      0.5,
      1.5,
      2.0,
      4.0,
      std::numeric_limits<float>::infinity(),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].mutable_grad() =
          grads[i].clone().view_as(l->parameters()[i].data());
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    // clip_grad_norm_ returns the pre-clipping norm; afterwards the total
    // norm must sit at (and not above) max_norm.
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_NEAR(norm_after, max_norm, 1e-6);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
  }
  // Small gradients should be left unchanged
  grads = {
      torch::rand({10, 10}).div(10000),
      torch::ones(10).div(500),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].grad().data().copy_(grads[i]);
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    // Already under max_norm: the norm must be untouched and the
    // grad/original ratio must be exactly 1 everywhere.
    ASSERT_FLOAT_EQ(norm_before, norm_after);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
    ASSERT_FLOAT_EQ(scaled[0].item().toFloat(), 1);
  }
  // should accept a single tensor as input
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(1., 101).view({10, 10});
  p1.mutable_grad() = g.clone();
  p2.mutable_grad() = g.clone();
  for (const auto norm_type : norm_types) {
    // Clipping a bare tensor must behave like clipping a one-element list.
    utils::clip_grad_norm_(p1, max_norm, norm_type);
    utils::clip_grad_norm_({p2}, max_norm, norm_type);
    ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
  }
}
109
110// Check that clip_grad_norm_ raises an error if the norm of a gradient
111// is non-finite
112TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) {
113 double inf = std::numeric_limits<double>::infinity();
114 double nan = std::numeric_limits<double>::quiet_NaN();
115
116 using Vector = std::vector<double>;
117
118 Vector norms_pos = {0.1, 1, 2, 3.5, inf};
119 Vector norms_neg = {-0.1, -1, -2, -3.5};
120 Vector norms_neg_plus_0 = {0, -0.1, -1, -2, -3.5};
121 Vector norms_except_0 = {0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5};
122 Vector norms_all = {0, 0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5};
123
124 // Each entry in test_cases has the following values, in this order:
125 //
126 // grad_only_one_elem If True, only one element of the parameter's
127 // gradient is set to the scalar grad, and the
128 // rest of the elements are 0. If False, all grad
129 // elements are equal to the scalar.
130 //
131 // prefix_finite_grad_param If True, prefix a parameter that has a grad
132 // of 1.
133 //
134 // scalars Scalars to use as the parameter's grad, through
135 // multiplication
136 //
137 // norms_nonfinite Norm types that should produce nonfinite total norm
138 //
139 // norms_finite Norm types that should produce finite total norm
140 std::vector<std::tuple<bool, bool, Vector, Vector, Vector>> test_cases({
141 // Test errors from an infinite grad
142 std::make_tuple(
143 false, false, Vector({inf, -inf}), norms_except_0, Vector({0})),
144 std::make_tuple(
145 false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
146 std::make_tuple(
147 true, false, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
148 std::make_tuple(
149 false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
150
151 // Test errors from a NaN grad
152 std::make_tuple(false, false, Vector({nan}), norms_except_0, Vector({0})),
153 std::make_tuple(false, true, Vector({nan}), norms_except_0, Vector({0})),
154 std::make_tuple(true, false, Vector({nan}), norms_except_0, Vector({0})),
155 std::make_tuple(true, true, Vector({nan}), norms_except_0, Vector({0})),
156
157 // Test a grad that should never error
158 std::make_tuple(false, false, Vector({2e22, -2e22}), Vector(), norms_all),
159 std::make_tuple(false, true, Vector({2e22, -2e22}), Vector(), norms_all),
160 std::make_tuple(true, false, Vector({2e22, -2e22}), Vector(), norms_all),
161 std::make_tuple(true, true, Vector({2e22, -2e22}), Vector(), norms_all),
162
163 // Test a grad that will overflow to inf for only some norm orders
164 std::make_tuple(
165 false,
166 false,
167 Vector({2e200, -2e200}),
168 Vector({3.5, 2, -2, -3.5}),
169 Vector({inf, 1, 0.1, 0, -1, -0.1})),
170 std::make_tuple(
171 false,
172 true,
173 Vector({2e200, -2e200}),
174 Vector({3.5, 2}),
175 Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
176 std::make_tuple(
177 true,
178 false,
179 Vector({2e200, -2e200}),
180 Vector({3.5, 2}),
181 Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
182 std::make_tuple(
183 false,
184 true,
185 Vector({2e200, -2e200}),
186 Vector({3.5, 2}),
187 Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
188 });
189
190 auto gen_parameters = [](double scalar,
191 bool grad_only_one_elem,
192 bool prefix_finite_grad_param,
193 torch::DeviceType device_type) {
194 auto param = torch::ones(
195 10,
196 torch::TensorOptions()
197 .dtype(torch::kDouble)
198 .device(device_type)
199 .requires_grad(true));
200 if (grad_only_one_elem) {
201 param[1].mul(scalar).sum().backward();
202 } else {
203 param.mul(scalar).sum().backward();
204 }
205
206 std::vector<torch::Tensor> parameters;
207 if (prefix_finite_grad_param) {
208 auto prefix_param = torch::ones(
209 1,
210 torch::TensorOptions()
211 .dtype(torch::kDouble)
212 .device(device_type)
213 .requires_grad(true));
214 prefix_param.mul(1).sum().backward();
215 parameters.push_back(prefix_param);
216 }
217 parameters.push_back(param);
218
219 return parameters;
220 };
221
222 auto run_test_case = [&gen_parameters](
223 double norm_type,
224 bool error_if_nonfinite,
225 double scalar,
226 bool grad_only_one_elem,
227 bool prefix_finite_grad_param,
228 bool is_norm_nonfinite,
229 torch::DeviceType device_type) {
230 std::stringstream ss;
231 ss << "device: " << device_type << ", norm_type: " << norm_type
232 << ", error_if_nonfinite: " << error_if_nonfinite
233 << ", scalar: " << scalar
234 << ", grad_only_one_elem: " << grad_only_one_elem
235 << ", prefix_finite_grad_param: " << prefix_finite_grad_param
236 << ", is_norm_nonfinite: " << is_norm_nonfinite;
237 std::string msg = ss.str();
238
239 auto parameters = gen_parameters(
240 scalar, grad_only_one_elem, prefix_finite_grad_param, device_type);
241
242 if (is_norm_nonfinite && error_if_nonfinite) {
243 std::vector<torch::Tensor> grads_before;
244 // NOLINTNEXTLINE(performance-for-range-copy)
245 for (auto p : parameters) {
246 // NOLINTNEXTLINE(performance-inefficient-vector-operation)
247 grads_before.push_back(p.grad().clone());
248 }
249 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
250 EXPECT_THROW(
251 utils::clip_grad_norm_(parameters, 1., norm_type, true),
252 std::exception)
253 << msg;
254 // Grads should not change if error is thrown
255 for (const auto p_idx : c10::irange(parameters.size())) {
256 ASSERT_TRUE(torch::allclose(
257 parameters[p_idx].grad(),
258 grads_before[p_idx],
259 1.0,
260 0.0,
261 /*equal_nan*/ true))
262 << msg;
263 }
264 } else {
265 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
266 EXPECT_NO_THROW(
267 utils::clip_grad_norm_(parameters, 1., norm_type, error_if_nonfinite))
268 << msg;
269 }
270 };
271
272 for (auto device_type : {torch::kCPU, torch::kCUDA}) {
273 if (device_type == torch::kCUDA && !torch::cuda::is_available()) {
274 continue;
275 }
276 for (auto test_case : test_cases) {
277 auto grad_only_one_elem = std::get<0>(test_case);
278 auto prefix_finite_grad_param = std::get<1>(test_case);
279 auto scalars = std::get<2>(test_case);
280 auto norms_nonfinite = std::get<3>(test_case);
281 auto norms_finite = std::get<4>(test_case);
282
283 for (auto error_if_nonfinite : {false, true}) {
284 for (auto scalar : scalars) {
285 for (auto norm_type : norms_nonfinite) {
286 run_test_case(
287 norm_type,
288 error_if_nonfinite,
289 scalar,
290 grad_only_one_elem,
291 prefix_finite_grad_param,
292 true,
293 device_type);
294 }
295
296 for (auto norm_type : norms_finite) {
297 run_test_case(
298 norm_type,
299 error_if_nonfinite,
300 scalar,
301 grad_only_one_elem,
302 prefix_finite_grad_param,
303 false,
304 device_type);
305 }
306 }
307 }
308 }
309 }
310}
311
312TEST_F(NNUtilsTest, ClipGradValue) {
313 auto l = Linear(10, 10);
314 float clip_value = 2.5;
315
316 torch::Tensor grad_w = torch::arange(-50., 50).view({10, 10}).div_(5);
317 torch::Tensor grad_b = torch::ones({10}).mul_(2);
318 std::vector<std::vector<torch::Tensor>> grad_lists = {
319 {grad_w, grad_b}, {grad_w, torch::Tensor()}};
320 for (auto grad_list : grad_lists) {
321 for (const auto i : c10::irange(grad_list.size())) {
322 auto p = l->parameters()[i];
323 auto g = grad_list[i];
324 p.mutable_grad() = g.defined() ? g.clone().view_as(p.data()) : g;
325 }
326
327 utils::clip_grad_value_(l->parameters(), clip_value);
328 for (const auto& p : l->parameters()) {
329 if (p.grad().defined()) {
330 ASSERT_LE(p.grad().data().max().item().toFloat(), clip_value);
331 ASSERT_GE(p.grad().data().min().item().toFloat(), -clip_value);
332 }
333 }
334 }
335
336 // Should accept a single Tensor as input
337 auto p1 = torch::randn({10, 10});
338 auto p2 = torch::randn({10, 10});
339 auto g = torch::arange(-50., 50).view({10, 10}).div_(5);
340 p1.mutable_grad() = g.clone();
341 p2.mutable_grad() = g.clone();
342 utils::clip_grad_value_(p1, clip_value);
343 utils::clip_grad_value_({p2}, clip_value);
344 ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
345}
346
TEST_F(NNUtilsTest, ConvertParameters) {
  // parameters_to_vector must flatten each parameter and concatenate them in
  // order; vector_to_parameters is its inverse.
  std::vector<torch::Tensor> parameters{
      torch::arange(9, torch::kFloat32),
      torch::arange(9, torch::kFloat32).view({3, 3}),
      torch::arange(8, torch::kFloat32).view({2, 2, 2})};

  auto expected = torch::cat(
      {torch::arange(9, torch::kFloat32),
       torch::arange(9, torch::kFloat32).view(-1),
       torch::arange(8, torch::kFloat32).view(-1)});
  auto vector = utils::parameters_to_vector(parameters);
  ASSERT_TRUE(vector.allclose(expected));

  // Round trip: writing the flat vector into zeroed parameters of the same
  // shapes must reproduce the originals.
  std::vector<torch::Tensor> zero_parameters{
      torch::zeros({9}, torch::kFloat32),
      torch::zeros({9}, torch::kFloat32).view({3, 3}),
      torch::zeros({8}, torch::kFloat32).view({2, 2, 2})};

  utils::vector_to_parameters(vector, zero_parameters);
  for (const auto i : c10::irange(zero_parameters.size())) {
    ASSERT_TRUE(zero_parameters[i].allclose(parameters[i]));
  }

  {
    // Total element count of Conv2d(3,10,5) + Linear(10,20) parameters:
    // 980 per the assertion below.
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = utils::parameters_to_vector(model->parameters());
    ASSERT_EQ(vec.size(0), 980);
  }
  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = torch::arange(0., 980);
    utils::vector_to_parameters(vec, model->parameters());

    // The first kernel row of the first conv filter must receive the first
    // five entries of the flat vector.
    auto sample = model->parameters()[0][0][0][0];
    ASSERT_TRUE(torch::equal(sample.data(), vec.data().slice(0, 0, 5)));
  }
}
390
// Batch size and (exclusive) maximum sequence length shared by the
// PackedSequence helpers and tests below.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_batch_size = 5;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_max_length = 6;
395
396std::vector<torch::Tensor> PackedSequenceTest_ordered_sequence(
397 torch::ScalarType tensor_type) {
398 std::vector<torch::Tensor> seqs;
399 seqs.reserve(PackedSequenceTest_batch_size);
400 for (const auto i : c10::irange(PackedSequenceTest_batch_size)) {
401 (void)i; // Suppress unused variable warning
402 seqs.emplace_back(torch::empty(
403 {torch::randint(1, PackedSequenceTest_max_length, {1}).item<int64_t>()},
404 tensor_type));
405 }
406 for (auto& s : seqs) {
407 s.random_(-128, 128);
408 }
409 sort(
410 seqs.begin(),
411 seqs.end(),
412 [&](const torch::Tensor& t1, const torch::Tensor& t2) {
413 return t1.size(0) > t2.size(0);
414 });
415 return seqs;
416}
417
418std::tuple<torch::Tensor, torch::Tensor> PackedSequenceTest_padded_sequence(
419 torch::ScalarType tensor_type) {
420 // Create Tensor of random padded sequences
421 auto ordered = PackedSequenceTest_ordered_sequence(tensor_type);
422 auto lengths = torch::empty({(int64_t)ordered.size()}, torch::kInt64);
423 for (const auto i : c10::irange(ordered.size())) {
424 lengths[i] = ordered[i].size(0);
425 }
426 auto padded_tensor = rnn_utils::pad_sequence(ordered);
427 return std::make_tuple(padded_tensor, lengths);
428}
429
430void assert_is_equal_packed_sequence(
431 const rnn_utils::PackedSequence& a,
432 const rnn_utils::PackedSequence& b) {
433 ASSERT_TRUE(torch::allclose(a.data(), b.data()));
434 ASSERT_TRUE(torch::allclose(a.batch_sizes(), b.batch_sizes()));
435 ASSERT_TRUE(
436 (!a.sorted_indices().defined() && !b.sorted_indices().defined()) ||
437 torch::allclose(a.sorted_indices(), b.sorted_indices()));
438 ASSERT_TRUE(
439 (!a.unsorted_indices().defined() && !b.unsorted_indices().defined()) ||
440 torch::allclose(a.unsorted_indices(), b.unsorted_indices()));
441}
442
443void assert_is_same_packed_sequence(
444 const rnn_utils::PackedSequence& a,
445 const rnn_utils::PackedSequence& b) {
446 ASSERT_TRUE(a.data().is_same(b.data()));
447 ASSERT_TRUE(a.batch_sizes().is_same(b.batch_sizes()));
448 ASSERT_TRUE(a.sorted_indices().is_same(b.sorted_indices()));
449 ASSERT_TRUE(a.unsorted_indices().is_same(b.unsorted_indices()));
450}
451
452TEST_F(PackedSequenceTest, WrongOrder) {
453 auto a = torch::ones({25, 300});
454 auto b = torch::ones({22, 300});
455 auto b_a = rnn_utils::pad_sequence({b, a});
456 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
457 ASSERT_THROW(
458 rnn_utils::pack_padded_sequence(
459 b_a,
460 torch::tensor({22, 25}),
461 /*batch_first=*/false,
462 /*enforce_sorted=*/true),
463 c10::Error);
464}
465
TEST_F(PackedSequenceTest, TotalLength) {
  torch::Tensor padded, lengths;
  std::tie(padded, lengths) = PackedSequenceTest_padded_sequence(torch::kFloat);
  int64_t max_length = torch::max(lengths).item<int64_t>();
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_padded_sequence(padded, lengths);

  // test ValueError if total_length < max_length
  for (int64_t total_length : std::vector<int64_t>{-1, 0, max_length - 1}) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      auto err_fn = [&]() {
        rnn_utils::pad_packed_sequence(
            packed,
            /*batch_first=*/batch_first,
            /*padding_value=*/0.0,
            /*total_length=*/total_length);
      };
      ASSERT_THROWS_WITH(
          err_fn(),
          "Expected total_length to be at least the length of the longest sequence in input");
    }
  }

  // test that pad_packed_sequence returns results of correct length
  for (bool batch_first : std::vector<bool>{true, false}) {
    // Baseline: unpacking without a total_length argument.
    torch::Tensor no_extra_pad, ignored;
    std::tie(no_extra_pad, ignored) =
        rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first);
    for (int64_t total_length_delta : std::vector<int64_t>{0, 1, 8}) {
      int64_t total_length = max_length + total_length_delta;
      torch::Tensor unpacked, lengths_out;
      std::tie(unpacked, lengths_out) = rnn_utils::pad_packed_sequence(
          packed,
          /*batch_first=*/batch_first,
          /*padding_value=*/0.0,
          /*total_length=*/total_length);
      ASSERT_TRUE(torch::allclose(lengths, lengths_out));
      // The time dimension is dim 1 when batch_first, else dim 0.
      ASSERT_EQ(unpacked.size(batch_first ? 1 : 0), total_length);
      // Build the expected output by appending zero padding of the requested
      // extra length along the time dimension of the baseline.
      torch::Tensor ref_output, extra_pad;
      if (total_length_delta == 0) {
        ref_output = no_extra_pad;
      } else if (batch_first) {
        extra_pad = torch::zeros(
            {PackedSequenceTest_batch_size, total_length_delta},
            no_extra_pad.options());
        ref_output = torch::cat({no_extra_pad, extra_pad}, 1);
      } else {
        extra_pad = torch::zeros(
            {total_length_delta, PackedSequenceTest_batch_size},
            no_extra_pad.options());
        ref_output = torch::cat({no_extra_pad, extra_pad}, 0);
      }
      ASSERT_TRUE(torch::allclose(unpacked, ref_output));
    }
  }
}
522
TEST_F(PackedSequenceTest, To) {
  // PackedSequence::to/cpu/cuda must be a no-op (same underlying tensors)
  // when the target device/dtype already matches, and must preserve contents
  // when an actual transfer happens.
  for (bool enforce_sorted : std::vector<bool>{true, false}) {
    torch::Tensor padded, lengths;
    std::tie(padded, lengths) = PackedSequenceTest_padded_sequence(torch::kInt);
    rnn_utils::PackedSequence a = rnn_utils::pack_padded_sequence(
                                      padded,
                                      lengths,
                                      /*batch_first=*/false,
                                      /*enforce_sorted=*/enforce_sorted)
                                      .cpu();

    // Already on CPU with matching dtype: expect identity, not a copy.
    assert_is_same_packed_sequence(a, a.to(torch::kCPU));
    assert_is_same_packed_sequence(a, a.cpu());
    assert_is_same_packed_sequence(
        a, a.to(torch::device(torch::kCPU).dtype(torch::kInt32)));

    if (torch::cuda::is_available()) {
      auto b = a.cuda();
      assert_is_same_packed_sequence(b, b.to(torch::kCUDA));
      assert_is_same_packed_sequence(b, b.cuda());
      // Cross-device round trips keep values equal (but not identity).
      assert_is_equal_packed_sequence(a, b.to(torch::kCPU));
      assert_is_equal_packed_sequence(b, a.to(torch::kCUDA));
      assert_is_equal_packed_sequence(
          a, b.to(torch::device(torch::kCPU).dtype(torch::kInt32)));
      assert_is_same_packed_sequence(b, b.to(torch::kInt32));
    }
  }
}
551
TEST_F(NNUtilsTest, PackSequence) {
  // pack_sequence must agree with the pad_sequence -> pack_padded_sequence
  // route, and pad_packed_sequence must invert it.
  auto _compatibility_test = [&](torch::ArrayRef<torch::Tensor> sequences,
                                 torch::Tensor lengths,
                                 bool batch_first,
                                 bool enforce_sorted = false) {
    torch::Tensor padded = rnn_utils::pad_sequence(sequences, batch_first);
    rnn_utils::PackedSequence packed =
        rnn_utils::pack_sequence(sequences, enforce_sorted);
    std::tuple<torch::Tensor, torch::Tensor> unpacked =
        rnn_utils::pad_packed_sequence(packed, batch_first);
    ASSERT_TRUE(torch::allclose(padded, std::get<0>(unpacked)));
    rnn_utils::PackedSequence pack_padded = rnn_utils::pack_padded_sequence(
        padded, lengths, batch_first, enforce_sorted);
    assert_is_equal_packed_sequence(packed, pack_padded);
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/false);
  // Packed data interleaves time steps across the batch: step 0 of every
  // sequence, then step 1, etc.
  auto expected = torch::tensor({1, 4, 6, 2, 5, 3});
  ASSERT_TRUE(torch::allclose(packed.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed.data(), expected));
  ASSERT_TRUE(
      torch::allclose(packed.sorted_indices(), torch::tensor({0, 1, 2})));
  ASSERT_TRUE(
      torch::allclose(packed.unsorted_indices(), torch::tensor({0, 1, 2})));

  // Same data passed out of order: the index tensors must record the
  // permutation, while the packed data stays identical.
  rnn_utils::PackedSequence packed_unsorted =
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/false);
  ASSERT_TRUE(
      torch::allclose(packed_unsorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_unsorted.data(), expected));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.sorted_indices(), torch::tensor({2, 0, 1})));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.unsorted_indices(), torch::tensor({1, 2, 0})));

  // single dimensional, enforce_sorted = True
  rnn_utils::PackedSequence packed_enforce_sorted =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/true);
  ASSERT_TRUE(torch::allclose(
      packed_enforce_sorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_enforce_sorted.data(), expected));
  // With enforce_sorted, no permutation bookkeeping is kept.
  ASSERT_FALSE(packed_enforce_sorted.sorted_indices().defined());
  ASSERT_FALSE(packed_enforce_sorted.unsorted_indices().defined());

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "must be sorted in decreasing order");

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "You can pass `enforce_sorted=False`");

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    // Sequences of shape {i*i, 5, 4, 4, ...} with num_dim trailing dims,
    // generated in decreasing-length order.
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> lengths_vec;
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = maxlen; i > 0; i--) {
      int64_t seq_len = i * i;
      lengths_vec.emplace_back(seq_len);
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    // An unsorted copy exercises the enforce_sorted=false path.
    std::vector<torch::Tensor> unsorted_sequences;
    for (const auto& s : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences.emplace_back(s.clone());
    }
    std::shuffle(
        std::begin(unsorted_sequences),
        std::end(unsorted_sequences),
        std::default_random_engine{});

    std::vector<int64_t> unsorted_sequences_lengths_vec;
    for (const auto& t : unsorted_sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences_lengths_vec.emplace_back(t.size(0));
    }

    // compatibility with other utilities
    for (bool batch_first : std::vector<bool>{true, false}) {
      for (bool enforce_sorted : std::vector<bool>{true, false}) {
        _compatibility_test(
            sequences, torch::tensor(lengths_vec), batch_first, enforce_sorted);
      }
      _compatibility_test(
          unsorted_sequences,
          torch::tensor(unsorted_sequences_lengths_vec),
          batch_first);
    }
  }
}
652
TEST_F(NNUtilsTest, PackPaddedSequence) {
  // Builds one test fixture: a padded batch with known contents, the packed
  // data/batch_sizes it should produce, and (optionally) a shuffled batch
  // order with its unsorted_indices.
  auto generate_test_case = [&](torch::ArrayRef<int64_t> sorted_lengths,
                                bool should_shuffle) {
    // Zero-pads `tensor` along dim 0 up to `length`.
    auto pad = [&](torch::Tensor tensor, int64_t length) {
      std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
      tensor_sizes.insert(
          tensor_sizes.end(),
          tensor.sizes().slice(1).begin(),
          tensor.sizes().slice(1).end());
      return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
    };
    int64_t max_length = sorted_lengths[0];
    // batch_sizes[t] = number of sequences whose length is > t.
    torch::Tensor batch_sizes = torch::empty({max_length}, torch::kInt64);
    for (int64_t i = 1; i < max_length + 1; i++) {
      int64_t total = 0;
      for (const auto& x : sorted_lengths) {
        if (x >= i) {
          total++;
        }
      }
      batch_sizes[i - 1] = total;
    }
    // Sequence i's values are i*100 + 1..5l, reshaped to {l, 1, 5}, so every
    // element encodes which sequence and which step it came from.
    std::vector<torch::Tensor> tensors_to_be_cat;
    for (int64_t i = 1; i < static_cast<int64_t>(sorted_lengths.size() + 1);
         i++) {
      int64_t l = sorted_lengths.at(i - 1);
      tensors_to_be_cat.emplace_back(pad(
          i * 100 + torch::arange(1., 5 * l + 1).view({l, 1, 5}), max_length));
    }
    auto padded = torch::cat(tensors_to_be_cat, 1);
    // Expected packed layout: for each time step n, the first batch_sizes[n]
    // sequences contribute one row each.
    std::vector<torch::Tensor> expected_data_vec;
    for (const auto n : c10::irange(batch_sizes.size(0))) {
      int64_t batch_size = batch_sizes[n].item<int64_t>();
      for (const auto i : c10::irange(batch_size)) {
        expected_data_vec.emplace_back(
            torch::arange(1., 6) + (i + 1) * 100 + 5 * n);
      }
    }
    auto expected_data = torch::stack(expected_data_vec, /*dim=*/0);

    torch::Tensor unsorted_indices, lengths;
    if (should_shuffle) {
      // Shuffle the padded sequence to create an unsorted sequence
      std::vector<int64_t> permutation;
      for (const auto i : c10::irange(sorted_lengths.size())) {
        permutation.emplace_back(i);
      }
      std::shuffle(
          std::begin(permutation),
          std::end(permutation),
          std::default_random_engine{});

      unsorted_indices = torch::tensor(permutation);
      padded = padded.index_select(1, unsorted_indices);
      lengths = torch::tensor(sorted_lengths).index_select(0, unsorted_indices);
    } else {
      unsorted_indices = torch::Tensor();
      lengths = torch::tensor(sorted_lengths);
    }

    // padded requires grad so the backward pass can be checked below.
    return std::make_tuple(
        padded.requires_grad_(),
        lengths,
        expected_data,
        batch_sizes,
        unsorted_indices);
  };

  std::vector<std::pair<std::vector<int64_t>, bool>> test_cases = {
      // sorted_lengths, should_shuffle
      {{10, 8, 4, 2, 2, 2, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, true}};

  for (const auto& test_case : test_cases) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      std::vector<int64_t> sorted_lengths = std::get<0>(test_case);
      bool should_shuffle = std::get<1>(test_case);

      torch::Tensor padded, lengths, expected_data, batch_sizes,
          unsorted_indices;
      std::tie(padded, lengths, expected_data, batch_sizes, unsorted_indices) =
          generate_test_case(sorted_lengths, should_shuffle);

      auto src = padded;
      if (batch_first) {
        src = src.transpose(0, 1);
      }

      // check output
      rnn_utils::PackedSequence packed = rnn_utils::pack_padded_sequence(
          src,
          lengths,
          /*batch_first=*/batch_first,
          /*enforce_sorted=*/!should_shuffle);
      ASSERT_TRUE(torch::allclose(packed.data(), expected_data));
      ASSERT_TRUE(torch::allclose(packed.batch_sizes(), batch_sizes));
      ASSERT_TRUE(
          (!packed.unsorted_indices().defined() &&
           !unsorted_indices.defined()) ||
          torch::allclose(packed.unsorted_indices(), unsorted_indices));

      // test inverse
      torch::Tensor unpacked, unpacked_len;
      std::tie(unpacked, unpacked_len) =
          rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first);
      ASSERT_TRUE(torch::allclose(unpacked, src));
      ASSERT_TRUE(torch::allclose(unpacked_len, lengths));

      // check grad
      if (padded.grad().defined()) {
        // Reset grads accumulated by a previous batch_first iteration.
        torch::NoGradGuard no_grad;
        padded.grad().zero_();
      }
      torch::Tensor grad_output;
      {
        torch::NoGradGuard no_grad;
        grad_output = unpacked.clone().normal_();
      }
      unpacked.backward(grad_output);
      if (batch_first) {
        grad_output.transpose_(0, 1);
      }
      for (const auto i : c10::irange(lengths.size(0))) {
        int64_t l = lengths[i].item<int64_t>();
        // Within each sequence's real length, grads flow through unchanged...
        ASSERT_TRUE(torch::allclose(
            padded.grad().narrow(0, 0, l).select(1, i),
            grad_output.narrow(0, 0, l).select(1, i)));
        // ...and the padding region receives zero gradient.
        if (l < 10) {
          ASSERT_EQ(
              padded.grad()
                  .narrow(0, l, padded.grad().size(0) - l)
                  .select(1, i)
                  .abs()
                  .sum()
                  .item<double>(),
              0);
        }
      }
    }
  }

  // test error messages
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(
          torch::randn({3, 3}), torch::tensor({1, 3, 2})),
      "You can pass `enforce_sorted=False`");
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(torch::randn({0, 0}), torch::tensor({})),
      "empty tensor");
}
805
TEST_F(NNUtilsTest, PadSequence) {
  // Reference padding helper: zero-pads `tensor` along dim 0 up to `length`.
  auto pad = [&](const torch::Tensor& tensor, int64_t length) {
    torch::NoGradGuard no_grad;
    std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
    tensor_sizes.insert(
        tensor_sizes.end(),
        tensor.sizes().slice(1).begin(),
        tensor.sizes().slice(1).end());
    return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});

  torch::Tensor expected, padded;

  // batch_first = true
  expected = torch::tensor({{4, 5, 0}, {1, 2, 3}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({b, a, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // batch_first = false
  padded = rnn_utils::pad_sequence({b, a, c});
  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

  // pad with non-zero value
  expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}});
  padded = rnn_utils::pad_sequence({b, a, c}, true, 1);
  ASSERT_TRUE(padded.allclose(expected));

  // Test pad sorted sequence
  expected = torch::tensor({{1, 2, 3}, {4, 5, 0}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({a, b, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    // Sequences of shape {i*i, 5, 4, 4, ...}, shuffled so pad_sequence has
    // to handle arbitrary input order.
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = 1; i < maxlen + 1; i++) {
      int64_t seq_len = i * i;
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    std::shuffle(
        std::begin(sequences),
        std::end(sequences),
        std::default_random_engine{});
    // Expected output built with the reference helper above.
    std::vector<torch::Tensor> expected_tensors;
    for (const torch::Tensor& seq : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      expected_tensors.emplace_back(pad(seq, maxlen * maxlen));
    }

    // batch first = true
    auto expected = torch::stack(expected_tensors);
    auto padded = rnn_utils::pad_sequence(sequences, true);
    ASSERT_TRUE(padded.allclose(expected));

    // batch first = false
    padded = rnn_utils::pad_sequence(sequences);
    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
  }
}
875