1#include <gtest/gtest.h>
2
3#include <ATen/core/qualified_name.h>
4#include <test/cpp/jit/test_utils.h>
5#include <torch/csrc/jit/frontend/resolver.h>
6#include <torch/csrc/jit/serialization/import_source.h>
7#include <torch/torch.h>
8
9namespace torch {
10namespace jit {
11
// TorchScript source defining three classes under the `__torch__` prefix.
// `FooTest` is the class the tests load by name. Its `__init__` constructs
// `FooNestedTest` and `FooNestedTest2`, and `FooNestedTest2` in turn constructs
// `FooNestedTest`. The Basic test therefore expects `FooNestedTest` to be
// present in the compilation unit that `FooTest` is imported into from here.
static constexpr c10::string_view classSrcs1 = R"JIT(
class FooNestedTest:
    def __init__(self, y):
        self.y = y

class FooNestedTest2:
    def __init__(self, y):
        self.y = y
        self.nested = __torch__.FooNestedTest(y)

class FooTest:
    def __init__(self, x):
        self.class_attr = __torch__.FooNestedTest(x)
        self.class_attr2 = __torch__.FooNestedTest2(x)
        self.x = self.class_attr.y + self.class_attr2.y
)JIT";
28
// An alternative, self-contained definition of `__torch__.FooTest`. Its
// constructor stores the argument as attribute `dx` (rather than `x`) and
// references no other classes. This lets the tests tell apart which version a
// given compilation unit resolved.
static constexpr c10::string_view classSrcs2 = R"JIT(
class FooTest:
    def __init__(self, x):
        self.dx = x
)JIT";
34
35static void import_libs(
36 std::shared_ptr<CompilationUnit> cu,
37 const std::string& class_name,
38 const std::shared_ptr<Source>& src,
39 const std::vector<at::IValue>& tensor_table) {
40 SourceImporter si(
41 cu,
42 &tensor_table,
43 [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
44 /*version=*/2);
45 si.loadType(QualifiedName(class_name));
46}
47
48TEST(ClassImportTest, Basic) {
49 auto cu1 = std::make_shared<CompilationUnit>();
50 auto cu2 = std::make_shared<CompilationUnit>();
51 std::vector<at::IValue> constantTable;
52 // Import different versions of FooTest into two namespaces.
53 import_libs(
54 cu1,
55 "__torch__.FooTest",
56 std::make_shared<Source>(classSrcs1),
57 constantTable);
58 import_libs(
59 cu2,
60 "__torch__.FooTest",
61 std::make_shared<Source>(classSrcs2),
62 constantTable);
63
64 // We should get the correct version of `FooTest` for whichever namespace we
65 // are referencing
66 c10::QualifiedName base("__torch__");
67 auto classType1 = cu1->get_class(c10::QualifiedName(base, "FooTest"));
68 ASSERT_TRUE(classType1->hasAttribute("x"));
69 ASSERT_FALSE(classType1->hasAttribute("dx"));
70
71 auto classType2 = cu2->get_class(c10::QualifiedName(base, "FooTest"));
72 ASSERT_TRUE(classType2->hasAttribute("dx"));
73 ASSERT_FALSE(classType2->hasAttribute("x"));
74
75 // We should only see FooNestedTest in the first namespace
76 auto c = cu1->get_class(c10::QualifiedName(base, "FooNestedTest"));
77 ASSERT_TRUE(c);
78
79 c = cu2->get_class(c10::QualifiedName(base, "FooNestedTest"));
80 ASSERT_FALSE(c);
81}
82
83TEST(ClassImportTest, ScriptObject) {
84 Module m1("m1");
85 Module m2("m2");
86 std::vector<at::IValue> constantTable;
87 import_libs(
88 m1._ivalue()->compilation_unit(),
89 "__torch__.FooTest",
90 std::make_shared<Source>(classSrcs1),
91 constantTable);
92 import_libs(
93 m2._ivalue()->compilation_unit(),
94 "__torch__.FooTest",
95 std::make_shared<Source>(classSrcs2),
96 constantTable);
97
98 // Incorrect arguments for constructor should throw
99 c10::QualifiedName base("__torch__");
100 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
101 ASSERT_ANY_THROW(m1.create_class(c10::QualifiedName(base, "FooTest"), {1}));
102 auto x = torch::ones({2, 3});
103 auto obj = m2.create_class(c10::QualifiedName(base, "FooTest"), x).toObject();
104 auto dx = obj->getAttr("dx");
105 ASSERT_TRUE(almostEqual(x, dx.toTensor()));
106
107 auto new_x = torch::rand({2, 3});
108 obj->setAttr("dx", new_x);
109 auto new_dx = obj->getAttr("dx");
110 ASSERT_TRUE(almostEqual(new_x, new_dx.toTensor()));
111}
112
// TorchScript source for a single free-standing `__init__` definition. The
// ClassDerive test compiles it as a method of the class `foo.bar`.
static const char* const methodSrc = R"JIT(
def __init__(self, x):
    return x
)JIT";
117
118TEST(ClassImportTest, ClassDerive) {
119 auto cu = std::make_shared<CompilationUnit>();
120 auto cls = ClassType::create("foo.bar", cu);
121 const auto self = SimpleSelf(cls);
122 auto methods = cu->define("foo.bar", methodSrc, nativeResolver(), &self);
123 auto method = methods[0];
124 cls->addAttribute("attr", TensorType::get());
125 ASSERT_TRUE(cls->findMethod(method->name()));
126
127 // Refining a new class should retain attributes and methods
128 auto newCls = cls->refine({TensorType::get()});
129 ASSERT_TRUE(newCls->hasAttribute("attr"));
130 ASSERT_TRUE(newCls->findMethod(method->name()));
131
132 auto newCls2 = cls->withContained({TensorType::get()})->expect<ClassType>();
133 ASSERT_TRUE(newCls2->hasAttribute("attr"));
134 ASSERT_TRUE(newCls2->findMethod(method->name()));
135}
136
// TorchScript source for a module class `__torch__.FooBar1234`. Its attribute
// `f` has the TorchBind custom-class type
// `__torch__.torch.classes._TorchScriptTesting._StackString`, and its
// `forward` calls `top()` on that attribute. NOTE(review): importing this
// presumably requires the `_TorchScriptTesting` custom classes to be
// registered in the test binary — confirm.
static constexpr c10::string_view torchbindSrc = R"JIT(
class FooBar1234(Module):
  __parameters__ = []
  f : __torch__.torch.classes._TorchScriptTesting._StackString
  training : bool
  def forward(self: __torch__.FooBar1234) -> str:
    return (self.f).top()
)JIT";
145
146TEST(ClassImportTest, CustomClass) {
147 auto cu1 = std::make_shared<CompilationUnit>();
148 std::vector<at::IValue> constantTable;
149 // Import different versions of FooTest into two namespaces.
150 import_libs(
151 cu1,
152 "__torch__.FooBar1234",
153 std::make_shared<Source>(torchbindSrc),
154 constantTable);
155}
156
157} // namespace jit
158} // namespace torch
159