1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_custom_class_registrations.h> |
4 | #include <torch/csrc/jit/passes/freeze_module.h> |
5 | #include <torch/custom_class.h> |
6 | #include <torch/script.h> |
7 | |
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | TEST(CustomClassTest, TorchbindIValueAPI) { |
16 | script::Module m("m" ); |
17 | |
18 | // test make_custom_class API |
19 | auto custom_class_obj = make_custom_class<MyStackClass<std::string>>( |
20 | std::vector<std::string>{"foo" , "bar" }); |
21 | m.define(R"( |
22 | def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString): |
23 | return s.pop(), s |
24 | )" ); |
25 | |
26 | auto test_with_obj = [&m](IValue obj, std::string expected) { |
27 | auto res = m.run_method("forward" , obj); |
28 | auto tup = res.toTuple(); |
29 | AT_ASSERT(tup->elements().size() == 2); |
30 | auto str = tup->elements()[0].toStringRef(); |
31 | auto other_obj = |
32 | tup->elements()[1].toCustomClass<MyStackClass<std::string>>(); |
33 | AT_ASSERT(str == expected); |
34 | auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>(); |
35 | AT_ASSERT(other_obj.get() == ref_obj.get()); |
36 | }; |
37 | |
38 | test_with_obj(custom_class_obj, "bar" ); |
39 | |
40 | // test IValue() API |
41 | auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>( |
42 | std::vector<std::string>{"baz" , "boo" }); |
43 | auto new_stack_ivalue = c10::IValue(my_new_stack); |
44 | |
45 | test_with_obj(new_stack_ivalue, "boo" ); |
46 | } |
47 | |
48 | TEST(CustomClassTest, ScalarTypeClass) { |
49 | script::Module m("m" ); |
50 | |
51 | // test make_custom_class API |
52 | auto cc = make_custom_class<ScalarTypeClass>(at::kFloat); |
53 | m.register_attribute("s" , cc.type(), cc, false); |
54 | |
55 | std::ostringstream oss; |
56 | m.save(oss); |
57 | std::istringstream iss(oss.str()); |
58 | caffe2::serialize::IStreamAdapter adapter{&iss}; |
59 | auto loaded_module = torch::jit::load(iss, torch::kCPU); |
60 | } |
61 | |
62 | class TorchBindTestClass : public torch::jit::CustomClassHolder { |
63 | public: |
64 | std::string get() { |
65 | return "Hello, I am your test custom class" ; |
66 | } |
67 | }; |
68 | |
// Doc strings attached to TorchBindTestClass when it is registered with
// torch::class_. One is for the class itself; the other is for the
// `get_with_docstring` method.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char class_doc_string[] = R"(
I am docstring for TorchBindTestClass
Args:
What is an argument? Oh never mind, I don't take any.

Return:
How would I know? I am just a holder of some meaningless test methods.
)";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char method_doc_string[] =
    "I am docstring for TorchBindTestClass "
    "get_with_docstring method";
81 | |
82 | namespace { |
83 | static auto reg = |
84 | torch::class_<TorchBindTestClass>( |
85 | "_TorchBindTest" , |
86 | "_TorchBindTestClass" , |
87 | class_doc_string) |
88 | .def("get" , &TorchBindTestClass::get) |
89 | .def("get_with_docstring" , &TorchBindTestClass::get, method_doc_string); |
90 | |
91 | } // namespace |
92 | |
93 | // Tests DocString is properly propagated when defining CustomClasses. |
94 | TEST(CustomClassTest, TestDocString) { |
95 | auto class_type = getCustomClass( |
96 | "__torch__.torch.classes._TorchBindTest._TorchBindTestClass" ); |
97 | AT_ASSERT(class_type); |
98 | AT_ASSERT(class_type->doc_string() == class_doc_string); |
99 | |
100 | AT_ASSERT(class_type->getMethod("get" ).doc_string().empty()); |
101 | AT_ASSERT( |
102 | class_type->getMethod("get_with_docstring" ).doc_string() == |
103 | method_doc_string); |
104 | } |
105 | |
106 | TEST(CustomClassTest, Serialization) { |
107 | script::Module m("m" ); |
108 | |
109 | // test make_custom_class API |
110 | auto custom_class_obj = make_custom_class<MyStackClass<std::string>>( |
111 | std::vector<std::string>{"foo" , "bar" }); |
112 | m.register_attribute( |
113 | "s" , |
114 | custom_class_obj.type(), |
115 | custom_class_obj, |
116 | // NOLINTNEXTLINE(bugprone-argument-comment) |
117 | /*is_parameter=*/false); |
118 | m.define(R"( |
119 | def forward(self): |
120 | return self.s.return_a_tuple() |
121 | )" ); |
122 | |
123 | auto test_with_obj = [](script::Module& mod) { |
124 | auto res = mod.run_method("forward" ); |
125 | auto tup = res.toTuple(); |
126 | AT_ASSERT(tup->elements().size() == 2); |
127 | auto i = tup->elements()[1].toInt(); |
128 | AT_ASSERT(i == 123); |
129 | }; |
130 | |
131 | auto frozen_m = torch::jit::freeze_module(m.clone()); |
132 | |
133 | test_with_obj(m); |
134 | test_with_obj(frozen_m); |
135 | |
136 | std::ostringstream oss; |
137 | m.save(oss); |
138 | std::istringstream iss(oss.str()); |
139 | caffe2::serialize::IStreamAdapter adapter{&iss}; |
140 | auto loaded_module = torch::jit::load(iss, torch::kCPU); |
141 | |
142 | std::ostringstream oss_frozen; |
143 | frozen_m.save(oss_frozen); |
144 | std::istringstream iss_frozen(oss_frozen.str()); |
145 | caffe2::serialize::IStreamAdapter adapter_frozen{&iss_frozen}; |
146 | auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU); |
147 | } |
148 | |
149 | } // namespace jit |
150 | } // namespace torch |
151 | |