1 | #include <c10/util/Exception.h> |
---|---|
2 | #include <c10/util/ThreadLocal.h> |
3 | #include <c10/util/ThreadLocalDebugInfo.h> |
4 | |
5 | #include <utility> |
6 | |
7 | namespace c10 { |
8 | |
9 | C10_DEFINE_TLS_static(std::shared_ptr<ThreadLocalDebugInfo>, tls_debug_info); |
10 | #define debug_info (tls_debug_info.get()) |
11 | |
12 | /* static */ |
13 | DebugInfoBase* ThreadLocalDebugInfo::get(DebugInfoKind kind) { |
14 | ThreadLocalDebugInfo* cur = debug_info.get(); |
15 | while (cur) { |
16 | if (cur->kind_ == kind) { |
17 | return cur->info_.get(); |
18 | } |
19 | cur = cur->parent_info_.get(); |
20 | } |
21 | return nullptr; |
22 | } |
23 | |
24 | /* static */ |
25 | std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() { |
26 | return debug_info; |
27 | } |
28 | |
29 | /* static */ |
30 | void ThreadLocalDebugInfo::_forceCurrentDebugInfo( |
31 | std::shared_ptr<ThreadLocalDebugInfo> info) { |
32 | debug_info = std::move(info); |
33 | } |
34 | |
35 | /* static */ |
36 | void ThreadLocalDebugInfo::_push( |
37 | DebugInfoKind kind, |
38 | std::shared_ptr<DebugInfoBase> info) { |
39 | auto prev_info = debug_info; |
40 | debug_info = std::make_shared<ThreadLocalDebugInfo>(); |
41 | debug_info->parent_info_ = prev_info; |
42 | debug_info->kind_ = kind; |
43 | debug_info->info_ = std::move(info); |
44 | } |
45 | |
46 | /* static */ |
47 | std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_pop(DebugInfoKind kind) { |
48 | TORCH_CHECK( |
49 | debug_info && debug_info->kind_ == kind, |
50 | "Expected debug info of type ", |
51 | (size_t)kind); |
52 | auto res = debug_info; |
53 | debug_info = debug_info->parent_info_; |
54 | return res->info_; |
55 | } |
56 | |
57 | /* static */ |
58 | std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_peek(DebugInfoKind kind) { |
59 | TORCH_CHECK( |
60 | debug_info && debug_info->kind_ == kind, |
61 | "Expected debug info of type ", |
62 | (size_t)kind); |
63 | return debug_info->info_; |
64 | } |
65 | |
66 | DebugInfoGuard::DebugInfoGuard( |
67 | DebugInfoKind kind, |
68 | std::shared_ptr<DebugInfoBase> info) { |
69 | if (!info) { |
70 | return; |
71 | } |
72 | prev_info_ = debug_info; |
73 | ThreadLocalDebugInfo::_push(kind, std::move(info)); |
74 | active_ = true; |
75 | } |
76 | |
77 | DebugInfoGuard::~DebugInfoGuard() { |
78 | if (active_) { |
79 | debug_info = prev_info_; |
80 | } |
81 | } |
82 | |
83 | // Used only for setting a debug info after crossing the thread boundary; |
84 | // in this case we assume that thread pool's thread does not have an |
85 | // active debug info |
86 | DebugInfoGuard::DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info) { |
87 | if (!info) { |
88 | return; |
89 | } |
90 | prev_info_ = std::move(debug_info); |
91 | debug_info = std::move(info); |
92 | active_ = true; |
93 | } |
94 | |
95 | } // namespace c10 |
96 |