#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/torch.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// All tests live in the torch::jit namespace.
namespace torch {
namespace jit {
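// The backends exercised below ("test_backend", "test_backend_unavailable",
// "backend_with_compiler_demo") are registered in other translation units of
// the test suite. As a rough, minimal sketch (assuming the
// PyTorchBackendInterface API from torch/csrc/jit/backends/backend.h), a
// backend implementation looks like:
//
//   class TestBackend : public PyTorchBackendInterface {
//    public:
//     bool is_available() override { return true; }
//     c10::impl::GenericDict compile(
//         c10::IValue processed,
//         c10::impl::GenericDict method_compile_spec) override { /* ... */ }
//     c10::impl::GenericList execute(
//         c10::IValue handle,
//         c10::impl::GenericList inputs) override { /* ... */ }
//   };
//   static const auto reg = torch::jit::backend<TestBackend>("test_backend");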
TEST(BackendTest, ToBackend) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs).toTupleRef().elements().vec();

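  // The compile spec is a Dict[str, Any] mapping each method name to
  // backend-specific compilation options; this test passes a placeholder
  // dict for "forward".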
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "test_backend", m, compile_spec, any_dict_ty);
  // Generated lowered module code (for reference):
  /*
    class test_backendLoweredModule(Module):
      __parameters__ = []
      __buffers__ = []
      __processed_module : Any
      __method_compile_spec : Dict[str, Any]
      __backend : __torch__.torch.classes.__backends__.test_backend
      __handles : Dict[str, Any]

      def __create_backend(self: torch.jit.test_backendLoweredModule) -> None:
        _0 = __torch__.torch.classes.__backends__.test_backend.__new__(
            __torch__.torch.classes.__backends__.test_backend)
        _1 = (_0).__init__()
        self.__backend = _0
        return None

      def __getstate__(self: torch.jit.test_backendLoweredModule)
          -> Tuple[Dict[str, Any], Any]:
        _2 = (self.__method_compile_spec, self.__processed_module)
        return _2

      def __setstate__(self: torch.jit.test_backendLoweredModule,
                       state: Tuple[Dict[str, Any], Any]) -> None:
        self.__method_compile_spec = (state)[0]
        self.__processed_module = (state)[1]
        _3 = (self).__create_backend()
        _4 = (self.__backend).compile(self.__processed_module,
                                      self.__method_compile_spec, )
        self.__handles = _4
        return None

      def forward(self: torch.jit.test_backendLoweredModule,
                  x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
        _5 = uninitialized(Tensor)
        typed_inputs = annotate(List[Any], [x, h])
        _6 = (self.__backend).execute((self.__handles)["forward"],
                                      typed_inputs, )
        _7, _8, = _6
        _9 = isinstance(_7, Tensor)
        if _9:
          _10 = unchecked_cast(Tensor, _7)
        else:
          ops.prim.RaiseException("AssertionError: ")
          _10 = _5
        _11 = isinstance(_8, Tensor)
        if _11:
          _12 = unchecked_cast(Tensor, _8)
        else:
          ops.prim.RaiseException("AssertionError: ")
          _12 = _5
        return (_10, _12)
  */
  auto res = lm.forward(inputs).toTupleRef().elements().vec();
  AT_ASSERT(res[0].toTensor().equal(ref[0].toTensor()));
  AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor()));
}

TEST(BackendTest, ToBackendNotAvailable) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs).toTupleRef().elements().vec();

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // Produce the lowered module even though the backend is unavailable;
  // no exception is thrown at lowering time.
  auto lm = torch::jit::detail::codegen_backend_module(
      "test_backend_unavailable", m, compile_spec, any_dict_ty);
  // Validate that an exception is thrown when execution is attempted and
  // the backend is not available.
  ASSERT_THROWS_WITH_MESSAGE(
      lm.forward(inputs).toTupleRef().elements(), "Backend is not available.");
}

TEST(BackendTest, TestCompiler) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs);

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
  auto res = lm.forward(inputs);
  AT_ASSERT(res.toTensor().equal(ref.toTensor()));

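  // Round-trip the lowered module through the mobile (bytecode) format and
  // check that the mobile module produces the same result.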
  std::stringstream ss;
  lm._save_for_mobile(ss);
  auto mlm = _load_for_mobile(ss);
  auto mres = mlm.forward(inputs);
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestCompilerWithStringTable) {
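  // Opt in to the serialization format that deduplicates strings via a
  // shared string table; the flag is reset before the test ends.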
  setShouldUseFormatWithStringTable(true);
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs);

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
  auto res = lm.forward(inputs);
  AT_ASSERT(res.toTensor().equal(ref.toTensor()));

  std::stringstream ss;
  lm._save_for_mobile(ss);
  auto mlm = _load_for_mobile(ss);
  auto mres = mlm.forward(inputs);
  setShouldUseFormatWithStringTable(false);
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestComposite) {
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());

  Module m_add("m_add");
  m_add.define(R"(
    def forward(self, x, y):
        return x + y
  )");
  auto lm_add = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m_add, compile_spec, any_dict_ty);

  Module m_sub("m_sub");
  m_sub.define(R"(
    def forward(self, x, y):
        return x - y
  )");
  auto lm_sub = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m_sub, compile_spec, any_dict_ty);

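  // Compose the two independently lowered modules under a regular scripted
  // parent module.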
  Module c("C");
  c.register_module("Add", lm_add);
  c.register_module("Sub", lm_sub);
  c.define(R"(
    def forward(self, x, y):
        return self.Add.forward(x, y) * self.Sub.forward(x, y)
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto res_jit = c.forward(inputs);

  std::stringstream ss;
  c._save_for_mobile(ss);
  auto mc = _load_for_mobile(ss);
  auto res_mobile = mc.forward(inputs);

  AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
}

TEST(BackendTest, TestPrimDtype) {
  Module c("name");
  c.define(R"(
    def forward(self, x, y):
        c = y.dtype
        return c
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto res_jit = c.forward(inputs);

  std::stringstream ss;
  c._save_for_mobile(ss);
  auto mc = _load_for_mobile(ss);
  auto res_mobile = mc.forward(inputs);

  ASSERT_EQ(res_jit.toInt(), res_mobile.toInt());
}

Module getCompositeModuleWithSameNameSubModules() {
  // Two submodules that share a module type name ("m_add") but have
  // different forward (and other) methods should be serialized and loaded
  // correctly.

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());

  Module sub1("m_add");
  sub1.define(R"(
    def forward(self, x, y):
        return x + y
  )");
  auto lowered_sub1 = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", sub1, compile_spec, any_dict_ty);

  Module sub2("m_add");
  sub2.define(R"(
    def forward(self, x, y):
        return x - y
  )");
  auto lowered_sub2 = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", sub2, compile_spec, any_dict_ty);

  Module c("C");
  c.register_module("Add", lowered_sub1);
  c.register_module("Sub", lowered_sub2);
  c.define(R"(
    def forward(self, a, b, s:int):
        c = self.Add.forward(a, b)
        d = self.Sub.forward(a, b)
        y = s * (c * d)
        return y
  )");

  return c;
}

TEST(BackendTest, TestCompositeWithSetStates) {
  Module c = getCompositeModuleWithSameNameSubModules();

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::ones({}));
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(3);
  auto res_jit = c.forward(inputs);

  std::stringstream ss;
  c._save_for_mobile(ss);
  auto mc = _load_for_mobile(ss);
  auto res_mobile = mc.forward(inputs);
  AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
}

TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
  Module c = getCompositeModuleWithSameNameSubModules();

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::ones({}));
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(3);

  std::stringstream ss, ss_resave;
  c._save_for_mobile(ss);
  auto mc = _load_for_mobile(ss);
  auto res_mobile = mc.forward(inputs);
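  // Rewind the stream so it can be re-read as a full JIT module below.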
  ss.seekg(0, ss.beg);

  // Check that method names are stable across a round trip: reload the saved
  // stream as a script module, save that for mobile again, and verify that
  // the method names and numerical outputs of the two mobile modules match.
  auto script_module_load = torch::jit::load(ss);
  script_module_load._save_for_mobile(ss_resave);
  auto mc_reload = _load_for_mobile(ss_resave);
  auto res_mobile_reload = mc_reload.forward(inputs);

  AT_ASSERT(res_mobile_reload.toTensor().equal(res_mobile.toTensor()));

  auto mc_methods = mc.get_methods();
  auto mc_reload_methods = mc_reload.get_methods();

  std::vector<std::string> mc_method_qns, mc_reload_method_qns;

  auto get_qual_name = [](const mobile::Method& method) -> std::string {
    return method.function().qualname().qualifiedName();
  };

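  // A method's qualified name includes its owning class name, so comparing
  // qualified names catches renamings introduced by a save/load round trip.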
  std::transform(
      mc_methods.begin(),
      mc_methods.end(),
      std::back_inserter(mc_method_qns),
      get_qual_name);

  std::transform(
      mc_reload_methods.begin(),
      mc_reload_methods.end(),
      std::back_inserter(mc_reload_method_qns),
      get_qual_name);

  AT_ASSERT(mc_method_qns.size() == mc_reload_method_qns.size());
  AT_ASSERT(std::equal(
      mc_method_qns.begin(),
      mc_method_qns.end(),
      mc_reload_method_qns.begin()));
}

TEST(BackendTest, TestCompilerNotSupport) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x * h
  )");

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // Lowering fails because the demo compiler does not support aten::mul.
  ASSERT_THROWS_WITH_MESSAGE(
      torch::jit::detail::codegen_backend_module(
          "backend_with_compiler_demo", m, compile_spec, any_dict_ty),
      "The node of aten::mul is not supported in this compiler. Source code:");
}

TEST(BackendTestDebugInfo, TestCompiler) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);

  std::stringstream ss;
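  // The third argument to _save_for_mobile enables saving of mobile debug
  // info, so the runtime error below carries a module hierarchy and a
  // TorchScript traceback.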
  lm._save_for_mobile(ss, ExtraFilesMap(), true);
  auto mlm = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x: Tensor, h: Tensor):
        return self.__loweredModule__.forward(x, h)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
    typed_inputs: List[Any] = [x, h, ]
    if self.__backend.is_available() :
      _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      assert isinstance(_0, Tensor)
      return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, h):
        return x + h
               ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestCompilerWithStringTable) {
  setShouldUseFormatWithStringTable(true);
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);

  std::stringstream ss;
  lm._save_for_mobile(ss, ExtraFilesMap(), true);
  auto mlm = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x: Tensor, h: Tensor):
        return self.__loweredModule__.forward(x, h)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
    typed_inputs: List[Any] = [x, h, ]
    if self.__backend.is_available() :
      _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      assert isinstance(_0, Tensor)
      return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, h):
        return x + h
               ~~~~~ <--- HERE
  )";
  setShouldUseFormatWithStringTable(false);
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def forward(self, x, y):
        return x + y
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
        return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
        return self.A0.forward(x, y) + self.B0.forward(x)
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", c, compile_spec, any_dict_ty);

  std::stringstream ss;
  lm._save_for_mobile(ss, ExtraFilesMap(), true);
  auto mlm = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.__loweredModule__(C)::forward.A0(A)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x: Tensor, y: Tensor):
        return self.__loweredModule__.forward(x, y)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
    typed_inputs: List[Any] = [x, y, ]
    if self.__backend.is_available() :
      _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      assert isinstance(_0, Tensor)
      return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
        return self.A0.forward(x, y) + self.B0.forward(x)
               ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
        return x + y
               ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(
    BackendTestDebugInfo,
    TestExceptionStackForCompilerWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def forward(self, x, y):
        return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def forward(self, x, y):
        return self.A0.forward(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
        return self.B0.forward(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", c, compile_spec, any_dict_ty);

  std::stringstream ss;
  lm._save_for_mobile(ss, ExtraFilesMap(), true);
  auto mlm = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.__loweredModule__(C)::forward.B0(B)::forward.A0(A)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x: Tensor, y: Tensor):
        return self.__loweredModule__.forward(x, y)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
    typed_inputs: List[Any] = [x, y, ]
    if self.__backend.is_available() :
      _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      assert isinstance(_0, Tensor)
      return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
        return self.B0.forward(x, y) + 3
               ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
        return self.A0.forward(x, y) + 2
               ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
        return x + y
               ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
  Module a("A");
  a.define(R"(
    def forward(self, x, y):
        return x + y
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
        return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
        return self.A0.forward(x, y) + self.B0.forward(x)
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
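  // Lower only the A0 submodule; B0 remains a regular scripted module.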
  IValue submodule = c.attr("A0");
  Module current_sm = submodule.toModule();
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  auto lowered_submodule = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", current_sm, compile_spec, any_dict_ty);

  c.type()->unsafeChangeAttributeType("A0", lowered_submodule.type());
  c.setattr("A0", lowered_submodule._ivalue());
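  // Every use of A0's old type in C's method graphs and schemas must be
  // remapped to the lowered module's type; setattr alone does not update
  // the graphs.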
  std::unordered_map<TypePtr, TypePtr> type_remap;
  type_remap[a.type()] = lowered_submodule.type();
  auto type_remap_fn = [&type_remap](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end()) {
      return in;
    }
    return it->second;
  };
  for (auto& fn : c.type()->methods()) {
    auto method = c.get_method(fn->name());
    auto graph = method.graph();
    graph->remapTypes(type_remap_fn);
    auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
    fn->setSchema(new_schema);
  }

  std::stringstream ss;
  c._save_for_mobile(ss, ExtraFilesMap(), true);
  auto c_loaded = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.A0(A)::forward.__loweredModule__(A)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
        return self.A0.forward(x, y) + self.B0.forward(x)
               ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x: Tensor, y: Tensor):
        return self.__loweredModule__.forward(x, y)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
    typed_inputs: List[Any] = [x, y, ]
    if self.__backend.is_available() :
      _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      assert isinstance(_0, Tensor)
      return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
        return x + y
               ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
}

TEST(
    BackendTestDebugInfo,
    TestExceptionStackForCompilerWithSelectiveLoweredSubModule) {
  Module aa("AA");
  aa.define(R"(
    def forward(self, x, y):
        return x + y
  )");
  Module a("A");
  a.register_module("AA0", aa);
  a.define(R"(
    def forward(self, x, y):
        return self.AA0.forward(x, y) + 3
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
        return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
        return self.A0.forward(x, y) + self.B0.forward(x)
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
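  // Selectively lower only A0 (which itself contains AA0); B0 is untouched.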
  IValue submodule = c.attr("A0");
  Module current_sm = submodule.toModule();
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  auto lowered_submodule = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", current_sm, compile_spec, any_dict_ty);

  c.type()->unsafeChangeAttributeType("A0", lowered_submodule.type());
  c.setattr("A0", lowered_submodule._ivalue());
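  // As above, remap every use of A0's old type in C to the lowered type.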
  std::unordered_map<TypePtr, TypePtr> type_remap;
  type_remap[a.type()] = lowered_submodule.type();
  auto type_remap_fn = [&type_remap](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end()) {
      return in;
    }
    return it->second;
  };
  for (auto& fn : c.type()->methods()) {
    auto method = c.get_method(fn->name());
    auto graph = method.graph();
    graph->remapTypes(type_remap_fn);
    auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
    fn->setSchema(new_schema);
  }

  std::stringstream ss;
  c._save_for_mobile(ss, ExtraFilesMap(), true);
  auto c_loaded = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.A0(A)::forward.__loweredModule__(A)::forward.AA0(AA)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
        return self.A0.forward(x, y) + self.B0.forward(x)
               ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x: Tensor, y: Tensor):
        return self.__loweredModule__.forward(x, y)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
    typed_inputs: List[Any] = [x, y, ]
    if self.__backend.is_available() :
      _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
            ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      assert isinstance(_0, Tensor)
      return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
        return self.AA0.forward(x, y) + 3
               ~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
        return x + y
               ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
}

} // namespace jit
} // namespace torch