#include <gtest/gtest.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <ATen/core/qualified_name.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | static constexpr c10::string_view classSrcs1 = R"JIT( |
13 | class FooNestedTest: |
14 | def __init__(self, y): |
15 | self.y = y |
16 | |
17 | class FooNestedTest2: |
18 | def __init__(self, y): |
19 | self.y = y |
20 | self.nested = __torch__.FooNestedTest(y) |
21 | |
22 | class FooTest: |
23 | def __init__(self, x): |
24 | self.class_attr = __torch__.FooNestedTest(x) |
25 | self.class_attr2 = __torch__.FooNestedTest2(x) |
26 | self.x = self.class_attr.y + self.class_attr2.y |
27 | )JIT" ; |
28 | |
29 | static constexpr c10::string_view classSrcs2 = R"JIT( |
30 | class FooTest: |
31 | def __init__(self, x): |
32 | self.dx = x |
33 | )JIT" ; |
34 | |
35 | static void import_libs( |
36 | std::shared_ptr<CompilationUnit> cu, |
37 | const std::string& class_name, |
38 | const std::shared_ptr<Source>& src, |
39 | const std::vector<at::IValue>& tensor_table) { |
40 | SourceImporter si( |
41 | cu, |
42 | &tensor_table, |
43 | [&](const std::string& name) -> std::shared_ptr<Source> { return src; }, |
44 | /*version=*/2); |
45 | si.loadType(QualifiedName(class_name)); |
46 | } |
47 | |
48 | TEST(ClassImportTest, Basic) { |
49 | auto cu1 = std::make_shared<CompilationUnit>(); |
50 | auto cu2 = std::make_shared<CompilationUnit>(); |
51 | std::vector<at::IValue> constantTable; |
52 | // Import different versions of FooTest into two namespaces. |
53 | import_libs( |
54 | cu1, |
55 | "__torch__.FooTest" , |
56 | std::make_shared<Source>(classSrcs1), |
57 | constantTable); |
58 | import_libs( |
59 | cu2, |
60 | "__torch__.FooTest" , |
61 | std::make_shared<Source>(classSrcs2), |
62 | constantTable); |
63 | |
64 | // We should get the correct version of `FooTest` for whichever namespace we |
65 | // are referencing |
66 | c10::QualifiedName base("__torch__" ); |
67 | auto classType1 = cu1->get_class(c10::QualifiedName(base, "FooTest" )); |
68 | ASSERT_TRUE(classType1->hasAttribute("x" )); |
69 | ASSERT_FALSE(classType1->hasAttribute("dx" )); |
70 | |
71 | auto classType2 = cu2->get_class(c10::QualifiedName(base, "FooTest" )); |
72 | ASSERT_TRUE(classType2->hasAttribute("dx" )); |
73 | ASSERT_FALSE(classType2->hasAttribute("x" )); |
74 | |
75 | // We should only see FooNestedTest in the first namespace |
76 | auto c = cu1->get_class(c10::QualifiedName(base, "FooNestedTest" )); |
77 | ASSERT_TRUE(c); |
78 | |
79 | c = cu2->get_class(c10::QualifiedName(base, "FooNestedTest" )); |
80 | ASSERT_FALSE(c); |
81 | } |
82 | |
83 | TEST(ClassImportTest, ScriptObject) { |
84 | Module m1("m1" ); |
85 | Module m2("m2" ); |
86 | std::vector<at::IValue> constantTable; |
87 | import_libs( |
88 | m1._ivalue()->compilation_unit(), |
89 | "__torch__.FooTest" , |
90 | std::make_shared<Source>(classSrcs1), |
91 | constantTable); |
92 | import_libs( |
93 | m2._ivalue()->compilation_unit(), |
94 | "__torch__.FooTest" , |
95 | std::make_shared<Source>(classSrcs2), |
96 | constantTable); |
97 | |
98 | // Incorrect arguments for constructor should throw |
99 | c10::QualifiedName base("__torch__" ); |
100 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
101 | ASSERT_ANY_THROW(m1.create_class(c10::QualifiedName(base, "FooTest" ), {1})); |
102 | auto x = torch::ones({2, 3}); |
103 | auto obj = m2.create_class(c10::QualifiedName(base, "FooTest" ), x).toObject(); |
104 | auto dx = obj->getAttr("dx" ); |
105 | ASSERT_TRUE(almostEqual(x, dx.toTensor())); |
106 | |
107 | auto new_x = torch::rand({2, 3}); |
108 | obj->setAttr("dx" , new_x); |
109 | auto new_dx = obj->getAttr("dx" ); |
110 | ASSERT_TRUE(almostEqual(new_x, new_dx.toTensor())); |
111 | } |
112 | |
113 | static const auto methodSrc = R"JIT( |
114 | def __init__(self, x): |
115 | return x |
116 | )JIT" ; |
117 | |
118 | TEST(ClassImportTest, ClassDerive) { |
119 | auto cu = std::make_shared<CompilationUnit>(); |
120 | auto cls = ClassType::create("foo.bar" , cu); |
121 | const auto self = SimpleSelf(cls); |
122 | auto methods = cu->define("foo.bar" , methodSrc, nativeResolver(), &self); |
123 | auto method = methods[0]; |
124 | cls->addAttribute("attr" , TensorType::get()); |
125 | ASSERT_TRUE(cls->findMethod(method->name())); |
126 | |
127 | // Refining a new class should retain attributes and methods |
128 | auto newCls = cls->refine({TensorType::get()}); |
129 | ASSERT_TRUE(newCls->hasAttribute("attr" )); |
130 | ASSERT_TRUE(newCls->findMethod(method->name())); |
131 | |
132 | auto newCls2 = cls->withContained({TensorType::get()})->expect<ClassType>(); |
133 | ASSERT_TRUE(newCls2->hasAttribute("attr" )); |
134 | ASSERT_TRUE(newCls2->findMethod(method->name())); |
135 | } |
136 | |
137 | static constexpr c10::string_view torchbindSrc = R"JIT( |
138 | class FooBar1234(Module): |
139 | __parameters__ = [] |
140 | f : __torch__.torch.classes._TorchScriptTesting._StackString |
141 | training : bool |
142 | def forward(self: __torch__.FooBar1234) -> str: |
143 | return (self.f).top() |
144 | )JIT" ; |
145 | |
146 | TEST(ClassImportTest, CustomClass) { |
147 | auto cu1 = std::make_shared<CompilationUnit>(); |
148 | std::vector<at::IValue> constantTable; |
149 | // Import different versions of FooTest into two namespaces. |
150 | import_libs( |
151 | cu1, |
152 | "__torch__.FooBar1234" , |
153 | std::make_shared<Source>(torchbindSrc), |
154 | constantTable); |
155 | } |
156 | |
157 | } // namespace jit |
158 | } // namespace torch |
159 | |