1#include <gtest/gtest.h>
2#include <test/cpp/jit/test_utils.h>
3
4#include <ATen/core/jit_type.h>
5#include <torch/csrc/jit/mobile/type_parser.h>
6
7namespace torch {
8namespace jit {
9
10// Parse Success cases
11TEST(MobileTypeParserTest, Int) {
12 std::string int_ps("int");
13 auto int_tp = c10::parseType(int_ps);
14 EXPECT_EQ(*int_tp, *IntType::get());
15}
16
17TEST(MobileTypeParserTest, NestedContainersAnnotationStr) {
18 std::string tuple_ps(
19 "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
20 auto tuple_tp = c10::parseType(tuple_ps);
21 std::vector<TypePtr> args = {
22 c10::StringType::get(),
23 c10::OptionalType::create(c10::FloatType::get()),
24 c10::DictType::create(
25 StringType::get(), ListType::create(TensorType::get())),
26 IntType::get()};
27 auto tp = TupleType::create(std::move(args));
28 ASSERT_EQ(*tuple_tp, *tp);
29}
30
31TEST(MobileTypeParserTest, TorchBindClass) {
32 std::string tuple_ps("__torch__.torch.classes.rnn.CellParamsBase");
33 auto tuple_tp = c10::parseType(tuple_ps);
34 std::string tuple_tps = tuple_tp->annotation_str();
35 ASSERT_EQ(tuple_ps, tuple_tps);
36}
37
38TEST(MobileTypeParserTest, ListOfTorchBindClass) {
39 std::string tuple_ps("List[__torch__.torch.classes.rnn.CellParamsBase]");
40 auto tuple_tp = c10::parseType(tuple_ps);
41 EXPECT_TRUE(tuple_tp->isSubtypeOf(AnyListType::get()));
42 EXPECT_EQ(
43 "__torch__.torch.classes.rnn.CellParamsBase",
44 tuple_tp->containedType(0)->annotation_str());
45}
46
47TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) {
48 std::string tuple_space_ps(
49 "Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]");
50 auto tuple_space_tp = c10::parseType(tuple_space_ps);
51 // tuple_space_tps should not have weird white spaces
52 std::string tuple_space_tps = tuple_space_tp->annotation_str();
53 ASSERT_TRUE(tuple_space_tps.find("[ ") == std::string::npos);
54 ASSERT_TRUE(tuple_space_tps.find(" ]") == std::string::npos);
55 ASSERT_TRUE(tuple_space_tps.find(" ,") == std::string::npos);
56}
57
58TEST(MobileTypeParserTest, NamedTuple) {
59 std::string named_tuple_ps(
60 "__torch__.base_models.preproc_types.PreprocOutputType["
61 " NamedTuple, ["
62 " [float_features, Tensor],"
63 " [id_list_features, List[Tensor]],"
64 " [label, Tensor],"
65 " [weight, Tensor],"
66 " [prod_prediction, Tuple[Tensor, Tensor]],"
67 " [id_score_list_features, List[Tensor]],"
68 " [embedding_features, List[Tensor]],"
69 " [teacher_label, Tensor]"
70 " ]"
71 " ]");
72
73 c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps);
74 std::string named_tuple_annotation_str = named_tuple_tp->annotation_str();
75 ASSERT_EQ(
76 named_tuple_annotation_str,
77 "__torch__.base_models.preproc_types.PreprocOutputType");
78}
79
80TEST(MobileTypeParserTest, DictNestedNamedTupleTypeList) {
81 std::string type_str_1(
82 "__torch__.base_models.preproc_types.PreprocOutputType["
83 " NamedTuple, ["
84 " [float_features, Tensor],"
85 " [id_list_features, List[Tensor]],"
86 " [label, Tensor],"
87 " [weight, Tensor],"
88 " [prod_prediction, Tuple[Tensor, Tensor]],"
89 " [id_score_list_features, List[Tensor]],"
90 " [embedding_features, List[Tensor]],"
91 " [teacher_label, Tensor]"
92 " ]");
93 std::string type_str_2(
94 "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");
95 std::vector<std::string> type_strs = {type_str_1, type_str_2};
96 std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs);
97 EXPECT_EQ(*named_tuple_tps[1]->containedType(0), *c10::StringType::get());
98 EXPECT_EQ(*named_tuple_tps[0], *named_tuple_tps[1]->containedType(1));
99}
100
101TEST(MobileTypeParserTest, NamedTupleNestedNamedTupleTypeList) {
102 std::string type_str_1(
103 " __torch__.ccc.xxx ["
104 " NamedTuple, ["
105 " [field_name_c_1, Tensor],"
106 " [field_name_c_2, Tuple[Tensor, Tensor]]"
107 " ]"
108 "]");
109 std::string type_str_2(
110 "__torch__.bbb.xxx ["
111 " NamedTuple,["
112 " [field_name_b, __torch__.ccc.xxx]]"
113 " ]"
114 "]");
115
116 std::string type_str_3(
117 "__torch__.aaa.xxx["
118 " NamedTuple, ["
119 " [field_name_a, __torch__.bbb.xxx]"
120 " ]"
121 "]");
122
123 std::vector<std::string> type_strs = {type_str_1, type_str_2, type_str_3};
124 std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs);
125 std::string named_tuple_annotation_str = named_tuple_tps[2]->annotation_str();
126 ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx");
127}
128
129TEST(MobileTypeParserTest, NamedTupleNestedNamedTuple) {
130 std::string named_tuple_ps(
131 "__torch__.aaa.xxx["
132 " NamedTuple, ["
133 " [field_name_a, __torch__.bbb.xxx ["
134 " NamedTuple, ["
135 " [field_name_b, __torch__.ccc.xxx ["
136 " NamedTuple, ["
137 " [field_name_c_1, Tensor],"
138 " [field_name_c_2, Tuple[Tensor, Tensor]]"
139 " ]"
140 " ]"
141 " ]"
142 " ]"
143 " ]"
144 " ]"
145 " ] "
146 "]");
147
148 c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps);
149 std::string named_tuple_annotation_str = named_tuple_tp->str();
150 ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx");
151}
152
153// Parse throw cases
154TEST(MobileTypeParserTest, Empty) {
155 std::string empty_ps("");
156 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
157 ASSERT_ANY_THROW(c10::parseType(empty_ps));
158}
159
160TEST(MobileTypeParserTest, TypoRaises) {
161 std::string typo_token("List[tensor]");
162 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
163 ASSERT_ANY_THROW(c10::parseType(typo_token));
164}
165
166TEST(MobileTypeParserTest, MismatchBracketRaises) {
167 std::string mismatch1("List[Tensor");
168 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
169 ASSERT_ANY_THROW(c10::parseType(mismatch1));
170}
171
172TEST(MobileTypeParserTest, MismatchBracketRaises2) {
173 std::string mismatch2("List[[Tensor]");
174 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
175 ASSERT_ANY_THROW(c10::parseType(mismatch2));
176}
177
178TEST(MobileTypeParserTest, DictWithoutValueRaises) {
179 std::string mismatch3("Dict[Tensor]");
180 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
181 ASSERT_ANY_THROW(c10::parseType(mismatch3));
182}
183
184TEST(MobileTypeParserTest, ListArgCountMismatchRaises) {
185 // arg count mismatch
186 std::string mismatch4("List[int, str]");
187 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
188 ASSERT_ANY_THROW(c10::parseType(mismatch4));
189}
190
191TEST(MobileTypeParserTest, DictArgCountMismatchRaises) {
192 std::string trailing_commm("Dict[str,]");
193 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
194 ASSERT_ANY_THROW(c10::parseType(trailing_commm));
195}
196
197TEST(MobileTypeParserTest, ValidTypeWithExtraStuffRaises) {
198 std::string extra_stuff("int int");
199 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
200 ASSERT_ANY_THROW(c10::parseType(extra_stuff));
201}
202
203TEST(MobileTypeParserTest, NonIdentifierRaises) {
204 std::string non_id("(int)");
205 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
206 ASSERT_ANY_THROW(c10::parseType(non_id));
207}
208
209TEST(MobileTypeParserTest, DictNestedNamedTupleTypeListRaises) {
210 std::string type_str_1(
211 "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");
212 std::string type_str_2(
213 "__torch__.base_models.preproc_types.PreprocOutputType["
214 " NamedTuple, ["
215 " [float_features, Tensor],"
216 " [id_list_features, List[Tensor]],"
217 " [label, Tensor],"
218 " [weight, Tensor],"
219 " [prod_prediction, Tuple[Tensor, Tensor]],"
220 " [id_score_list_features, List[Tensor]],"
221 " [embedding_features, List[Tensor]],"
222 " [teacher_label, Tensor]"
223 " ]");
224 std::vector<std::string> type_strs = {type_str_1, type_str_2};
225 std::string error_message =
226 R"(Can't find definition for the type: __torch__.base_models.preproc_types.PreprocOutputType)";
227 ASSERT_THROWS_WITH_MESSAGE(c10::parseType(type_strs), error_message);
228}
229
230} // namespace jit
231} // namespace torch
232