#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <iostream>
#include <sstream>

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>

#include "caffe2/serialize/istream_adapter.h"

namespace torch {
namespace jit {

namespace {

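// Round-trips a JIT module: export its Python source and constants, lower it
// to a mobile module, then rebuild a JIT module from the mobile object plus
// the exported source and constants.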
Module roundtripThroughMobile(const Module& m) {
  ExtraFilesMap files;
  std::vector<IValue> constants;
  jitModuleToPythonCodeAndConstants(m, &files, &constants);
  CompilationOptions options;
  mobile::Module mobilem = jitModuleToMobile(m, options);
  return jitModuleFromSourceAndConstants(
      mobilem._ivalue(), files, constants, 8);
}

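// Runs `functor` and expects it to throw a c10::Error whose message (without
// the backtrace) matches `expectedMessage` exactly; fails the test otherwise.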
template <class Functor>
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
  try {
    std::forward<Functor>(functor)();
  } catch (const Error& e) {
    EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
    return;
  }
  ADD_FAILURE() << "Expected to throw exception with message \""
                << expectedMessage << "\" but didn't throw";
}

} // namespace

TEST(SerializationTest, ExtraFilesHookPreference) {
  // Tests that an extra file written explicitly has precedence over
  // extra files written by a hook
  // TODO: test for the warning, too
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);
  std::ostringstream oss;
  std::unordered_map<std::string, std::string> extra_files;
  extra_files["metadata.json"] = "abc";
  SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
    return {{"metadata.json", "def"}};
  });
  module->save(oss, extra_files);
  SetExportModuleExtraFilesHook(nullptr);

  std::istringstream iss(oss.str());
  caffe2::serialize::IStreamAdapter adapter{&iss};
  std::unordered_map<std::string, std::string> loaded_extra_files;
  loaded_extra_files["metadata.json"] = "";
  auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files);
  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
}

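// Without an extra-files hook, only the explicitly saved extra file comes back
// on load; the never-written "secret.json" entry stays empty.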
TEST(SerializationTest, ExtraFileHooksNoSecret) {
  // no secrets
  std::stringstream ss;
  {
    Module m("__torch__.m");
    ExtraFilesMap extra;
    extra["metadata.json"] = "abc";
    m.save(ss, extra);
  }
  ss.seekg(0);
  {
    ExtraFilesMap extra;
    extra["metadata.json"] = "";
    extra["secret.json"] = "";
    jit::load(ss, c10::nullopt, extra);
    ASSERT_EQ(extra["metadata.json"], "abc");
    ASSERT_EQ(extra["secret.json"], "");
  }
}

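// With an extra-files hook installed at save time, the hook's "secret.json"
// file is written alongside the explicit one and both are read back on load.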
TEST(SerializationTest, ExtraFileHooksWithSecret) {
  std::stringstream ss;
  {
    SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
      return {{"secret.json", "topsecret"}};
    });
    Module m("__torch__.m");
    ExtraFilesMap extra;
    extra["metadata.json"] = "abc";
    m.save(ss, extra);
    SetExportModuleExtraFilesHook(nullptr);
  }
  ss.seekg(0);
  {
    ExtraFilesMap extra;
    extra["metadata.json"] = "";
    extra["secret.json"] = "";
    jit::load(ss, c10::nullopt, extra);
    ASSERT_EQ(extra["metadata.json"], "abc");
    ASSERT_EQ(extra["secret.json"], "topsecret");
  }
}

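// Container type tags (lists, dicts, tuples) must survive a
// pickle_save/pickle_load round trip: the loaded value's type and the
// expected type have to be mutual subtypes of each other.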
TEST(SerializationTest, TypeTags) {
  auto list = c10::List<c10::List<int64_t>>();
  list.push_back(c10::List<int64_t>({1, 2, 3}));
  list.push_back(c10::List<int64_t>({4, 5, 6}));
  auto dict = c10::Dict<std::string, at::Tensor>();
  dict.insert("Hello", torch::ones({2, 2}));
  auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
  for (size_t i = 0; i < 5; i++) {
    auto another_dict = c10::Dict<std::string, at::Tensor>();
    another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
    dict_list.push_back(another_dict);
  }
  auto tuple = std::tuple<int, std::string>(2, "hi");
  struct TestItem {
    IValue value;
    TypePtr expected_type;
  };
  std::vector<TestItem> items = {
      {list, ListType::create(ListType::create(IntType::get()))},
      {2, IntType::get()},
      {dict, DictType::create(StringType::get(), TensorType::get())},
      {dict_list,
       ListType::create(
           DictType::create(StringType::get(), TensorType::get()))},
      {tuple, TupleType::create({IntType::get(), StringType::get()})}};
  // NOLINTNEXTLINE(performance-for-range-copy)
  for (auto item : items) {
    auto bytes = torch::pickle_save(item.value);
    auto loaded = torch::pickle_load(bytes);
    ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type));
    ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type()));
  }
}

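// Verifies a model that was scripted and saved while a user CUDA stream was
// active (see TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py): the
// stream flag must be set and the saved cat result must match a recomputed
// one.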
TEST(SerializationTest, TestJitStream_CUDA) {
  torch::jit::Module model;
  std::vector<torch::jit::IValue> inputs;
  // Deserialize the ScriptModule from a file using torch::jit::load().
  // Load the scripted model. This should have been generated by tests_setup.py
  // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
  model = torch::jit::load("saved_stream_model.pt");

  auto output = model.forward(inputs);
  const auto& list_of_elements = output.toTupleRef().elements();
  auto is_stream_s = list_of_elements[0].toBool();

  // a, b: the two input tensors
  // c: the output tensor generated by the operation torch.cat(a, b)
  auto a = list_of_elements[1].toTensor();
  auto b = list_of_elements[2].toTensor();
  auto c = list_of_elements[3].toTensor();
  // op: used to verify that the cat operation produced the same result
  // as torch.cat did on the GPU
  auto op = at::cat({a, b}, 0);

  // Check that the stream was set
  ASSERT_TRUE(is_stream_s);
  // Check that the sizes of the outputs (op and c) are the same
  ASSERT_EQ(op.sizes(), c.sizes());
  // Check that both output tensors are equal
  ASSERT_TRUE(op.equal(c));
}

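// upsample_nearest2d with scale factors must produce identical outputs before
// and after the round trip through the mobile format.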
TEST(TestSourceRoundTrip, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  Module m2 = roundtripThroughMobile(m);
  auto res = m2.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));
}

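// A registered bool attribute must survive the round trip and be readable via
// attr() on the rebuilt module.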
TEST(TestSourceRoundTrip, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);
  Module m2 = roundtripThroughMobile(m);
  bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
  AT_ASSERT(mobile_optimized);
}

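// Method invocation (including methods with default arguments) must give the
// same result before and after the round trip through the mobile format.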
TEST(TestSourceRoundTrip,
     MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    Module m2 = roundtripThroughMobile(m);
    const auto& test_func = m2.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}

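// torch::save should fail with a clear error when the target file's parent
// directory does not exist.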
TEST(SerializationTest, ParentDirNotExist) {
  expectThrowsEq(
      []() {
        auto t = torch::nn::Linear(5, 5);
        torch::save(t, "./doesnotexist/file.pt");
      },
      "Parent directory ./doesnotexist does not exist.");
}

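// If the only argument provided equals the schema default, neither regular
// nor out arguments need to be serialized, so both counts are zero.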
TEST(SerializationTest, CalculateNecessaryArgsTest) {
  auto schema = torch::schema(
      "sync_stream(int stream_id = -1) -> ()",
      c10::AliasAnalysisKind::CONSERVATIVE);

  auto graph = std::make_shared<Graph>();
  auto one_val = graph->insertConstant(-1);
  auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true);
  EXPECT_EQ(0, necessary.first);
  EXPECT_EQ(0, necessary.second);
}

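// Saving a module records a debug_pkl with the original source locations.
// With debug symbols loaded, errors point at the original Python-style source;
// without them, errors point at the generated TorchScript source instead.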
TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def test_func(self, x):
      b = 4
      return self.foo + x + b
    )");
  m.define(
      R"(
    def exception(self):
      assert False, "message"
    )");
  std::stringstream ss;
  m.save(ss);
  ss.seekg(0);
  caffe2::serialize::PyTorchStreamReader reader(&ss);
  reader.setShouldLoadDebugSymbol(true);
  EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  ss.seekg(0);
  Module m2 = torch::jit::load(ss);
  std::string error_msg = R"(
    def exception(self):
      assert False, "message"
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)";
  ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg);

  ss.seekg(0);
  // NO DEBUG trace so error message points to torchscript generated
  // source instead of original python source.
  std::string error2 = R"(
    def exception(self: __torch__.m) -> NoneType:
      _0 = uninitialized(NoneType)
      ops.prim.RaiseException("AssertionError: message")
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      return _0
    )";
  Module m3 = torch::jit::load(ss, c10::nullopt, false);
  ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
}

} // namespace jit
} // namespace torch