1 | #include <gtest/gtest.h> |
2 | |
3 | #include <ATen/Dimname.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/Optional.h> |
6 | |
7 | using at::NameType; |
8 | using at::Symbol; |
9 | using at::Dimname; |
10 | |
11 | TEST(DimnameTest, isValidIdentifier) { |
12 | ASSERT_TRUE(Dimname::isValidName("a" )); |
13 | ASSERT_TRUE(Dimname::isValidName("batch" )); |
14 | ASSERT_TRUE(Dimname::isValidName("N" )); |
15 | ASSERT_TRUE(Dimname::isValidName("CHANNELS" )); |
16 | ASSERT_TRUE(Dimname::isValidName("foo_bar_baz" )); |
17 | ASSERT_TRUE(Dimname::isValidName("batch1" )); |
18 | ASSERT_TRUE(Dimname::isValidName("batch_9" )); |
19 | ASSERT_TRUE(Dimname::isValidName("_" )); |
20 | ASSERT_TRUE(Dimname::isValidName("_1" )); |
21 | |
22 | ASSERT_FALSE(Dimname::isValidName("" )); |
23 | ASSERT_FALSE(Dimname::isValidName(" " )); |
24 | ASSERT_FALSE(Dimname::isValidName(" a " )); |
25 | ASSERT_FALSE(Dimname::isValidName("1batch" )); |
26 | ASSERT_FALSE(Dimname::isValidName("?" )); |
27 | ASSERT_FALSE(Dimname::isValidName("-" )); |
28 | ASSERT_FALSE(Dimname::isValidName("1" )); |
29 | ASSERT_FALSE(Dimname::isValidName("01" )); |
30 | } |
31 | |
32 | TEST(DimnameTest, wildcardName) { |
33 | Dimname wildcard = Dimname::wildcard(); |
34 | ASSERT_EQ(wildcard.type(), NameType::WILDCARD); |
35 | ASSERT_EQ(wildcard.symbol(), Symbol::dimname("*" )); |
36 | } |
37 | |
38 | TEST(DimnameTest, createNormalName) { |
39 | auto foo = Symbol::dimname("foo" ); |
40 | auto dimname = Dimname::fromSymbol(foo); |
41 | ASSERT_EQ(dimname.type(), NameType::BASIC); |
42 | ASSERT_EQ(dimname.symbol(), foo); |
43 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
44 | ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("inva.lid" )), c10::Error); |
45 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
46 | ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("1invalid" )), c10::Error); |
47 | } |
48 | |
49 | static void check_unify_and_match( |
50 | const std::string& dimname, |
51 | const std::string& other, |
52 | at::optional<const std::string> expected) { |
53 | auto dimname1 = Dimname::fromSymbol(Symbol::dimname(dimname)); |
54 | auto dimname2 = Dimname::fromSymbol(Symbol::dimname(other)); |
55 | auto result = dimname1.unify(dimname2); |
56 | if (expected) { |
57 | auto expected_result = Dimname::fromSymbol(Symbol::dimname(*expected)); |
58 | ASSERT_EQ(result->symbol(), expected_result.symbol()); |
59 | ASSERT_EQ(result->type(), expected_result.type()); |
60 | ASSERT_TRUE(dimname1.matches(dimname2)); |
61 | } else { |
62 | ASSERT_FALSE(result); |
63 | ASSERT_FALSE(dimname1.matches(dimname2)); |
64 | } |
65 | } |
66 | |
67 | TEST(DimnameTest, unifyAndMatch) { |
68 | check_unify_and_match("a" , "a" , "a" ); |
69 | check_unify_and_match("a" , "*" , "a" ); |
70 | check_unify_and_match("*" , "a" , "a" ); |
71 | check_unify_and_match("*" , "*" , "*" ); |
72 | check_unify_and_match("a" , "b" , c10::nullopt); |
73 | } |
74 | |