1#include <gtest/gtest.h>
2
3#include <ATen/Dimname.h>
4#include <c10/util/Exception.h>
5#include <c10/util/Optional.h>
6
7using at::NameType;
8using at::Symbol;
9using at::Dimname;
10
11TEST(DimnameTest, isValidIdentifier) {
12 ASSERT_TRUE(Dimname::isValidName("a"));
13 ASSERT_TRUE(Dimname::isValidName("batch"));
14 ASSERT_TRUE(Dimname::isValidName("N"));
15 ASSERT_TRUE(Dimname::isValidName("CHANNELS"));
16 ASSERT_TRUE(Dimname::isValidName("foo_bar_baz"));
17 ASSERT_TRUE(Dimname::isValidName("batch1"));
18 ASSERT_TRUE(Dimname::isValidName("batch_9"));
19 ASSERT_TRUE(Dimname::isValidName("_"));
20 ASSERT_TRUE(Dimname::isValidName("_1"));
21
22 ASSERT_FALSE(Dimname::isValidName(""));
23 ASSERT_FALSE(Dimname::isValidName(" "));
24 ASSERT_FALSE(Dimname::isValidName(" a "));
25 ASSERT_FALSE(Dimname::isValidName("1batch"));
26 ASSERT_FALSE(Dimname::isValidName("?"));
27 ASSERT_FALSE(Dimname::isValidName("-"));
28 ASSERT_FALSE(Dimname::isValidName("1"));
29 ASSERT_FALSE(Dimname::isValidName("01"));
30}
31
32TEST(DimnameTest, wildcardName) {
33 Dimname wildcard = Dimname::wildcard();
34 ASSERT_EQ(wildcard.type(), NameType::WILDCARD);
35 ASSERT_EQ(wildcard.symbol(), Symbol::dimname("*"));
36}
37
38TEST(DimnameTest, createNormalName) {
39 auto foo = Symbol::dimname("foo");
40 auto dimname = Dimname::fromSymbol(foo);
41 ASSERT_EQ(dimname.type(), NameType::BASIC);
42 ASSERT_EQ(dimname.symbol(), foo);
43 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
44 ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("inva.lid")), c10::Error);
45 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
46 ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("1invalid")), c10::Error);
47}
48
49static void check_unify_and_match(
50 const std::string& dimname,
51 const std::string& other,
52 at::optional<const std::string> expected) {
53 auto dimname1 = Dimname::fromSymbol(Symbol::dimname(dimname));
54 auto dimname2 = Dimname::fromSymbol(Symbol::dimname(other));
55 auto result = dimname1.unify(dimname2);
56 if (expected) {
57 auto expected_result = Dimname::fromSymbol(Symbol::dimname(*expected));
58 ASSERT_EQ(result->symbol(), expected_result.symbol());
59 ASSERT_EQ(result->type(), expected_result.type());
60 ASSERT_TRUE(dimname1.matches(dimname2));
61 } else {
62 ASSERT_FALSE(result);
63 ASSERT_FALSE(dimname1.matches(dimname2));
64 }
65}
66
67TEST(DimnameTest, unifyAndMatch) {
68 check_unify_and_match("a", "a", "a");
69 check_unify_and_match("a", "*", "a");
70 check_unify_and_match("*", "a", "a");
71 check_unify_and_match("*", "*", "*");
72 check_unify_and_match("a", "b", c10::nullopt);
73}
74