#include <gtest/gtest.h>

#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <sstream>
#include <string>
#include <vector>
11
12namespace torch {
13namespace jit {
14
15TEST(CustomClassTest, TorchbindIValueAPI) {
16 script::Module m("m");
17
18 // test make_custom_class API
19 auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
20 std::vector<std::string>{"foo", "bar"});
21 m.define(R"(
22 def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
23 return s.pop(), s
24 )");
25
26 auto test_with_obj = [&m](IValue obj, std::string expected) {
27 auto res = m.run_method("forward", obj);
28 auto tup = res.toTuple();
29 AT_ASSERT(tup->elements().size() == 2);
30 auto str = tup->elements()[0].toStringRef();
31 auto other_obj =
32 tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
33 AT_ASSERT(str == expected);
34 auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
35 AT_ASSERT(other_obj.get() == ref_obj.get());
36 };
37
38 test_with_obj(custom_class_obj, "bar");
39
40 // test IValue() API
41 auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
42 std::vector<std::string>{"baz", "boo"});
43 auto new_stack_ivalue = c10::IValue(my_new_stack);
44
45 test_with_obj(new_stack_ivalue, "boo");
46}
47
48TEST(CustomClassTest, ScalarTypeClass) {
49 script::Module m("m");
50
51 // test make_custom_class API
52 auto cc = make_custom_class<ScalarTypeClass>(at::kFloat);
53 m.register_attribute("s", cc.type(), cc, false);
54
55 std::ostringstream oss;
56 m.save(oss);
57 std::istringstream iss(oss.str());
58 caffe2::serialize::IStreamAdapter adapter{&iss};
59 auto loaded_module = torch::jit::load(iss, torch::kCPU);
60}
61
62class TorchBindTestClass : public torch::jit::CustomClassHolder {
63 public:
64 std::string get() {
65 return "Hello, I am your test custom class";
66 }
67};
68
// Docstring attached to the _TorchBindTestClass registration. TestDocString
// compares the registered class's doc_string() against this exact text, so
// any whitespace inside the raw literal is significant.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char class_doc_string[] = R"(
  I am docstring for TorchBindTestClass
  Args:
      What is an argument? Oh never mind, I don't take any.

  Return:
      How would I know? I am just a holder of some meaningless test methods.
  )";
// Docstring attached only to the "get_with_docstring" method; TestDocString
// expects the plain "get" method to have an empty docstring.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char method_doc_string[] =
    "I am docstring for TorchBindTestClass get_with_docstring method";
81
82namespace {
83static auto reg =
84 torch::class_<TorchBindTestClass>(
85 "_TorchBindTest",
86 "_TorchBindTestClass",
87 class_doc_string)
88 .def("get", &TorchBindTestClass::get)
89 .def("get_with_docstring", &TorchBindTestClass::get, method_doc_string);
90
91} // namespace
92
93// Tests DocString is properly propagated when defining CustomClasses.
94TEST(CustomClassTest, TestDocString) {
95 auto class_type = getCustomClass(
96 "__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
97 AT_ASSERT(class_type);
98 AT_ASSERT(class_type->doc_string() == class_doc_string);
99
100 AT_ASSERT(class_type->getMethod("get").doc_string().empty());
101 AT_ASSERT(
102 class_type->getMethod("get_with_docstring").doc_string() ==
103 method_doc_string);
104}
105
106TEST(CustomClassTest, Serialization) {
107 script::Module m("m");
108
109 // test make_custom_class API
110 auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
111 std::vector<std::string>{"foo", "bar"});
112 m.register_attribute(
113 "s",
114 custom_class_obj.type(),
115 custom_class_obj,
116 // NOLINTNEXTLINE(bugprone-argument-comment)
117 /*is_parameter=*/false);
118 m.define(R"(
119 def forward(self):
120 return self.s.return_a_tuple()
121 )");
122
123 auto test_with_obj = [](script::Module& mod) {
124 auto res = mod.run_method("forward");
125 auto tup = res.toTuple();
126 AT_ASSERT(tup->elements().size() == 2);
127 auto i = tup->elements()[1].toInt();
128 AT_ASSERT(i == 123);
129 };
130
131 auto frozen_m = torch::jit::freeze_module(m.clone());
132
133 test_with_obj(m);
134 test_with_obj(frozen_m);
135
136 std::ostringstream oss;
137 m.save(oss);
138 std::istringstream iss(oss.str());
139 caffe2::serialize::IStreamAdapter adapter{&iss};
140 auto loaded_module = torch::jit::load(iss, torch::kCPU);
141
142 std::ostringstream oss_frozen;
143 frozen_m.save(oss_frozen);
144 std::istringstream iss_frozen(oss_frozen.str());
145 caffe2::serialize::IStreamAdapter adapter_frozen{&iss_frozen};
146 auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU);
147}
148
149} // namespace jit
150} // namespace torch
151