1#include <gtest/gtest.h>
2
3#include <ATen/core/jit_type.h>
4#include <test/cpp/jit/test_utils.h>
5#include <torch/csrc/jit/ir/ir.h>
6
7namespace torch {
8namespace jit {
9
10class UnionTypeTest : public ::testing::Test {
11 public:
12 // None
13 const TypePtr none = NoneType::get();
14
15 // List[str]
16 const TypePtr l1 = ListType::ofStrings();
17
18 // Optional[int]
19 const TypePtr opt1 = OptionalType::create(IntType::get());
20
21 // Optional[float]
22 const TypePtr opt2 = OptionalType::create(FloatType::get());
23
24 // Optional[List[str]]
25 const TypePtr opt3 = OptionalType::create(ListType::ofStrings());
26
27 // Tuple[Optional[int], int]
28 const TypePtr tup1 =
29 TupleType::create({OptionalType::create(IntType::get()), IntType::get()});
30
31 // Tuple[int, int]
32 const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()});
33
34 bool hasType(UnionTypePtr u, TypePtr t) {
35 auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t);
36 return res != u->getTypes().end();
37 }
38};
39
40TEST_F(UnionTypeTest, UnionOperatorEquals) {
41 const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()});
42
43 // Same thing, but using different TypePtrs
44 const TypePtr l1_ = ListType::ofStrings();
45 const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()});
46 const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()});
47
48 ASSERT_TRUE(*u1 == *u2);
49}
50
51TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) {
52 // Goal: Union[int, float, None]
53 const UnionTypePtr u = UnionType::create({opt1, opt2});
54
55 ASSERT_EQ(u->getTypes().size(), 3);
56 ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
57 ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get()));
58 ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
59}
60
61TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) {
62 // Goal: Union[int, None]
63 const UnionTypePtr u = UnionType::create({opt1, IntType::get()});
64
65 ASSERT_EQ(u->getTypes().size(), 2);
66 ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
67 ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
68}
69
70TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) {
71 // Goal: Union[Tuple[Optional[int], int], str]
72 const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2});
73
74 ASSERT_EQ(u->getTypes().size(), 2);
75 ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
76 ASSERT_TRUE(UnionTypeTest::hasType(u, tup1));
77}
78
79TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) {
80 // Goal: Union[List[str], str]
81 const UnionTypePtr u = UnionType::create({l1, StringType::get()});
82
83 ASSERT_EQ(u->getTypes().size(), 2);
84 ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
85 ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
86}
87
88TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) {
89 // Goal: Union[List[str], None, str]
90 const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()});
91
92 ASSERT_EQ(u->getTypes().size(), 3);
93 ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
94 ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
95}
96
97TEST_F(UnionTypeTest, Subtyping_NumberType) {
98 // Union[int, float, Complex]
99 const UnionTypePtr union1 =
100 UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()});
101
102 // Union[int, float, Complex, None]
103 const UnionTypePtr union2 = UnionType::create(
104 {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()});
105
106 const NumberTypePtr num = NumberType::get();
107
108 ASSERT_TRUE(num->isSubtypeOf(*union1));
109 ASSERT_TRUE(union1->isSubtypeOf(*num));
110 ASSERT_TRUE(*num == *union1);
111
112 ASSERT_TRUE(num->isSubtypeOf(*union2));
113 ASSERT_FALSE(union2->isSubtypeOf(*num));
114 ASSERT_FALSE(*num == *union2);
115}
116
117TEST_F(UnionTypeTest, Subtyping_OptionalType) {
118 // Union[int, None]
119 const UnionTypePtr union1 =
120 UnionType::create({IntType::get(), NoneType::get()});
121
122 // Union[int, str, None]
123 const UnionTypePtr union2 =
124 UnionType::create({IntType::get(), StringType::get(), NoneType::get()});
125
126 // Union[int, str, List[str]]
127 const UnionTypePtr union3 = UnionType::create(
128 {IntType::get(), StringType::get(), ListType::ofStrings()});
129
130 ASSERT_TRUE(none->isSubtypeOf(opt1));
131 ASSERT_TRUE(none->isSubtypeOf(union1));
132 ASSERT_TRUE(none->isSubtypeOf(union2));
133 ASSERT_FALSE(none->isSubtypeOf(union3));
134
135 ASSERT_FALSE(opt1->isSubtypeOf(none));
136 ASSERT_TRUE(opt1->isSubtypeOf(union1));
137 ASSERT_TRUE(opt1->isSubtypeOf(union2));
138 ASSERT_FALSE(opt1->isSubtypeOf(union3));
139
140 ASSERT_FALSE(union1->isSubtypeOf(none));
141 ASSERT_TRUE(union1->isSubtypeOf(opt1));
142 ASSERT_TRUE(union1->isSubtypeOf(union2));
143 ASSERT_FALSE(union1->isSubtypeOf(union3));
144
145 ASSERT_FALSE(union2->isSubtypeOf(union1));
146}
147
148} // namespace jit
149} // namespace torch
150