1 | #include <gtest/gtest.h> |
2 | #include <test/cpp/jit/test_utils.h> |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/backends/backend_detail.h> |
5 | #include <torch/csrc/jit/mobile/import.h> |
6 | #include <torch/csrc/jit/serialization/import.h> |
7 | #include <torch/torch.h> |
8 | |
9 | // Tests go in torch::jit |
10 | namespace torch { |
11 | namespace jit { |
12 | TEST(BackendTest, ToBackend) { |
13 | Module m("m" ); |
14 | m.define(R"( |
15 | def forward(self, x, h): |
16 | return self.accum(x, h), self.sub_accum(x, h) |
17 | |
18 | def accum(self, x, h): |
19 | return x + h |
20 | |
21 | def sub_accum(self, x, h): |
22 | return x - h |
23 | )" ); |
24 | |
25 | std::vector<IValue> inputs; |
26 | inputs.emplace_back(2.0 * torch::ones({})); |
27 | inputs.emplace_back(1.0 * torch::ones({})); |
28 | auto ref = m.forward(inputs).toTupleRef().elements().vec(); |
29 | |
30 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
31 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
32 | fake_dict.insert("" , "" ); |
33 | compile_spec.insert("forward" , fake_dict); |
34 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
35 | // lowered module |
36 | auto lm = torch::jit::detail::codegen_backend_module( |
37 | "test_backend" , m, compile_spec, any_dict_ty); |
38 | // lowered module code: |
39 | /* |
40 | class test_backendLoweredModule(Module): |
41 | __parameters__ = [] |
42 | __buffers__ = [] |
43 | __processed_module : Any |
44 | __method_compile_spec : Dict[str, Any] |
45 | __backend : __torch__.torch.classes.__backends__.test_backend |
46 | __handles : Dict[str, Any] |
47 | def __create_backend(self: torch.jit.test_backendLoweredModule) -> None: |
48 | _0 = |
49 | __torch__.torch.classes.__backends__.test_backend.__new__(__torch__.torch.classes.__backends__.test_backend) |
50 | _1 = (_0).__init__() |
51 | self.__backend = _0 |
52 | return None |
53 | def __getstate__(self: torch.jit.test_backendLoweredModule) -> |
54 | Tuple[Dict[str, Any], Any]: _2 = (self.__method_compile_spec, |
55 | self.__processed_module) return _2 def __setstate__(self: |
56 | torch.jit.test_backendLoweredModule, state: Tuple[Dict[str, Any], Any]) -> |
57 | None: self.__method_compile_spec = (state)[0] self.__processed_module = |
58 | (state)[1] _3 = (self).__create_backend() _4 = |
59 | (self.__backend).compile(self.__processed_module, |
60 | self.__method_compile_spec, ) self.__handles = _4 return None def |
61 | forward(self: torch.jit.test_backendLoweredModule, x: Tensor, h: Tensor) -> |
62 | Tuple[Tensor, Tensor]: _5 = uninitialized(Tensor) typed_inputs = |
63 | annotate(List[Any], [x, h]) _6 = |
64 | (self.__backend).execute((self.__handles)["forward"], typed_inputs, ) _7, |
65 | _8, = _6 _9 = isinstance(_7, Tensor) if _9: _10 = unchecked_cast(Tensor, _7) |
66 | else: |
67 | ops.prim.RaiseException("AssertionError: ") |
68 | _10 = _5 |
69 | _11 = isinstance(_8, Tensor) |
70 | if _11: |
71 | _12 = unchecked_cast(Tensor, _8) |
72 | else: |
73 | ops.prim.RaiseException("AssertionError: ") |
74 | _12 = _5 |
75 | return (_10, _12) |
76 | |
77 | */ |
78 | auto res = lm.forward(inputs).toTupleRef().elements().vec(); |
79 | AT_ASSERT(res[0].toTensor().equal(ref[0].toTensor())); |
80 | AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor())); |
81 | } |
82 | |
83 | TEST(BackendTest, ToBackendNotAvailable) { |
84 | Module m("m" ); |
85 | m.define(R"( |
86 | def forward(self, x, h): |
87 | return self.accum(x, h), self.sub_accum(x, h) |
88 | |
89 | def accum(self, x, h): |
90 | return x + h |
91 | |
92 | def sub_accum(self, x, h): |
93 | return x - h |
94 | )" ); |
95 | |
96 | std::vector<IValue> inputs; |
97 | inputs.emplace_back(2.0 * torch::ones({})); |
98 | inputs.emplace_back(1.0 * torch::ones({})); |
99 | auto ref = m.forward(inputs).toTupleRef().elements().vec(); |
100 | |
101 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
102 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
103 | fake_dict.insert("" , "" ); |
104 | compile_spec.insert("forward" , fake_dict); |
105 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
106 | // Produce lowered module (backend not available). |
107 | // Exception is not thrown at this point. |
108 | auto lm = torch::jit::detail::codegen_backend_module( |
109 | "test_backend_unavailable" , m, compile_spec, any_dict_ty); |
110 | // Validate exception is thrown when trying to execute and |
111 | // the backend is not available. |
112 | ASSERT_THROWS_WITH_MESSAGE( |
113 | lm.forward(inputs).toTupleRef().elements(), "Backend is not available." ); |
114 | } |
115 | |
116 | TEST(BackendTest, TestCompiler) { |
117 | Module m("m" ); |
118 | m.define(R"( |
119 | def forward(self, x, h): |
120 | return x + h |
121 | )" ); |
122 | |
123 | std::vector<IValue> inputs; |
124 | inputs.emplace_back(2.0 * torch::ones({})); |
125 | inputs.emplace_back(1.0 * torch::ones({})); |
126 | auto ref = m.forward(inputs); |
127 | |
128 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
129 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
130 | fake_dict.insert("" , "" ); |
131 | compile_spec.insert("forward" , fake_dict); |
132 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
133 | // lowered module |
134 | auto lm = torch::jit::detail::codegen_backend_module( |
135 | "backend_with_compiler_demo" , m, compile_spec, any_dict_ty); |
136 | auto res = lm.forward(inputs); |
137 | AT_ASSERT(res.toTensor().equal(ref.toTensor())); |
138 | |
139 | std::stringstream ss; |
140 | lm._save_for_mobile(ss); |
141 | auto mlm = _load_for_mobile(ss); |
142 | auto mres = mlm.forward(inputs); |
143 | AT_ASSERT(mres.toTensor().equal(ref.toTensor())); |
144 | } |
145 | |
146 | TEST(BackendTest, TestCompilerWithStringTable) { |
147 | setShouldUseFormatWithStringTable(true); |
148 | Module m("m" ); |
149 | m.define(R"( |
150 | def forward(self, x, h): |
151 | return x + h |
152 | )" ); |
153 | |
154 | std::vector<IValue> inputs; |
155 | inputs.emplace_back(2.0 * torch::ones({})); |
156 | inputs.emplace_back(1.0 * torch::ones({})); |
157 | auto ref = m.forward(inputs); |
158 | |
159 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
160 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
161 | fake_dict.insert("" , "" ); |
162 | compile_spec.insert("forward" , fake_dict); |
163 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
164 | // lowered module |
165 | auto lm = torch::jit::detail::codegen_backend_module( |
166 | "backend_with_compiler_demo" , m, compile_spec, any_dict_ty); |
167 | auto res = lm.forward(inputs); |
168 | AT_ASSERT(res.toTensor().equal(ref.toTensor())); |
169 | |
170 | std::stringstream ss; |
171 | lm._save_for_mobile(ss); |
172 | auto mlm = _load_for_mobile(ss); |
173 | auto mres = mlm.forward(inputs); |
174 | setShouldUseFormatWithStringTable(false); |
175 | AT_ASSERT(mres.toTensor().equal(ref.toTensor())); |
176 | } |
177 | |
178 | TEST(BackendTest, TestComposite) { |
179 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
180 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
181 | fake_dict.insert("" , "" ); |
182 | compile_spec.insert("forward" , fake_dict); |
183 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
184 | |
185 | Module m_add("m_add" ); |
186 | m_add.define(R"( |
187 | def forward(self, x, y): |
188 | return x + y |
189 | )" ); |
190 | auto lm_add = torch::jit::detail::codegen_backend_module( |
191 | "backend_with_compiler_demo" , m_add, compile_spec, any_dict_ty); |
192 | |
193 | Module m_sub("m_sub" ); |
194 | m_sub.define(R"( |
195 | def forward(self, x, y): |
196 | return x - y |
197 | )" ); |
198 | auto lm_sub = torch::jit::detail::codegen_backend_module( |
199 | "backend_with_compiler_demo" , m_sub, compile_spec, any_dict_ty); |
200 | |
201 | Module c("C" ); |
202 | c.register_module("Add" , lm_add); |
203 | c.register_module("Sub" , lm_sub); |
204 | c.define(R"( |
205 | def forward(self, x, y): |
206 | return self.Add.forward(x, y) * self.Sub.forward(x, y) |
207 | )" ); |
208 | |
209 | std::vector<IValue> inputs; |
210 | inputs.emplace_back(3.0 * torch::ones({})); |
211 | inputs.emplace_back(1.0 * torch::ones({})); |
212 | auto res_jit = c.forward(inputs); |
213 | |
214 | std::stringstream ss; |
215 | c._save_for_mobile(ss); |
216 | auto mc = _load_for_mobile(ss); |
217 | auto res_mobile = mc.forward(inputs); |
218 | |
219 | AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor())); |
220 | } |
221 | |
222 | TEST(BackendTest, TestPrimDtype) { |
223 | Module c("name" ); |
224 | c.define(R"( |
225 | def forward(self, x, y): |
226 | c = y.dtype |
227 | return c |
228 | )" ); |
229 | |
230 | std::vector<IValue> inputs; |
231 | inputs.emplace_back(3.0 * torch::ones({})); |
232 | inputs.emplace_back(1.0 * torch::ones({})); |
233 | auto res_jit = c.forward(inputs); |
234 | |
235 | std::stringstream ss; |
236 | c._save_for_mobile(ss); |
237 | auto mc = _load_for_mobile(ss); |
238 | auto res_mobile = mc.forward(inputs); |
239 | |
240 | ASSERT_EQ(res_jit.toInt(), res_mobile.toInt()); |
241 | } |
242 | |
243 | Module getCompositeModuleWithSameNameSubModules() { |
244 | // Two submodules with same module name but different forward and other |
245 | // functions should be serialized and loaded correctly. |
246 | |
247 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
248 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
249 | fake_dict.insert("" , "" ); |
250 | compile_spec.insert("forward" , fake_dict); |
251 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
252 | |
253 | Module sub1("m_add" ); |
254 | sub1.define(R"( |
255 | def forward(self, x, y): |
256 | return x + y |
257 | )" ); |
258 | auto lowered_sub1 = torch::jit::detail::codegen_backend_module( |
259 | "backend_with_compiler_demo" , sub1, compile_spec, any_dict_ty); |
260 | |
261 | Module sub2("m_add" ); |
262 | sub2.define(R"( |
263 | def forward(self, x, y): |
264 | return x - y |
265 | )" ); |
266 | auto lowered_sub2 = torch::jit::detail::codegen_backend_module( |
267 | "backend_with_compiler_demo" , sub2, compile_spec, any_dict_ty); |
268 | |
269 | Module c("C" ); |
270 | c.register_module("Add" , lowered_sub1); |
271 | c.register_module("Sub" , lowered_sub2); |
272 | c.define(R"( |
273 | def forward(self, a, b, s:int): |
274 | c = self.Add.forward(a, b) |
275 | d = self.Sub.forward(a, b) |
276 | y = s * (c * d) |
277 | return y |
278 | )" ); |
279 | |
280 | return c; |
281 | } |
282 | |
283 | TEST(BackendTest, TestCompositeWithSetStates) { |
284 | Module c = getCompositeModuleWithSameNameSubModules(); |
285 | |
286 | std::vector<IValue> inputs; |
287 | inputs.emplace_back(torch::ones({})); |
288 | inputs.emplace_back(3.0 * torch::ones({})); |
289 | inputs.emplace_back(3); |
290 | auto res_jit = c.forward(inputs); |
291 | |
292 | std::stringstream ss; |
293 | c._save_for_mobile(ss); |
294 | auto mc = _load_for_mobile(ss); |
295 | auto res_mobile = mc.forward(inputs); |
296 | AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor())); |
297 | } |
298 | |
299 | TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) { |
300 | Module c = getCompositeModuleWithSameNameSubModules(); |
301 | |
302 | std::vector<IValue> inputs; |
303 | inputs.emplace_back(torch::ones({})); |
304 | inputs.emplace_back(3.0 * torch::ones({})); |
305 | inputs.emplace_back(3); |
306 | |
307 | std::stringstream ss, ss_resave; |
308 | c._save_for_mobile(ss); |
309 | auto mc = _load_for_mobile(ss); |
310 | auto res_mobile = mc.forward(inputs); |
311 | ss.seekg(0, ss.beg); |
312 | |
313 | // check if the methods names are always the same |
314 | // by reloading the script module and saving it back as mobile |
315 | // The below checks ensure that the names of Methods |
316 | // and numerical outputs of mobile and reloaded mobile |
317 | // modules are same. |
318 | auto script_module_load = torch::jit::load(ss); |
319 | script_module_load._save_for_mobile(ss_resave); |
320 | auto mc_reload = _load_for_mobile(ss_resave); |
321 | auto res_mobile_reload = mc_reload.forward(inputs); |
322 | |
323 | AT_ASSERT(res_mobile_reload.toTensor().equal(res_mobile.toTensor())); |
324 | |
325 | auto mc_methods = mc.get_methods(); |
326 | auto mc_reload_methods = mc_reload.get_methods(); |
327 | |
328 | std::vector<std::string> mc_method_qns, mc_reload_method_qns; |
329 | |
330 | auto get_qual_name = [](mobile::Method method) -> std::string { |
331 | return method.function().qualname().qualifiedName(); |
332 | }; |
333 | |
334 | std::transform( |
335 | mc_methods.begin(), |
336 | mc_methods.end(), |
337 | std::back_inserter(mc_method_qns), |
338 | get_qual_name); |
339 | |
340 | std::transform( |
341 | mc_reload_methods.begin(), |
342 | mc_reload_methods.end(), |
343 | std::back_inserter(mc_reload_method_qns), |
344 | get_qual_name); |
345 | |
346 | AT_ASSERT(std::equal( |
347 | mc_method_qns.begin(), |
348 | mc_method_qns.end(), |
349 | mc_reload_method_qns.begin())); |
350 | } |
351 | |
352 | TEST(BackendTest, TestCompilerNotSupport) { |
353 | Module m("m" ); |
354 | m.define(R"( |
355 | def forward(self, x, h): |
356 | return x * h |
357 | )" ); |
358 | |
359 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
360 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
361 | fake_dict.insert("" , "" ); |
362 | compile_spec.insert("forward" , fake_dict); |
363 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
364 | // lowered module |
365 | ASSERT_THROWS_WITH_MESSAGE( |
366 | torch::jit::detail::codegen_backend_module( |
367 | "backend_with_compiler_demo" , m, compile_spec, any_dict_ty), |
368 | "The node of aten::mul is not supported in this compiler. Source code:" ); |
369 | } |
370 | |
// Debug-info check for a module lowered to "backend_with_compiler_demo".
// The error raised when the reloaded mobile module fails must carry the module
// hierarchy and a TorchScript traceback from the lowered-module wrapper down
// to the original `x + h`.
371 | TEST(BackendTestDebugInfo, TestCompiler) { |
372 | Module m("m" ); |
373 | m.define(R"( |
374 | def forward(self, x, h): |
375 | return x + h |
376 | )" ); |
377 | |
// Shapes {2, 4} and {13, 9} cannot be broadcast together (4 vs 9 in the last
// dimension), so executing `x + h` is expected to throw.
378 | std::vector<IValue> inputs; |
379 | inputs.emplace_back(torch::rand({2, 4})); |
380 | inputs.emplace_back(torch::rand({13, 9})); |
381 | |
// The compile spec maps the method name "forward" to a placeholder dict with a
// single empty-string entry.
382 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
383 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
384 | fake_dict.insert("" , "" ); |
385 | compile_spec.insert("forward" , fake_dict); |
386 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
387 | // lowered module |
388 | auto lm = torch::jit::detail::codegen_backend_module( |
389 | "backend_with_compiler_demo" , m, compile_spec, any_dict_ty); |
390 | |
391 | std::stringstream ss; |
// NOTE(review): the third argument (`true`) presumably enables saving mobile
// debug info, which the expected traceback below relies on. Confirm against
// the _save_for_mobile signature.
392 | lm._save_for_mobile(ss, ExtraFilesMap(), true); |
393 | auto mlm = _load_for_mobile(ss); |
// Expected error text. Presumably ASSERT_THROWS_WITH_MESSAGE (test_utils.h)
// searches for it as a substring of the exception message, so the exact
// contents of this raw string, including whitespace, matter. Verify there.
394 | std::string error_pattern = R"( |
395 | Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add |
396 | Traceback of TorchScript (most recent call last): |
397 | File "<string>", line 3, in <unknown> |
398 | |
399 | def forward(self, x: Tensor, h: Tensor): |
400 | return self.__loweredModule__.forward(x, h) |
401 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
402 | |
403 | File "<string>", line 5, in forward |
404 | typed_inputs: List[Any] = [x, h, ] |
405 | if self.__backend.is_available() : |
406 | _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) |
407 | ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
408 | assert isinstance(_0, Tensor) |
409 | return _0 |
410 | File "<string>", line 3, in <unknown> |
411 | |
412 | def forward(self, x, h): |
413 | return x + h |
414 | ~~~~~ <--- HERE |
415 | )" ; |
416 | ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); |
417 | } |
418 | |
// Same debug-info check as BackendTestDebugInfo.TestCompiler, but with the
// string-table serialization format enabled while the lowered module is saved
// and reloaded for mobile.
419 | TEST(BackendTestDebugInfo, TestCompilerWithStringTable) { |
// NOTE(review): the flag is reset only near the end of this test. If anything
// before that reset throws (e.g. codegen_backend_module, _save_for_mobile or
// _load_for_mobile), the flag stays enabled for later tests. An RAII scope
// guard would make the reset unconditional.
420 | setShouldUseFormatWithStringTable(true); |
421 | Module m("m" ); |
422 | m.define(R"( |
423 | def forward(self, x, h): |
424 | return x + h |
425 | )" ); |
426 | |
// Shapes {2, 4} and {13, 9} cannot be broadcast together, so executing
// `x + h` is expected to throw.
427 | std::vector<IValue> inputs; |
428 | inputs.emplace_back(torch::rand({2, 4})); |
429 | inputs.emplace_back(torch::rand({13, 9})); |
430 | |
431 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
432 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
433 | fake_dict.insert("" , "" ); |
434 | compile_spec.insert("forward" , fake_dict); |
435 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
436 | // lowered module |
437 | auto lm = torch::jit::detail::codegen_backend_module( |
438 | "backend_with_compiler_demo" , m, compile_spec, any_dict_ty); |
439 | |
440 | std::stringstream ss; |
// NOTE(review): the third argument (`true`) presumably enables saving mobile
// debug info. Confirm against the _save_for_mobile signature.
441 | lm._save_for_mobile(ss, ExtraFilesMap(), true); |
442 | auto mlm = _load_for_mobile(ss); |
// Expected error text; the exact contents of this raw string matter for the
// match performed by ASSERT_THROWS_WITH_MESSAGE.
443 | std::string error_pattern = R"( |
444 | Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add |
445 | Traceback of TorchScript (most recent call last): |
446 | File "<string>", line 3, in <unknown> |
447 | |
448 | def forward(self, x: Tensor, h: Tensor): |
449 | return self.__loweredModule__.forward(x, h) |
450 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
451 | |
452 | File "<string>", line 5, in forward |
453 | typed_inputs: List[Any] = [x, h, ] |
454 | if self.__backend.is_available() : |
455 | _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) |
456 | ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
457 | assert isinstance(_0, Tensor) |
458 | return _0 |
459 | File "<string>", line 3, in <unknown> |
460 | |
461 | def forward(self, x, h): |
462 | return x + h |
463 | ~~~~~ <--- HERE |
464 | )" ; |
// Reset after save/load and before the expected-to-fail call. Presumably the
// flag only affects serialization, so this does not change forward()'s
// behavior.
465 | setShouldUseFormatWithStringTable(false); |
466 | ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); |
467 | } |
468 | |
// Lowers the whole module "C" (whose forward calls submodules A0 and B0) to
// "backend_with_compiler_demo". Checks that the runtime error from the mobile
// module carries the module hierarchy down into A0, and a traceback ending at
// A0's `x + y`.
469 | TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) { |
470 | Module a("A" ); |
471 | a.define(R"( |
472 | def forward(self, x, y): |
473 | return x + y |
474 | )" ); |
475 | Module b("B" ); |
476 | b.define(R"( |
477 | def forward(self, x): |
478 | return x + 2 |
479 | )" ); |
480 | Module c("C" ); |
481 | c.register_module("A0" , a); |
482 | c.register_module("B0" , b); |
483 | c.define(R"( |
484 | def forward(self, x, y): |
485 | return self.A0.forward(x, y) + self.B0.forward(x) |
486 | )" ); |
487 | |
// Shapes {2, 4} and {13, 9} cannot be broadcast together, so A0's `x + y` is
// expected to throw. The expected traceback below points there.
488 | std::vector<IValue> inputs; |
489 | inputs.emplace_back(torch::rand({2, 4})); |
490 | inputs.emplace_back(torch::rand({13, 9})); |
491 | |
492 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
493 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
494 | fake_dict.insert("" , "" ); |
495 | compile_spec.insert("forward" , fake_dict); |
496 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
497 | // lowered module |
498 | auto lm = torch::jit::detail::codegen_backend_module( |
499 | "backend_with_compiler_demo" , c, compile_spec, any_dict_ty); |
500 | |
501 | std::stringstream ss; |
// NOTE(review): `true` presumably enables saving mobile debug info. Confirm.
502 | lm._save_for_mobile(ss, ExtraFilesMap(), true); |
503 | auto mlm = _load_for_mobile(ss); |
// Expected error text; the exact contents of this raw string matter for the
// match performed by ASSERT_THROWS_WITH_MESSAGE.
504 | std::string error_pattern = R"( |
505 | Module hierarchy:top(C)::<unknown>.__loweredModule__(C)::forward.A0(A)::forward.aten::add |
506 | Traceback of TorchScript (most recent call last): |
507 | File "<string>", line 3, in <unknown> |
508 | |
509 | def forward(self, x: Tensor, y: Tensor): |
510 | return self.__loweredModule__.forward(x, y) |
511 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
512 | |
513 | File "<string>", line 5, in forward |
514 | typed_inputs: List[Any] = [x, y, ] |
515 | if self.__backend.is_available() : |
516 | _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) |
517 | ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
518 | assert isinstance(_0, Tensor) |
519 | return _0 |
520 | File "<string>", line 3, in <unknown> |
521 | |
522 | def forward(self, x, y): |
523 | return self.A0.forward(x, y) + self.B0.forward(x) |
524 | ~~~~~~~~~~~~~~~ <--- HERE |
525 | |
526 | File "<string>", line 3, in forward |
527 | |
528 | def forward(self, x, y): |
529 | return x + y |
530 | ~~~~~ <--- HERE |
531 | )" ; |
532 | ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); |
533 | } |
534 | |
// Lowers module "C", whose forward calls B0, which in turn calls A0. Checks
// that the runtime error's module hierarchy and traceback span both levels of
// submodules, down to A0's `x + y`.
535 | TEST( |
536 | BackendTestDebugInfo, |
537 | TestExceptionStackForCompilerWithTwoLevelModuleHierarchy) { |
538 | Module a("A" ); |
539 | a.define(R"( |
540 | def forward(self, x, y): |
541 | return x + y |
542 | )" ); |
543 | Module b("B" ); |
544 | b.register_module("A0" , a); |
545 | b.define(R"( |
546 | def forward(self, x, y): |
547 | return self.A0.forward(x, y) + 2 |
548 | )" ); |
549 | Module c("C" ); |
550 | c.register_module("B0" , b); |
551 | c.define(R"( |
552 | def forward(self, x, y): |
553 | return self.B0.forward(x, y) + 3 |
554 | )" ); |
555 | |
// Shapes {2, 4} and {13, 9} cannot be broadcast together, so A0's `x + y` is
// expected to throw.
556 | std::vector<IValue> inputs; |
557 | inputs.emplace_back(torch::rand({2, 4})); |
558 | inputs.emplace_back(torch::rand({13, 9})); |
559 | |
560 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
561 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
562 | fake_dict.insert("" , "" ); |
563 | compile_spec.insert("forward" , fake_dict); |
564 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
565 | // lowered module |
566 | auto lm = torch::jit::detail::codegen_backend_module( |
567 | "backend_with_compiler_demo" , c, compile_spec, any_dict_ty); |
568 | |
569 | std::stringstream ss; |
// NOTE(review): `true` presumably enables saving mobile debug info. Confirm.
570 | lm._save_for_mobile(ss, ExtraFilesMap(), true); |
571 | auto mlm = _load_for_mobile(ss); |
// NOTE(review): the block comment below uses an older format than
// error_pattern (e.g. `FunctionName_UNKNOWN` and a hierarchy without the
// `__loweredModule__` entry). error_pattern is what the test actually checks.
572 | /* |
573 | * Error stack throw will look like this: |
574 | * Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A) |
575 | * Traceback of TorchScript (most recent call last): |
576 | * File "<string>", line 5, in FunctionName_UNKNOWN |
577 | * typed_inputs: List[Any] = [x, y, ] |
578 | * if self.__backend.is_available() : |
579 | * _0, = self.__backend.execute(self.__handles["forward"], |
580 | * typed_inputs) |
581 | * ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
582 | * assert isinstance(_0, Tensor) |
583 | * return _0 |
584 | * File "<string>", line 3, in FunctionName_UNKNOWN |
585 | * |
586 | * def forward(self, x, y): |
587 | * return self.B0.forward(x, y) + 3 |
588 | * ~~~~~~~~~~~~~~~ <--- HERE |
589 | * |
590 | * File "<string>", line 3, in FunctionName_UNKNOWN |
591 | * |
592 | * def forward(self, x, y): |
593 | * return self.A0.forward(x, y) + 2 |
594 | * ~~~~~~~~~~~~~~~ <--- HERE |
595 | * |
596 | * File "<string>", line 3, in FunctionName_UNKNOWN |
597 | * |
598 | * def forward(self, x, y): |
599 | * return x + y |
600 | * ~~~~~ <--- HERE |
601 | * |
602 | */ |
// Expected error text; the exact contents of this raw string matter for the
// match performed by ASSERT_THROWS_WITH_MESSAGE.
603 | std::string error_pattern = R"( |
604 | Module hierarchy:top(C)::<unknown>.__loweredModule__(C)::forward.B0(B)::forward.A0(A)::forward.aten::add |
605 | Traceback of TorchScript (most recent call last): |
606 | File "<string>", line 3, in <unknown> |
607 | |
608 | def forward(self, x: Tensor, y: Tensor): |
609 | return self.__loweredModule__.forward(x, y) |
610 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
611 | |
612 | File "<string>", line 5, in forward |
613 | typed_inputs: List[Any] = [x, y, ] |
614 | if self.__backend.is_available() : |
615 | _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) |
616 | ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
617 | assert isinstance(_0, Tensor) |
618 | return _0 |
619 | File "<string>", line 3, in <unknown> |
620 | |
621 | def forward(self, x, y): |
622 | return self.B0.forward(x, y) + 3 |
623 | ~~~~~~~~~~~~~~~ <--- HERE |
624 | |
625 | File "<string>", line 3, in forward |
626 | |
627 | def forward(self, x, y): |
628 | return self.A0.forward(x, y) + 2 |
629 | ~~~~~~~~~~~~~~~ <--- HERE |
630 | |
631 | File "<string>", line 3, in forward |
632 | |
633 | def forward(self, x, y): |
634 | return x + y |
635 | ~~~~~ <--- HERE |
636 | )" ; |
637 | ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); |
638 | } |
639 | |
// Lowers only submodule A0 of "C" to "backend_with_compiler_demo" and patches
// C to use the lowered module in place of A0. Checks that the runtime error
// from the mobile module traces C -> A0 -> lowered module -> `x + y`.
640 | TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) { |
// NOTE(review): `cu` is never used in this test and could be removed.
641 | std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>(); |
642 | Module a("A" ); |
643 | a.define(R"( |
644 | def forward(self, x, y): |
645 | return x + y |
646 | )" ); |
647 | Module b("B" ); |
648 | b.define(R"( |
649 | def forward(self, x): |
650 | return x + 2 |
651 | )" ); |
652 | Module c("C" ); |
653 | c.register_module("A0" , a); |
654 | c.register_module("B0" , b); |
655 | c.define(R"( |
656 | def forward(self, x, y): |
657 | return self.A0.forward(x, y) + self.B0.forward(x) |
658 | )" ); |
659 | |
// Shapes {2, 4} and {13, 9} cannot be broadcast together, so A0's `x + y` is
// expected to throw.
660 | std::vector<IValue> inputs; |
661 | inputs.emplace_back(torch::rand({2, 4})); |
662 | inputs.emplace_back(torch::rand({13, 9})); |
663 | |
664 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
665 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
666 | fake_dict.insert("" , "" ); |
667 | compile_spec.insert("forward" , fake_dict); |
// Only the A0 submodule is lowered, not the whole of C.
668 | IValue submodule = c.attr("A0" ); |
669 | Module current_sm = submodule.toModule(); |
670 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
671 | // lowered module |
672 | auto lowered_submodule = torch::jit::detail::codegen_backend_module( |
673 | "backend_with_compiler_demo" , current_sm, compile_spec, any_dict_ty); |
673 | |
// Replace A0 with the lowered module: retype the attribute slot on C's class
// type, then store the lowered module's object in it.
675 | c.type()->unsafeChangeAttributeType("A0" , lowered_submodule.type()); |
676 | c.setattr("A0" , lowered_submodule._ivalue()); |
// Rewrite every reference to A's type in C's method graphs and schemas so it
// refers to the lowered module's type instead.
677 | std::unordered_map<TypePtr, TypePtr> type_remap; |
678 | type_remap[a.type()] = lowered_submodule.type(); |
679 | auto type_remap_fn = [&type_remap](TypePtr in) { |
680 | auto it = type_remap.find(in); |
681 | if (it == type_remap.end()) |
682 | return in; |
683 | return it->second; |
684 | }; |
685 | for (auto& fn : c.type()->methods()) { |
686 | auto method = c.get_method(fn->name()); |
687 | auto graph = method.graph(); |
688 | graph->remapTypes(type_remap_fn); |
689 | auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); |
690 | fn->setSchema(new_schema); |
691 | } |
692 | |
693 | std::stringstream ss; |
// NOTE(review): `true` presumably enables saving mobile debug info. Confirm.
694 | c._save_for_mobile(ss, ExtraFilesMap(), true); |
695 | auto c_loaded = _load_for_mobile(ss); |
// Expected error text; the exact contents of this raw string matter for the
// match performed by ASSERT_THROWS_WITH_MESSAGE.
696 | std::string error_pattern = R"( |
697 | Module hierarchy:top(C)::<unknown>.A0(A)::forward.__loweredModule__(A)::forward.aten::add |
698 | Traceback of TorchScript (most recent call last): |
699 | File "<string>", line 3, in <unknown> |
700 | |
701 | def forward(self, x, y): |
702 | return self.A0.forward(x, y) + self.B0.forward(x) |
703 | ~~~~~~~~~~~~~~~ <--- HERE |
704 | |
705 | File "<string>", line 3, in forward |
706 | |
707 | def forward(self, x: Tensor, y: Tensor): |
708 | return self.__loweredModule__.forward(x, y) |
709 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
710 | |
711 | File "<string>", line 5, in forward |
712 | typed_inputs: List[Any] = [x, y, ] |
713 | if self.__backend.is_available() : |
714 | _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) |
715 | ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
716 | assert isinstance(_0, Tensor) |
717 | return _0 |
718 | File "<string>", line 3, in <unknown> |
719 | |
720 | def forward(self, x, y): |
721 | return x + y |
722 | ~~~~~ <--- HERE |
723 | )" ; |
724 | ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern); |
725 | } |
726 | |
// Like TestExceptionStackForCompilerWithLoweredSubModule, but the lowered
// submodule A0 itself contains a submodule AA0. Only A0 (together with its
// AA0 child) is lowered. Checks that the runtime error traces
// C -> A0 -> lowered module -> AA0 -> `x + y`.
727 | TEST( |
728 | BackendTestDebugInfo, |
729 | TestExceptionStackForCompilerWithSelectiveLoweredSubModule) { |
// NOTE(review): `cu` is never used in this test and could be removed.
730 | std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>(); |
731 | Module aa("AA" ); |
732 | aa.define(R"( |
733 | def forward(self, x, y): |
734 | return x + y |
735 | )" ); |
736 | Module a("A" ); |
737 | a.register_module("AA0" , aa); |
738 | a.define(R"( |
739 | def forward(self, x, y): |
740 | return self.AA0.forward(x, y) + 3 |
741 | )" ); |
742 | Module b("B" ); |
743 | b.define(R"( |
744 | def forward(self, x): |
745 | return x + 2 |
746 | )" ); |
747 | Module c("C" ); |
748 | c.register_module("A0" , a); |
749 | c.register_module("B0" , b); |
750 | c.define(R"( |
751 | def forward(self, x, y): |
752 | return self.A0.forward(x, y) + self.B0.forward(x) |
753 | )" ); |
754 | |
// Shapes {2, 4} and {13, 9} cannot be broadcast together, so AA0's `x + y` is
// expected to throw.
755 | std::vector<IValue> inputs; |
756 | inputs.emplace_back(torch::rand({2, 4})); |
757 | inputs.emplace_back(torch::rand({13, 9})); |
758 | |
759 | c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get()); |
760 | c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get()); |
761 | fake_dict.insert("" , "" ); |
762 | compile_spec.insert("forward" , fake_dict); |
// Only the A0 submodule is lowered; B0 and C's own forward stay in TorchScript.
763 | IValue submodule = c.attr("A0" ); |
764 | Module current_sm = submodule.toModule(); |
765 | auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); |
766 | // lowered module |
767 | auto lowered_submodule = torch::jit::detail::codegen_backend_module( |
768 | "backend_with_compiler_demo" , current_sm, compile_spec, any_dict_ty); |
769 | |
// Replace A0 with the lowered module: retype the attribute slot on C's class
// type, then store the lowered module's object in it.
770 | c.type()->unsafeChangeAttributeType("A0" , lowered_submodule.type()); |
771 | c.setattr("A0" , lowered_submodule._ivalue()); |
// Rewrite every reference to A's type in C's method graphs and schemas so it
// refers to the lowered module's type instead.
772 | std::unordered_map<TypePtr, TypePtr> type_remap; |
773 | type_remap[a.type()] = lowered_submodule.type(); |
774 | auto type_remap_fn = [&type_remap](TypePtr in) { |
775 | auto it = type_remap.find(in); |
776 | if (it == type_remap.end()) |
777 | return in; |
778 | return it->second; |
779 | }; |
780 | for (auto& fn : c.type()->methods()) { |
781 | auto method = c.get_method(fn->name()); |
782 | auto graph = method.graph(); |
783 | graph->remapTypes(type_remap_fn); |
784 | auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); |
785 | fn->setSchema(new_schema); |
786 | } |
787 | |
788 | std::stringstream ss; |
// NOTE(review): `true` presumably enables saving mobile debug info. Confirm.
789 | c._save_for_mobile(ss, ExtraFilesMap(), true); |
790 | auto c_loaded = _load_for_mobile(ss); |
// NOTE(review): the block comment below uses an older format than
// error_pattern (e.g. `FunctionName_UNKNOWN`). error_pattern is what the test
// actually checks.
791 | /* |
792 | * Error stack trace will look like this: |
793 | * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) |
794 | * Traceback of TorchScript (most recent call last): |
795 | * File "<string>", line 3, in FunctionName_UNKNOWN |
796 | * |
797 | * def forward(self, x, y): |
798 | * return self.A0.forward(x, y) + self.B0.forward(x) |
799 | * ~~~~~~~~~~~~~~~ <--- HERE |
800 | * |
801 | * File "<string>", line 5, in FunctionName_UNKNOWN |
802 | * typed_inputs: List[Any] = [x, y, ] |
803 | * if self.__backend.is_available() : |
804 | * _0, = self.__backend.execute(self.__handles["forward"], |
805 | * typed_inputs) |
806 | * ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
807 | * assert isinstance(_0, Tensor) |
808 | * return _0 |
809 | * File "<string>", line 3, in FunctionName_UNKNOWN |
810 | * |
811 | * def forward(self, x, y): |
812 | * return self.AA0.forward(x, y) + 3 |
813 | * ~~~~~~~~~~~~~~~~ <--- HERE |
814 | * |
815 | * File "<string>", line 3, in FunctionName_UNKNOWN |
816 | * |
817 | * def forward(self, x, y): |
818 | * return x + y |
819 | * ~~~~~ <--- HERE |
820 | * |
821 | * |
822 | * */ |
// Expected error text; the exact contents of this raw string matter for the
// match performed by ASSERT_THROWS_WITH_MESSAGE.
823 | std::string error_pattern = R"( |
824 | Module hierarchy:top(C)::<unknown>.A0(A)::forward.__loweredModule__(A)::forward.AA0(AA)::forward.aten::add |
825 | Traceback of TorchScript (most recent call last): |
826 | File "<string>", line 3, in <unknown> |
827 | |
828 | def forward(self, x, y): |
829 | return self.A0.forward(x, y) + self.B0.forward(x) |
830 | ~~~~~~~~~~~~~~~ <--- HERE |
831 | |
832 | File "<string>", line 3, in forward |
833 | |
834 | def forward(self, x: Tensor, y: Tensor): |
835 | return self.__loweredModule__.forward(x, y) |
836 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
837 | |
838 | File "<string>", line 5, in forward |
839 | typed_inputs: List[Any] = [x, y, ] |
840 | if self.__backend.is_available() : |
841 | _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) |
842 | ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
843 | assert isinstance(_0, Tensor) |
844 | return _0 |
845 | File "<string>", line 3, in <unknown> |
846 | |
847 | def forward(self, x, y): |
848 | return self.AA0.forward(x, y) + 3 |
849 | ~~~~~~~~~~~~~~~~ <--- HERE |
850 | |
851 | File "<string>", line 3, in forward |
852 | |
853 | def forward(self, x, y): |
854 | return x + y |
855 | ~~~~~ <--- HERE |
856 | )" ; |
857 | ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern); |
858 | } |
859 | |
860 | } // namespace jit |
861 | } // namespace torch |
862 | |