1 | #include <torch/csrc/profiler/orchestration/observer.h> |
2 | |
3 | #include <torch/csrc/profiler/util.h> |
4 | |
5 | #include <utility> |
6 | |
7 | namespace torch { |
8 | namespace profiler { |
9 | namespace impl { |
10 | |
// Shorthand for the manager of the process-global profiler state; used by
// ProfilerStateBase::get/push/pop below when `global` is true.
using GlobalManager = GlobalStateManager<ProfilerStateBase>;
12 | |
13 | // ---------------------------------------------------------------------------- |
14 | // -- Profiler Config --------------------------------------------------------- |
15 | // ---------------------------------------------------------------------------- |
16 | ExperimentalConfig::ExperimentalConfig( |
17 | std::vector<std::string> profiler_metrics, |
18 | bool profiler_measure_per_kernel, |
19 | bool verbose, |
20 | std::vector<std::string> performance_events, |
21 | bool adjust_timestamps) |
22 | : profiler_metrics{std::move(profiler_metrics)}, |
23 | profiler_measure_per_kernel{profiler_measure_per_kernel}, |
24 | verbose{verbose}, |
25 | performance_events(std::move(performance_events)), |
26 | adjust_timestamps{adjust_timestamps} {} |
27 | |
28 | /*explicit*/ ExperimentalConfig::operator bool() const { |
29 | return !profiler_metrics.empty(); |
30 | } |
31 | |
32 | ProfilerConfig::ProfilerConfig( |
33 | ProfilerState state, |
34 | bool report_input_shapes, |
35 | bool profile_memory, |
36 | bool with_stack, |
37 | bool with_flops, |
38 | bool with_modules, |
39 | ExperimentalConfig experimental_config) |
40 | : state{state}, |
41 | experimental_config{experimental_config}, |
42 | report_input_shapes{report_input_shapes}, |
43 | profile_memory{profile_memory}, |
44 | with_stack{with_stack}, |
45 | with_flops{with_flops}, |
46 | with_modules{with_modules} {} |
47 | |
48 | bool ProfilerConfig::disabled() const { |
49 | return state == torch::profiler::impl::ProfilerState::Disabled; |
50 | } |
51 | |
52 | bool ProfilerConfig::global() const { |
53 | return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND; |
54 | } |
55 | |
namespace {
// Positions of the fields inside the list produced by
// ProfilerConfig::toIValue() and consumed by ProfilerConfig::fromIValue().
enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES = 1,
  PROFILE_MEMORY = 2,
  // Number of serialized fields; must stay the last enumerator.
  NUM_PROFILER_CFG_IVALUE_IDX
};
} // namespace
64 | |
65 | at::IValue ProfilerConfig::toIValue() const { |
66 | c10::impl::GenericList eventIValueList(at::AnyType::get()); |
67 | eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX); |
68 | eventIValueList.emplace_back(static_cast<int64_t>(state)); |
69 | eventIValueList.emplace_back(report_input_shapes); |
70 | eventIValueList.emplace_back(profile_memory); |
71 | return eventIValueList; |
72 | } |
73 | |
74 | ProfilerConfig ProfilerConfig::fromIValue( |
75 | const at::IValue& profilerConfigIValue) { |
76 | TORCH_INTERNAL_ASSERT( |
77 | profilerConfigIValue.isList(), |
78 | "Expected IValue to contain type c10::impl::GenericList" ); |
79 | auto ivalues = profilerConfigIValue.toList(); |
80 | TORCH_INTERNAL_ASSERT( |
81 | ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX, |
82 | c10::str( |
83 | "Expected exactly " , |
84 | NUM_PROFILER_CFG_IVALUE_IDX, |
85 | " ivalues to resconstruct ProfilerConfig." )); |
86 | return ProfilerConfig( |
87 | static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()), |
88 | ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(), |
89 | ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool()); |
90 | } |
91 | |
92 | // ---------------------------------------------------------------------------- |
93 | // -- Profiler base class ----------------------------------------------------- |
94 | // ---------------------------------------------------------------------------- |
95 | /*explicit*/ ProfilerStateBase::ProfilerStateBase(const ProfilerConfig& config) |
96 | : c10::MemoryReportingInfoBase(), config_(config) {} |
97 | |
98 | ProfilerStateBase::~ProfilerStateBase() { |
99 | if (handle_) { |
100 | auto handle = handle_; |
101 | removeCallback(); |
102 | SOFT_ASSERT(false, "Leaked callback handle: " , handle); |
103 | } |
104 | } |
105 | |
106 | /*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) { |
107 | auto* out = global |
108 | ? GlobalManager::get() |
109 | : static_cast<ProfilerStateBase*>( |
110 | c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)); |
111 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global); |
112 | return out; |
113 | } |
114 | |
115 | /*static*/ void ProfilerStateBase::push( |
116 | std::shared_ptr<ProfilerStateBase>&& state) { |
117 | TORCH_INTERNAL_ASSERT(state != nullptr); |
118 | if (state->config().global()) { |
119 | GlobalManager::push(std::move(state)); |
120 | } else { |
121 | c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state); |
122 | } |
123 | } |
124 | |
125 | namespace { |
126 | std::shared_ptr<ProfilerStateBase> popTLS() { |
127 | // If there is no active thread local profiler then we simply return null. |
128 | // However if there is an active profiler but it is not the top |
129 | // `DebugInfoBase`then `c10::ThreadLocalDebugInfo::_pop` will throw. |
130 | // TODO(robieta): make `noexcept` version. |
131 | return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE) |
132 | ? std::static_pointer_cast<ProfilerStateBase>( |
133 | c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE)) |
134 | : nullptr; |
135 | } |
136 | } // namespace |
137 | |
138 | /*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop( |
139 | bool global) { |
140 | auto out = global ? GlobalManager::pop() : popTLS(); |
141 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global); |
142 | return out; |
143 | } |
144 | |
145 | void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) { |
146 | if (handle_) { |
147 | at::removeCallback(handle_); |
148 | SOFT_ASSERT( |
149 | false, |
150 | "ProfilerStateBase already has a registered callback. " |
151 | "Removing to avoid leaked callback." ); |
152 | } |
153 | |
154 | handle_ = handle; |
155 | } |
156 | |
157 | void ProfilerStateBase::removeCallback() { |
158 | if (handle_) { |
159 | at::removeCallback(handle_); |
160 | handle_ = 0; |
161 | } |
162 | } |
163 | |
164 | bool profilerEnabled() { |
165 | auto* state_ptr = ProfilerStateBase::get(/*global=*/false); |
166 | return state_ptr && !state_ptr->config().disabled(); |
167 | } |
168 | |
169 | TORCH_API ActiveProfilerType profilerType() { |
170 | auto* state_ptr = ProfilerStateBase::get(/*global=*/false); |
171 | return state_ptr == nullptr ? ActiveProfilerType::NONE |
172 | : state_ptr->profilerType(); |
173 | } |
174 | |
175 | torch::profiler::impl::ProfilerConfig getProfilerConfig() { |
176 | auto* state_ptr = ProfilerStateBase::get(/*global=*/false); |
177 | TORCH_CHECK( |
178 | state_ptr, |
179 | "Tried to access profiler config, but profiler is not enabled!" ); |
180 | return state_ptr->config(); |
181 | } |
182 | |
183 | } // namespace impl |
184 | } // namespace profiler |
185 | } // namespace torch |
186 | |