1#include <ATen/ATen.h>
2#include <gtest/gtest.h>
3#include <torch/torch.h>
4#include <ATen/core/jit_type.h>
5#include <torch/csrc/jit/frontend/resolver.h>
6#include <torch/csrc/jit/serialization/import_source.h>
7
8namespace c10 {
9
10TEST(TypeCustomPrinter, Basic) {
11 TypePrinter printer =
12 [](const Type& t) -> c10::optional<std::string> {
13 if (auto tensorType = t.cast<TensorType>()) {
14 return "CustomTensor";
15 }
16 return c10::nullopt;
17 };
18
19 // Tensor types should be rewritten
20 torch::Tensor iv = torch::rand({2, 3});
21 const auto type = TensorType::create(iv);
22 EXPECT_EQ(type->annotation_str(), "Tensor");
23 EXPECT_EQ(type->annotation_str(printer), "CustomTensor");
24
25 // Unrelated types shoudl not be affected
26 const auto intType = IntType::get();
27 EXPECT_EQ(intType->annotation_str(printer), intType->annotation_str());
28}
29
30TEST(TypeCustomPrinter, ContainedTypes) {
31 TypePrinter printer =
32 [](const Type& t) -> c10::optional<std::string> {
33 if (auto tensorType = t.cast<TensorType>()) {
34 return "CustomTensor";
35 }
36 return c10::nullopt;
37 };
38 torch::Tensor iv = torch::rand({2, 3});
39 const auto type = TensorType::create(iv);
40
41 // Contained types should work
42 const auto tupleType = TupleType::create({type, IntType::get(), type});
43 EXPECT_EQ(tupleType->annotation_str(), "Tuple[Tensor, int, Tensor]");
44 EXPECT_EQ(
45 tupleType->annotation_str(printer), "Tuple[CustomTensor, int, CustomTensor]");
46 const auto dictType = DictType::create(IntType::get(), type);
47 EXPECT_EQ(dictType->annotation_str(printer), "Dict[int, CustomTensor]");
48 const auto listType = ListType::create(tupleType);
49 EXPECT_EQ(
50 listType->annotation_str(printer),
51 "List[Tuple[CustomTensor, int, CustomTensor]]");
52}
53
54TEST(TypeCustomPrinter, NamedTuples) {
55 TypePrinter printer =
56 [](const Type& t) -> c10::optional<std::string> {
57 if (auto tupleType = t.cast<TupleType>()) {
58 // Rewrite only NamedTuples
59 if (tupleType->name()) {
60 return "Rewritten";
61 }
62 }
63 return c10::nullopt;
64 };
65 torch::Tensor iv = torch::rand({2, 3});
66 const auto type = TensorType::create(iv);
67
68 std::vector<std::string> field_names = {"foo", "bar"};
69 const auto namedTupleType = TupleType::createNamed(
70 "my.named.tuple", field_names, {type, IntType::get()});
71 EXPECT_EQ(namedTupleType->annotation_str(printer), "Rewritten");
72
73 // Put it inside another tuple, should still work
74 const auto outerTupleType = TupleType::create({IntType::get(), namedTupleType});
75 EXPECT_EQ(outerTupleType->annotation_str(printer), "Tuple[int, Rewritten]");
76}
77
78static TypePtr importType(
79 std::shared_ptr<CompilationUnit> cu,
80 const std::string& qual_name,
81 const std::string& src) {
82 std::vector<at::IValue> constantTable;
83 auto source = std::make_shared<torch::jit::Source>(src);
84 torch::jit::SourceImporter si(
85 cu,
86 &constantTable,
87 [&](const std::string& name) -> std::shared_ptr<torch::jit::Source> {
88 return source;
89 },
90 /*version=*/2);
91 return si.loadType(qual_name);
92}
93
94TEST(TypeEquality, ClassBasic) {
95 // Even if classes have the same name across two compilation units, they
96 // should not compare equal.
97 auto cu = std::make_shared<CompilationUnit>();
98 const auto src = R"JIT(
99class First:
100 def one(self, x: Tensor, y: Tensor) -> Tensor:
101 return x
102)JIT";
103
104 auto classType = importType(cu, "__torch__.First", src);
105 auto classType2 = cu->get_type("__torch__.First");
106 // Trivially these should be equal
107 EXPECT_EQ(*classType, *classType2);
108}
109
110TEST(TypeEquality, ClassInequality) {
111 // Even if classes have the same name across two compilation units, they
112 // should not compare equal.
113 auto cu = std::make_shared<CompilationUnit>();
114 const auto src = R"JIT(
115class First:
116 def one(self, x: Tensor, y: Tensor) -> Tensor:
117 return x
118)JIT";
119
120 auto classType = importType(cu, "__torch__.First", src);
121
122 auto cu2 = std::make_shared<CompilationUnit>();
123 const auto src2 = R"JIT(
124class First:
125 def one(self, x: Tensor, y: Tensor) -> Tensor:
126 return y
127)JIT";
128
129 auto classType2 = importType(cu2, "__torch__.First", src2);
130 EXPECT_NE(*classType, *classType2);
131}
132
133TEST(TypeEquality, InterfaceEquality) {
134 // Interfaces defined anywhere should compare equal, provided they share a
135 // name and interface
136 auto cu = std::make_shared<CompilationUnit>();
137 const auto interfaceSrc = R"JIT(
138class OneForward(Interface):
139 def one(self, x: Tensor, y: Tensor) -> Tensor:
140 pass
141 def forward(self, x: Tensor) -> Tensor:
142 pass
143)JIT";
144 auto interfaceType = importType(cu, "__torch__.OneForward", interfaceSrc);
145
146 auto cu2 = std::make_shared<CompilationUnit>();
147 auto interfaceType2 = importType(cu2, "__torch__.OneForward", interfaceSrc);
148
149 EXPECT_EQ(*interfaceType, *interfaceType2);
150}
151
152TEST(TypeEquality, InterfaceInequality) {
153 // Interfaces must match for them to compare equal, even if they share a name
154 auto cu = std::make_shared<CompilationUnit>();
155 const auto interfaceSrc = R"JIT(
156class OneForward(Interface):
157 def one(self, x: Tensor, y: Tensor) -> Tensor:
158 pass
159 def forward(self, x: Tensor) -> Tensor:
160 pass
161)JIT";
162 auto interfaceType = importType(cu, "__torch__.OneForward", interfaceSrc);
163
164 auto cu2 = std::make_shared<CompilationUnit>();
165 const auto interfaceSrc2 = R"JIT(
166class OneForward(Interface):
167 def two(self, x: Tensor, y: Tensor) -> Tensor:
168 pass
169 def forward(self, x: Tensor) -> Tensor:
170 pass
171)JIT";
172 auto interfaceType2 = importType(cu2, "__torch__.OneForward", interfaceSrc2);
173
174 EXPECT_NE(*interfaceType, *interfaceType2);
175}
176
177TEST(TypeEquality, TupleEquality) {
178 // Tuples should be structurally typed
179 auto type = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
180 auto type2 = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
181
182 EXPECT_EQ(*type, *type2);
183}
184
185TEST(TypeEquality, NamedTupleEquality) {
186 // Named tuples should compare equal if they share a name and field names
187 std::vector<std::string> fields = {"a", "b", "c", "d"};
188 std::vector<std::string> otherFields = {"wow", "so", "very", "different"};
189 auto type = TupleType::createNamed(
190 "MyNamedTuple",
191 fields,
192 {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
193 auto type2 = TupleType::createNamed(
194 "MyNamedTuple",
195 fields,
196 {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
197 EXPECT_EQ(*type, *type2);
198
199 auto differentName = TupleType::createNamed(
200 "WowSoDifferent",
201 fields,
202 {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
203 EXPECT_NE(*type, *differentName);
204
205 auto differentField = TupleType::createNamed(
206 "MyNamedTuple",
207 otherFields,
208 {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
209 EXPECT_NE(*type, *differentField);
210}
211} // namespace c10
212