1 | #include <gtest/gtest.h> |
2 | #include <test/cpp/jit/test_utils.h> |
3 | |
4 | #include <ATen/core/jit_type.h> |
5 | #include <torch/csrc/jit/mobile/type_parser.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | // Parse Success cases |
11 | TEST(MobileTypeParserTest, Int) { |
12 | std::string int_ps("int" ); |
13 | auto int_tp = c10::parseType(int_ps); |
14 | EXPECT_EQ(*int_tp, *IntType::get()); |
15 | } |
16 | |
17 | TEST(MobileTypeParserTest, NestedContainersAnnotationStr) { |
18 | std::string tuple_ps( |
19 | "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]" ); |
20 | auto tuple_tp = c10::parseType(tuple_ps); |
21 | std::vector<TypePtr> args = { |
22 | c10::StringType::get(), |
23 | c10::OptionalType::create(c10::FloatType::get()), |
24 | c10::DictType::create( |
25 | StringType::get(), ListType::create(TensorType::get())), |
26 | IntType::get()}; |
27 | auto tp = TupleType::create(std::move(args)); |
28 | ASSERT_EQ(*tuple_tp, *tp); |
29 | } |
30 | |
31 | TEST(MobileTypeParserTest, TorchBindClass) { |
32 | std::string tuple_ps("__torch__.torch.classes.rnn.CellParamsBase" ); |
33 | auto tuple_tp = c10::parseType(tuple_ps); |
34 | std::string tuple_tps = tuple_tp->annotation_str(); |
35 | ASSERT_EQ(tuple_ps, tuple_tps); |
36 | } |
37 | |
38 | TEST(MobileTypeParserTest, ListOfTorchBindClass) { |
39 | std::string tuple_ps("List[__torch__.torch.classes.rnn.CellParamsBase]" ); |
40 | auto tuple_tp = c10::parseType(tuple_ps); |
41 | EXPECT_TRUE(tuple_tp->isSubtypeOf(AnyListType::get())); |
42 | EXPECT_EQ( |
43 | "__torch__.torch.classes.rnn.CellParamsBase" , |
44 | tuple_tp->containedType(0)->annotation_str()); |
45 | } |
46 | |
47 | TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) { |
48 | std::string tuple_space_ps( |
49 | "Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]" ); |
50 | auto tuple_space_tp = c10::parseType(tuple_space_ps); |
51 | // tuple_space_tps should not have weird white spaces |
52 | std::string tuple_space_tps = tuple_space_tp->annotation_str(); |
53 | ASSERT_TRUE(tuple_space_tps.find("[ " ) == std::string::npos); |
54 | ASSERT_TRUE(tuple_space_tps.find(" ]" ) == std::string::npos); |
55 | ASSERT_TRUE(tuple_space_tps.find(" ," ) == std::string::npos); |
56 | } |
57 | |
58 | TEST(MobileTypeParserTest, NamedTuple) { |
59 | std::string named_tuple_ps( |
60 | "__torch__.base_models.preproc_types.PreprocOutputType[" |
61 | " NamedTuple, [" |
62 | " [float_features, Tensor]," |
63 | " [id_list_features, List[Tensor]]," |
64 | " [label, Tensor]," |
65 | " [weight, Tensor]," |
66 | " [prod_prediction, Tuple[Tensor, Tensor]]," |
67 | " [id_score_list_features, List[Tensor]]," |
68 | " [embedding_features, List[Tensor]]," |
69 | " [teacher_label, Tensor]" |
70 | " ]" |
71 | " ]" ); |
72 | |
73 | c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps); |
74 | std::string named_tuple_annotation_str = named_tuple_tp->annotation_str(); |
75 | ASSERT_EQ( |
76 | named_tuple_annotation_str, |
77 | "__torch__.base_models.preproc_types.PreprocOutputType" ); |
78 | } |
79 | |
80 | TEST(MobileTypeParserTest, DictNestedNamedTupleTypeList) { |
81 | std::string type_str_1( |
82 | "__torch__.base_models.preproc_types.PreprocOutputType[" |
83 | " NamedTuple, [" |
84 | " [float_features, Tensor]," |
85 | " [id_list_features, List[Tensor]]," |
86 | " [label, Tensor]," |
87 | " [weight, Tensor]," |
88 | " [prod_prediction, Tuple[Tensor, Tensor]]," |
89 | " [id_score_list_features, List[Tensor]]," |
90 | " [embedding_features, List[Tensor]]," |
91 | " [teacher_label, Tensor]" |
92 | " ]" ); |
93 | std::string type_str_2( |
94 | "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]" ); |
95 | std::vector<std::string> type_strs = {type_str_1, type_str_2}; |
96 | std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs); |
97 | EXPECT_EQ(*named_tuple_tps[1]->containedType(0), *c10::StringType::get()); |
98 | EXPECT_EQ(*named_tuple_tps[0], *named_tuple_tps[1]->containedType(1)); |
99 | } |
100 | |
101 | TEST(MobileTypeParserTest, NamedTupleNestedNamedTupleTypeList) { |
102 | std::string type_str_1( |
103 | " __torch__.ccc.xxx [" |
104 | " NamedTuple, [" |
105 | " [field_name_c_1, Tensor]," |
106 | " [field_name_c_2, Tuple[Tensor, Tensor]]" |
107 | " ]" |
108 | "]" ); |
109 | std::string type_str_2( |
110 | "__torch__.bbb.xxx [" |
111 | " NamedTuple,[" |
112 | " [field_name_b, __torch__.ccc.xxx]]" |
113 | " ]" |
114 | "]" ); |
115 | |
116 | std::string type_str_3( |
117 | "__torch__.aaa.xxx[" |
118 | " NamedTuple, [" |
119 | " [field_name_a, __torch__.bbb.xxx]" |
120 | " ]" |
121 | "]" ); |
122 | |
123 | std::vector<std::string> type_strs = {type_str_1, type_str_2, type_str_3}; |
124 | std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs); |
125 | std::string named_tuple_annotation_str = named_tuple_tps[2]->annotation_str(); |
126 | ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx" ); |
127 | } |
128 | |
129 | TEST(MobileTypeParserTest, NamedTupleNestedNamedTuple) { |
130 | std::string named_tuple_ps( |
131 | "__torch__.aaa.xxx[" |
132 | " NamedTuple, [" |
133 | " [field_name_a, __torch__.bbb.xxx [" |
134 | " NamedTuple, [" |
135 | " [field_name_b, __torch__.ccc.xxx [" |
136 | " NamedTuple, [" |
137 | " [field_name_c_1, Tensor]," |
138 | " [field_name_c_2, Tuple[Tensor, Tensor]]" |
139 | " ]" |
140 | " ]" |
141 | " ]" |
142 | " ]" |
143 | " ]" |
144 | " ]" |
145 | " ] " |
146 | "]" ); |
147 | |
148 | c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps); |
149 | std::string named_tuple_annotation_str = named_tuple_tp->str(); |
150 | ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx" ); |
151 | } |
152 | |
153 | // Parse throw cases |
154 | TEST(MobileTypeParserTest, Empty) { |
155 | std::string empty_ps("" ); |
156 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
157 | ASSERT_ANY_THROW(c10::parseType(empty_ps)); |
158 | } |
159 | |
160 | TEST(MobileTypeParserTest, TypoRaises) { |
161 | std::string typo_token("List[tensor]" ); |
162 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
163 | ASSERT_ANY_THROW(c10::parseType(typo_token)); |
164 | } |
165 | |
166 | TEST(MobileTypeParserTest, MismatchBracketRaises) { |
167 | std::string mismatch1("List[Tensor" ); |
168 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
169 | ASSERT_ANY_THROW(c10::parseType(mismatch1)); |
170 | } |
171 | |
172 | TEST(MobileTypeParserTest, MismatchBracketRaises2) { |
173 | std::string mismatch2("List[[Tensor]" ); |
174 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
175 | ASSERT_ANY_THROW(c10::parseType(mismatch2)); |
176 | } |
177 | |
178 | TEST(MobileTypeParserTest, DictWithoutValueRaises) { |
179 | std::string mismatch3("Dict[Tensor]" ); |
180 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
181 | ASSERT_ANY_THROW(c10::parseType(mismatch3)); |
182 | } |
183 | |
184 | TEST(MobileTypeParserTest, ListArgCountMismatchRaises) { |
185 | // arg count mismatch |
186 | std::string mismatch4("List[int, str]" ); |
187 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
188 | ASSERT_ANY_THROW(c10::parseType(mismatch4)); |
189 | } |
190 | |
191 | TEST(MobileTypeParserTest, DictArgCountMismatchRaises) { |
192 | std::string trailing_commm("Dict[str,]" ); |
193 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
194 | ASSERT_ANY_THROW(c10::parseType(trailing_commm)); |
195 | } |
196 | |
197 | TEST(MobileTypeParserTest, ValidTypeWithExtraStuffRaises) { |
198 | std::string ("int int" ); |
199 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
200 | ASSERT_ANY_THROW(c10::parseType(extra_stuff)); |
201 | } |
202 | |
203 | TEST(MobileTypeParserTest, NonIdentifierRaises) { |
204 | std::string non_id("(int)" ); |
205 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
206 | ASSERT_ANY_THROW(c10::parseType(non_id)); |
207 | } |
208 | |
209 | TEST(MobileTypeParserTest, DictNestedNamedTupleTypeListRaises) { |
210 | std::string type_str_1( |
211 | "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]" ); |
212 | std::string type_str_2( |
213 | "__torch__.base_models.preproc_types.PreprocOutputType[" |
214 | " NamedTuple, [" |
215 | " [float_features, Tensor]," |
216 | " [id_list_features, List[Tensor]]," |
217 | " [label, Tensor]," |
218 | " [weight, Tensor]," |
219 | " [prod_prediction, Tuple[Tensor, Tensor]]," |
220 | " [id_score_list_features, List[Tensor]]," |
221 | " [embedding_features, List[Tensor]]," |
222 | " [teacher_label, Tensor]" |
223 | " ]" ); |
224 | std::vector<std::string> type_strs = {type_str_1, type_str_2}; |
225 | std::string error_message = |
226 | R"(Can't find definition for the type: __torch__.base_models.preproc_types.PreprocOutputType)" ; |
227 | ASSERT_THROWS_WITH_MESSAGE(c10::parseType(type_strs), error_message); |
228 | } |
229 | |
230 | } // namespace jit |
231 | } // namespace torch |
232 | |