#include <gtest/gtest.h>

#include <algorithm>

#include <ATen/core/jit_type.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | class UnionTypeTest : public ::testing::Test { |
11 | public: |
12 | // None |
13 | const TypePtr none = NoneType::get(); |
14 | |
15 | // List[str] |
16 | const TypePtr l1 = ListType::ofStrings(); |
17 | |
18 | // Optional[int] |
19 | const TypePtr opt1 = OptionalType::create(IntType::get()); |
20 | |
21 | // Optional[float] |
22 | const TypePtr opt2 = OptionalType::create(FloatType::get()); |
23 | |
24 | // Optional[List[str]] |
25 | const TypePtr opt3 = OptionalType::create(ListType::ofStrings()); |
26 | |
27 | // Tuple[Optional[int], int] |
28 | const TypePtr tup1 = |
29 | TupleType::create({OptionalType::create(IntType::get()), IntType::get()}); |
30 | |
31 | // Tuple[int, int] |
32 | const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()}); |
33 | |
34 | bool hasType(UnionTypePtr u, TypePtr t) { |
35 | auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t); |
36 | return res != u->getTypes().end(); |
37 | } |
38 | }; |
39 | |
40 | TEST_F(UnionTypeTest, UnionOperatorEquals) { |
41 | const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()}); |
42 | |
43 | // Same thing, but using different TypePtrs |
44 | const TypePtr l1_ = ListType::ofStrings(); |
45 | const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()}); |
46 | const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()}); |
47 | |
48 | ASSERT_TRUE(*u1 == *u2); |
49 | } |
50 | |
51 | TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) { |
52 | // Goal: Union[int, float, None] |
53 | const UnionTypePtr u = UnionType::create({opt1, opt2}); |
54 | |
55 | ASSERT_EQ(u->getTypes().size(), 3); |
56 | ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get())); |
57 | ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get())); |
58 | ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get())); |
59 | } |
60 | |
61 | TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) { |
62 | // Goal: Union[int, None] |
63 | const UnionTypePtr u = UnionType::create({opt1, IntType::get()}); |
64 | |
65 | ASSERT_EQ(u->getTypes().size(), 2); |
66 | ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get())); |
67 | ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get())); |
68 | } |
69 | |
70 | TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) { |
71 | // Goal: Union[Tuple[Optional[int], int], str] |
72 | const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2}); |
73 | |
74 | ASSERT_EQ(u->getTypes().size(), 2); |
75 | ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); |
76 | ASSERT_TRUE(UnionTypeTest::hasType(u, tup1)); |
77 | } |
78 | |
79 | TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) { |
80 | // Goal: Union[List[str], str] |
81 | const UnionTypePtr u = UnionType::create({l1, StringType::get()}); |
82 | |
83 | ASSERT_EQ(u->getTypes().size(), 2); |
84 | ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); |
85 | ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings())); |
86 | } |
87 | |
88 | TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) { |
89 | // Goal: Union[List[str], None, str] |
90 | const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()}); |
91 | |
92 | ASSERT_EQ(u->getTypes().size(), 3); |
93 | ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); |
94 | ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings())); |
95 | } |
96 | |
97 | TEST_F(UnionTypeTest, Subtyping_NumberType) { |
98 | // Union[int, float, Complex] |
99 | const UnionTypePtr union1 = |
100 | UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()}); |
101 | |
102 | // Union[int, float, Complex, None] |
103 | const UnionTypePtr union2 = UnionType::create( |
104 | {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()}); |
105 | |
106 | const NumberTypePtr num = NumberType::get(); |
107 | |
108 | ASSERT_TRUE(num->isSubtypeOf(*union1)); |
109 | ASSERT_TRUE(union1->isSubtypeOf(*num)); |
110 | ASSERT_TRUE(*num == *union1); |
111 | |
112 | ASSERT_TRUE(num->isSubtypeOf(*union2)); |
113 | ASSERT_FALSE(union2->isSubtypeOf(*num)); |
114 | ASSERT_FALSE(*num == *union2); |
115 | } |
116 | |
117 | TEST_F(UnionTypeTest, Subtyping_OptionalType) { |
118 | // Union[int, None] |
119 | const UnionTypePtr union1 = |
120 | UnionType::create({IntType::get(), NoneType::get()}); |
121 | |
122 | // Union[int, str, None] |
123 | const UnionTypePtr union2 = |
124 | UnionType::create({IntType::get(), StringType::get(), NoneType::get()}); |
125 | |
126 | // Union[int, str, List[str]] |
127 | const UnionTypePtr union3 = UnionType::create( |
128 | {IntType::get(), StringType::get(), ListType::ofStrings()}); |
129 | |
130 | ASSERT_TRUE(none->isSubtypeOf(opt1)); |
131 | ASSERT_TRUE(none->isSubtypeOf(union1)); |
132 | ASSERT_TRUE(none->isSubtypeOf(union2)); |
133 | ASSERT_FALSE(none->isSubtypeOf(union3)); |
134 | |
135 | ASSERT_FALSE(opt1->isSubtypeOf(none)); |
136 | ASSERT_TRUE(opt1->isSubtypeOf(union1)); |
137 | ASSERT_TRUE(opt1->isSubtypeOf(union2)); |
138 | ASSERT_FALSE(opt1->isSubtypeOf(union3)); |
139 | |
140 | ASSERT_FALSE(union1->isSubtypeOf(none)); |
141 | ASSERT_TRUE(union1->isSubtypeOf(opt1)); |
142 | ASSERT_TRUE(union1->isSubtypeOf(union2)); |
143 | ASSERT_FALSE(union1->isSubtypeOf(union3)); |
144 | |
145 | ASSERT_FALSE(union2->isSubtypeOf(union1)); |
146 | } |
147 | |
148 | } // namespace jit |
149 | } // namespace torch |
150 | |