1 | #include <cstring> |
2 | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
3 | #include <torch/csrc/autograd/profiler_kineto.h> |
4 | |
5 | #include <c10/macros/Export.h> |
6 | #include <c10/util/C++17.h> |
7 | #include <c10/util/Exception.h> |
8 | #include <c10/util/flat_hash_map.h> |
9 | #include <c10/util/irange.h> |
10 | #include <c10/util/overloaded.h> |
11 | #include <c10/util/variant.h> |
12 | |
13 | #include <torch/csrc/profiler/api.h> |
14 | #include <torch/csrc/profiler/collection.h> |
15 | #include <torch/csrc/profiler/containers.h> |
16 | #include <torch/csrc/profiler/events.h> |
17 | #include <torch/csrc/profiler/kineto_shim.h> |
18 | #include <torch/csrc/profiler/orchestration/observer.h> |
19 | #include <torch/csrc/profiler/perf.h> |
20 | #include <torch/csrc/profiler/standalone/itt_observer.h> |
21 | #include <torch/csrc/profiler/standalone/nvtx_observer.h> |
22 | #include <torch/csrc/profiler/util.h> |
23 | |
24 | #include <ATen/Context.h> |
25 | |
26 | #include <deque> |
27 | #include <limits> |
28 | #include <sstream> |
29 | #include <stdexcept> |
30 | #include <utility> |
31 | |
32 | #ifdef USE_KINETO |
33 | #include <libkineto.h> |
34 | #include <time_since_epoch.h> |
35 | |
#ifndef _MSC_VER
// TODO: TO be removed, once this properly works from libkineto
// Literal copy-n-paste from third_party/kineto/libkineto/src/WeakSymbols.cpp
extern "C" {
// This function is needed to avoid superfluous dependency on GNU OpenMP library
// when cuPTI is linked statically For more details see
// https://github.com/pytorch/pytorch/issues/51026
// Declared weak so that a real OpenACC runtime, if one is linked in, takes
// precedence over this stub.
__attribute__((weak)) int acc_get_device_type() {
  throw std::runtime_error(
      "Dummy implementation of acc_get_device_type is not supposed to be called!");
}
} // extern "C"
#endif // _MSC_VER
49 | #endif // USE_KINETO |
50 | |
51 | namespace torch { |
52 | namespace autograd { |
53 | namespace profiler { |
54 | |
55 | namespace { |
// Wall-clock time in microseconds. Uses Kineto's epoch helper when available
// so profiler timestamps line up with Kineto trace timestamps; otherwise
// falls back to the profiler's ns clock divided down to us.
inline int64_t getTimeUs() {
#ifdef USE_KINETO
  return libkineto::timeSinceEpoch(std::chrono::system_clock::now());
#else
  return torch::profiler::impl::getTime() / 1000;
#endif // USE_KINETO
}
63 | |
64 | using torch::profiler::impl::ActiveProfilerType; |
65 | using torch::profiler::impl::dtypesToStr; |
66 | using torch::profiler::impl::EventType; |
67 | using torch::profiler::impl::ExtraFields; |
68 | using torch::profiler::impl::op_input_t; |
69 | using torch::profiler::impl::ProfilerStateBase; |
70 | using torch::profiler::impl::PyExtraFieldsBase; |
71 | using torch::profiler::impl::Result; |
72 | using torch::profiler::impl::shapesToStr; |
73 | using torch::profiler::impl::stacksToStr; |
74 | using torch::profiler::impl::TensorMetadata; |
75 | |
76 | auto shapesAndDtypes(const std::vector<op_input_t>& inputs) { |
77 | std::vector<std::vector<int64_t>> shapes; |
78 | std::vector<std::string> dtypes; |
79 | for (const auto& i : inputs) { |
80 | c10::visit( |
81 | c10::overloaded( |
82 | [&](const TensorMetadata& t) { |
83 | shapes.emplace_back(t.sizes_); |
84 | dtypes.emplace_back(scalarTypeToTypeMeta(t.dtype_).name()); |
85 | }, |
86 | [&](const std::vector<TensorMetadata>&) { |
87 | shapes.emplace_back(); |
88 | dtypes.emplace_back("TensorList" ); |
89 | }, |
90 | [&](const c10::IValue&) { |
91 | shapes.emplace_back(); |
92 | dtypes.emplace_back("Scalar" ); |
93 | }, |
94 | [&](const auto&) { |
95 | shapes.emplace_back(); |
96 | dtypes.emplace_back(); |
97 | }), |
98 | i); |
99 | } |
100 | return std::make_pair(shapes, dtypes); |
101 | } |
102 | |
// Base for metadata-attaching visitors. Holds the (possibly null) Kineto
// activity associated with a Result and funnels all metadata writes through
// `addMetadata`, which no-ops when there is no activity to annotate.
struct MetadataBase {
  MetadataBase(const std::shared_ptr<Result>& result)
      : kineto_activity_{result->kineto_activity_} {
    if (c10::holds_alternative<ExtraFields<EventType::Kineto>>(
            result->extra_fields_)) {
      // In order to add metadata we have to downcast from
      // `libkineto::ITraceActivity` to `libkineto::GenericTraceActivity`. We
      // know that all activities provided by PyTorch are of the correct type,
      // however Kineto profilers can (and do) add events that inherit directly
      // from ITraceActivity. As a result, any Result which was constructed from
      // an event that Kineto provided is unsafe to cast.
      if (!(SOFT_ASSERT(!hasKinetoActivity()))) {
        result->kineto_activity_ = nullptr;
      }
      kineto_activity_ = result->kineto_activity_;
    }
  }

  // Skips empty values (and the literal "\"\"") so we don't emit useless
  // metadata entries into the trace.
  void addMetadata(const std::string& key, const std::string& value) {
    if (kineto_activity_ && !value.empty() && value != "\"\"") {
      torch::profiler::impl::kineto::addMetadata(kineto_activity_, key, value);
    }
  }

  bool hasKinetoActivity() const {
    return kineto_activity_ != nullptr;
  }

 private:
  // Non-owning; the activity is owned by the Kineto trace.
  const torch::profiler::impl::kineto::activity_t* kineto_activity_{nullptr};
};
134 | |
135 | struct AddTensorboardFields : public MetadataBase { |
136 | AddTensorboardFields( |
137 | const std::shared_ptr<Result>& result, |
138 | KinetoEvent& kineto_event) |
139 | : MetadataBase(result) { |
140 | result->visit(*this); |
141 | const auto module_hierarchy = kineto_event.moduleHierarchy(); |
142 | addMetadata("Module Hierarchy" , stacksToStr(module_hierarchy.vec(), "." )); |
143 | addMetadata("Call stack" , stacksToStr(kineto_event.stack().vec(), ";" )); |
144 | |
145 | result->visit_if_base<PyExtraFieldsBase>([&, this](const auto& i) -> void { |
146 | this->addMetadata("Python id" , std::to_string(i.id_)); |
147 | |
148 | c10::optional<std::string> parent_id; |
149 | std::shared_ptr<Result> parent = result->parent_.lock(); |
150 | while (parent && !parent_id.has_value()) { |
151 | parent->visit_if_base<PyExtraFieldsBase>( |
152 | [&](const auto& j) { parent_id = std::to_string(j.id_); }); |
153 | parent = parent->parent_.lock(); |
154 | } |
155 | this->addMetadata("Python parent id" , parent_id.value_or("null" )); |
156 | }); |
157 | } |
158 | |
159 | void (const ExtraFields<EventType::PyCall>& py_call) { |
160 | if (py_call.module_.has_value()) { |
161 | addMetadata("Python module id" , std::to_string(py_call.module_->id_)); |
162 | } |
163 | } |
164 | |
165 | template <typename T> |
166 | void operator()(const T&) {} |
167 | }; |
168 | |
169 | struct AddGenericMetadata : public MetadataBase { |
170 | AddGenericMetadata( |
171 | std::shared_ptr<Result>& result, |
172 | const torch::profiler::impl::ProfilerConfig* config) |
173 | : MetadataBase(result), config_(config) { |
174 | result->visit(*this); |
175 | if (config->experimental_config.verbose) { |
176 | result->visit_if_base<PyExtraFieldsBase>( |
177 | [&, this](const auto& i) -> void { |
178 | this->addMetadata("Python thread" , std::to_string(i.python_tid_)); |
179 | }); |
180 | } |
181 | } |
182 | |
183 | void (ExtraFields<EventType::TorchOp>& op_event) { |
184 | const auto shapes_and_dtypes = shapesAndDtypes(op_event.inputs_); |
185 | if (!shapes_and_dtypes.first.empty()) { |
186 | addMetadata("Input Dims" , shapesToStr(shapes_and_dtypes.first)); |
187 | } |
188 | |
189 | if (!shapes_and_dtypes.second.empty()) { |
190 | addMetadata("Input type" , dtypesToStr(shapes_and_dtypes.second)); |
191 | } |
192 | |
193 | if (config_ && !config_->experimental_config.performance_events.empty()) { |
194 | auto& event_names = config_->experimental_config.performance_events; |
195 | for (auto i = 0; i < op_event.perf_event_counters_->size(); ++i) { |
196 | addMetadata( |
197 | event_names[i], |
198 | std::to_string((*op_event.perf_event_counters_)[i])); |
199 | } |
200 | } |
201 | |
202 | // add information about an associated forward op, if a sequence number |
203 | // is available (e.g. during training) |
204 | if (op_event.sequence_number_ >= 0) { |
205 | addMetadata("Fwd thread id" , std::to_string(op_event.forward_tid_)); |
206 | addMetadata("Sequence number" , std::to_string(op_event.sequence_number_)); |
207 | } |
208 | } |
209 | |
210 | void (ExtraFields<EventType::Backend>& backend_event) { |
211 | if (!backend_event.backend_.empty()) { |
212 | addMetadata("Backend" , "\"" + backend_event.backend_ + "\"" ); |
213 | } |
214 | } |
215 | |
216 | void (const ExtraFields<EventType::Allocation>& alloc) { |
217 | addMetadata("Device Type" , std::to_string((int8_t)alloc.device_type_)); |
218 | addMetadata("Device Id" , std::to_string(alloc.device_index_)); |
219 | addMetadata("Addr" , std::to_string(reinterpret_cast<intptr_t>(alloc.ptr_))); |
220 | addMetadata("Bytes" , std::to_string(alloc.alloc_size_)); |
221 | addMetadata("Total Allocated" , std::to_string(alloc.total_allocated_)); |
222 | addMetadata("Total Reserved" , std::to_string(alloc.total_reserved_)); |
223 | } |
224 | |
225 | void (const ExtraFields<EventType::OutOfMemory>& alloc) { |
226 | addMetadata("Device Type" , std::to_string((int8_t)alloc.device_type_)); |
227 | addMetadata("Device Id" , std::to_string(alloc.device_index_)); |
228 | addMetadata("Bytes" , std::to_string(alloc.alloc_size_)); |
229 | addMetadata("Total Allocated" , std::to_string(alloc.total_allocated_)); |
230 | addMetadata("Total Reserved" , std::to_string(alloc.total_reserved_)); |
231 | } |
232 | |
233 | template <typename T> |
234 | void operator()(const T&) {} |
235 | |
236 | private: |
237 | /* To get names of the performance events */ |
238 | const torch::profiler::impl::ProfilerConfig* config_; |
239 | }; |
240 | |
// Packs a thread id (high 16 bits) and a sequence number (low 48 bits) into
// one map key. Assumption: total threads number will not exceed 2^16-1, and
// total ops will not exceed 2^48-1.
static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
  constexpr uint64_t kSeqBits = 48;
  constexpr uint64_t kSeqMask = ((uint64_t)1 << kSeqBits) - 1;
  return (tid << kSeqBits) | (seqNr & kSeqMask);
}
246 | |
// Profiler state for the Kineto backend (thread-local, or process-global for
// the on-demand flow). Owns the RecordQueue into which callbacks push events
// and materializes those events into KinetoEvents when tracing stops.
struct KinetoThreadLocalState : public ProfilerStateBase {
  explicit KinetoThreadLocalState(
      const ProfilerConfig& config,
      std::set<torch::profiler::impl::ActivityType> activities)
      : ProfilerStateBase(config),
        start_time_(getTimeUs()),
        record_queue_(config, std::move(activities)) {}
  ~KinetoThreadLocalState() override = default;

  // Fetch the registered state (thread-local or global) and downcast.
  // Debug builds assert that any registered state really is Kineto's.
  static KinetoThreadLocalState* get(bool global) {
    auto* state = ProfilerStateBase::get(/*global=*/global);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        state == nullptr ||
        state->profilerType() == ActiveProfilerType::KINETO);
    return static_cast<KinetoThreadLocalState*>(state);
  }

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::KINETO;
  }

  // Queue a Vulkan timing event, stamped with the approximate clock.
  void reportVulkanEventToProfiler(torch::profiler::impl::vulkan_id_t id) {
    if (!config_.disabled()) {
      record_queue_.getSubqueue()->emplace_vulkan_event(
          torch::profiler::impl::getApproximateTime(), id);
    }
  }

  // Allocator callback; only recorded when profile_memory is enabled.
  void reportMemoryUsage(
      void* ptr,
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      c10::Device device) override {
    if (config_.profile_memory && !config_.disabled()) {
      record_queue_.getSubqueue()->emplace_allocation_event(
          torch::profiler::impl::getApproximateTime(),
          ptr,
          alloc_size,
          total_allocated,
          total_reserved,
          device.type(),
          device.index());
    }
  }

  // OOM callback; only recorded when profile_memory is enabled.
  void reportOutOfMemory(
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      c10::Device device) override {
    if (config_.profile_memory && !config_.disabled()) {
      record_queue_.getSubqueue()->emplace_ooms_event(
          torch::profiler::impl::getApproximateTime(),
          alloc_size,
          total_allocated,
          total_reserved,
          device.type(),
          device.index());
    }
  }

  const post_process_t& getEventPostProcessingCallback() const {
    return event_post_process_cb_;
  }

  void setEventPostProcessingCallback(post_process_t&& cb) {
    event_post_process_cb_ = std::move(cb);
  }

  // Stop collection, convert queued records into Results/KinetoEvents, and
  // return the Kineto trace wrapper for the caller to own.
  std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>
  finalizeTrace() {
    auto end_time = getTimeUs();
    record_queue_.stop();

    std::lock_guard<std::mutex> guard(state_mutex_);
    auto converter = clock_converter_.makeConverter();
    auto records_and_trace =
        record_queue_.getRecords(std::move(converter), start_time_, end_time);

    materializeOpEvents(records_and_trace.first);

    // finalizeCPUTrace(cpu_trace_.get());

    // `kineto_events_` does not include Python events. Instead it exposes them
    // via the `stacks` property.
    kineto_events_.erase(
        std::remove_if(
            kineto_events_.begin(),
            kineto_events_.end(),
            [](const auto& i) { return i.isPythonFunction(); }),
        kineto_events_.end());

    return std::move(records_and_trace.second);
  }

  // Run the user post-processing hook (if any) on a finished event's debug
  // handle and JIT stack/module info.
  template <typename T>
  void invokeCallback(T& t) {
    if (event_post_process_cb_) {
      event_post_process_cb_(t.debug_handle_, t.jit_stack_, t.jit_modules_);
    }
  }

  // Build the event tree and the flat kineto_events_ list, attaching
  // TensorBoard and generic metadata to each finished event's activity.
  void materializeOpEvents(std::vector<std::shared_ptr<Result>>& events) {
    for (auto& e : events) {
      // Events with no living parent are roots of the event tree.
      if (e->parent_.expired()) {
        event_tree_.push_back(e);
      }

      if (e->finished_) {
        e->visit(c10::overloaded(
            [this](ExtraFields<EventType::TorchOp>& i) { invokeCallback(i); },
            [this](ExtraFields<EventType::Backend>& i) { invokeCallback(i); },
            [](auto&) {}));

        kineto_events_.emplace_back(e, config_.experimental_config.verbose);
        // These visitors attach metadata as a constructor side effect.
        AddTensorboardFields add_tb(e, kineto_events_.back());
        AddGenericMetadata add_generic(e, &config_);

        // It is not safe to use the activity after post processing.
        e->kineto_activity_ = nullptr;
      }
    }
  }

  // NOTE: when USE_KINETO is off this function's body is just `}` — the
  // #ifndef below closes it immediately and the remainder is compiled out.
  void finalizeCPUTrace(
      std::unique_ptr<torch::profiler::impl::kineto::trace_t>& cpu_trace) {
#ifndef USE_KINETO
  }
#else // USE_KINETO
    TORCH_INTERNAL_ASSERT(
        cpu_trace->activities.size() == kineto_events_.size());
    // startThreadId_seqNum to pointer of activity.
    // Low-16bits of startThreadId and low-48bits seqNum are concatenated into
    // one uint64_t variable as key.

    // From the time being, we need disable the forward/backward correlation
    // feature to workaround the crash bug.
    // TODO: by Mike Guo
    // reenable the forward/backward correlation when kineto fix the following
    // raw pointer
    //    GenericTraceActivity.flow.linkedActivity
    /*
    std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>
        tidSeq2activity;

    for (const auto idx : c10::irange(cpu_trace->activities.size())) {
      auto& kineto_event = kineto_events_[idx];
      auto& activity = cpu_trace->activities[idx];

      // add information about an associated forward op, if a sequence number
      // is available (e.g. during training)
      if (kineto_event.sequenceNr() >= 0) {
        generateForwardBackwardLink(
            kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
      }
    }
    */
  }

  // Link a backward op's activity to the forward op that spawned it (matched
  // by forward thread id + autograd sequence number). Currently only called
  // from the disabled code above.
  void generateForwardBackwardLink(
      const KinetoEvent& kineto_event,
      uint64_t& fwd_bwd_link_id,
      libkineto::GenericTraceActivity& activity,
      std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>&
          tidSeq2activity) {
    if (kineto_event.fwdThreadId() > 0) {
      // act is backward op.
      uint64_t key = getForwardThreadKey(
          kineto_event.fwdThreadId(), kineto_event.sequenceNr());
      auto iter = tidSeq2activity.find(key);
      if (iter != tidSeq2activity.end()) {
        libkineto::GenericTraceActivity* fwd = iter->second;
        fwd->flow.start = true;
        activity.flow.id = fwd->flow.id = fwd_bwd_link_id;
        activity.flow.type = fwd->flow.type = libkineto::kLinkFwdBwd;
        ++fwd_bwd_link_id;
      }
    } else if (kineto_event.startThreadId() != 0) {
      // act is forward op.
      uint64_t key = getForwardThreadKey(
          kineto_event.startThreadId(), kineto_event.sequenceNr());
      // Assumption: Among all ops with same sequence number,
      // the one with biggest start time is most likely launching backward op.
      auto iter = tidSeq2activity.find(key);
      if (iter == tidSeq2activity.end()) {
        tidSeq2activity[key] = &activity;
      } else {
        // Now the sequence number is only incremented on creating a "Node"
        // object for backward pass, by calling
        // "at::sequence_number::get_and_increment()". Among all ops with same
        // sequence number, the one with biggest startTime is the one launching
        // backward op.
        if (activity.startTime >= iter->second->startTime) {
          tidSeq2activity[key] = &activity;
        }
      }
    }
  }
#endif // USE_KINETO

  // Trace start, microseconds since epoch (see getTimeUs()).
  uint64_t start_time_;
  torch::profiler::impl::ApproximateClockToUnixTimeConverter clock_converter_;
  torch::profiler::impl::RecordQueue record_queue_;
  std::vector<KinetoEvent> kineto_events_;
  std::vector<experimental_event_t> event_tree_;
  // Optional, if event post-processing is enabled.
  post_process_t event_post_process_cb_;
};
456 | |
457 | template <bool use_global_state_ptr = false> |
458 | std::unique_ptr<at::ObserverContext> onFunctionEnter( |
459 | const at::RecordFunction& fn) { |
460 | auto state_ptr = KinetoThreadLocalState::get(use_global_state_ptr); |
461 | if (!state_ptr) { |
462 | return nullptr; |
463 | } |
464 | return state_ptr->record_queue_.getSubqueue()->begin_op(fn); |
465 | } |
466 | |
// @lint-ignore CLANGTIDY clang-diagnostic-unused-parameter
// RecordFunction exit callback: stamp the end time on the event opened by
// onFunctionEnter, flush perf counters / CUDA fallback events, and pop the
// correlation id pushed for this op.
template <bool use_global_state_ptr = false>
void onFunctionExit(
    const at::RecordFunction& fn,
    at::ObserverContext* ctx_ptr) {
  auto state_ptr = KinetoThreadLocalState::get(use_global_state_ptr);
  if (!state_ptr) {
    return;
  }
  const auto& config = state_ptr->config();
  auto* kineto_ctx_ptr =
      static_cast<torch::profiler::impl::KinetoObserverContext*>(ctx_ptr);
  TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
  // Capture the end time first so later bookkeeping isn't counted in the op.
  kineto_ctx_ptr->event_->end_time_ =
      torch::profiler::impl::getApproximateTime();
  if (!config.experimental_config.performance_events.empty()) {
    state_ptr->record_queue_.getSubqueue()->disable_perf_profiler(
        *kineto_ctx_ptr->event_->counters_);
  }
  // The op may end on a different thread than it started on.
  kineto_ctx_ptr->event_->basic_fields_.end_tid_ =
      at::RecordFunction::currentThreadId();
  if (config.state == ProfilerState::KINETO_GPU_FALLBACK) {
    try {
      auto fallback = kineto_ctx_ptr->fallback_;
      TORCH_INTERNAL_ASSERT(fallback != nullptr);
      torch::profiler::impl::cudaStubs()->record(
          nullptr, &fallback->cuda_event_end_, nullptr);
    } catch (const std::exception& e) {
      // Best effort: a failed CUDA event record should not kill the op.
      LOG(WARNING) << "Failed to record CUDA event. " << e.what();
    }
  }

  // Pop must mirror the push done at op begin (user vs. internal scope).
  if (fn.scope() == at::RecordScope::USER_SCOPE) {
    torch::profiler::impl::kineto::popUserCorrelationId();
  } else {
    torch::profiler::impl::kineto::popCorrelationId();
  }
}
505 | |
// Register the enter/exit callbacks with RecordFunction, either globally
// (on-demand tracing) or thread-locally, and remember the handle so the
// state can unregister them later.
template <bool use_global_callback = false>
void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
  auto registration_state_ptr =
      KinetoThreadLocalState::get(use_global_callback);
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto recordFunctionCallback =
      at::RecordFunctionCallback(
          onFunctionEnter<use_global_callback>,
          onFunctionExit<use_global_callback>)
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .scopes(scopes);

  // Compile-time dispatch: global vs. thread-local registration.
  auto handle = c10::guts::if_constexpr<use_global_callback>(
      [&] { return at::addGlobalCallback(recordFunctionCallback); },
      [&] { return at::addThreadLocalCallback(recordFunctionCallback); });
  registration_state_ptr->setCallbackHandle(handle);
}
523 | |
524 | } // namespace |
525 | |
// Inject an externally-timed event (e.g. from a mobile/edge backend) into the
// active thread-local Kineto profiler. No-op when no profiler is running.
void reportBackendEventToActiveKinetoProfiler(
    const int64_t start_time_us,
    const int64_t end_time_us,
    const int64_t debug_handle,
    const at::RecordScope scope,
    const std::string& event_name,
    const std::string& backend_name) {
  // Backend events are only supported for the thread-local flow.
  TORCH_INTERNAL_ASSERT(
      KinetoThreadLocalState::get(/*global=*/true) == nullptr,
      "On-demand profiling does not support post processing callback");

  auto state_ptr = KinetoThreadLocalState::get(/*global=*/false);
  if (!state_ptr) {
    return;
  }

  state_ptr->record_queue_.getSubqueue()->emplace_backend_event(
      start_time_us,
      end_time_us,
      debug_handle,
      scope,
      event_name,
      backend_name);

  /* no support for input shapes now?
  if (config.report_input_shapes) {
    ctx_ptr->shapes = inputSizes(fn);
    ctx_ptr->dtypes = inputTypes(fn);
  }
  */
}
557 | |
558 | void prepareProfiler( |
559 | const torch::profiler::impl::ProfilerConfig& config, |
560 | const std::set<torch::profiler::impl::ActivityType>& activities) { |
561 | if (config.state == ProfilerState::NVTX || |
562 | config.state == ProfilerState::ITT) { |
563 | return; |
564 | } |
565 | TORCH_CHECK( |
566 | config.state == ProfilerState::KINETO || |
567 | config.state == ProfilerState::KINETO_GPU_FALLBACK, |
568 | "Supported only in Kineto profiler" ); |
569 | torch::profiler::impl::kineto::prepareTrace( |
570 | /*cpuOnly=*/!at::hasCUDA(), activities, config.experimental_config); |
571 | |
572 | if (!config.experimental_config.performance_events.empty()) { |
573 | /* For now only CPU activity is supported */ |
574 | TORCH_CHECK( |
575 | activities.count(torch::autograd::profiler::ActivityType::CPU), |
576 | "Cannot run cpu hardware profiler without CPU activities, please only use CPU activity type" ); |
577 | /* |
578 | * Sending a warning and passing the non-standard event to the backend |
579 | * Backend can abort if the event is not supported. |
580 | * TODO Should we gracefully drop the invalid event if we have atleast one |
581 | * valid? |
582 | */ |
583 | auto is_standard_event = [](const std::string& event) -> bool { |
584 | for (auto e : torch::profiler::ProfilerPerfEvents) { |
585 | if (!std::strcmp(event.c_str(), e)) { |
586 | return true; |
587 | } |
588 | } |
589 | return false; |
590 | }; |
591 | |
592 | for (const auto& e : config.experimental_config.performance_events) { |
593 | if (!is_standard_event(e)) { |
594 | TORCH_WARN("Forwarding a non-standard CPU performance event : " , e); |
595 | } |
596 | } |
597 | } |
598 | } |
599 | |
// Enable the Kineto profiler and install a callback that post-processes each
// finished op/backend event (used e.g. for mobile debug-handle resolution).
// Only valid for thread-local Kineto profiling: NVTX, ITT and on-demand
// (global) flows are rejected up front.
void enableProfilerWithEventPostProcess(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    post_process_t&& cb,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      config.state != ProfilerState::NVTX,
      "NVTX does not support post processing callback.");
  TORCH_CHECK(
      config.state != ProfilerState::ITT,
      "ITT does not support post processing callback.");
  TORCH_INTERNAL_ASSERT(
      KinetoThreadLocalState::get(/*global=*/true) == nullptr,
      "On-demand profiling does not support post processing callback");

  // enableProfiler registers the state, so get() below is non-null.
  enableProfiler(config, activities, scopes);
  auto state_ptr = KinetoThreadLocalState::get(config.global());
  state_ptr->setEventPostProcessingCallback(std::move(cb));
}
619 | |
// Start profiling: register the state object, hook RecordFunction callbacks
// for CPU activity, and (for the thread-local flow) start the Kineto trace.
// NVTX and ITT configurations delegate entirely to their own observers.
void enableProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    const std::unordered_set<at::RecordScope>& scopes) {
  const auto has_cpu = activities.count(ActivityType::CPU);
  TORCH_CHECK(
      KinetoThreadLocalState::get(/*global=*/config.global()) == nullptr,
      "Profiler is already enabled",
      (config.global() ? "." : " on this thread."));

  if (config.state == ProfilerState::NVTX) {
    torch::profiler::impl::pushNVTXCallbacks(config, scopes);
    return;
  } else if (config.state == ProfilerState::ITT) {
    torch::profiler::impl::pushITTCallbacks(config, scopes);
    return;
  }

  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
      config.state == ProfilerState::KINETO_GPU_FALLBACK || config.global());
  TORCH_CHECK(!activities.empty(), "No activities specified.");
  TORCH_INTERNAL_ASSERT(
      has_cpu || !config.global(),
      "Ondemand profiling must enable CPU tracing");

  KinetoThreadLocalState::push(
      std::make_shared<KinetoThreadLocalState>(config, activities));

  if (has_cpu) {
    // Template parameter selects global vs. thread-local registration.
    config.global() ? pushProfilingCallbacks</*global=*/true>(scopes)
                    : pushProfilingCallbacks</*global=*/false>(scopes);
  }

  // On-demand (global) tracing is started by libkineto itself.
  if (!config.global()) {
    torch::profiler::impl::kineto::startTrace();
  }
}
658 | |
659 | std::unique_ptr<ProfilerResult> disableProfiler() { |
660 | auto state_ptr = ProfilerStateBase::pop(); |
661 | const auto& config = state_ptr->config(); |
662 | TORCH_CHECK( |
663 | state_ptr && |
664 | (config.state == ProfilerState::KINETO || |
665 | config.state == ProfilerState::KINETO_GPU_FALLBACK || |
666 | config.state == ProfilerState::KINETO_ONDEMAND || |
667 | config.state == ProfilerState::NVTX || |
668 | config.state == ProfilerState::ITT), |
669 | "Can't disable Kineto profiler when it's not running" ); |
670 | |
671 | state_ptr->removeCallback(); |
672 | |
673 | // Traces are converged via libkineto automatically for ondemand flow |
674 | if (state_ptr->config().global()) { |
675 | (void)std::static_pointer_cast<KinetoThreadLocalState>(state_ptr) |
676 | ->finalizeTrace(); |
677 | return std::make_unique<ProfilerResult>(); |
678 | } |
679 | |
680 | // Shared among NVTX, KINETO, KINETO_GPU_FALLBACK |
681 | std::unique_ptr<ProfilerResult> result; |
682 | if (state_ptr->config().state == ProfilerState::NVTX) { |
683 | result = std::make_unique<ProfilerResult>(); |
684 | } |
685 | |
686 | if (config.state == ProfilerState::KINETO || |
687 | config.state == ProfilerState::KINETO_GPU_FALLBACK) { |
688 | auto kineto_state_ptr = |
689 | std::static_pointer_cast<KinetoThreadLocalState>(state_ptr); |
690 | auto trace = kineto_state_ptr->finalizeTrace(); |
691 | result = std::make_unique<ProfilerResult>( |
692 | kineto_state_ptr->start_time_, |
693 | std::move(kineto_state_ptr->kineto_events_), |
694 | std::move(trace), |
695 | std::move(kineto_state_ptr->event_tree_)); |
696 | } |
697 | |
698 | return result; |
699 | } |
700 | |
701 | KinetoEvent::KinetoEvent( |
702 | std::shared_ptr<const torch::profiler::impl::Result> result, |
703 | const bool verbose) |
704 | : result_{result} { |
705 | TORCH_INTERNAL_ASSERT(result != nullptr); |
706 | |
707 | if (verbose) { |
708 | // Populate Python stack |
709 | auto parent = result_->parent_.lock(); |
710 | while (parent != nullptr) { |
711 | parent->visit_if_base<PyExtraFieldsBase>( |
712 | [&](const auto& i) { python_stack_.push_back(parent->name()); }); |
713 | parent = parent->parent_.lock(); |
714 | } |
715 | } |
716 | |
717 | result->visit_if_base<ExtraFields<EventType::TorchOp>>([&](const auto& op) { |
718 | std::tie(shapes_, dtypes_) = shapesAndDtypes(op.inputs_); |
719 | }); |
720 | } |
721 | |
722 | bool KinetoEvent::isPythonFunction() const { |
723 | bool out{false}; |
724 | result_->visit_if_base<PyExtraFieldsBase>([&](const auto&) { out = true; }); |
725 | return out; |
726 | } |
727 | |
728 | bool KinetoEvent::hasShapes() const { |
729 | return !shapes_.empty(); |
730 | } |
731 | |
732 | const c10::ArrayRef<std::vector<int64_t>> KinetoEvent::shapes() const { |
733 | return shapes_; |
734 | } |
735 | |
736 | bool KinetoEvent::hasTypes() const { |
737 | return !dtypes_.empty(); |
738 | } |
739 | |
740 | const c10::ArrayRef<std::string> KinetoEvent::dtypes() const { |
741 | return dtypes_; |
742 | } |
743 | |
// Call stack for the event: the JIT stack when a TorchOp/Backend event
// recorded one, otherwise the (possibly empty) Python stack gathered in the
// constructor.
const c10::ArrayRef<std::string> KinetoEvent::stack() const {
  auto get = [&](const auto& i) -> auto& {
    return !i.jit_stack_.empty() ? i.jit_stack_ : python_stack_;
  };

  using out_t = const c10::ArrayRef<std::string>;
  return result_->visit(c10::overloaded(
      [&](const ExtraFields<EventType::TorchOp>& i) -> out_t { return get(i); },
      [&](const ExtraFields<EventType::Backend>& i) -> out_t { return get(i); },
      [&](const auto&) -> out_t { return python_stack_; }));
}
755 | |
// JIT module hierarchy for TorchOp/Backend events; empty for everything else.
const c10::ArrayRef<std::string> KinetoEvent::moduleHierarchy() const {
  return result_->visit(c10::overloaded(
      [](const ExtraFields<EventType::TorchOp>& e)
          -> const c10::ArrayRef<std::string> { return e.jit_modules_; },
      [](const ExtraFields<EventType::Backend>& e)
          -> const c10::ArrayRef<std::string> { return e.jit_modules_; },
      [](const auto&) -> const c10::ArrayRef<std::string> { return {}; }));
}
764 | |
765 | uint64_t KinetoEvent::durationUs() const { |
766 | return (result_->endTimeNS() - result_->start_time_ns_) / 1000; |
767 | } |
768 | |
// Debug handle recorded for TorchOp/Backend events; -1 for all other types.
int64_t KinetoEvent::debugHandle() const {
  return result_->visit(c10::overloaded(
      [](const ExtraFields<EventType::TorchOp>& i) { return i.debug_handle_; },
      [](const ExtraFields<EventType::Backend>& i) { return i.debug_handle_; },
      [](const auto&) -> int64_t { return -1; }));
}
775 | |
// Device index: memory events carry their own device; other events fall back
// to the device recorded in the Result's Kineto info.
uint8_t KinetoEvent::deviceIndex() const {
  return result_->visit(c10::overloaded(
      [](const ExtraFields<EventType::Allocation>& i) {
        return static_cast<uint8_t>(i.device_index_);
      },
      [](const ExtraFields<EventType::OutOfMemory>& i) {
        return static_cast<uint8_t>(i.device_index_);
      },
      [&](const auto&) {
        return static_cast<uint8_t>(result_->kineto_info_.device);
      }));
}
788 | |
789 | bool KinetoEvent::hasStack() const { |
790 | return !stack().empty(); |
791 | } |
792 | |
// Elapsed time between the GPU-fallback CUDA events, or -1 when either event
// is missing or the measurement fails.
int64_t KinetoEvent::cudaElapsedUs() const {
  auto cuda_event_start = fallbackStart();
  auto cuda_event_end = fallbackEnd();
  if (!cuda_event_start || !cuda_event_end) {
    return -1;
  }
  try {
    // The stub takes addresses of the (local copies of the) event handles.
    return (int64_t)torch::profiler::impl::cudaStubs()->elapsed(
        &cuda_event_start, &cuda_event_end);
  } catch (std::exception& e) {
    LOG(WARNING) << "Failed to measure time between two CUDA events. "
                 << e.what();
  }
  return -1;
}
808 | |
809 | void KinetoEvent::getPerfEventCounters(std::vector<uint64_t>& in) const { |
810 | return result_->visit(c10::overloaded( |
811 | [&in](const ExtraFields<EventType::TorchOp>& e) -> void { |
812 | const size_t n = e.perf_event_counters_->size(); |
813 | // should be rare |
814 | if (in.size() < n) { |
815 | in.resize(n, 0); |
816 | } |
817 | for (size_t i = 0; i < n; ++i) { |
818 | in[i] = (*e.perf_event_counters_)[i]; |
819 | } |
820 | }, |
821 | [](const auto&) -> void { return; })); |
822 | } |
823 | |
// Generates a KinetoEvent accessor that forwards to a member (or member
// function) of the underlying Result, static_cast to the accessor's declared
// return type.
#define FORWARD_FROM_RESULT(method_name, result_expr)                        \
  decltype(std::declval<KinetoEvent>().method_name())                        \
  KinetoEvent::method_name() const {                                         \
    return static_cast<decltype(std::declval<KinetoEvent>().method_name())>( \
        result_->result_expr);                                               \
  }

FORWARD_FROM_RESULT(startThreadId, start_tid_)
FORWARD_FROM_RESULT(endThreadId, endTID())
FORWARD_FROM_RESULT(activityType, kinetoType())
FORWARD_FROM_RESULT(name, name())
FORWARD_FROM_RESULT(deviceType, deviceType())
FORWARD_FROM_RESULT(startUs, start_time_ns_ / 1000)
FORWARD_FROM_RESULT(correlationId, correlationID())
FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource)
#undef FORWARD_FROM_RESULT
840 | |
// Most of the fields in `KinetoEvent` only make sense for a single event type.
// (Generally TorchOp.) For all other types they simply return the default
// value. This macro provides a succinct way of expressing this behavior.
// Inside `expression`, `e` is bound to the matching ExtraFields instance.
#define TYPED_ATTR_WITH_DEFAULT(                                       \
    event_type, method_name, expression, default_value)                \
  decltype(std::declval<KinetoEvent>().method_name())                  \
  KinetoEvent::method_name() const {                                   \
    using out_t = decltype(std::declval<KinetoEvent>().method_name()); \
    return result_->visit(c10::overloaded(                             \
        [](const ExtraFields<EventType::event_type>& e) -> out_t {     \
          return expression;                                           \
        },                                                             \
        [](const auto&) -> out_t { return default_value; }));          \
  }

// Shorthand for a value-initialized ({}) default.
#define TYPED_ATTR(event_type, method_name, expression) \
  TYPED_ATTR_WITH_DEFAULT(event_type, method_name, expression, {})

TYPED_ATTR_WITH_DEFAULT(TorchOp, sequenceNr, e.sequence_number_, -1)
TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
TYPED_ATTR(TorchOp, fallbackStart, e.gpu_fallback_.cuda_event_start_)
TYPED_ATTR(TorchOp, fallbackEnd, e.gpu_fallback_.cuda_event_end_)
TYPED_ATTR(
    TorchOp,
    flops,
    !e.extra_args_.empty() ? computeFlops(e.name_, e.extra_args_) : 0)
TYPED_ATTR(Backend, backend, e.backend_)
TYPED_ATTR(Allocation, nBytes, e.alloc_size_)
TYPED_ATTR(Kineto, linkedCorrelationId, [&]() {
  const auto linked = e.linked_activity_.lock();
  return linked ? linked->correlationID() : 0;
}())
#undef TYPED_ATTR
#undef TYPED_ATTR_WITH_DEFAULT
878 | |
// Bundle of everything a finished profiling session produced: start time,
// flattened KinetoEvents, the Kineto trace handle, and the event tree.
ProfilerResult::ProfilerResult(
    uint64_t start_time,
    std::vector<KinetoEvent> events,
    std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>&&
        trace,
    std::vector<experimental_event_t>&& event_tree)
    : trace_start_us_(start_time),
      events_(std::move(events)),
      trace_(std::move(trace)),
      event_tree_(std::move(event_tree)) {}
// Default-constructed results (NVTX / on-demand flows) carry no trace.
ProfilerResult::ProfilerResult() = default;
ProfilerResult::~ProfilerResult() = default;
891 | |
892 | void ProfilerResult::save(const std::string& path) { |
893 | trace_->save(path); |
894 | } |
895 | |
896 | } // namespace profiler |
897 | } // namespace autograd |
898 | |
899 | namespace profiler { |
900 | namespace impl { |
901 | void _reportVulkanEventToProfiler(vulkan_id_t id) { |
902 | auto state_ptr = ::torch::autograd::profiler::KinetoThreadLocalState::get( |
903 | /*global=*/false); |
904 | if (state_ptr) { |
905 | state_ptr->reportVulkanEventToProfiler(id); |
906 | } |
907 | } |
908 | } // namespace impl |
909 | } // namespace profiler |
910 | |
911 | } // namespace torch |
912 | |