1 | #include <test/cpp/jit/test_utils.h> |
2 | |
3 | #include <gtest/gtest.h> |
4 | |
5 | #include <c10/core/TensorOptions.h> |
6 | #include <torch/csrc/autograd/generated/variable_factories.h> |
7 | #include <torch/csrc/jit/api/module.h> |
8 | #include <torch/csrc/jit/backends/backend_debug_handler.h> |
9 | #include <torch/csrc/jit/frontend/resolver.h> |
10 | #include <torch/csrc/jit/mobile/import.h> |
11 | #include <torch/csrc/jit/mobile/module.h> |
12 | #include <torch/csrc/jit/passes/inliner.h> |
13 | #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h> |
14 | #include <torch/csrc/jit/serialization/export.h> |
15 | #include <torch/csrc/jit/serialization/import.h> |
16 | #include <torch/custom_class.h> |
17 | #include <torch/torch.h> |
18 | |
19 | #include <stack> |
20 | #include <unordered_set> |
21 | |
22 | // Tests go in torch::jit |
23 | namespace torch { |
24 | namespace jit { |
25 | |
26 | namespace { |
27 | bool validate_debug_info( |
28 | const DebugInfoTuple& pre_serialize, |
29 | const DebugInfoTuple& post_serialize) { |
30 | auto sr1 = std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize); |
31 | auto sr2 = std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize); |
32 | if (sr1 != sr2) { |
33 | return false; |
34 | } |
35 | auto csptr1 = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize); |
36 | auto csptr2 = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize); |
37 | if (!csptr1.defined()) { |
38 | return !csptr2.defined(); |
39 | } |
40 | if (!csptr2.defined()) { |
41 | return false; |
42 | } |
43 | auto vec1 = csptr1->vec(); |
44 | auto vec2 = csptr2->vec(); |
45 | if (vec1.size() != vec2.size()) { |
46 | return false; |
47 | } |
48 | while (csptr1) { |
49 | auto rhs_sr = csptr1->source_range(); |
50 | auto lhs_sr = csptr2->source_range(); |
51 | auto rhs_module = csptr1->module_instance(); |
52 | auto lhs_module = csptr2->module_instance(); |
53 | std::string rhs_fn_name, lhs_fn_name; |
54 | if (csptr1->function()) { |
55 | rhs_fn_name = csptr1->function()->name(); |
56 | } else { |
57 | rhs_fn_name = csptr1->function_name(); |
58 | } |
59 | if (csptr2->function()) { |
60 | lhs_fn_name = csptr2->function()->name(); |
61 | } else { |
62 | lhs_fn_name = csptr2->function_name(); |
63 | } |
64 | if (!((rhs_module.has_value() == lhs_module.has_value()) && |
65 | (rhs_module.has_value() && |
66 | (rhs_module.value().class_type()->name().value() == |
67 | lhs_module.value().class_type()->name().value()) && |
68 | (rhs_module.value().instance_name() == |
69 | lhs_module.value().instance_name())) && |
70 | (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) { |
71 | return false; |
72 | } |
73 | if (csptr1->callee()) { |
74 | csptr1 = csptr1->callee().value(); |
75 | csptr2 = csptr2->callee().value(); |
76 | } else { |
77 | csptr1 = c10::intrusive_ptr<InlinedCallStack>(); |
78 | } |
79 | } |
80 | return true; |
81 | } |
82 | |
83 | TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) { |
84 | std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>(); |
85 | Module a("A" , cu); |
86 | a.define(R"JIT( |
87 | def forward(self, x): |
88 | return x + 1 |
89 | )JIT" ); |
90 | Module b("B" , cu); |
91 | b.define(R"JIT( |
92 | def forward(self, x): |
93 | return x + 2 |
94 | )JIT" ); |
95 | Module c("C" , cu); |
96 | c.register_module("A0" , a); |
97 | c.register_module("B0" , b); |
98 | c.define(R"JIT( |
99 | def forward(self, x): |
100 | return self.A0.forward(x) + self.B0.forward(x) |
101 | )JIT" ); |
102 | |
103 | BackendDebugInfoRecorder debug_info_recorder; |
104 | auto graph = c.get_method("forward" ).graph(); |
105 | Inline(*graph); |
106 | std::stack<Block*> blocks_to_visit; |
107 | |
108 | // maps from source range to debug handle |
109 | SourceRangeTagMap source_range_tags; |
110 | // Maps from debug handle to source range |
111 | ska::flat_hash_map<int64_t, SourceRange> source_range_map; |
112 | int64_t source_range_tag{0}; |
113 | |
114 | blocks_to_visit.push(graph->block()); |
115 | while (!blocks_to_visit.empty()) { |
116 | Block* b = blocks_to_visit.top(); |
117 | blocks_to_visit.pop(); |
118 | for (Node* n : b->nodes()) { |
119 | source_range_tags[n->sourceRange()] = source_range_tag; |
120 | source_range_map[source_range_tag] = n->sourceRange(); |
121 | source_range_tag++; |
122 | debug_info_recorder.getNextDebugHandle(n); |
123 | if (n->callstack().has_value()) { |
124 | for (const auto& e : n->callstack().value()->vec()) { |
125 | auto sr = std::get<1>(e); |
126 | source_range_tags[sr] = source_range_tag; |
127 | source_range_map[source_range_tag] = sr; |
128 | source_range_tag++; |
129 | } |
130 | } |
131 | } |
132 | } |
133 | auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording(); |
134 | CallStackDebugInfoPickler cs_debug_info_pickler; |
135 | auto cs_data = |
136 | cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags); |
137 | at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU); |
138 | CallStackDebugInfoUnpickler unpickler; |
139 | auto deserialized_cs_map = unpickler.unpickle( |
140 | std::move(data_ptr), cs_data.size(), source_range_map, cu); |
141 | for (const auto& it : debug_handle_cs_ptr_map) { |
142 | auto handle = it.first; |
143 | auto debug_info_one = it.second; |
144 | TORCH_CHECK( |
145 | deserialized_cs_map.count(handle), |
146 | "Serialized debug handle must be in deserialized map." ); |
147 | auto debug_info_two = deserialized_cs_map[handle]; |
148 | ASSERT_TRUE(validate_debug_info(debug_info_one, debug_info_two)); |
149 | } |
150 | } |
151 | |
152 | } // namespace |
153 | |
154 | } // namespace jit |
155 | } // namespace torch |
156 | |