1#include <torch/csrc/profiler/orchestration/observer.h>
2
3#include <torch/csrc/profiler/util.h>
4
5#include <utility>
6
7namespace torch {
8namespace profiler {
9namespace impl {
10
11using GlobalManager = GlobalStateManager<ProfilerStateBase>;
12
13// ----------------------------------------------------------------------------
14// -- Profiler Config ---------------------------------------------------------
15// ----------------------------------------------------------------------------
16ExperimentalConfig::ExperimentalConfig(
17 std::vector<std::string> profiler_metrics,
18 bool profiler_measure_per_kernel,
19 bool verbose,
20 std::vector<std::string> performance_events,
21 bool adjust_timestamps)
22 : profiler_metrics{std::move(profiler_metrics)},
23 profiler_measure_per_kernel{profiler_measure_per_kernel},
24 verbose{verbose},
25 performance_events(std::move(performance_events)),
26 adjust_timestamps{adjust_timestamps} {}
27
28/*explicit*/ ExperimentalConfig::operator bool() const {
29 return !profiler_metrics.empty();
30}
31
32ProfilerConfig::ProfilerConfig(
33 ProfilerState state,
34 bool report_input_shapes,
35 bool profile_memory,
36 bool with_stack,
37 bool with_flops,
38 bool with_modules,
39 ExperimentalConfig experimental_config)
40 : state{state},
41 experimental_config{experimental_config},
42 report_input_shapes{report_input_shapes},
43 profile_memory{profile_memory},
44 with_stack{with_stack},
45 with_flops{with_flops},
46 with_modules{with_modules} {}
47
48bool ProfilerConfig::disabled() const {
49 return state == torch::profiler::impl::ProfilerState::Disabled;
50}
51
52bool ProfilerConfig::global() const {
53 return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
54}
55
namespace {
// Slot positions used when a ProfilerConfig is flattened into an IValue list
// (`toIValue`) and read back (`fromIValue`). The values are spelled out so the
// serialized layout is visible at a glance.
enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES = 1,
  PROFILE_MEMORY = 2,
  // Number of slots above. Must be last in list and must be bumped whenever a
  // slot is added.
  NUM_PROFILER_CFG_IVALUE_IDX = 3
};
} // namespace
64
65at::IValue ProfilerConfig::toIValue() const {
66 c10::impl::GenericList eventIValueList(at::AnyType::get());
67 eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
68 eventIValueList.emplace_back(static_cast<int64_t>(state));
69 eventIValueList.emplace_back(report_input_shapes);
70 eventIValueList.emplace_back(profile_memory);
71 return eventIValueList;
72}
73
74ProfilerConfig ProfilerConfig::fromIValue(
75 const at::IValue& profilerConfigIValue) {
76 TORCH_INTERNAL_ASSERT(
77 profilerConfigIValue.isList(),
78 "Expected IValue to contain type c10::impl::GenericList");
79 auto ivalues = profilerConfigIValue.toList();
80 TORCH_INTERNAL_ASSERT(
81 ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
82 c10::str(
83 "Expected exactly ",
84 NUM_PROFILER_CFG_IVALUE_IDX,
85 " ivalues to resconstruct ProfilerConfig."));
86 return ProfilerConfig(
87 static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()),
88 ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(),
89 ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool());
90}
91
92// ----------------------------------------------------------------------------
93// -- Profiler base class -----------------------------------------------------
94// ----------------------------------------------------------------------------
95/*explicit*/ ProfilerStateBase::ProfilerStateBase(const ProfilerConfig& config)
96 : c10::MemoryReportingInfoBase(), config_(config) {}
97
98ProfilerStateBase::~ProfilerStateBase() {
99 if (handle_) {
100 auto handle = handle_;
101 removeCallback();
102 SOFT_ASSERT(false, "Leaked callback handle: ", handle);
103 }
104}
105
106/*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) {
107 auto* out = global
108 ? GlobalManager::get()
109 : static_cast<ProfilerStateBase*>(
110 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
111 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
112 return out;
113}
114
115/*static*/ void ProfilerStateBase::push(
116 std::shared_ptr<ProfilerStateBase>&& state) {
117 TORCH_INTERNAL_ASSERT(state != nullptr);
118 if (state->config().global()) {
119 GlobalManager::push(std::move(state));
120 } else {
121 c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
122 }
123}
124
125namespace {
126std::shared_ptr<ProfilerStateBase> popTLS() {
127 // If there is no active thread local profiler then we simply return null.
128 // However if there is an active profiler but it is not the top
129 // `DebugInfoBase`then `c10::ThreadLocalDebugInfo::_pop` will throw.
130 // TODO(robieta): make `noexcept` version.
131 return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)
132 ? std::static_pointer_cast<ProfilerStateBase>(
133 c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE))
134 : nullptr;
135}
136} // namespace
137
138/*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop(
139 bool global) {
140 auto out = global ? GlobalManager::pop() : popTLS();
141 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
142 return out;
143}
144
145void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) {
146 if (handle_) {
147 at::removeCallback(handle_);
148 SOFT_ASSERT(
149 false,
150 "ProfilerStateBase already has a registered callback. "
151 "Removing to avoid leaked callback.");
152 }
153
154 handle_ = handle;
155}
156
157void ProfilerStateBase::removeCallback() {
158 if (handle_) {
159 at::removeCallback(handle_);
160 handle_ = 0;
161 }
162}
163
164bool profilerEnabled() {
165 auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
166 return state_ptr && !state_ptr->config().disabled();
167}
168
169TORCH_API ActiveProfilerType profilerType() {
170 auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
171 return state_ptr == nullptr ? ActiveProfilerType::NONE
172 : state_ptr->profilerType();
173}
174
175torch::profiler::impl::ProfilerConfig getProfilerConfig() {
176 auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
177 TORCH_CHECK(
178 state_ptr,
179 "Tried to access profiler config, but profiler is not enabled!");
180 return state_ptr->config();
181}
182
183} // namespace impl
184} // namespace profiler
185} // namespace torch
186