1 | #pragma once |
2 | |
#include <c10/macros/Export.h>

#include <cstdint>
#include <memory>
#include <string>
7 | |
8 | namespace c10 { |
9 | |
10 | enum class C10_API_ENUM DebugInfoKind : uint8_t { |
11 | PRODUCER_INFO = 0, |
12 | MOBILE_RUNTIME_INFO, |
13 | PROFILER_STATE, |
14 | INFERENCE_CONTEXT, // for inference usage |
15 | PARAM_COMMS_INFO, |
16 | |
17 | TEST_INFO, // used only in tests |
18 | TEST_INFO_2, // used only in tests |
19 | }; |
20 | |
21 | class C10_API DebugInfoBase { |
22 | public: |
23 | DebugInfoBase() = default; |
24 | virtual ~DebugInfoBase() = default; |
25 | }; |
26 | |
27 | // Thread local debug information is propagated across the forward |
28 | // (including async fork tasks) and backward passes and is supposed |
29 | // to be utilized by the user's code to pass extra information from |
30 | // the higher layers (e.g. model id) down to the lower levels |
31 | // (e.g. to the operator observers used for debugging, logging, |
32 | // profiling, etc) |
33 | class C10_API ThreadLocalDebugInfo { |
34 | public: |
35 | static DebugInfoBase* get(DebugInfoKind kind); |
36 | |
37 | // Get current ThreadLocalDebugInfo |
38 | static std::shared_ptr<ThreadLocalDebugInfo> current(); |
39 | |
40 | // Internal, use DebugInfoGuard/ThreadLocalStateGuard |
41 | static void _forceCurrentDebugInfo( |
42 | std::shared_ptr<ThreadLocalDebugInfo> info); |
43 | |
44 | // Push debug info struct of a given kind |
45 | static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info); |
46 | // Pop debug info, throws in case the last pushed |
47 | // debug info is not of a given kind |
48 | static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind); |
49 | // Peek debug info, throws in case the last pushed debug info is not of the |
50 | // given kind |
51 | static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind); |
52 | |
53 | private: |
54 | std::shared_ptr<DebugInfoBase> info_; |
55 | DebugInfoKind kind_; |
56 | std::shared_ptr<ThreadLocalDebugInfo> parent_info_; |
57 | |
58 | friend class DebugInfoGuard; |
59 | }; |
60 | |
61 | // DebugInfoGuard is used to set debug information, |
62 | // ThreadLocalDebugInfo is semantically immutable, the values are set |
63 | // through the scope-based guard object. |
64 | // Nested DebugInfoGuard adds/overrides existing values in the scope, |
65 | // restoring the original values after exiting the scope. |
66 | // Users can access the values through the ThreadLocalDebugInfo::get() call; |
67 | class C10_API DebugInfoGuard { |
68 | public: |
69 | DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info); |
70 | |
71 | explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info); |
72 | |
73 | ~DebugInfoGuard(); |
74 | |
75 | DebugInfoGuard(const DebugInfoGuard&) = delete; |
76 | DebugInfoGuard(DebugInfoGuard&&) = delete; |
77 | |
78 | private: |
79 | bool active_ = false; |
80 | std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr; |
81 | }; |
82 | |
83 | } // namespace c10 |
84 | |