1#include <gtest/gtest.h>
2
3#include <c10/util/Exception.h>
4#include <torch/csrc/lazy/core/permutation_util.h>
5
6namespace torch {
7namespace lazy {
8
9TEST(PermutationUtilTest, TestInversePermutation) {
10 EXPECT_EQ(InversePermutation({0}), std::vector<int64_t>({0}));
11 EXPECT_EQ(InversePermutation({0, 1, 2}), std::vector<int64_t>({0, 1, 2}));
12 EXPECT_EQ(
13 InversePermutation({1, 3, 2, 0}), std::vector<int64_t>({3, 0, 2, 1}));
14 // Not a valid permutation
15 EXPECT_THROW(InversePermutation({-1}), c10::Error);
16 EXPECT_THROW(InversePermutation({1, 1}), c10::Error);
17}
18
19TEST(PermutationUtilTest, TestIsPermutation) {
20 EXPECT_TRUE(IsPermutation({0}));
21 EXPECT_TRUE(IsPermutation({0, 1, 2, 3}));
22 EXPECT_FALSE(IsPermutation({-1}));
23 EXPECT_FALSE(IsPermutation({5, 3}));
24 EXPECT_FALSE(IsPermutation({1, 2, 3}));
25}
26
27TEST(PermutationUtilTest, TestPermute) {
28 EXPECT_EQ(
29 PermuteDimensions({0}, std::vector<int64_t>({224})),
30 std::vector<int64_t>({224}));
31 EXPECT_EQ(
32 PermuteDimensions({1, 2, 0}, std::vector<int64_t>({3, 224, 224})),
33 std::vector<int64_t>({224, 224, 3}));
34 // Not a valid permutation
35 EXPECT_THROW(
36 PermuteDimensions({-1}, std::vector<int64_t>({244})), c10::Error);
37 EXPECT_THROW(
38 PermuteDimensions({3, 2}, std::vector<int64_t>({244})), c10::Error);
39 // Permutation size is different from the to-be-permuted vector size
40 EXPECT_THROW(
41 PermuteDimensions({0, 1}, std::vector<int64_t>({244})), c10::Error);
42 EXPECT_THROW(
43 PermuteDimensions({0}, std::vector<int64_t>({3, 244, 244})), c10::Error);
44}
45
46} // namespace lazy
47} // namespace torch
48