1#pragma once
2
#include <c10/macros/Export.h>

#include <cstdint>
#include <memory>
#include <string>
7
8namespace c10 {
9
10enum class C10_API_ENUM DebugInfoKind : uint8_t {
11 PRODUCER_INFO = 0,
12 MOBILE_RUNTIME_INFO,
13 PROFILER_STATE,
14 INFERENCE_CONTEXT, // for inference usage
15 PARAM_COMMS_INFO,
16
17 TEST_INFO, // used only in tests
18 TEST_INFO_2, // used only in tests
19};
20
21class C10_API DebugInfoBase {
22 public:
23 DebugInfoBase() = default;
24 virtual ~DebugInfoBase() = default;
25};
26
27// Thread local debug information is propagated across the forward
28// (including async fork tasks) and backward passes and is supposed
29// to be utilized by the user's code to pass extra information from
30// the higher layers (e.g. model id) down to the lower levels
31// (e.g. to the operator observers used for debugging, logging,
32// profiling, etc)
33class C10_API ThreadLocalDebugInfo {
34 public:
35 static DebugInfoBase* get(DebugInfoKind kind);
36
37 // Get current ThreadLocalDebugInfo
38 static std::shared_ptr<ThreadLocalDebugInfo> current();
39
40 // Internal, use DebugInfoGuard/ThreadLocalStateGuard
41 static void _forceCurrentDebugInfo(
42 std::shared_ptr<ThreadLocalDebugInfo> info);
43
44 // Push debug info struct of a given kind
45 static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
46 // Pop debug info, throws in case the last pushed
47 // debug info is not of a given kind
48 static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
49 // Peek debug info, throws in case the last pushed debug info is not of the
50 // given kind
51 static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
52
53 private:
54 std::shared_ptr<DebugInfoBase> info_;
55 DebugInfoKind kind_;
56 std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
57
58 friend class DebugInfoGuard;
59};
60
61// DebugInfoGuard is used to set debug information,
62// ThreadLocalDebugInfo is semantically immutable, the values are set
63// through the scope-based guard object.
64// Nested DebugInfoGuard adds/overrides existing values in the scope,
65// restoring the original values after exiting the scope.
66// Users can access the values through the ThreadLocalDebugInfo::get() call;
67class C10_API DebugInfoGuard {
68 public:
69 DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
70
71 explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
72
73 ~DebugInfoGuard();
74
75 DebugInfoGuard(const DebugInfoGuard&) = delete;
76 DebugInfoGuard(DebugInfoGuard&&) = delete;
77
78 private:
79 bool active_ = false;
80 std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
81};
82
83} // namespace c10
84