1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <torch/csrc/lazy/core/permutation_util.h> |
5 | |
6 | namespace torch { |
7 | namespace lazy { |
8 | |
9 | TEST(PermutationUtilTest, TestInversePermutation) { |
10 | EXPECT_EQ(InversePermutation({0}), std::vector<int64_t>({0})); |
11 | EXPECT_EQ(InversePermutation({0, 1, 2}), std::vector<int64_t>({0, 1, 2})); |
12 | EXPECT_EQ( |
13 | InversePermutation({1, 3, 2, 0}), std::vector<int64_t>({3, 0, 2, 1})); |
14 | // Not a valid permutation |
15 | EXPECT_THROW(InversePermutation({-1}), c10::Error); |
16 | EXPECT_THROW(InversePermutation({1, 1}), c10::Error); |
17 | } |
18 | |
19 | TEST(PermutationUtilTest, TestIsPermutation) { |
20 | EXPECT_TRUE(IsPermutation({0})); |
21 | EXPECT_TRUE(IsPermutation({0, 1, 2, 3})); |
22 | EXPECT_FALSE(IsPermutation({-1})); |
23 | EXPECT_FALSE(IsPermutation({5, 3})); |
24 | EXPECT_FALSE(IsPermutation({1, 2, 3})); |
25 | } |
26 | |
27 | TEST(PermutationUtilTest, TestPermute) { |
28 | EXPECT_EQ( |
29 | PermuteDimensions({0}, std::vector<int64_t>({224})), |
30 | std::vector<int64_t>({224})); |
31 | EXPECT_EQ( |
32 | PermuteDimensions({1, 2, 0}, std::vector<int64_t>({3, 224, 224})), |
33 | std::vector<int64_t>({224, 224, 3})); |
34 | // Not a valid permutation |
35 | EXPECT_THROW( |
36 | PermuteDimensions({-1}, std::vector<int64_t>({244})), c10::Error); |
37 | EXPECT_THROW( |
38 | PermuteDimensions({3, 2}, std::vector<int64_t>({244})), c10::Error); |
39 | // Permutation size is different from the to-be-permuted vector size |
40 | EXPECT_THROW( |
41 | PermuteDimensions({0, 1}, std::vector<int64_t>({244})), c10::Error); |
42 | EXPECT_THROW( |
43 | PermuteDimensions({0}, std::vector<int64_t>({3, 244, 244})), c10::Error); |
44 | } |
45 | |
46 | } // namespace lazy |
47 | } // namespace torch |
48 | |