1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | |
5 | #include <ATen/core/qualified_name.h> |
6 | #include <torch/csrc/jit/frontend/resolver.h> |
7 | #include <torch/csrc/jit/serialization/import.h> |
8 | #include <torch/csrc/jit/serialization/import_source.h> |
9 | #include <torch/torch.h> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
// TorchScript sources for the submodule's methods; each entry is passed to
// Module::define. Method bodies must be indented because TorchScript follows
// Python's indentation-sensitive syntax.
static const std::vector<std::string> subMethodSrcs = {R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
    return x + y + 1

def forward(self, x: Tensor) -> Tensor:
    return x
)JIT"};

// TorchScript source for the parent's forward, which calls through the
// `subMod` attribute (registered with the OneForward interface type).
static const std::string parentForward = R"JIT(
def forward(self, x: Tensor) -> Tensor:
    return self.subMod.forward(x)
)JIT";
25 | |
26 | static constexpr c10::string_view moduleInterfaceSrc = R"JIT( |
27 | class OneForward(ModuleInterface): |
28 | def one(self, x: Tensor, y: Tensor) -> Tensor: |
29 | pass |
30 | def forward(self, x: Tensor) -> Tensor: |
31 | pass |
32 | )JIT" ; |
33 | |
34 | static void import_libs( |
35 | std::shared_ptr<CompilationUnit> cu, |
36 | const std::string& class_name, |
37 | const std::shared_ptr<Source>& src, |
38 | const std::vector<at::IValue>& tensor_table) { |
39 | SourceImporter si( |
40 | cu, |
41 | &tensor_table, |
42 | [&](const std::string& name) -> std::shared_ptr<Source> { return src; }, |
43 | /*version=*/2); |
44 | si.loadType(QualifiedName(class_name)); |
45 | } |
46 | |
47 | TEST(InterfaceTest, ModuleInterfaceSerialization) { |
48 | auto cu = std::make_shared<CompilationUnit>(); |
49 | Module parentMod("parentMod" , cu); |
50 | Module subMod("subMod" , cu); |
51 | |
52 | std::vector<at::IValue> constantTable; |
53 | import_libs( |
54 | cu, |
55 | "__torch__.OneForward" , |
56 | std::make_shared<Source>(moduleInterfaceSrc), |
57 | constantTable); |
58 | |
59 | for (const std::string& method : subMethodSrcs) { |
60 | subMod.define(method, nativeResolver()); |
61 | } |
62 | parentMod.register_attribute( |
63 | "subMod" , |
64 | cu->get_interface("__torch__.OneForward" ), |
65 | subMod._ivalue(), |
66 | // NOLINTNEXTLINE(bugprone-argument-comment) |
67 | /*is_parameter=*/false); |
68 | parentMod.define(parentForward, nativeResolver()); |
69 | ASSERT_TRUE(parentMod.hasattr("subMod" )); |
70 | std::stringstream ss; |
71 | parentMod.save(ss); |
72 | Module reloaded_mod = jit::load(ss); |
73 | ASSERT_TRUE(reloaded_mod.hasattr("subMod" )); |
74 | InterfaceTypePtr submodType = |
75 | reloaded_mod.type()->getAttribute("subMod" )->cast<InterfaceType>(); |
76 | ASSERT_TRUE(submodType->is_module()); |
77 | } |
78 | |
79 | } // namespace jit |
80 | } // namespace torch |
81 | |