1#include <test/cpp/jit/test_utils.h>
2
3#include <gtest/gtest.h>
4
5#include <c10/core/TensorOptions.h>
6#include <torch/csrc/autograd/generated/variable_factories.h>
7#include <torch/csrc/jit/api/module.h>
8#include <torch/csrc/jit/backends/backend_debug_handler.h>
9#include <torch/csrc/jit/frontend/resolver.h>
10#include <torch/csrc/jit/mobile/import.h>
11#include <torch/csrc/jit/mobile/module.h>
12#include <torch/csrc/jit/passes/inliner.h>
13#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
14#include <torch/csrc/jit/serialization/export.h>
15#include <torch/csrc/jit/serialization/import.h>
16#include <torch/custom_class.h>
17#include <torch/torch.h>
18
19#include <stack>
20#include <unordered_set>
21
22// Tests go in torch::jit
23namespace torch {
24namespace jit {
25
26namespace {
27bool validate_debug_info(
28 const DebugInfoTuple& pre_serialize,
29 const DebugInfoTuple& post_serialize) {
30 auto sr1 = std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
31 auto sr2 = std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
32 if (sr1 != sr2) {
33 return false;
34 }
35 auto csptr1 = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
36 auto csptr2 = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
37 if (!csptr1.defined()) {
38 return !csptr2.defined();
39 }
40 if (!csptr2.defined()) {
41 return false;
42 }
43 auto vec1 = csptr1->vec();
44 auto vec2 = csptr2->vec();
45 if (vec1.size() != vec2.size()) {
46 return false;
47 }
48 while (csptr1) {
49 auto rhs_sr = csptr1->source_range();
50 auto lhs_sr = csptr2->source_range();
51 auto rhs_module = csptr1->module_instance();
52 auto lhs_module = csptr2->module_instance();
53 std::string rhs_fn_name, lhs_fn_name;
54 if (csptr1->function()) {
55 rhs_fn_name = csptr1->function()->name();
56 } else {
57 rhs_fn_name = csptr1->function_name();
58 }
59 if (csptr2->function()) {
60 lhs_fn_name = csptr2->function()->name();
61 } else {
62 lhs_fn_name = csptr2->function_name();
63 }
64 if (!((rhs_module.has_value() == lhs_module.has_value()) &&
65 (rhs_module.has_value() &&
66 (rhs_module.value().class_type()->name().value() ==
67 lhs_module.value().class_type()->name().value()) &&
68 (rhs_module.value().instance_name() ==
69 lhs_module.value().instance_name())) &&
70 (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) {
71 return false;
72 }
73 if (csptr1->callee()) {
74 csptr1 = csptr1->callee().value();
75 csptr2 = csptr2->callee().value();
76 } else {
77 csptr1 = c10::intrusive_ptr<InlinedCallStack>();
78 }
79 }
80 return true;
81}
82
83TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) {
84 std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
85 Module a("A", cu);
86 a.define(R"JIT(
87 def forward(self, x):
88 return x + 1
89 )JIT");
90 Module b("B", cu);
91 b.define(R"JIT(
92 def forward(self, x):
93 return x + 2
94 )JIT");
95 Module c("C", cu);
96 c.register_module("A0", a);
97 c.register_module("B0", b);
98 c.define(R"JIT(
99 def forward(self, x):
100 return self.A0.forward(x) + self.B0.forward(x)
101 )JIT");
102
103 BackendDebugInfoRecorder debug_info_recorder;
104 auto graph = c.get_method("forward").graph();
105 Inline(*graph);
106 std::stack<Block*> blocks_to_visit;
107
108 // maps from source range to debug handle
109 SourceRangeTagMap source_range_tags;
110 // Maps from debug handle to source range
111 ska::flat_hash_map<int64_t, SourceRange> source_range_map;
112 int64_t source_range_tag{0};
113
114 blocks_to_visit.push(graph->block());
115 while (!blocks_to_visit.empty()) {
116 Block* b = blocks_to_visit.top();
117 blocks_to_visit.pop();
118 for (Node* n : b->nodes()) {
119 source_range_tags[n->sourceRange()] = source_range_tag;
120 source_range_map[source_range_tag] = n->sourceRange();
121 source_range_tag++;
122 debug_info_recorder.getNextDebugHandle(n);
123 if (n->callstack().has_value()) {
124 for (const auto& e : n->callstack().value()->vec()) {
125 auto sr = std::get<1>(e);
126 source_range_tags[sr] = source_range_tag;
127 source_range_map[source_range_tag] = sr;
128 source_range_tag++;
129 }
130 }
131 }
132 }
133 auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
134 CallStackDebugInfoPickler cs_debug_info_pickler;
135 auto cs_data =
136 cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags);
137 at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU);
138 CallStackDebugInfoUnpickler unpickler;
139 auto deserialized_cs_map = unpickler.unpickle(
140 std::move(data_ptr), cs_data.size(), source_range_map, cu);
141 for (const auto& it : debug_handle_cs_ptr_map) {
142 auto handle = it.first;
143 auto debug_info_one = it.second;
144 TORCH_CHECK(
145 deserialized_cs_map.count(handle),
146 "Serialized debug handle must be in deserialized map.");
147 auto debug_info_two = deserialized_cs_map[handle];
148 ASSERT_TRUE(validate_debug_info(debug_info_one, debug_info_two));
149 }
150}
151
152} // namespace
153
154} // namespace jit
155} // namespace torch
156