1 | #include <torch/csrc/profiler/kineto_shim.h> |
2 | |
3 | #include <type_traits> |
4 | |
5 | #ifdef USE_KINETO |
6 | #include <libkineto.h> |
7 | #endif |
8 | |
9 | #include <c10/util/Exception.h> |
10 | |
11 | namespace torch { |
12 | namespace profiler { |
13 | namespace impl { |
14 | namespace kineto { |
15 | |
16 | // Here lies pain and `#ifdef USE_KINETO` |
17 | |
#ifdef USE_KINETO
namespace {
// Activity types collected when the caller requests CPU profiling.
// Note: CUDA_RUNTIME is a host-side activity (cuLaunchKernel etc.), so it
// belongs to the CPU set as well as the CUDA set below.
const std::set<libkineto::ActivityType> cpuTypes{
    libkineto::ActivityType::CPU_OP,
    libkineto::ActivityType::CPU_INSTANT_EVENT,
    libkineto::ActivityType::USER_ANNOTATION,
    libkineto::ActivityType::EXTERNAL_CORRELATION,
    libkineto::ActivityType::CUDA_RUNTIME,
    libkineto::ActivityType::PYTHON_FUNCTION,
};

// Activity types collected when the caller requests CUDA profiling.
const std::set<libkineto::ActivityType> cudaTypes = {
    libkineto::ActivityType::GPU_MEMCPY,
    libkineto::ActivityType::GPU_MEMSET,
    libkineto::ActivityType::CONCURRENT_KERNEL,
    // CUDA_RUNTIME appears in both cpuTypes and cudaTypes.
    libkineto::ActivityType::CUDA_RUNTIME,
};
} // namespace
#endif // USE_KINETO
38 | |
// DeviceAndResource is passed around by value; keeping it POD guarantees that
// is cheap and that no Kineto-specific state leaks into the type itself.
static_assert(
    c10::is_pod_v<DeviceAndResource>,
    "Kineto specific details should be in `kineto_ids`.");

// Returns the identifiers Kineto uses to attribute activity to this process
// (device) and thread (resource). Without Kineto, returns a
// default-constructed DeviceAndResource.
const DeviceAndResource kineto_ids() {
#ifdef USE_KINETO
  return {
      /*device=*/libkineto::processId(),
      /*resource=*/libkineto::systemThreadId()};
#else
  return {};
#endif // USE_KINETO
}
52 | |
// Attaches a key/value metadata pair to an already-recorded activity.
// No-op when Kineto is disabled.
void addMetadata(
    const activity_t* activity,
    const std::string& key,
    const std::string& value) {
#ifdef USE_KINETO
  // ActivityTraceInterface returns const pointers, so we have to cast away the
  // constness to add metadata.
  const_cast<activity_t*>(activity)->addMetadata(key, value);
#endif // USE_KINETO
}
63 | |
// Starts a new CPU-side trace buffer with the given span start time and name.
// gpuOpCount is set to -1, matching how the buffer is initialized here before
// any GPU ops are counted. Without Kineto, the wrapper is an empty shell.
TraceWrapper::TraceWrapper(const int64_t start_time, const std::string& name)
#ifdef USE_KINETO
    : cpu_trace_(std::make_unique<libkineto::CpuTraceBuffer>()) {
  cpu_trace_->span.startTime = start_time;
  cpu_trace_->gpuOpCount = -1;
  cpu_trace_->span.name = name;
}
#else
{
}
#endif // USE_KINETO

TraceWrapper::~TraceWrapper() = default;
77 | |
78 | activity_t* TraceWrapper::addCPUActivity( |
79 | const std::string& name, |
80 | const libkineto::ActivityType type, |
81 | const DeviceAndResource device_and_resource, |
82 | const uint64_t correlation_id, |
83 | const int64_t start_time, |
84 | const int64_t end_time) { |
85 | #ifdef USE_KINETO |
86 | TORCH_CHECK((bool)(*this), "Cannot add event to non-existent trace." ); |
87 | cpu_trace_->emplace_activity(cpu_trace_->span, type, name); |
88 | auto& act = libkineto::CpuTraceBuffer::toRef(cpu_trace_->activities.back()); |
89 | act.device = device_and_resource.device; |
90 | act.resource = device_and_resource.resource; |
91 | act.id = correlation_id; |
92 | act.startTime = start_time; |
93 | if (type != libkineto::ActivityType::CPU_INSTANT_EVENT) { |
94 | act.endTime = end_time; |
95 | } |
96 | return cpu_trace_->activities.back().get(); |
97 | #else |
98 | return nullptr; |
99 | #endif // USE_KINETO |
100 | } |
101 | |
// Closes the span at `end_time` and hands ownership of the CPU trace buffer
// to the Kineto activity profiler. After this call cpu_trace_ is null, so the
// wrapper's operator bool() becomes false.
void TraceWrapper::transferCpuTrace(int64_t end_time) {
#ifdef USE_KINETO
  cpu_trace_->span.endTime = end_time;
  libkineto::api().activityProfiler().transferCpuTrace(std::move(cpu_trace_));
#endif // USE_KINETO
}
108 | |
// True while the wrapper still owns a live trace buffer (i.e. before
// transferCpuTrace moves it out). Always false without Kineto.
TraceWrapper::operator bool() const {
#ifdef USE_KINETO
  return cpu_trace_ != nullptr;
#else
  return false;
#endif // USE_KINETO
}
116 | |
// Takes ownership of a completed activity trace returned by the profiler.
ActivityTraceWrapper::ActivityTraceWrapper(
    std::unique_ptr<interface_trace_t>&& trace)
    : trace_(std::move(trace)) {}

// True while the wrapper holds a trace. Always false without Kineto.
ActivityTraceWrapper::operator bool() const {
#ifdef USE_KINETO
  return trace_ != nullptr;
#else
  return false;
#endif // USE_KINETO
}
128 | |
129 | void ActivityTraceWrapper::save(const std::string& path) { |
130 | #ifdef USE_KINETO |
131 | TORCH_CHECK(!saved_, "Trace is already saved." ); |
132 | TORCH_CHECK(trace_ != nullptr, "Missing trace." ) |
133 | trace_->save(path); |
134 | saved_ = true; |
135 | #else |
136 | TORCH_CHECK( |
137 | false, |
138 | "Saving a trace requires using torch.profiler with Kineto support (USE_KINETO=1)" ); |
139 | #endif // USE_KINETO |
140 | } |
141 | |
142 | namespace { |
143 | // Handles processing of Experimental Config options for Kineto |
144 | class ExperimentalConfigWrapper { |
145 | public: |
146 | explicit ExperimentalConfigWrapper( |
147 | const torch::profiler::impl::ExperimentalConfig& config) |
148 | : config_(config) {} |
149 | |
150 | bool assertValid(const ActivitySet& activities) { |
151 | // Kineto supports reading performance events per kernel/iteration |
152 | // using CUPTI Range based profiler API. In this mode however we |
153 | // do not trace CPU or GPU events. |
154 | bool cupti_range_profiler = !config_.profiler_metrics.empty(); |
155 | if (cupti_range_profiler && |
156 | activities.count(torch::autograd::profiler::ActivityType::CPU)) { |
157 | LOG(WARNING) |
158 | << "Cannot run range profiler with CPU activities, please only" |
159 | << " use CUDA activity type" ; |
160 | return false; |
161 | } |
162 | return cupti_range_profiler; |
163 | } |
164 | |
165 | void prepareTraceWithExperimentalOptions() { |
166 | #ifdef USE_KINETO |
167 | std::set<libkineto::ActivityType> k_activities{ |
168 | libkineto::ActivityType::CUDA_PROFILER_RANGE}; |
169 | |
170 | const size_t num_metrics = config_.profiler_metrics.size(); |
171 | std::stringstream configss; |
172 | |
173 | LOG(INFO) << "CUPTI profiler metrics size = " << num_metrics; |
174 | |
175 | configss << "ACTIVITIES_WARMUP_PERIOD_SECS=0\n" |
176 | << "CUPTI_PROFILER_METRICS=" ; |
177 | |
178 | for (int i = 0; i < num_metrics; i++) { |
179 | configss << config_.profiler_metrics[i]; |
180 | if (num_metrics > 1 && i < (num_metrics - 1)) { |
181 | configss << "," ; |
182 | } |
183 | } |
184 | configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL=" |
185 | << (config_.profiler_measure_per_kernel ? "true" : "false" ) |
186 | << "\n" ; |
187 | LOG(INFO) << "Generated config = " << configss.str(); |
188 | |
189 | libkineto::api().activityProfiler().prepareTrace( |
190 | k_activities, configss.str()); |
191 | #endif // USE_KINETO |
192 | } |
193 | |
194 | private: |
195 | const torch::profiler::impl::ExperimentalConfig& config_; |
196 | }; |
197 | } // namespace |
198 | |
// Prepares the Kineto profiler for an upcoming trace: lazily registers and
// initializes libkineto on first use, then maps the PyTorch-level activity
// set to libkineto activity types. If a valid experimental (CUPTI range
// profiler) config is present, it takes over and the normal activity set is
// not used. No-op without Kineto.
void prepareTrace(
    const bool cpuOnly,
    const ActivitySet& activities,
    const torch::profiler::impl::ExperimentalConfig& config) {
#ifdef USE_KINETO
  // First use: register the profiler and silence libkineto's own logging.
  if (!libkineto::api().isProfilerRegistered()) {
    libkineto_init(/*cpuOnly=*/cpuOnly, /*logOnError=*/true);
    libkineto::api().suppressLogMessages();
  }

  if (!libkineto::api().isProfilerInitialized()) {
    libkineto::api().initProfilerIfRegistered();
  }

  // Union of the type sets for each requested activity class (CUDA_RUNTIME
  // appears in both, but std::set de-duplicates).
  std::set<libkineto::ActivityType> k_activities;
  if (activities.count(torch::autograd::profiler::ActivityType::CPU)) {
    k_activities.insert(cpuTypes.begin(), cpuTypes.end());
  }
  if (activities.count(torch::autograd::profiler::ActivityType::CUDA)) {
    k_activities.insert(cudaTypes.begin(), cudaTypes.end());
  }

  ExperimentalConfigWrapper configWrap(config);

  // Experimental Configuration options are present
  if (config && configWrap.assertValid(activities)) {
    configWrap.prepareTraceWithExperimentalOptions();
    return;
  }

  libkineto::api().activityProfiler().prepareTrace(k_activities);
#endif // USE_KINETO
}
232 | |
// Begins collection for a trace previously set up by prepareTrace().
// No-op without Kineto.
void startTrace() {
#ifdef USE_KINETO
  libkineto::api().activityProfiler().startTrace();
#endif // USE_KINETO
}

// Stops collection and returns the completed trace. Without Kineto, returns
// a wrapper around a default-constructed (empty) interface trace.
ActivityTraceWrapper stopTrace() {
  return ActivityTraceWrapper{
#ifdef USE_KINETO
      libkineto::api().activityProfiler().stopTrace()
#else
      std::make_unique<interface_trace_t>()
#endif // USE_KINETO
  };
}
248 | |
// Thin shims over the profiler's correlation-id stacks, used to tie CPU-side
// ops to the GPU activity they launch. All are no-ops without Kineto.

// Push/pop on the default (runtime) correlation stack.
void pushCorrelationId(uint64_t correlation_id) {
#ifdef USE_KINETO
  libkineto::api().activityProfiler().pushCorrelationId(correlation_id);
#endif // USE_KINETO
}

// Push/pop on the user-annotation correlation stack.
void pushUserCorrelationId(uint64_t correlation_id) {
#ifdef USE_KINETO
  libkineto::api().activityProfiler().pushUserCorrelationId(correlation_id);
#endif // USE_KINETO
}

void popCorrelationId() {
#ifdef USE_KINETO
  libkineto::api().activityProfiler().popCorrelationId();
#endif // USE_KINETO
}

void popUserCorrelationId() {
#ifdef USE_KINETO
  libkineto::api().activityProfiler().popUserCorrelationId();
#endif // USE_KINETO
}

// Informs Kineto about the current thread so its events can be attributed.
void recordThreadInfo() {
#ifdef USE_KINETO
  libkineto::api().activityProfiler().recordThreadInfo();
#endif // USE_KINETO
}
278 | |
// Forwards a profiler invariant violation to Kineto for logging, but only if
// the profiler is initialized. Note the parameter order differs from the
// underlying call: Kineto takes (profile_id, assertion, error,
// group_profile_id). No-op without Kineto.
void logInvariantViolation(
    const std::string& assertion,
    const std::string& error,
    const std::string& profile_id,
    const std::string& group_profile_id) {
#ifdef USE_KINETO
  if (libkineto::api().isProfilerInitialized()) {
    libkineto::api().activityProfiler().logInvariantViolation(
        profile_id, assertion, error, group_profile_id);
  }
#endif // USE_KINETO
}
291 | |
292 | } // namespace kineto |
293 | } // namespace impl |
294 | } // namespace profiler |
295 | |
296 | namespace autograd { |
297 | namespace profiler { |
298 | c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { |
299 | // fallthrough |
300 | switch (activity_type) { |
301 | case libkineto::ActivityType::GPU_MEMCPY: |
302 | case libkineto::ActivityType::GPU_MEMSET: |
303 | case libkineto::ActivityType::CONCURRENT_KERNEL: |
304 | case libkineto::ActivityType::GPU_USER_ANNOTATION: |
305 | case libkineto::ActivityType::CUDA_PROFILER_RANGE: |
306 | return c10::DeviceType::CUDA; |
307 | case libkineto::ActivityType::CPU_OP: |
308 | case libkineto::ActivityType::USER_ANNOTATION: |
309 | case libkineto::ActivityType::EXTERNAL_CORRELATION: |
310 | case libkineto::ActivityType::CUDA_RUNTIME: |
311 | case libkineto::ActivityType::CPU_INSTANT_EVENT: |
312 | case libkineto::ActivityType::GLOW_RUNTIME: |
313 | case libkineto::ActivityType::PYTHON_FUNCTION: |
314 | return c10::DeviceType::CPU; |
315 | default: { |
316 | TORCH_WARN( |
317 | "Unknown activity type (" , |
318 | (uint8_t)activity_type, |
319 | "), assuming CPU device" ); |
320 | return c10::DeviceType::CPU; |
321 | } |
322 | } |
323 | } |
324 | |
// Attaches trace-level metadata (key/value) to the current profile via
// Kineto; the value is expected by callers to be JSON-formatted text
// (per the function name — the body itself does not validate this).
// Warns and skips if the profiler is not initialized or Kineto is disabled.
void addMetadataJson(const std::string& key, const std::string& value) {
#ifdef USE_KINETO
  if (libkineto::api().isProfilerInitialized()) {
    libkineto::api().activityProfiler().addMetadata(key, value);
  } else {
    LOG(WARNING) << "Profiler is not initialized: skipping profiling metadata";
  }
#else
  LOG(WARNING) << "Adding profiling metadata requires using "
               << "torch.profiler with Kineto support (USE_KINETO=1)";
#endif // USE_KINETO
}
337 | |
// Notifies Kineto that one profiler "step" (e.g. a training iteration) has
// completed. Warns and skips if the profiler is not initialized; no-op
// without Kineto.
void profilerStep() {
#ifdef USE_KINETO
  if (libkineto::api().isProfilerInitialized()) {
    libkineto::api().activityProfiler().step();
  } else {
    LOG(WARNING) << "Profiler is not initialized: skipping step() invocation";
  }
#endif // USE_KINETO
}
347 | |
348 | } // namespace profiler |
349 | } // namespace autograd |
350 | } // namespace torch |
351 | |