1 | #pragma once |
2 | #include <c10/util/flat_hash_map.h> |
3 | #include <caffe2/serialize/inline_container.h> |
4 | #include <torch/csrc/jit/api/compilation_unit.h> |
5 | #include <torch/csrc/jit/ir/scope.h> |
6 | #include <torch/csrc/jit/serialization/source_range_serialization.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
/*
 * MobileDebugTable:
 * Deserializes debug_pkl and callstack_map records from a PT model's zip
 * archive and stores them in a map of debug handles to DebugInfoTuple. Debug
 * handles are unique per model and runtime, be it the lite interpreter or a
 * delegate; an exception of type BackendRuntimeException should be raised
 * using debug handles.
 * The getSourceDebugString method is responsible for translating debug
 * handles to the corresponding debug information.
 * This debug information includes the stack trace of model-level source code
 * and the module hierarchy where the exception occurred.
 */
21 | class MobileDebugTable { |
22 | public: |
23 | MobileDebugTable() = default; |
24 | MobileDebugTable( |
25 | std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader, |
26 | const std::shared_ptr<CompilationUnit>& cu); |
27 | |
28 | template <typename It> |
29 | MobileDebugTable(It begin, It end) : callstack_ptr_map_(begin, end) {} |
30 | |
31 | std::string getSourceDebugString( |
32 | const int64_t debug_handle, |
33 | const std::string& top_module_type_name = "ModuleTypeUnknown" ) const; |
34 | std::string getSourceDebugString( |
35 | const std::vector<int64_t>& debug_handles, |
36 | const std::string& top_module_type_name = "ModuleTypeUnknown" ) const; |
37 | std::string getModuleHierarchyInfo( |
38 | const int64_t debug_handle, |
39 | const std::string& top_module_type_name = "ModuleTypeUnknown" ) const; |
40 | std::string getModuleHierarchyInfo( |
41 | const std::vector<int64_t>& debug_handles, |
42 | const std::string& top_module_type_name = "ModuleTypeUnknown" ) const; |
43 | |
44 | const ska::flat_hash_map<int64_t, DebugInfoTuple>& getCallStackPtrMap() |
45 | const { |
46 | return callstack_ptr_map_; |
47 | } |
48 | |
49 | private: |
50 | std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo( |
51 | const std::vector<int64_t>& debug_handles, |
52 | const std::string& top_module_type_name = "ModuleTypeUnknown" ) const; |
53 | ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptr_map_; |
54 | }; |
55 | |
56 | } // namespace jit |
57 | } // namespace torch |
58 | |