1#include <c10/util/Exception.h>
2#include <c10/util/ThreadLocal.h>
3#include <c10/util/ThreadLocalDebugInfo.h>
4
5#include <utility>
6
7namespace c10 {
8
// Thread-local slot holding the innermost ThreadLocalDebugInfo entry of the
// calling thread. Entries form a linked chain through parent_info_: _push
// prepends a node and _pop re-points the slot at the popped node's parent.
C10_DEFINE_TLS_static(std::shared_ptr<ThreadLocalDebugInfo>, tls_debug_info);
// Shorthand for this thread's shared_ptr slot. The functions below both read
// it and assign through it, so it is used as a mutable lvalue.
#define debug_info (tls_debug_info.get())
11
12/* static */
13DebugInfoBase* ThreadLocalDebugInfo::get(DebugInfoKind kind) {
14 ThreadLocalDebugInfo* cur = debug_info.get();
15 while (cur) {
16 if (cur->kind_ == kind) {
17 return cur->info_.get();
18 }
19 cur = cur->parent_info_.get();
20 }
21 return nullptr;
22}
23
24/* static */
25std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() {
26 return debug_info;
27}
28
29/* static */
30void ThreadLocalDebugInfo::_forceCurrentDebugInfo(
31 std::shared_ptr<ThreadLocalDebugInfo> info) {
32 debug_info = std::move(info);
33}
34
35/* static */
36void ThreadLocalDebugInfo::_push(
37 DebugInfoKind kind,
38 std::shared_ptr<DebugInfoBase> info) {
39 auto prev_info = debug_info;
40 debug_info = std::make_shared<ThreadLocalDebugInfo>();
41 debug_info->parent_info_ = prev_info;
42 debug_info->kind_ = kind;
43 debug_info->info_ = std::move(info);
44}
45
46/* static */
47std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_pop(DebugInfoKind kind) {
48 TORCH_CHECK(
49 debug_info && debug_info->kind_ == kind,
50 "Expected debug info of type ",
51 (size_t)kind);
52 auto res = debug_info;
53 debug_info = debug_info->parent_info_;
54 return res->info_;
55}
56
57/* static */
58std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_peek(DebugInfoKind kind) {
59 TORCH_CHECK(
60 debug_info && debug_info->kind_ == kind,
61 "Expected debug info of type ",
62 (size_t)kind);
63 return debug_info->info_;
64}
65
66DebugInfoGuard::DebugInfoGuard(
67 DebugInfoKind kind,
68 std::shared_ptr<DebugInfoBase> info) {
69 if (!info) {
70 return;
71 }
72 prev_info_ = debug_info;
73 ThreadLocalDebugInfo::_push(kind, std::move(info));
74 active_ = true;
75}
76
77DebugInfoGuard::~DebugInfoGuard() {
78 if (active_) {
79 debug_info = prev_info_;
80 }
81}
82
83// Used only for setting a debug info after crossing the thread boundary;
84// in this case we assume that thread pool's thread does not have an
85// active debug info
86DebugInfoGuard::DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info) {
87 if (!info) {
88 return;
89 }
90 prev_info_ = std::move(debug_info);
91 debug_info = std::move(info);
92 active_ = true;
93}
94
95} // namespace c10
96