1 | #include <ATen/ATen.h> |
2 | #include <gtest/gtest.h> |
3 | #include <torch/torch.h> |
4 | #include <ATen/core/jit_type.h> |
5 | #include <torch/csrc/jit/frontend/resolver.h> |
6 | #include <torch/csrc/jit/serialization/import_source.h> |
7 | |
8 | namespace c10 { |
9 | |
10 | TEST(TypeCustomPrinter, Basic) { |
11 | TypePrinter printer = |
12 | [](const Type& t) -> c10::optional<std::string> { |
13 | if (auto tensorType = t.cast<TensorType>()) { |
14 | return "CustomTensor" ; |
15 | } |
16 | return c10::nullopt; |
17 | }; |
18 | |
19 | // Tensor types should be rewritten |
20 | torch::Tensor iv = torch::rand({2, 3}); |
21 | const auto type = TensorType::create(iv); |
22 | EXPECT_EQ(type->annotation_str(), "Tensor" ); |
23 | EXPECT_EQ(type->annotation_str(printer), "CustomTensor" ); |
24 | |
25 | // Unrelated types shoudl not be affected |
26 | const auto intType = IntType::get(); |
27 | EXPECT_EQ(intType->annotation_str(printer), intType->annotation_str()); |
28 | } |
29 | |
30 | TEST(TypeCustomPrinter, ContainedTypes) { |
31 | TypePrinter printer = |
32 | [](const Type& t) -> c10::optional<std::string> { |
33 | if (auto tensorType = t.cast<TensorType>()) { |
34 | return "CustomTensor" ; |
35 | } |
36 | return c10::nullopt; |
37 | }; |
38 | torch::Tensor iv = torch::rand({2, 3}); |
39 | const auto type = TensorType::create(iv); |
40 | |
41 | // Contained types should work |
42 | const auto tupleType = TupleType::create({type, IntType::get(), type}); |
43 | EXPECT_EQ(tupleType->annotation_str(), "Tuple[Tensor, int, Tensor]" ); |
44 | EXPECT_EQ( |
45 | tupleType->annotation_str(printer), "Tuple[CustomTensor, int, CustomTensor]" ); |
46 | const auto dictType = DictType::create(IntType::get(), type); |
47 | EXPECT_EQ(dictType->annotation_str(printer), "Dict[int, CustomTensor]" ); |
48 | const auto listType = ListType::create(tupleType); |
49 | EXPECT_EQ( |
50 | listType->annotation_str(printer), |
51 | "List[Tuple[CustomTensor, int, CustomTensor]]" ); |
52 | } |
53 | |
54 | TEST(TypeCustomPrinter, NamedTuples) { |
55 | TypePrinter printer = |
56 | [](const Type& t) -> c10::optional<std::string> { |
57 | if (auto tupleType = t.cast<TupleType>()) { |
58 | // Rewrite only NamedTuples |
59 | if (tupleType->name()) { |
60 | return "Rewritten" ; |
61 | } |
62 | } |
63 | return c10::nullopt; |
64 | }; |
65 | torch::Tensor iv = torch::rand({2, 3}); |
66 | const auto type = TensorType::create(iv); |
67 | |
68 | std::vector<std::string> field_names = {"foo" , "bar" }; |
69 | const auto namedTupleType = TupleType::createNamed( |
70 | "my.named.tuple" , field_names, {type, IntType::get()}); |
71 | EXPECT_EQ(namedTupleType->annotation_str(printer), "Rewritten" ); |
72 | |
73 | // Put it inside another tuple, should still work |
74 | const auto outerTupleType = TupleType::create({IntType::get(), namedTupleType}); |
75 | EXPECT_EQ(outerTupleType->annotation_str(printer), "Tuple[int, Rewritten]" ); |
76 | } |
77 | |
78 | static TypePtr importType( |
79 | std::shared_ptr<CompilationUnit> cu, |
80 | const std::string& qual_name, |
81 | const std::string& src) { |
82 | std::vector<at::IValue> constantTable; |
83 | auto source = std::make_shared<torch::jit::Source>(src); |
84 | torch::jit::SourceImporter si( |
85 | cu, |
86 | &constantTable, |
87 | [&](const std::string& name) -> std::shared_ptr<torch::jit::Source> { |
88 | return source; |
89 | }, |
90 | /*version=*/2); |
91 | return si.loadType(qual_name); |
92 | } |
93 | |
94 | TEST(TypeEquality, ClassBasic) { |
95 | // Even if classes have the same name across two compilation units, they |
96 | // should not compare equal. |
97 | auto cu = std::make_shared<CompilationUnit>(); |
98 | const auto src = R"JIT( |
99 | class First: |
100 | def one(self, x: Tensor, y: Tensor) -> Tensor: |
101 | return x |
102 | )JIT" ; |
103 | |
104 | auto classType = importType(cu, "__torch__.First" , src); |
105 | auto classType2 = cu->get_type("__torch__.First" ); |
106 | // Trivially these should be equal |
107 | EXPECT_EQ(*classType, *classType2); |
108 | } |
109 | |
110 | TEST(TypeEquality, ClassInequality) { |
111 | // Even if classes have the same name across two compilation units, they |
112 | // should not compare equal. |
113 | auto cu = std::make_shared<CompilationUnit>(); |
114 | const auto src = R"JIT( |
115 | class First: |
116 | def one(self, x: Tensor, y: Tensor) -> Tensor: |
117 | return x |
118 | )JIT" ; |
119 | |
120 | auto classType = importType(cu, "__torch__.First" , src); |
121 | |
122 | auto cu2 = std::make_shared<CompilationUnit>(); |
123 | const auto src2 = R"JIT( |
124 | class First: |
125 | def one(self, x: Tensor, y: Tensor) -> Tensor: |
126 | return y |
127 | )JIT" ; |
128 | |
129 | auto classType2 = importType(cu2, "__torch__.First" , src2); |
130 | EXPECT_NE(*classType, *classType2); |
131 | } |
132 | |
133 | TEST(TypeEquality, InterfaceEquality) { |
134 | // Interfaces defined anywhere should compare equal, provided they share a |
135 | // name and interface |
136 | auto cu = std::make_shared<CompilationUnit>(); |
137 | const auto interfaceSrc = R"JIT( |
138 | class OneForward(Interface): |
139 | def one(self, x: Tensor, y: Tensor) -> Tensor: |
140 | pass |
141 | def forward(self, x: Tensor) -> Tensor: |
142 | pass |
143 | )JIT" ; |
144 | auto interfaceType = importType(cu, "__torch__.OneForward" , interfaceSrc); |
145 | |
146 | auto cu2 = std::make_shared<CompilationUnit>(); |
147 | auto interfaceType2 = importType(cu2, "__torch__.OneForward" , interfaceSrc); |
148 | |
149 | EXPECT_EQ(*interfaceType, *interfaceType2); |
150 | } |
151 | |
152 | TEST(TypeEquality, InterfaceInequality) { |
153 | // Interfaces must match for them to compare equal, even if they share a name |
154 | auto cu = std::make_shared<CompilationUnit>(); |
155 | const auto interfaceSrc = R"JIT( |
156 | class OneForward(Interface): |
157 | def one(self, x: Tensor, y: Tensor) -> Tensor: |
158 | pass |
159 | def forward(self, x: Tensor) -> Tensor: |
160 | pass |
161 | )JIT" ; |
162 | auto interfaceType = importType(cu, "__torch__.OneForward" , interfaceSrc); |
163 | |
164 | auto cu2 = std::make_shared<CompilationUnit>(); |
165 | const auto interfaceSrc2 = R"JIT( |
166 | class OneForward(Interface): |
167 | def two(self, x: Tensor, y: Tensor) -> Tensor: |
168 | pass |
169 | def forward(self, x: Tensor) -> Tensor: |
170 | pass |
171 | )JIT" ; |
172 | auto interfaceType2 = importType(cu2, "__torch__.OneForward" , interfaceSrc2); |
173 | |
174 | EXPECT_NE(*interfaceType, *interfaceType2); |
175 | } |
176 | |
177 | TEST(TypeEquality, TupleEquality) { |
178 | // Tuples should be structurally typed |
179 | auto type = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()}); |
180 | auto type2 = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()}); |
181 | |
182 | EXPECT_EQ(*type, *type2); |
183 | } |
184 | |
185 | TEST(TypeEquality, NamedTupleEquality) { |
186 | // Named tuples should compare equal if they share a name and field names |
187 | std::vector<std::string> fields = {"a" , "b" , "c" , "d" }; |
188 | std::vector<std::string> otherFields = {"wow" , "so" , "very" , "different" }; |
189 | auto type = TupleType::createNamed( |
190 | "MyNamedTuple" , |
191 | fields, |
192 | {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()}); |
193 | auto type2 = TupleType::createNamed( |
194 | "MyNamedTuple" , |
195 | fields, |
196 | {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()}); |
197 | EXPECT_EQ(*type, *type2); |
198 | |
199 | auto differentName = TupleType::createNamed( |
200 | "WowSoDifferent" , |
201 | fields, |
202 | {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()}); |
203 | EXPECT_NE(*type, *differentName); |
204 | |
205 | auto differentField = TupleType::createNamed( |
206 | "MyNamedTuple" , |
207 | otherFields, |
208 | {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()}); |
209 | EXPECT_NE(*type, *differentField); |
210 | } |
211 | } // namespace c10 |
212 | |