#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/orchestration/vulkan.h>

#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <queue>
#include <type_traits>
#include <utility>

#include <fmt/format.h>

#ifdef USE_KINETO
#include <libkineto.h>
#endif

#include <ATen/Context.h>
#include <ATen/record_function.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/hash.h>
#include <c10/util/overloaded.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/profiler/data_flow.h>
#include <torch/csrc/profiler/kineto_shim.h>

namespace torch {
namespace profiler {
namespace impl {
using result_ptr_t = std::shared_ptr<Result>;
using trace_ptr_t =
    std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>;

RawTensorMetadataBase::RawTensorMetadataBase(const at::Tensor& t)
    : data_{t.has_storage() ? t.storage().data() : nullptr},
      dtype_{t.scalar_type()},
      layout_{t.layout()},
      dim_{static_cast<uint32_t>(t.sizes().size())} {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      t.sizes().size() <= std::numeric_limits<uint32_t>::max(),
      "Cannot profile Tensors of size > uint32 max. Got dim: ",
      t.sizes().size());
}

RawTensorMetadata::RawTensorMetadata(const at::Tensor& t)
    : RawTensorMetadataBase(t),
      weak_self_{WeakTensor(t)},
      device_type_{t.device().type()},
      device_index_{t.device().index()} {}

TensorMetadata::TensorMetadata(
    const RawTensorMetadata& r,
    std::vector<int64_t> sizes,
    std::vector<int64_t> strides)
    : RawTensorMetadataBase(r),
      weak_self_{r.weak_self_.value_or(WeakTensor(at::Tensor()))},
      device_{r.device_type_, r.device_index_},
      sizes_{std::move(sizes)},
      strides_{std::move(strides)} {
  SOFT_ASSERT(r.weak_self_.has_value());
}

// ============================================================================
// == PyTorch Ops =============================================================
// ============================================================================

// ----------------------------
// | Input / Output encoder |
// ----------------------------
void InputOutputEncoder::push(c10::ArrayRef<const c10::IValue> values) {
  for (const auto& value : values) {
    if (value.isTensor()) {
      push(value.toTensor());
    } else if (value.isScalar()) {
      tags_.emplace_back(Tag::Scalar);
      // Scalars are small enough that they are stored in ivalues without an
      // extra memory alloc
      // TODO: further optimize this by maybe giving Profiler access to the
      // guts of IValue.
      ivalues_.emplace_back(value);
    } else if (value.isTensorList()) {
      tags_.emplace_back(Tag::TensorListBegin);
      for (const auto& t : value.toTensorList()) {
        push(t);
      }
      tags_.emplace_back(Tag::TERMINATOR);
    } else {
      tags_.emplace_back(Tag::Other);
    }
  }
  tags_.emplace_back(Tag::TERMINATOR);
}

void InputOutputEncoder::push(const at::Tensor& t) {
  if (t.defined() && !t.is_nested()) { // TODO fix nested sizes
    tags_.emplace_back(Tag::Tensor);
    tensor_metadata_.emplace_back(t);
    tensor_sizes_strides_.copy(t.sizes());
    if (t.layout() == at::kStrided) {
      // Only Strided layout tensors have strides
      tensor_sizes_strides_.copy(t.strides());
    }
  } else {
    tags_.emplace_back(Tag::UndefinedTensor);
  }
}

// This is a custom-iterator-like getter to obtain input shapes and dtypes.
auto InputOutputEncoder::getNextShapesAndDtypes() {
  return [this,
          tag_it = tags_.begin(),
          tensor_metadata_it = tensor_metadata_.begin(),
          tensor_size_strides_it = tensor_sizes_strides_.begin(),
          ivals_it = ivalues_.begin()]() mutable {
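    // Sizes and strides were flattened into a single `tensor_sizes_strides_`
    // buffer at record time; rebuild the per-tensor vectors here using the
    // recorded `dim_`. Strides were only recorded for strided layouts.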
    auto decode_tensor = [&]() -> TensorMetadata {
      const auto& raw_metadata = *tensor_metadata_it++;
      std::vector<int64_t> sizes;
      std::vector<int64_t> strides;
      for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) {
        sizes.push_back(*tensor_size_strides_it++);
      }
      if (raw_metadata.layout_ == at::kStrided) {
        for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) {
          strides.push_back(*tensor_size_strides_it++);
        }
      }
      return {raw_metadata, sizes, strides};
    };

    std::vector<op_input_t> out;
    bool terminate = false;
    while (!terminate && tag_it != tags_.end()) {
      switch (*tag_it) {
        case Tag::Tensor:
          out.emplace_back(decode_tensor());
          break;

        case Tag::TensorListBegin: {
          std::vector<TensorMetadata> arg;
          while (*(++tag_it) != Tag::TERMINATOR) {
            TORCH_INTERNAL_ASSERT(*tag_it == Tag::Tensor, (int)(*tag_it));
            arg.emplace_back(decode_tensor());
          }
          out.emplace_back(std::move(arg));
        } break;

        case Tag::Scalar:
          out.emplace_back(*ivals_it++);
          break;

        case Tag::UndefinedTensor:
        case Tag::Other:
          out.emplace_back(c10::nullopt);
          break;

        case Tag::TERMINATOR:
          // This marks the end of this op.
          terminate = true;
          break;

        default:
          break;
      }
      ++tag_it;
    }
    return out;
  };
}

void InputOutputEncoder::clear() {
  tags_.clear();
  tensor_metadata_.clear();
  tensor_sizes_strides_.clear();
  ivalues_.clear();
}

// ---------------------------------------------------
// | Correlation ID tracking (OpList & EventBlock) |
// ---------------------------------------------------
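// Each EventBlock reserves a contiguous range of ChunkSize correlation IDs
// (starting at 1), so an event's correlation ID is simply the block's starting
// ID plus the event's offset within the block.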
template <typename T, size_t ChunkSize>
ThreadLocalSubqueue::TorchOpStorage::EventBlock<T, ChunkSize>::EventBlock() {
  static std::atomic<uint64_t> counter_{0};
  id_start_ = 1 + ChunkSize * counter_++;
}

template <class... Args>
std::pair<KinetoObserverContext::Event*, uint64_t> ThreadLocalSubqueue::
    TorchOpStorage::OpList::emplace_back(Args&&... args) {
  maybe_grow();
  *next_ = {std::forward<Args>(args)...};
  auto corr_id = buffer_last_->correlation_id(next_);
  return {next_++, corr_id};
}

uint64_t ThreadLocalSubqueue::TorchOpStorage::OpList::correlationID(
    const OpList::Iterator& e) {
  return e.address().first->correlation_id(&*e);
}

template <typename T, size_t ChunkSize>
uint64_t ThreadLocalSubqueue::TorchOpStorage::EventBlock<T, ChunkSize>::
    correlation_id(const T* ptr) const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      ptr >= this->data() && ptr < this->data() + ChunkSize);
  return id_start_ + (ptr - this->data());
}

// ---------------------------------
// | Collection (Observer logic) |
// ---------------------------------
std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
    const at::RecordFunction& fn) {
  KinetoObserverContext::Event* event;
  uint64_t corr_id;
  std::tie(event, corr_id) = torch_ops_.op_events_.emplace_back(
      fn.seqNr(),
      fn.forwardThreadId(),
      fn.scope(),
      fn.isAsync(),
      fn.debugHandle(),
      fn.name());
  if (config_.report_input_shapes) {
    torch_ops_.inputs_outputs_.push(fn.inputs());
  }
  if (fn.scope() == at::RecordScope::USER_SCOPE) {
    torch::profiler::impl::kineto::pushUserCorrelationId(corr_id);
  } else {
    torch::profiler::impl::kineto::pushCorrelationId(corr_id);
  }

#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
  // The backward node's source range corresponds to the forward node.
  // TODO: consider using C++ stack trace
  if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
    auto cs = torch::profiler::impl::prepareCallstack(jit::currentCallstack());
    torch_ops_.jit_stack_.emplace_back(callstackStr(cs));
  }
  if (config_.with_modules &&
      fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
    torch_ops_.jit_modules_.emplace_back(jit::currentModuleHierarchy());
  }
#endif
  if (config_.with_flops) {
    torch_ops_.extra_args_.emplace_back(
        torch::profiler::impl::saveExtraArgs(fn));
  }

  auto out = std::make_unique<KinetoObserverContext>(event);

  if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) {
    try {
      out->fallback_ = torch_ops_.gpu_fallback_.emplace_back();
      torch::profiler::impl::cudaStubs()->record(
          nullptr, &out->fallback_->cuda_event_start_, nullptr);
    } catch (const std::exception& e) {
      LOG(WARNING) << "Failed to record CUDA event. " << e.what();
    }
  }

  event->start_time_ = torch::profiler::impl::getApproximateTime();
  event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS();
  if (!config_.experimental_config.performance_events.empty()) {
    const size_t n = config_.experimental_config.performance_events.size();
    event->counters_ = std::make_unique<perf_counters_t>(n, 0);
    perf_profiler_->Enable();
  }
  return out;
}

// ---------------
// | Collation |
// ---------------
namespace {
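// Moves ("steals") successive elements out of a container, returning a
// default constructed value once the container is exhausted, and clears the
// container when the helper is destroyed.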
template <typename T>
struct StealOrDefault {
  StealOrDefault(T& container)
      : container_{container}, it_{container.begin()} {}

  ~StealOrDefault() {
    container_.get().clear();
  }

  typename T::Iterator::value_type operator()() {
    if (it_.exhausted()) {
      return typename T::Iterator::value_type();
    } else {
      auto result = std::move(*it_);
      ++it_;
      return result;
    }
  }

  std::reference_wrapper<T> container_;
  typename T::Iterator it_;
};
} // namespace

void ThreadLocalSubqueue::TorchOpStorage::materialize(
    std::vector<std::shared_ptr<Result>>& out,
    const std::function<time_t(approx_time_t)> time_converter,
    const uint64_t tid,
    const kineto::DeviceAndResource& kineto_info) {
  // Plumb Autograd info to the top level annotation.
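  // The autograd engine emits a FUNCTION scoped
  // "autograd::engine::evaluate_function: ..." annotation immediately before
  // the BACKWARD_FUNCTION event it wraps, so copy the sequence number and
  // forward thread id from the latter onto the former.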
  auto it = op_events_.begin();
  for (C10_UNUSED const auto _ :
       c10::irange(static_cast<int64_t>(op_events_.size()) - 1)) {
    auto& first = it->basic_fields_;
    auto& second = (++it)->basic_fields_;
    if (first.scope_ == at::RecordScope::FUNCTION &&
        second.scope_ == at::RecordScope::BACKWARD_FUNCTION &&
        first.name_.rfind("autograd::engine::evaluate_function: ", 0) == 0) {
      first.sequence_number_ = second.sequence_number_;
      first.forward_tid_ = second.forward_tid_;
    }
  }

  // `AccumulateGrad` is an important marker for profile analysis; however the
  // annotation relies on `c10::demangle` which is platform dependent. In
  // particular, Windows will add a "struct " prefix.
  const std::string accumulate_grad = "torch::autograd::AccumulateGrad";
  const std::string windows_pattern = std::string("struct ") + accumulate_grad;
  for (auto& event : op_events_) {
    auto& name = event.basic_fields_.name_;
    auto position = name.find(windows_pattern);
    if (position != std::string::npos) {
      name.replace(position, windows_pattern.size(), accumulate_grad);
    }
  }

  auto input_getter = inputs_outputs_.getNextShapesAndDtypes();

  // TODO: CTAD will take care of template args when we move to C++17
  auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_);
  auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_);
  auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
  auto gpu_fallback = StealOrDefault<decltype(gpu_fallback_)>(gpu_fallback_);

  for (auto event = op_events_.begin(); event != op_events_.end(); ++event) {
    ExtraFields<EventType::TorchOp> e{
        std::move(event->basic_fields_),
        ThreadLocalSubqueue::TorchOpStorage::OpList::correlationID(event),
        time_converter(event->end_time_),
        input_getter(),
        jit_stack(),
        jit_module(),
        extra_args(),
        gpu_fallback(),
        event->allow_tf32_cublas_,
        std::move(event->counters_)};

    out.emplace_back(Result::create(
        time_converter(event->start_time_), tid, kineto_info, std::move(e)));
  }

  op_events_.clear();
  inputs_outputs_.clear();
}

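// Convert raw Vulkan events (an approximate CPU timestamp paired with a
// Vulkan event id) into Results, resolving the shader name and duration
// through the Vulkan orchestration layer.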
template <size_t BlockSize>
void materialize_vulkan(
    std::vector<std::shared_ptr<Result>>& out,
    AppendOnlyList<ExtraFields<EventType::Vulkan>::raw_event_t, BlockSize>&
        raw_events,
    const std::function<time_t(approx_time_t)> time_converter,
    const uint64_t tid,
    const kineto::DeviceAndResource& kineto_info) {
  for (const auto& i : raw_events) {
    const auto name_and_duration_ns =
        torch::profiler::impl::vulkan::getShaderNameAndDurationNs(i.second);

    out.emplace_back(Result::create(
        /*start_time_ns_=*/time_converter(i.first),
        /*start_tid_=*/tid,
        /*kineto_info_=*/kineto_info,
        /*extra_fields_=*/
        ExtraFields<EventType::Vulkan>{
            /*name_=*/std::get<0>(name_and_duration_ns),
            /*duration_ns_=*/
            static_cast<int64_t>(std::get<1>(name_and_duration_ns)),
            /*in_tree_building_=*/false}));
  }
}

namespace {
// See `RecordQueue::getSubqueue()` for an overview of this cache.
struct SubQueueThreadCache {
  uint32_t key_;
  ThreadLocalSubqueue* ref_;
};

// The astute observer will note that this leaves a dangling reference; nothing
// in the teardown of `RecordQueue` or `ThreadLocalSubqueue` clears this value.
// (And the raw pointer in `SubQueueThreadCache` will not extend the lifetime
// of `*ref_`.) This is safe, however, because `getSubqueue` will check
// `sub_queue_cache_.key_` before attempting to access `ref_`, and if `key_`
// does not match the RecordQueue's *unique* `id_` it will evict
// `sub_queue_cache_` and fall back to a different mechanism.
std::atomic<uint32_t> queue_id_{0};
thread_local SubQueueThreadCache sub_queue_cache_{0, nullptr};

std::string toString(const ExtraFields<EventType::PyCall>& e) {
  if (e.module_.has_value()) {
    return fmt::format(
        "nn.Module: {}_{}", e.module_->cls_name_.str(), e.module_->id_);
  }
  return fmt::format(
      "{}({}): {}",
      e.callsite_.filename_.str(),
      e.callsite_.line_no_,
      e.callsite_.funcname_.str());
}

auto scopeToType(at::RecordScope scope) {
  return scope == at::RecordScope::USER_SCOPE
      ? libkineto::ActivityType::USER_ANNOTATION
      : libkineto::ActivityType::CPU_OP;
}

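// Torch ops that were still running when profiling stopped carry a sentinel
// (time_t min) end time; once the call tree has been built they borrow their
// parent's end time instead.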
int64_t torchOpEndNS(
    const ExtraFields<EventType::TorchOp>& e,
    const bool finished,
    const std::weak_ptr<Result>& parent) {
  if (finished && e.end_time_ns_ == std::numeric_limits<time_t>::min()) {
    auto p = parent.lock();
    if (p) {
      return p->endTimeNS();
    }
  }
  return e.end_time_ns_;
}

auto kinetoEventCorrelationID(
    const ExtraFields<EventType::Kineto>& e,
    const std::weak_ptr<Result>& parent) {
  if (e.correlation_id_) {
    return e.correlation_id_;
  }
  auto p = parent.lock();
  return p ? p->correlationID() : 0;
}
} // namespace

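// Helper for building `c10::overloaded` visitors: expands to a lambda for a
// single event type. `(void)e` silences unused-parameter warnings when `expr`
// does not reference the extra fields.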
#define ATTRIBUTE(event_type, expr)                  \
  [&](const ExtraFields<EventType::event_type>& e) { \
    (void)e;                                         \
    return expr;                                     \
  }

std::string Result::name() const {
  return visit(c10::overloaded(
      ATTRIBUTE(Vulkan, std::string(e.name_)),
      ATTRIBUTE(Allocation, std::string("[memory]")),
      ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")),
      ATTRIBUTE(PyCall, toString(e)),
      ATTRIBUTE(PyCCall, std::string(e.function_name_.str())),
      [](const auto& e) -> std::string { return e.name_; }));
}

libkineto::ActivityType Result::kinetoType() const {
  return visit(c10::overloaded(
      ATTRIBUTE(TorchOp, scopeToType(e.scope_)),
      ATTRIBUTE(Backend, scopeToType(e.scope_)),
      ATTRIBUTE(Vulkan, libkineto::ActivityType::CPU_OP),
      ATTRIBUTE(Allocation, libkineto::ActivityType::CPU_INSTANT_EVENT),
      ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT),
      ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION),
      ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION),
      ATTRIBUTE(Kineto, e.activity_type_)));
}

uint64_t Result::correlationID() const {
  return visit(c10::overloaded(
      ATTRIBUTE(TorchOp, e.correlation_id_),
      ATTRIBUTE(Kineto, kinetoEventCorrelationID(e, parent_)),
      [&](const auto&) -> uint64_t { return 0; }));
}

int64_t Result::endTimeNS() const {
  auto end_time_ns = visit(c10::overloaded(
      ATTRIBUTE(TorchOp, torchOpEndNS(e, finished_, parent_)),
      ATTRIBUTE(Backend, e.end_time_us_ * 1000),
      ATTRIBUTE(
          Vulkan, start_time_ns_ + (e.in_tree_building_ ? 0 : e.duration_ns_)),
      ATTRIBUTE(Allocation, start_time_ns_),
      ATTRIBUTE(OutOfMemory, start_time_ns_),
      ATTRIBUTE(Kineto, start_time_ns_ + e.duration_us_ * 1000),
      [&](const auto& e) -> int64_t { return e.end_time_ns_; }));

  // In rare cases we're willing to tolerate ops which are missing an end time
  // so long as they can borrow their parent's end time. A consequence of this,
  // however, is that `endTimeNS` may not make sense until tree construction is
  // complete.
  auto end_time_is_valid =
      !finished_ || SOFT_ASSERT(end_time_ns >= start_time_ns_, name());
  return end_time_is_valid ? end_time_ns : start_time_ns_;
}

uint64_t Result::endTID() const {
  return visit(c10::overloaded(
      ATTRIBUTE(TorchOp, e.end_tid_),
      [&](const auto&) -> uint64_t { return start_tid_; }));
}

c10::DeviceType Result::deviceType() const {
  using torch::autograd::profiler::deviceTypeFromActivity;
  return visit(c10::overloaded(
      ATTRIBUTE(Vulkan, c10::DeviceType::Vulkan),
      ATTRIBUTE(Allocation, e.device_type_),
      ATTRIBUTE(OutOfMemory, e.device_type_),
      ATTRIBUTE(Kineto, deviceTypeFromActivity(e.activity_type_)),
      [&](const auto&) { return c10::DeviceType::CPU; }));
}
#undef ATTRIBUTE

ThreadLocalSubqueue::ThreadLocalSubqueue(
    const uint64_t tid,
    const ProfilerConfig& config)
    : tid_{tid}, config_{config}, kineto_info_{kineto::kineto_ids()} {
  torch::profiler::impl::kineto::recordThreadInfo();
  if (!config_.experimental_config.performance_events.empty()) {
    perf_profiler_ =
        std::make_unique<torch::profiler::impl::linux_perf::PerfProfiler>();
    perf_profiler_->Configure(config_.experimental_config.performance_events);
  }
}

RecordQueue::RecordQueue(
    const ProfilerConfig& config,
    std::set<ActivityType> activities)
    : id_(++queue_id_), config_{config}, activities_{std::move(activities)} {
  if (tracePython()) {
    python_tracer_ = python_tracer::PythonTracerBase::make(this);
  }
}

bool RecordQueue::tracePython() const {
  return config_.with_stack && activities_.count(ActivityType::CPU);
}

ThreadLocalSubqueue* RecordQueue::getSubqueue() {
  // In the most common case, a thread will want to write to the same sub-queue
  // that it wrote to last call. The only time that isn't true is if:
  //  A) The profiler context has ended and we are in a new one.
  //  B) Two profilers are active in different TLS contexts, and this thread
  //     is a worker helping with intra-op parallelism.
  // Since we expect this to be the OVERWHELMINGLY common case (>99%), we add a
  // special thread_local cache so that we can skip the overall `flat_hash_map`
  // (and corresponding lock).
  if (id_ == sub_queue_cache_.key_) {
    return sub_queue_cache_.ref_;
  }

  const auto tid = at::RecordFunction::currentThreadId();
  std::lock_guard<std::mutex> guard(sub_queue_mutex_);
  auto it = sub_queues_.find(tid);
  if (it == sub_queues_.end()) {
    it = sub_queues_
             .emplace(tid, std::make_unique<ThreadLocalSubqueue>(tid, config_))
             .first;
  }

  sub_queue_cache_ = SubQueueThreadCache{id_, it->second.get()};
  return it->second.get();
}

void RecordQueue::stop() {
  if (python_tracer_) {
    python_tracer_->stop();
  }
}

namespace {
void mark_finished(std::shared_ptr<Result>& r) {
  TORCH_INTERNAL_ASSERT(!r->finished_, r->name());
  r->finished_ = true;
  TORCH_INTERNAL_ASSERT(r->endTimeNS() >= r->start_time_ns_, r->name());
}

static constexpr const char* indexKey = "Ev Idx";

void passEventsToKineto(
    const std::vector<std::shared_ptr<Result>>& results,
    uint64_t start_time_us,
    uint64_t end_time_us) {
  using namespace torch::profiler::impl::kineto;
  TraceWrapper cpu_trace(start_time_us, "PyTorch Profiler");

  // Generate Kineto events for each event recorded by the PyTorch profiler.
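  // Each Result's index is also attached as "Ev Idx" metadata so that
  // `TransferEvents` can re-associate the returned activity with its Result
  // even if Kineto relocates the activity object.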
  for (const auto i : c10::irange(results.size())) {
    const auto& e = results[i];
    const auto* activity = cpu_trace.addCPUActivity(
        e->name(),
        e->kinetoType(),
        e->kineto_info_,
        e->correlationID(),
        e->start_time_ns_ / 1000,
        e->endTimeNS() / 1000);

    TORCH_INTERNAL_ASSERT(activity || !kKinetoAvailable);
    if (activity) {
      addMetadata(activity, indexKey, std::to_string(i));
    }
  }

  // Kineto adds the events that it collected.
  cpu_trace.transferCpuTrace(end_time_us);
}

#ifdef USE_KINETO
// There are two mechanisms that we use to connect Profiler and Kineto events.
// The first is the correlation ID. The profiler pushes a unique integer at the
// start of an op and pops it at the end. Kineto then associates the events
// that it collects with that correlation ID and sets the linked activity of
// the events that it collected to point to the profiler op.
//
// However, this is not a sufficient description because it does not retain
// dependency information between Kineto ops. Consider a call to `torch.add`.
// Three events will be collected:
//   `aten::add`          (TorchOp, collected by profiler)
//   `cudaLaunchKernel`   (CUDA runtime event, collected by Kineto)
//   `at::vectorized_...` (GPU kernel, collected by Kineto)
// If we only relied on correlation IDs we would set both Kineto events as
// children of `aten::add`, rather than the correct
//   `aten::add -> cudaLaunchKernel -> at::vectorized_...`
//
// Kineto surfaces this information through a second concept called a "flow".
// In this example, the `cudaLaunchKernel` event is the start of a flow and the
// GPU kernel has the same flow id but is not a start event. Thus, when merging
// the Kineto events into the call tree we first add all events which are flow
// start nodes. We then merge the rest, trying to pair them with flow starts
// and falling back to correlation ID if necessary. For any nodes without
// linked events the parent is determined using the normal tree construction
// algorithm.
class TransferEvents {
  using itrace_t = libkineto::ITraceActivity;
  using activity_t = torch::profiler::impl::kineto::activity_t;

 public:
  TransferEvents(
      std::vector<std::shared_ptr<Result>>& results,
      trace_ptr_t& trace)
      : results_{results} {
    auto* trace_activities_ptr = trace->get()->activities();
    TORCH_INTERNAL_ASSERT(trace_activities_ptr != nullptr);
    trace_activities_ = *trace_activities_ptr;
    reassociate();
    extractEventsFromTrace();
    setParents();
  }

 private:
  static long long extractIndex(const std::string& metadata_json) {
    static const auto prefix = fmt::format("\"{}\": ", indexKey);
    auto pos = metadata_json.find(prefix);
    return (pos == std::string::npos) ? unmatchedIndex : [&]() {
      auto end = metadata_json.find(',', pos);
      end = (end == std::string::npos) ? metadata_json.size() : end;
      return std::stoll(metadata_json.substr(pos + prefix.size(), end));
    }();
  }

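  // Map a Kineto activity back to the profiler `Result` that produced it:
  // first via the pointer cache, then via the "Ev Idx" metadata that
  // `passEventsToKineto` attached to the activity.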
  std::shared_ptr<Result> lookup(const itrace_t* key) {
    if (key == nullptr) {
      return nullptr;
    }

    // First check the map.
    auto it = kineto_events_.find(key);
    if (it != kineto_events_.end()) {
      return it->second;
    }

    // Then fallback to the encoded metadata.
    const auto index = extractIndex(key ? key->metadataJson() : "");
    if (index != unmatchedIndex) {
      auto out = results_.get().at(index);
      kineto_events_[key] = out;
      return out;
    }

    // And finally give up.
    return nullptr;
  }

  void reassociate() {
    // Match profiler events with the corresponding kineto events. Kineto may
    // have moved or copied the activities, so we have to recover the
    // relationship between `libkineto::ITraceActivity` and `Result`.
    for (const auto* activity : trace_activities_) {
      TORCH_INTERNAL_ASSERT(activity != nullptr);
      auto e = lookup(activity);
      if (e != nullptr) {
        TORCH_INTERNAL_ASSERT(e->kineto_activity_ == nullptr);
        e->kineto_activity_ = static_cast<const activity_t*>(activity);
      }
    }
    if (results_.get().size() != kineto_events_.size()) {
      TORCH_WARN(fmt::format(
          "Failed to recover relationship between all profiler and kineto events: "
          "{} vs. {} reassociated.",
          results_.get().size(),
          kineto_events_.size()));
    }
  }

  std::shared_ptr<Result> resultFromActivity(const itrace_t* activity) {
    TORCH_INTERNAL_ASSERT(activity != nullptr);

    // Kineto is inconsistent with types, so we have to cast to int32.
    torch::profiler::impl::kineto::DeviceAndResource device_and_resource{
        static_cast<int32_t>(activity->deviceId()),
        static_cast<int32_t>(activity->resourceId())};

    auto event = Result::create(
        activity->timestamp() * 1000,
        noTID, // Placeholder
        device_and_resource,
        ExtraFields<EventType::Kineto>{
            activity->name(),
            activity->duration(),
            static_cast<uint64_t>(activity->correlationId()),
            activity->type(),
            {/*id=*/static_cast<uint32_t>(activity->flowId()),
             /*type=*/static_cast<uint32_t>(activity->flowType()),
             /*start=*/activity->flowStart()}});

    // NB: It's tempting to set `event->kineto_activity_`; however we can only
    // guarantee that the events we passed to Kineto are of type
    // `GenericTraceActivity`. Others may derive from ITraceActivity and thus
    // are not safe to cast.
    return event;
  }

  std::shared_ptr<Result> toResult(const itrace_t* activity) {
    auto e = lookup(activity);

    // Until we are very sure that we can reassociate kineto and profiler
    // events we need to be very defensive.
    const auto type = activity->type();
    if (e == nullptr &&
        (type == libkineto::ActivityType::CPU_OP ||
         type == libkineto::ActivityType::CPU_INSTANT_EVENT ||
         type == libkineto::ActivityType::USER_ANNOTATION ||
         type == libkineto::ActivityType::PYTHON_FUNCTION)) {
      TORCH_WARN_ONCE(
          "Detected an event which was likely passed to kineto by the PyTorch "
          "profiler, but is not present in the set of known events: ",
          activity->name(),
          " This most likely means that Kineto has not "
          "maintained address stability for this event. Please report this to "
          "the PyTorch team.");
      return nullptr;
    }

    if (e == nullptr) {
      e = resultFromActivity(activity);
      results_.get().push_back(e);
      kineto_events_[activity] = e;
    }
    return e;
  }

  void extractEventsFromTrace() {
    for (const auto* activity : trace_activities_) {
      auto e = toResult(activity);
      const auto* linked_activity = activity->linkedActivity();
      if (e && linked_activity) {
        e->visit(c10::overloaded(
            [&](ExtraFields<EventType::Kineto>& i) {
              i.linked_activity_ = toResult(linked_activity);
            },
            [](auto&) { TORCH_INTERNAL_ASSERT(false); }));
      }
    }
  }

  void setKinetoTID(
      std::shared_ptr<Result>& r,
      std::shared_ptr<Result> parent) {
    r->visit(c10::overloaded(
        [&](ExtraFields<EventType::Kineto>& i) {
          TORCH_INTERNAL_ASSERT(r->start_tid_ == noTID);
          r->start_tid_ = parent ? parent->start_tid_
                                 : at::RecordFunction::currentThreadId();
        },
        [](auto&) {}));

    for (auto& child : r->children_) {
      setKinetoTID(child, r);
    }
  }

  void setParents() {
    // First pass: Collect start events and set parent to linked event.
    ska::flat_hash_map<int, std::shared_ptr<Result>> flow_map;
    for (auto& e : results_.get()) {
      TORCH_INTERNAL_ASSERT(e != nullptr);
      e->visit(c10::overloaded(
          [&](const ExtraFields<EventType::Kineto>& i) {
            if (i.flow.type == libkineto::kLinkAsyncCpuGpu && i.flow.start) {
              auto inserted = flow_map.insert({i.flow.id, e});
#ifdef USE_ROCM
              if (!inserted.second) {
                TORCH_WARN_ONCE(
                    "ROCTracer produced duplicate flow start: ", i.flow.id);
              }
#else // USE_ROCM
              TORCH_INTERNAL_ASSERT(inserted.second);
#endif // USE_ROCM
            }
            TORCH_INTERNAL_ASSERT(e->parent_.expired());
            e->parent_ = i.linked_activity_;
          },
          [](const auto&) {}));
    }

    // Second pass
    for (auto& e : results_.get()) {
      e->visit(c10::overloaded(
          [&](const ExtraFields<EventType::Kineto>& i) {
            // Flow takes priority over linked event.
            const auto it = flow_map.find(i.flow.id);
            if (it != flow_map.end() &&
                i.flow.type == libkineto::kLinkAsyncCpuGpu && !i.flow.start) {
              e->parent_ = it->second;
            }

            // If a parent was set we have to do some bookkeeping.
            auto parent = e->parent_.lock();
            if (parent) {
              parent->children_.push_back(e);
              mark_finished(e);
            }
          },
          [](const auto&) {}));
    }

    // Set TIDs now that we have established lineage.
    for (auto& e : results_.get()) {
      if (e->parent_.expired()) {
        setKinetoTID(e, nullptr);
      }
    }
  }

  static constexpr long long unmatchedIndex = -1;
  static constexpr auto noTID = std::numeric_limits<uint64_t>::max();
  std::reference_wrapper<std::vector<std::shared_ptr<Result>>> results_;
  std::vector<const itrace_t*> trace_activities_;
  ska::flat_hash_map<const itrace_t*, std::shared_ptr<Result>> kineto_events_;
};
#else
class TransferEvents {
 public:
  template <class... Args>
  TransferEvents(Args&&...) {}
};
#endif

trace_ptr_t addKinetoEvents(
    std::vector<std::shared_ptr<Result>>& results,
    uint64_t start_time_us,
    uint64_t end_time_us,
    const ProfilerConfig& config) {
  using namespace torch::profiler::impl::kineto;
  passEventsToKineto(results, start_time_us, end_time_us);

  // In on demand mode kineto is directly controlled by other machinery.
  if (config.global()) {
    return nullptr;
  }

  auto trace = std::make_unique<ActivityTraceWrapper>(stopTrace());
  TORCH_INTERNAL_ASSERT(trace || !kKinetoAvailable);
  TransferEvents transfer{results, trace};
  return trace;
}

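// Orders the priority queue of open events so that the event with the
// earliest end time is popped first (i.e. a min-heap on end time).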
struct ResultGreater {
  bool operator()(const result_ptr_t& a, const result_ptr_t& b) const {
    return a->endTimeNS() > b->endTimeNS();
  }
};

void set_in_tree_building(
    std::vector<result_ptr_t>& results,
    const bool value) {
  for (result_ptr_t& r : results) {
    r->visit(c10::overloaded(
        [value](ExtraFields<EventType::Vulkan>& i) {
          i.in_tree_building_ = value;
        },
        [&](auto&) {
          // pass
        }));
  }
}

void build_tree(std::vector<std::shared_ptr<Result>>& sorted_events) {
  set_in_tree_building(sorted_events, true);

  using op_fields = ExtraFields<EventType::TorchOp>;
  ska::flat_hash_map<uint64_t, std::shared_ptr<Result>> stacks;
  std::priority_queue<result_ptr_t, std::vector<result_ptr_t>, ResultGreater>
      end_events_;

  auto push_event = [&stacks, &end_events_](std::shared_ptr<Result>& event) {
    // Kineto builds subtrees using correlation ids and flows, so some Kineto
    // events are already marked finished before the main tree building
    // algorithm. It's fine to ignore them; the root event of these subtrees is
    // not a Kineto op and will be handled normally.
    if (c10::holds_alternative<ExtraFields<EventType::Kineto>>(
            event->extra_fields_) &&
        event->finished_) {
      return;
    }

    TORCH_INTERNAL_ASSERT(event->parent_.expired());
    for (const auto& child : event->children_) {
      TORCH_INTERNAL_ASSERT(child->finished_);
    }
    TORCH_INTERNAL_ASSERT(!event->finished_);

    auto parent_it = stacks.find(event->start_tid_);
    if (parent_it == stacks.end()) {
      auto fwd_tid = event->visit(c10::overloaded(
          [](const op_fields& i) { return i.forward_tid_; },
          [](const auto&) -> uint64_t { return 0; }));
      if (fwd_tid) {
        parent_it = stacks.find(fwd_tid);
      }
    }

    if (parent_it != stacks.end()) {
      event->parent_ = parent_it->second;
      parent_it->second->children_.push_back(event);
    }

    if (event->endTimeNS() > event->start_time_ns_) {
      stacks[event->start_tid_] = event;
      end_events_.push(event);
    } else if (event->endTimeNS() == std::numeric_limits<time_t>::min()) {
      // We use min time to indicate the lack of a termination event, so if we
      // encounter such a case we don't push to `end_events_`.
      stacks[event->start_tid_] = event;
    } else {
      mark_finished(event);
    }
  };

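  // Popping an event also finishes any of its descendants that are still open
  // on the same thread's stack; a child cannot outlive its parent in the
  // replayed call stack.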
  auto pop_event = [&stacks](std::shared_ptr<Result> event) {
    if (event->finished_) {
      // This event was marked finished by a previous `pop_event` call.
      return;
    }

    auto start_tid = event->start_tid_;
    auto frame = stacks.at(start_tid);

    while (frame.get() != event.get()) {
      TORCH_INTERNAL_ASSERT(frame != nullptr);
      mark_finished(frame);
      TORCH_INTERNAL_ASSERT(!frame->parent_.expired());
      frame = frame->parent_.lock();
    }

    mark_finished(event);
    stacks.erase(start_tid);
    auto new_frame = event->parent_.lock();
    if (new_frame != nullptr) {
      stacks[start_tid] = new_frame;
    }
  };

  // Stack replay loop.
  for (auto& event : sorted_events) {
    while (!end_events_.empty() &&
           end_events_.top()->endTimeNS() < event->start_time_ns_) {
      pop_event(end_events_.top());
      end_events_.pop();
    }
    push_event(event);
  }

  // Cleanup remaining exit events.
  while (!end_events_.empty()) {
    pop_event(end_events_.top());
    end_events_.pop();
  }

  set_in_tree_building(sorted_events, false);
}

/**
 * Adjust r's duration to be the max of its current duration and the sum of all
 * of its children's adjusted durations (keeping its start time the same)
 * (adjust all child durations recursively)
 */
int64_t adjust_durations_dfs(std::shared_ptr<Result>& r) {
  if (SOFT_ASSERT(r != nullptr)) {
    int64_t original_duration = r->endTimeNS() - r->start_time_ns_;
    int64_t children_total_duration = std::accumulate(
        r->children_.begin(),
        r->children_.end(),
        0,
        [](int64_t acc, std::shared_ptr<Result>& child) {
          return acc + adjust_durations_dfs(child);
        });

    if (children_total_duration > original_duration) {
      r->visit(c10::overloaded(
          [&r, &children_total_duration](ExtraFields<EventType::TorchOp>& i) {
            i.end_time_ns_ = r->start_time_ns_ + children_total_duration;
          },
          [&children_total_duration](ExtraFields<EventType::Vulkan>& i) {
            i.duration_ns_ = children_total_duration;
          },
          [](ExtraFields<EventType::Allocation>& _) {
            // Pass- Allocation events can't have children
          },
          [&](auto&) {
            SOFT_ASSERT(
                false,
                "unexpected event type in mobile profiler adjust_durations_dfs: ",
                r->name());
          }));
      return children_total_duration;
    } else {
      return original_duration;
    }
  } else {
    return 0;
  }
}

/**
 * 1) Adjust r's start time to be [new_start_time] (also adjusting end time and
      keeping duration the same)
 * 2) Recursively adjust r's children's start times, making them line up such
      that the last one ends at the same time as r
 * 3) Return r's final end time
 */
int64_t adjust_timestamps_dfs(
    std::shared_ptr<Result>& r,
    int64_t new_start_time) {
  if (SOFT_ASSERT(r != nullptr)) {
    if (r->start_time_ns_ != new_start_time) {
      // Adjust start time (keeping duration constant)
      r->visit(c10::overloaded(
          [&r, &new_start_time](ExtraFields<EventType::TorchOp>& i) {
            i.end_time_ns_ =
                new_start_time + (i.end_time_ns_ - r->start_time_ns_);
          },
          [](ExtraFields<EventType::Vulkan>& i) {
            // Pass- We don't need to manually adjust end time for Vulkan events
          },
          [](ExtraFields<EventType::Allocation>& _) {
            // Pass- No duration or end time to adjust
          },
          [&](auto&) {
            SOFT_ASSERT(
                false,
                "unexpected event type in mobile profiler adjust_timestamps_dfs: ",
                r->name());
          }));
      r->start_time_ns_ = new_start_time;
    }
    int64_t children_total_duration = std::accumulate(
        r->children_.begin(),
        r->children_.end(),
        0,
        [](int64_t acc, std::shared_ptr<Result>& child) {
          return acc + (child->endTimeNS() - child->start_time_ns_);
        });

    int64_t child_start_time = r->endTimeNS() - children_total_duration;
    for (std::shared_ptr<Result>& child : r->children_) {
      child_start_time = adjust_timestamps_dfs(child, child_start_time);
    }
  }
  return r->endTimeNS();
}

/**
 * Adjust timestamps and durations of nodes in [out] such that
 *  - Vulkan event timelines are synchronized with CPU event times
 *  - Parent event timelines fully contain their child timelines
 *  - No overlaps in timelines for nodes at the same depth
 */
void adjust_timestamps(std::vector<std::shared_ptr<Result>>& out) {
  if (out.empty()) {
    return;
  }

  int64_t min_start_time = out[0]->start_time_ns_;
  for (std::shared_ptr<Result>& r : out) {
    // Only begin traversal for root nodes.
    if (r->parent_.expired()) {
      adjust_durations_dfs(r);
      min_start_time = adjust_timestamps_dfs(
          r,
          std::max(
              r->tag() != EventType::Vulkan
                  ? r->start_time_ns_
                  : std::numeric_limits<int64_t>::min(),
              min_start_time));
    }
  }
}
} // namespace

std::pair<
    std::vector<std::shared_ptr<Result>>,
    std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
RecordQueue::getRecords(
    std::function<time_t(approx_time_t)> time_converter,
    uint64_t start_time_us,
    uint64_t end_time_us) {
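  // Preserve the "missing timestamp" sentinel rather than converting it.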
  auto converter = [&](approx_time_t t) {
    return t == std::numeric_limits<approx_time_t>::min()
        ? std::numeric_limits<time_t>::min()
        : time_converter(t);
  };
  std::vector<std::shared_ptr<Result>> out;
  std::vector<python_tracer::CompressedEvent> python_enters;
  for (auto& subqueue_it : sub_queues_) {
    auto& queue = *subqueue_it.second;
    auto materialize = [&](auto& events) {
      for (auto& i : events) {
        time_t start_time_ns;
        if constexpr (std::is_same<
                          std::remove_reference_t<decltype(i)>,
                          ExtraFields<EventType::Backend>>::value) {
          start_time_ns = i.start_time_us_ * 1000;
        } else {
          start_time_ns = converter(i.start_time_);
        }
        out.emplace_back(Result::create(
            /*start_time_ns_=*/start_time_ns,
            /*start_tid_=*/queue.tid(),
            /*kineto_info_=*/queue.kineto_info(),
            /*extra_fields_=*/std::move(i)));
      }
      events.clear();
    };

    queue.torch_ops_.materialize(
        out, converter, queue.tid(), queue.kineto_info());
    materialize(queue.backend_events_);
    materialize_vulkan(
        out, queue.vulkan_events_, converter, queue.tid(), queue.kineto_info());
    for (auto& i : queue.allocations_) {
      out.emplace_back(Result::create(
          /*start_time_ns_=*/converter(i.start_time_),
          /*start_tid_=*/queue.tid(),
          /*kineto_info_=*/queue.kineto_info(),
          /*extra_fields_=*/ExtraFields<EventType::Allocation>(i)));
    }
    materialize(queue.ooms_);

    for (auto& i : queue.py_calls_) {
      python_enters.push_back(
          {i.first, queue.tid(), queue.kineto_info(), converter(i.second)});
    }
  }

  if (python_tracer_) {
    for (const auto& i : python_tracer_->getEvents(
             converter, python_enters, end_time_us * 1000)) {
      out.push_back(i);
    }
    python_tracer_.reset();
  }

  if (config_.experimental_config.adjust_timestamps) {
    std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
      return a->start_time_ns_ < b->start_time_ns_;
    });
    build_tree(out);
    adjust_timestamps(out);
    for (auto& r : out) {
      r->parent_.reset();
      // Reset these so that second build_tree can happen
      r->finished_ = false;
      r->children_.clear();
    }
  }

  auto trace = addKinetoEvents(out, start_time_us, end_time_us, config_);

  std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
    return a->start_time_ns_ < b->start_time_ns_;
  });

  if (config_.report_input_shapes && config_.profile_memory) {
    calculateUniqueTensorIDs(out);
  }

  build_tree(out);
  return {out, std::move(trace)};
}

} // namespace impl
} // namespace profiler
} // namespace torch