#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>

#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
10
11namespace torch {
12namespace jit {
13
// TorchScript method definitions that are installed on `subMod` one entry at
// a time. The single entry defines both `one` and `forward`, i.e. the two
// methods declared by the OneForward interface source in this file.
static const std::vector<std::string> subMethodSrcs{
    R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
    return x + y + 1

def forward(self, x: Tensor) -> Tensor:
    return x
)JIT"};
// TorchScript `forward` for the parent module. It simply forwards its input
// to the `forward` method of the `subMod` attribute.
static const std::string parentForward(R"JIT(
def forward(self, x: Tensor) -> Tensor:
    return self.subMod.forward(x)
)JIT");
25
26static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
27class OneForward(ModuleInterface):
28 def one(self, x: Tensor, y: Tensor) -> Tensor:
29 pass
30 def forward(self, x: Tensor) -> Tensor:
31 pass
32)JIT";
33
34static void import_libs(
35 std::shared_ptr<CompilationUnit> cu,
36 const std::string& class_name,
37 const std::shared_ptr<Source>& src,
38 const std::vector<at::IValue>& tensor_table) {
39 SourceImporter si(
40 cu,
41 &tensor_table,
42 [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
43 /*version=*/2);
44 si.loadType(QualifiedName(class_name));
45}
46
47TEST(InterfaceTest, ModuleInterfaceSerialization) {
48 auto cu = std::make_shared<CompilationUnit>();
49 Module parentMod("parentMod", cu);
50 Module subMod("subMod", cu);
51
52 std::vector<at::IValue> constantTable;
53 import_libs(
54 cu,
55 "__torch__.OneForward",
56 std::make_shared<Source>(moduleInterfaceSrc),
57 constantTable);
58
59 for (const std::string& method : subMethodSrcs) {
60 subMod.define(method, nativeResolver());
61 }
62 parentMod.register_attribute(
63 "subMod",
64 cu->get_interface("__torch__.OneForward"),
65 subMod._ivalue(),
66 // NOLINTNEXTLINE(bugprone-argument-comment)
67 /*is_parameter=*/false);
68 parentMod.define(parentForward, nativeResolver());
69 ASSERT_TRUE(parentMod.hasattr("subMod"));
70 std::stringstream ss;
71 parentMod.save(ss);
72 Module reloaded_mod = jit::load(ss);
73 ASSERT_TRUE(reloaded_mod.hasattr("subMod"));
74 InterfaceTypePtr submodType =
75 reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
76 ASSERT_TRUE(submodType->is_module());
77}
78
79} // namespace jit
80} // namespace torch
81