1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <iostream>
5#include <sstream>
6
7#include <caffe2/serialize/inline_container.h>
8#include <torch/csrc/jit/mobile/module.h>
9#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
10#include <torch/csrc/jit/serialization/export.h>
11#include <torch/csrc/jit/serialization/export_bytecode.h>
12#include <torch/csrc/jit/serialization/import.h>
13#include <torch/csrc/jit/serialization/import_source.h>
14#include <torch/torch.h>
15
16#include "caffe2/serialize/istream_adapter.h"
17
18namespace torch {
19namespace jit {
20
21namespace {
22
23Module roundtripThroughMobile(const Module& m) {
24 ExtraFilesMap files;
25 std::vector<IValue> constants;
26 jitModuleToPythonCodeAndConstants(m, &files, &constants);
27 CompilationOptions options;
28 mobile::Module mobilem = jitModuleToMobile(m, options);
29 return jitModuleFromSourceAndConstants(
30 mobilem._ivalue(), files, constants, 8);
31}
32
33template <class Functor>
34inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
35 try {
36 std::forward<Functor>(functor)();
37 } catch (const Error& e) {
38 EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
39 return;
40 }
41 ADD_FAILURE() << "Expected to throw exception with message \""
42 << expectedMessage << "\" but didn't throw";
43}
44
45} // namespace
46
47TEST(SerializationTest, ExtraFilesHookPreference) {
48 // Tests that an extra file written explicitly has precedence over
49 // extra files written by a hook
50 // TODO: test for the warning, too
51 const auto script = R"JIT(
52 def forward(self):
53 x = torch.rand(5, 5)
54 x = x.mm(x)
55 return x
56 )JIT";
57
58 auto module =
59 std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
60 module->define(script);
61 std::ostringstream oss;
62 std::unordered_map<std::string, std::string> extra_files;
63 extra_files["metadata.json"] = "abc";
64 SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
65 return {{"metadata.json", "def"}};
66 });
67 module->save(oss, extra_files);
68 SetExportModuleExtraFilesHook(nullptr);
69
70 std::istringstream iss(oss.str());
71 caffe2::serialize::IStreamAdapter adapter{&iss};
72 std::unordered_map<std::string, std::string> loaded_extra_files;
73 loaded_extra_files["metadata.json"] = "";
74 auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files);
75 ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
76}
77
78TEST(SerializationTest, ExtraFileHooksNoSecret) {
79 // no secrets
80 std::stringstream ss;
81 {
82 Module m("__torch__.m");
83 ExtraFilesMap extra;
84 extra["metadata.json"] = "abc";
85 m.save(ss, extra);
86 }
87 ss.seekg(0);
88 {
89 ExtraFilesMap extra;
90 extra["metadata.json"] = "";
91 extra["secret.json"] = "";
92 jit::load(ss, c10::nullopt, extra);
93 ASSERT_EQ(extra["metadata.json"], "abc");
94 ASSERT_EQ(extra["secret.json"], "");
95 }
96}
97
98TEST(SerializationTest, ExtraFileHooksWithSecret) {
99 std::stringstream ss;
100 {
101 SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
102 return {{"secret.json", "topsecret"}};
103 });
104 Module m("__torch__.m");
105 ExtraFilesMap extra;
106 extra["metadata.json"] = "abc";
107 m.save(ss, extra);
108 SetExportModuleExtraFilesHook(nullptr);
109 }
110 ss.seekg(0);
111 {
112 ExtraFilesMap extra;
113 extra["metadata.json"] = "";
114 extra["secret.json"] = "";
115 jit::load(ss, c10::nullopt, extra);
116 ASSERT_EQ(extra["metadata.json"], "abc");
117 ASSERT_EQ(extra["secret.json"], "topsecret");
118 }
119}
120
121TEST(SerializationTest, TypeTags) {
122 auto list = c10::List<c10::List<int64_t>>();
123 list.push_back(c10::List<int64_t>({1, 2, 3}));
124 list.push_back(c10::List<int64_t>({4, 5, 6}));
125 auto dict = c10::Dict<std::string, at::Tensor>();
126 dict.insert("Hello", torch::ones({2, 2}));
127 auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
128 for (size_t i = 0; i < 5; i++) {
129 auto another_dict = c10::Dict<std::string, at::Tensor>();
130 another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
131 dict_list.push_back(another_dict);
132 }
133 auto tuple = std::tuple<int, std::string>(2, "hi");
134 struct TestItem {
135 IValue value;
136 TypePtr expected_type;
137 };
138 std::vector<TestItem> items = {
139 {list, ListType::create(ListType::create(IntType::get()))},
140 {2, IntType::get()},
141 {dict, DictType::create(StringType::get(), TensorType::get())},
142 {dict_list,
143 ListType::create(
144 DictType::create(StringType::get(), TensorType::get()))},
145 {tuple, TupleType::create({IntType::get(), StringType::get()})}};
146 // NOLINTNEXTLINE(performance-for-range-copy)
147 for (auto item : items) {
148 auto bytes = torch::pickle_save(item.value);
149 auto loaded = torch::pickle_load(bytes);
150 ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type));
151 ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type()));
152 }
153}
154
155TEST(SerializationTest, TestJitStream_CUDA) {
156 torch::jit::Module model;
157 std::vector<torch::jit::IValue> inputs;
158 // Deserialize the ScriptModule from a file using torch::jit::load().
159 // Load the scripted model. This should have been generated by tests_setup.py
160 // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
161 model = torch::jit::load("saved_stream_model.pt");
162
163 auto output = model.forward(inputs);
164 const auto& list_of_elements = output.toTupleRef().elements();
165 auto is_stream_s = list_of_elements[0].toBool();
166
167 // a,b: These are the two input tensors
168 // c: This is output tensor generated by the operation torch.cat(a,b)
169 auto a = list_of_elements[1].toTensor();
170 auto b = list_of_elements[2].toTensor();
171 auto c = list_of_elements[3].toTensor();
172 // op: this is used to verify if the cat operation produced the same results
173 // as that on the GPU with torch.cat
174 auto op = at::cat({a, b}, 0);
175
176 // Check if the stream is set
177 ASSERT_TRUE(is_stream_s);
178 // Check if the sizes of the outputs (op and c) is same on the GPU and CPU
179 ASSERT_EQ(op.sizes(), c.sizes());
180 // Check if both the output tensors are equal
181 ASSERT_TRUE(op.equal(c));
182}
183
184TEST(TestSourceRoundTrip, UpsampleNearest2d) {
185 Module m("m");
186 m.define(R"(
187 def forward(self, input: Tensor, scale:float):
188 return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
189 )");
190
191 std::vector<IValue> inputs;
192 inputs.emplace_back(torch::rand({1, 3, 128, 128}));
193 inputs.emplace_back(at::Scalar(2.0));
194 auto ref = m.forward(inputs);
195
196 Module m2 = roundtripThroughMobile(m);
197 auto res = m2.forward(inputs);
198
199 auto resd = res.toTensor();
200 auto refd = ref.toTensor();
201 ASSERT_TRUE(resd.equal(refd));
202}
203
204TEST(TestSourceRoundTrip, CheckAttrAccess) {
205 Module m("m");
206 m.register_attribute("mobile_optimized", BoolType::get(), true);
207 Module m2 = roundtripThroughMobile(m);
208 bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
209 AT_ASSERT(mobile_optimized);
210}
211
212TEST(TestSourceRoundTrip,
213 MethodInvocation) { // NOLINT (use =delete in gtest)
214 const std::vector<std::string> test_programs{
215 // test invoking a method with default parameter
216 R"(
217 def test_func(self, x, b : int = 4):
218 return self.foo + x + b
219 )",
220 // inner method call with default parameter (gets inlined)
221 R"(
222 def add_with_default_arg(self, x, b : int = 4):
223 return self.foo + x + b
224 def test_func(self, x):
225 return self.add_with_default_arg(x) # invoke method w/ default arg
226 )",
227 // simple method call
228 R"(
229 def test_func(self, x):
230 b = 4
231 return self.foo + x + b
232 )",
233 };
234 for (const auto& test_program : test_programs) {
235 Module m("m");
236 m.register_parameter("foo", torch::ones({}), false);
237 m.define(test_program);
238
239 const int fortyTwo = 42; // (keep linter happy)
240 auto minput = fortyTwo * torch::ones({});
241 auto ref = m.run_method("test_func", minput);
242
243 Module m2 = roundtripThroughMobile(m);
244 const auto& test_func = m2.get_method("test_func");
245 IValue res;
246 for (int i = 0; i < 3; ++i) {
247 res = test_func({minput});
248 }
249
250 auto resd = res.toTensor().item<float>();
251 auto refd = ref.toTensor().item<float>();
252 AT_ASSERT(resd == refd);
253 }
254}
255
256TEST(SerializationTest, ParentDirNotExist) {
257 expectThrowsEq(
258 []() {
259 auto t = torch::nn::Linear(5, 5);
260 torch::save(t, "./doesnotexist/file.pt");
261 },
262 "Parent directory ./doesnotexist does not exist.");
263}
264
265TEST(SerializationTest, CalculateNecessaryArgsTest) {
266 auto schema = torch::schema(
267 "sync_stream(int stream_id = -1) -> ()",
268 c10::AliasAnalysisKind::CONSERVATIVE);
269
270 auto graph = std::make_shared<Graph>();
271 auto one_val = graph->insertConstant(-1);
272 auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true);
273 EXPECT_EQ(0, necessary.first);
274 EXPECT_EQ(0, necessary.second);
275}
276
// Verifies two things about the debug-symbol record
// "code/__torch__.py.debug_pkl" of a saved module:
//  1. PyTorchStreamReader::hasRecord reports it only while debug-symbol
//     loading is enabled on the reader.
//  2. Loading with debug info (default load) makes a runtime error point at
//     the original TorchScript source, while loading with the third argument
//     `false` makes it point at the serialized (generated) source instead.
TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def test_func(self, x):
      b = 4
      return self.foo + x + b
    )");
  // A method that always raises, so the error message (and the source it
  // points at) can be inspected below.
  m.define(
      R"(
    def exception(self):
      assert False, "message"
    )");
  std::stringstream ss;
  m.save(ss);
  ss.seekg(0);
  caffe2::serialize::PyTorchStreamReader reader(&ss);
  // The debug record is visible only when the reader is told to load debug
  // symbols.
  reader.setShouldLoadDebugSymbol(true);
  EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  ss.seekg(0);
  // Default load: the error highlights the original source passed to define().
  Module m2 = torch::jit::load(ss);
  std::string error_msg = R"(
    def exception(self):
      assert False, "message"
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)";
  ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg);

  ss.seekg(0);
  // NO DEBUG trace so error message points to torchscript generated
  // source instead of original python source.
  std::string error2 = R"(
  def exception(self: __torch__.m) -> NoneType:
    _0 = uninitialized(NoneType)
    ops.prim.RaiseException("AssertionError: message")
    ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _0
  )";
  // NOTE(review): the trailing `false` presumably disables loading of debug
  // files — confirm against torch::jit::load's signature.
  Module m3 = torch::jit::load(ss, c10::nullopt, false);
  ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
}
320
321} // namespace jit
322} // namespace torch
323