1 | #pragma once |
2 | |
3 | #include <ATen/record_function.h> |
4 | #include <torch/csrc/Export.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace torch { |
9 | namespace profiler { |
10 | namespace impl { |
11 | |
12 | // ---------------------------------------------------------------------------- |
13 | // -- Profiler Config --------------------------------------------------------- |
14 | // ---------------------------------------------------------------------------- |
// Activity categories. Enumerators are numbered implicitly starting from
// CPU = 0, so the trailing sentinel's value equals the number of real entries
// declared before it.
enum class C10_API_ENUM ActivityType {
  CPU = 0,
  CUDA, // CUDA kernels, runtime
  // Sentinel / count. New activity types must be inserted above this line.
  NUM_KINETO_ACTIVITIES, // must be the last one
};
20 | |
// Operating mode requested for the profiler. Enumerators are numbered
// implicitly starting from Disabled = 0, so the trailing sentinel's value
// equals the number of real states declared before it.
enum class C10_API_ENUM ProfilerState {
  Disabled = 0,
  CPU, // CPU-only profiling
  CUDA, // CPU + CUDA events
  NVTX, // only emit NVTX markers
  ITT, // only emit ITT markers
  KINETO, // use libkineto
  KINETO_GPU_FALLBACK, // use CUDA events when CUPTI is not available
  KINETO_ONDEMAND, // run the profiler in on-demand mode
  // Sentinel / count. New states must be inserted above this line.
  NUM_PROFILER_STATES, // must be the last one
};
32 | |
// Identifies which profiler implementation is active; this is the type
// returned by ProfilerStateBase::profilerType() and the free function
// profilerType(). NONE (= 0) is the zero value.
// NOTE(review): presumably NONE means "no profiler is active" -- confirm
// against the implementation of profilerType().
enum class C10_API_ENUM ActiveProfilerType {
  NONE = 0,
  LEGACY,
  KINETO,
  NVTX,
  ITT
};
40 | |
41 | struct TORCH_API ExperimentalConfig { |
42 | ExperimentalConfig( |
43 | std::vector<std::string> profiler_metrics = {}, |
44 | bool profiler_measure_per_kernel = false, |
45 | bool verbose = false, |
46 | std::vector<std::string> performance_events = {}, |
47 | bool adjust_timestamps = false); |
48 | ~ExperimentalConfig() = default; |
49 | explicit operator bool() const; |
50 | |
51 | std::vector<std::string> profiler_metrics; |
52 | bool profiler_measure_per_kernel; |
53 | bool verbose; |
54 | /* |
55 | * List of performance events to be profiled. |
56 | * An empty list will disable performance event based profiling altogether. |
57 | */ |
58 | std::vector<std::string> performance_events; |
59 | /* |
60 | * Controls whether or not timestamp adjustment occurs after profiling. |
61 | * The purpose of this is to adjust Vulkan event timelines to align with those |
62 | * of their parent CPU events. |
63 | * This sometimes requires increasing CPU event durations (to fully contain |
64 | * their child events) and delaying CPU event start times (to |
65 | * prevent overlaps), so this should not be used unless Vulkan events are |
66 | * being profiled and it is ok to use this modified timestamp/duration |
67 | * information instead of the the original information. |
68 | */ |
69 | bool adjust_timestamps; |
70 | }; |
71 | |
72 | struct TORCH_API ProfilerConfig { |
73 | ProfilerConfig( |
74 | ProfilerState state, |
75 | bool report_input_shapes = false, |
76 | bool profile_memory = false, |
77 | bool with_stack = false, |
78 | bool with_flops = false, |
79 | bool with_modules = false, |
80 | ExperimentalConfig experimental_config = ExperimentalConfig()); |
81 | ~ProfilerConfig() = default; |
82 | |
83 | bool disabled() const; |
84 | bool global() const; |
85 | |
86 | ProfilerState state; |
87 | ExperimentalConfig experimental_config; |
88 | bool report_input_shapes; |
89 | bool profile_memory; |
90 | bool with_stack; |
91 | bool with_flops; |
92 | bool with_modules; |
93 | |
94 | // For serialization |
95 | at::IValue toIValue() const; |
96 | static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue); |
97 | }; |
98 | |
99 | // ---------------------------------------------------------------------------- |
100 | // -- Profiler base class ----------------------------------------------------- |
101 | // ---------------------------------------------------------------------------- |
102 | struct TORCH_API ProfilerStateBase : public c10::MemoryReportingInfoBase { |
103 | explicit ProfilerStateBase(const ProfilerConfig& config); |
104 | ~ProfilerStateBase() override; |
105 | |
106 | static ProfilerStateBase* get(bool global); |
107 | static ProfilerStateBase* get() { |
108 | auto* out = get(/*global=*/true); |
109 | return out ? out : get(/*global=*/false); |
110 | } |
111 | |
112 | static void push(std::shared_ptr<ProfilerStateBase>&& state); |
113 | |
114 | static std::shared_ptr<ProfilerStateBase> pop(bool global); |
115 | static std::shared_ptr<ProfilerStateBase> pop() { |
116 | auto out = pop(/*global=*/true); |
117 | return out ? std::move(out) : pop(/*global=*/false); |
118 | } |
119 | |
120 | const ProfilerConfig& config() const { |
121 | return config_; |
122 | } |
123 | |
124 | void setCallbackHandle(at::CallbackHandle handle); |
125 | void removeCallback(); |
126 | |
127 | bool memoryProfilingEnabled() const override { |
128 | return config_.profile_memory; |
129 | } |
130 | |
131 | virtual ActiveProfilerType profilerType() = 0; |
132 | |
133 | protected: |
134 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
135 | std::mutex state_mutex_; |
136 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
137 | ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled); |
138 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
139 | at::CallbackHandle handle_ = 0; |
140 | }; |
141 | |
// Note: The following are only for the active *thread local* profiler.
// All three are defined out of line.

// Whether a profiler is enabled.
// NOTE(review): exact criteria live in the definition -- confirm there.
TORCH_API bool profilerEnabled();
// Which profiler implementation is active (see ActiveProfilerType).
TORCH_API ActiveProfilerType profilerType();
// Returns the profiler configuration by value.
TORCH_API ProfilerConfig getProfilerConfig();
146 | |
147 | } // namespace impl |
148 | } // namespace profiler |
149 | } // namespace torch |
150 | |