1 | #include <torch/csrc/profiler/collection.h> |
2 | #include <torch/csrc/profiler/orchestration/vulkan.h> |
3 | |
4 | #include <algorithm> |
5 | #include <functional> |
6 | #include <limits> |
7 | #include <memory> |
8 | #include <queue> |
9 | #include <type_traits> |
10 | #include <utility> |
11 | |
12 | #include <fmt/format.h> |
13 | |
14 | #ifdef USE_KINETO |
15 | #include <libkineto.h> |
16 | #endif |
17 | |
18 | #include <ATen/Context.h> |
19 | #include <ATen/record_function.h> |
20 | #include <c10/core/ScalarTypeToTypeMeta.h> |
21 | #include <c10/util/Exception.h> |
22 | #include <c10/util/flat_hash_map.h> |
23 | #include <c10/util/hash.h> |
24 | #include <c10/util/overloaded.h> |
25 | #include <torch/csrc/jit/runtime/interpreter.h> |
26 | #include <torch/csrc/profiler/data_flow.h> |
27 | #include <torch/csrc/profiler/kineto_shim.h> |
28 | |
29 | namespace torch { |
30 | namespace profiler { |
31 | namespace impl { |
32 | using result_ptr_t = std::shared_ptr<Result>; |
33 | using trace_ptr_t = |
34 | std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>; |
35 | |
36 | RawTensorMetadataBase::RawTensorMetadataBase(const at::Tensor& t) |
37 | : data_{t.has_storage() ? t.storage().data() : nullptr}, |
38 | dtype_{t.scalar_type()}, |
39 | layout_{t.layout()}, |
40 | dim_{static_cast<uint32_t>(t.sizes().size())} { |
41 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
42 | t.sizes().size() <= std::numeric_limits<uint32_t>::max(), |
      "Cannot profile Tensors of size > uint32 max. Got dim: ",
44 | t.sizes().size()); |
45 | } |
46 | |
47 | RawTensorMetadata::RawTensorMetadata(const at::Tensor& t) |
48 | : RawTensorMetadataBase(t), |
49 | weak_self_{WeakTensor(t)}, |
50 | device_type_{t.device().type()}, |
51 | device_index_{t.device().index()} {} |
52 | |
53 | TensorMetadata::TensorMetadata( |
54 | const RawTensorMetadata& r, |
55 | std::vector<int64_t> sizes, |
56 | std::vector<int64_t> strides) |
57 | : RawTensorMetadataBase(r), |
58 | weak_self_{r.weak_self_.value_or(WeakTensor(at::Tensor()))}, |
59 | device_{r.device_type_, r.device_index_}, |
60 | sizes_{std::move(sizes)}, |
61 | strides_{std::move(strides)} { |
62 | SOFT_ASSERT(r.weak_self_.has_value()); |
63 | } |
64 | |
65 | // ============================================================================ |
66 | // == PyTorch Ops ============================================================= |
67 | // ============================================================================ |
68 | |
69 | // ---------------------------- |
70 | // | Input / Output encoder | |
71 | // ---------------------------- |
72 | void InputOutputEncoder::push(c10::ArrayRef<const c10::IValue> values) { |
73 | for (const auto& value : values) { |
74 | if (value.isTensor()) { |
75 | push(value.toTensor()); |
76 | } else if (value.isScalar()) { |
77 | tags_.emplace_back(Tag::Scalar); |
78 | // Scalars are small enough that they are stored in ivalues without an |
79 | // extra memory alloc |
80 | // TODO: further optimize this by maybe giving Profiler access to the |
81 | // guts of IValue. |
82 | ivalues_.emplace_back(value); |
83 | } else if (value.isTensorList()) { |
84 | tags_.emplace_back(Tag::TensorListBegin); |
85 | for (const auto& t : value.toTensorList()) { |
86 | push(t); |
87 | } |
88 | tags_.emplace_back(Tag::TERMINATOR); |
89 | } else { |
90 | tags_.emplace_back(Tag::Other); |
91 | } |
92 | } |
93 | tags_.emplace_back(Tag::TERMINATOR); |
94 | } |
95 | |
96 | void InputOutputEncoder::push(const at::Tensor& t) { |
97 | if (t.defined() && !t.is_nested()) { // TODO fix nested sizes |
98 | tags_.emplace_back(Tag::Tensor); |
99 | tensor_metadata_.emplace_back(t); |
100 | tensor_sizes_strides_.copy(t.sizes()); |
101 | if (t.layout() == at::kStrided) { |
102 | // Only Strided layout tensors have strides |
103 | tensor_sizes_strides_.copy(t.strides()); |
104 | } |
105 | } else { |
106 | tags_.emplace_back(Tag::UndefinedTensor); |
107 | } |
108 | } |
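// A sketch of the flattened encoding (hypothetical op): pushing the inputs of
// an op like `aten::add(Tensor, Tensor, Scalar)` appends
//   tags_:                 [Tensor, Tensor, Scalar, TERMINATOR]
//   tensor_metadata_:      one RawTensorMetadata per defined tensor
//   tensor_sizes_strides_: each tensor's sizes, then (for strided tensors) its
//                          strides, in push order
//   ivalues_:              one IValue per scalar
// `getNextShapesAndDtypes()` below replays this stream one op at a time.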
109 | |
110 | // This is a custom-iterator-like getter to obtain input shapes and dtypes. |
111 | auto InputOutputEncoder::getNextShapesAndDtypes() { |
112 | return [this, |
113 | tag_it = tags_.begin(), |
114 | tensor_metadata_it = tensor_metadata_.begin(), |
115 | tensor_size_strides_it = tensor_sizes_strides_.begin(), |
116 | ivals_it = ivalues_.begin()]() mutable { |
117 | auto decode_tensor = [&]() -> TensorMetadata { |
118 | const auto& raw_metadata = *tensor_metadata_it++; |
119 | std::vector<int64_t> sizes; |
120 | std::vector<int64_t> strides; |
121 | for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) { |
122 | sizes.push_back(*tensor_size_strides_it++); |
123 | } |
124 | if (raw_metadata.layout_ == at::kStrided) { |
125 | for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) { |
126 | strides.push_back(*tensor_size_strides_it++); |
127 | } |
128 | } |
129 | return {raw_metadata, sizes, strides}; |
130 | }; |
131 | |
132 | std::vector<op_input_t> out; |
133 | bool terminate = false; |
134 | while (!terminate && tag_it != tags_.end()) { |
135 | switch (*tag_it) { |
136 | case Tag::Tensor: |
137 | out.emplace_back(decode_tensor()); |
138 | break; |
139 | |
140 | case Tag::TensorListBegin: { |
141 | std::vector<TensorMetadata> arg; |
142 | while (*(++tag_it) != Tag::TERMINATOR) { |
143 | TORCH_INTERNAL_ASSERT(*tag_it == Tag::Tensor, (int)(*tag_it)); |
144 | arg.emplace_back(decode_tensor()); |
145 | } |
146 | out.emplace_back(std::move(arg)); |
147 | } break; |
148 | |
149 | case Tag::Scalar: |
150 | out.emplace_back(*ivals_it++); |
151 | break; |
152 | |
153 | case Tag::UndefinedTensor: |
154 | case Tag::Other: |
155 | out.emplace_back(c10::nullopt); |
156 | break; |
157 | |
158 | case Tag::TERMINATOR: |
159 | // This marks the end of this op. |
160 | terminate = true; |
161 | break; |
162 | |
163 | default: |
164 | break; |
165 | } |
166 | ++tag_it; |
167 | } |
168 | return out; |
169 | }; |
170 | } |
171 | |
172 | void InputOutputEncoder::clear() { |
173 | tags_.clear(); |
174 | tensor_metadata_.clear(); |
175 | tensor_sizes_strides_.clear(); |
176 | ivalues_.clear(); |
177 | } |
178 | |
179 | // --------------------------------------------------- |
180 | // | Correlation ID tracking (OpList & EventBlock) | |
181 | // --------------------------------------------------- |
182 | template <typename T, size_t ChunkSize> |
183 | ThreadLocalSubqueue::TorchOpStorage::EventBlock<T, ChunkSize>::EventBlock() { |
184 | static std::atomic<uint64_t> counter_{0}; |
185 | id_start_ = 1 + ChunkSize * counter_++; |
186 | } |
187 | |
188 | template <class... Args> |
189 | std::pair<KinetoObserverContext::Event*, uint64_t> ThreadLocalSubqueue:: |
190 | TorchOpStorage::OpList::emplace_back(Args&&... args) { |
191 | maybe_grow(); |
192 | *next_ = {std::forward<Args>(args)...}; |
193 | auto corr_id = buffer_last_->correlation_id(next_); |
194 | return {next_++, corr_id}; |
195 | } |
196 | |
197 | uint64_t ThreadLocalSubqueue::TorchOpStorage::OpList::correlationID( |
198 | const OpList::Iterator& e) { |
199 | return e.address().first->correlation_id(&*e); |
200 | } |
201 | |
202 | template <typename T, size_t ChunkSize> |
203 | uint64_t ThreadLocalSubqueue::TorchOpStorage::EventBlock<T, ChunkSize>:: |
204 | correlation_id(const T* ptr) const { |
205 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
206 | ptr >= this->data() && ptr < this->data() + ChunkSize); |
207 | return id_start_ + (ptr - this->data()); |
208 | } |
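// Worked example (a sketch; the real ChunkSize is fixed in the header): with
// ChunkSize == 512, the first EventBlock claims id_start_ == 1 and hands out
// correlation IDs [1, 512], the next block claims 513 and hands out
// [513, 1024], and so on. Given a pointer into a block, the ID is recovered by
// offset: correlation_id(ptr) == id_start_ + (ptr - data()), so the 8th slot
// of the second block maps to 513 + 7 == 520.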
209 | |
210 | // --------------------------------- |
211 | // | Collection (Observer logic) | |
212 | // --------------------------------- |
213 | std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op( |
214 | const at::RecordFunction& fn) { |
215 | KinetoObserverContext::Event* event; |
216 | uint64_t corr_id; |
217 | std::tie(event, corr_id) = torch_ops_.op_events_.emplace_back( |
218 | fn.seqNr(), |
219 | fn.forwardThreadId(), |
220 | fn.scope(), |
221 | fn.isAsync(), |
222 | fn.debugHandle(), |
223 | fn.name()); |
224 | if (config_.report_input_shapes) { |
225 | torch_ops_.inputs_outputs_.push(fn.inputs()); |
226 | } |
227 | if (fn.scope() == at::RecordScope::USER_SCOPE) { |
228 | torch::profiler::impl::kineto::pushUserCorrelationId(corr_id); |
229 | } else { |
230 | torch::profiler::impl::kineto::pushCorrelationId(corr_id); |
231 | } |
232 | |
233 | #if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE |
  // A backward node's source range corresponds to its forward node.
235 | // TODO: consider using C++ stack trace |
236 | if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { |
237 | auto cs = torch::profiler::impl::prepareCallstack(jit::currentCallstack()); |
238 | torch_ops_.jit_stack_.emplace_back(callstackStr(cs)); |
239 | } |
240 | if (config_.with_modules && |
241 | fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { |
242 | torch_ops_.jit_modules_.emplace_back(jit::currentModuleHierarchy()); |
243 | } |
244 | #endif |
245 | if (config_.with_flops) { |
246 | torch_ops_.extra_args_.emplace_back( |
247 | torch::profiler::impl::saveExtraArgs(fn)); |
248 | } |
249 | |
250 | auto out = std::make_unique<KinetoObserverContext>(event); |
251 | |
252 | if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) { |
253 | try { |
254 | out->fallback_ = torch_ops_.gpu_fallback_.emplace_back(); |
255 | torch::profiler::impl::cudaStubs()->record( |
256 | nullptr, &out->fallback_->cuda_event_start_, nullptr); |
257 | } catch (const std::exception& e) { |
258 | LOG(WARNING) << "Failed to record CUDA event. " << e.what(); |
259 | } |
260 | } |
261 | |
262 | event->start_time_ = torch::profiler::impl::getApproximateTime(); |
263 | event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS(); |
264 | if (!config_.experimental_config.performance_events.empty()) { |
265 | const size_t n = config_.experimental_config.performance_events.size(); |
266 | event->counters_ = std::make_unique<perf_counters_t>(n, 0); |
267 | perf_profiler_->Enable(); |
268 | } |
269 | return out; |
270 | } |
271 | |
272 | // --------------- |
273 | // | Collation | |
274 | // --------------- |
275 | namespace { |
276 | template <typename T> |
277 | struct StealOrDefault { |
278 | StealOrDefault(T& container) |
279 | : container_{container}, it_{container.begin()} {} |
280 | |
281 | ~StealOrDefault() { |
282 | container_.get().clear(); |
283 | } |
284 | |
285 | typename T::Iterator::value_type operator()() { |
286 | if (it_.exhausted()) { |
287 | return typename T::Iterator::value_type(); |
288 | } else { |
289 | auto result = std::move(*it_); |
290 | ++it_; |
291 | return result; |
292 | } |
293 | } |
294 | |
295 | std::reference_wrapper<T> container_; |
296 | typename T::Iterator it_; |
297 | }; |
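// Usage sketch: `materialize` below wraps each optional per-op stream (JIT
// stacks, module hierarchies, extra args for FLOP computation, GPU fallback
// events) in a StealOrDefault, so ops recorded while a feature was disabled
// simply receive a default-constructed value instead of desynchronizing the
// streams.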
298 | } // namespace |
299 | |
300 | void ThreadLocalSubqueue::TorchOpStorage::materialize( |
301 | std::vector<std::shared_ptr<Result>>& out, |
302 | const std::function<time_t(approx_time_t)> time_converter, |
303 | const uint64_t tid, |
304 | const kineto::DeviceAndResource& kineto_info) { |
305 | // Plumb Autograd info to the top level annotation. |
306 | auto it = op_events_.begin(); |
307 | for (C10_UNUSED const auto _ : |
308 | c10::irange(static_cast<int64_t>(op_events_.size()) - 1)) { |
309 | auto& first = it->basic_fields_; |
310 | auto& second = (++it)->basic_fields_; |
311 | if (first.scope_ == at::RecordScope::FUNCTION && |
312 | second.scope_ == at::RecordScope::BACKWARD_FUNCTION && |
        first.name_.rfind("autograd::engine::evaluate_function: ", 0) == 0) {
314 | first.sequence_number_ = second.sequence_number_; |
315 | first.forward_tid_ = second.forward_tid_; |
316 | } |
317 | } |
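  // For example (a sketch), the adjacent pair
  //   "autograd::engine::evaluate_function: AddBackward0"  (FUNCTION scope)
  //   "AddBackward0"                                        (BACKWARD_FUNCTION scope)
  // now shares the backward node's sequence number and forward thread id, so
  // the top level annotation can be matched to its forward op.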
318 | |
  // `AccumulateGrad` is an important marker for profile analysis; however, the
  // annotation relies on `c10::demangle`, which is platform dependent. In
  // particular, Windows will add a "struct " prefix.
  const std::string accumulate_grad = "torch::autograd::AccumulateGrad";
  const std::string windows_pattern = std::string("struct ") + accumulate_grad;
324 | for (auto& event : op_events_) { |
325 | auto& name = event.basic_fields_.name_; |
326 | auto position = name.find(windows_pattern); |
327 | if (position != std::string::npos) { |
328 | name.replace(position, windows_pattern.size(), accumulate_grad); |
329 | } |
330 | } |
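  // e.g. on Windows "struct torch::autograd::AccumulateGrad" is rewritten to
  // "torch::autograd::AccumulateGrad", matching the name produced on other
  // platforms.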
331 | |
332 | auto input_getter = inputs_outputs_.getNextShapesAndDtypes(); |
333 | |
334 | // TODO: CTAD will take care of template args when we move to C++17 |
335 | auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_); |
336 | auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_); |
  auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
338 | auto gpu_fallback = StealOrDefault<decltype(gpu_fallback_)>(gpu_fallback_); |
339 | |
340 | for (auto event = op_events_.begin(); event != op_events_.end(); ++event) { |
341 | ExtraFields<EventType::TorchOp> e{ |
342 | std::move(event->basic_fields_), |
343 | ThreadLocalSubqueue::TorchOpStorage::OpList::correlationID(event), |
344 | time_converter(event->end_time_), |
345 | input_getter(), |
346 | jit_stack(), |
347 | jit_module(), |
348 | extra_args(), |
349 | gpu_fallback(), |
350 | event->allow_tf32_cublas_, |
351 | std::move(event->counters_)}; |
352 | |
353 | out.emplace_back(Result::create( |
354 | time_converter(event->start_time_), tid, kineto_info, std::move(e))); |
355 | } |
356 | |
357 | op_events_.clear(); |
358 | inputs_outputs_.clear(); |
359 | } |
360 | |
361 | template <size_t BlockSize> |
362 | void materialize_vulkan( |
363 | std::vector<std::shared_ptr<Result>>& out, |
364 | AppendOnlyList<ExtraFields<EventType::Vulkan>::raw_event_t, BlockSize>& |
365 | raw_events, |
366 | const std::function<time_t(approx_time_t)> time_converter, |
367 | const uint64_t tid, |
368 | const kineto::DeviceAndResource& kineto_info) { |
369 | for (const auto& i : raw_events) { |
370 | const auto name_and_duration_ns = |
371 | torch::profiler::impl::vulkan::getShaderNameAndDurationNs(i.second); |
372 | |
373 | out.emplace_back(Result::create( |
374 | /*start_time_ns_=*/time_converter(i.first), |
375 | /*start_tid_=*/tid, |
376 | /*kineto_info_=*/kineto_info, |
377 | /*extra_fields_=*/ |
378 | ExtraFields<EventType::Vulkan>{ |
379 | /*name_=*/std::get<0>(name_and_duration_ns), |
380 | /*duration_ns_=*/ |
381 | static_cast<int64_t>(std::get<1>(name_and_duration_ns)), |
382 | /*in_tree_building_=*/false})); |
383 | } |
384 | } |
385 | |
386 | namespace { |
387 | // See `RecordQueue::getSubqueue()` for an overview of this cache. |
388 | struct SubQueueThreadCache { |
389 | uint32_t key_; |
390 | ThreadLocalSubqueue* ref_; |
391 | }; |
392 | |
393 | // The astute observer will note that this leaves a dangling reference; nothing |
394 | // in the teardown of `RecordQueue` or `ThreadLocalSubqueue` clears this value. |
395 | // (And the raw pointer in `SubQueueThreadCache` will not extend the lifetime |
396 | // of `*ref_`.) This is safe, however, because `getSubqueue` will check |
397 | // `sub_queue_cache_.key_` before attempting to access `ref_`, and if `key_` |
398 | // does not match the RecordQueue's *unique* `id_` it will evict |
399 | // `sub_queue_cache_` and fall back to a different mechanism. |
400 | std::atomic<uint32_t> queue_id_{0}; |
401 | thread_local SubQueueThreadCache sub_queue_cache_{0, nullptr}; |
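// Sketch of the invalidation (hypothetical ids): profiler #1 gets id_ == 1 and
// a worker thread caches {key_: 1, ref_: its subqueue}. When that profiler
// ends and profiler #2 (id_ == 2) starts, the stale entry still points at the
// old subqueue, but `getSubqueue` sees key_ != 2, takes the slow path, and
// overwrites the cache with the new subqueue.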
402 | |
std::string toString(const ExtraFields<EventType::PyCall>& e) {
404 | if (e.module_.has_value()) { |
405 | return fmt::format( |
        "nn.Module: {}_{}", e.module_->cls_name_.str(), e.module_->id_);
407 | } |
408 | return fmt::format( |
      "{}({}): {}",
410 | e.callsite_.filename_.str(), |
411 | e.callsite_.line_no_, |
412 | e.callsite_.funcname_.str()); |
413 | } |
414 | |
415 | auto scopeToType(at::RecordScope scope) { |
416 | return scope == at::RecordScope::USER_SCOPE |
417 | ? libkineto::ActivityType::USER_ANNOTATION |
418 | : libkineto::ActivityType::CPU_OP; |
419 | } |
420 | |
int64_t torchOpEndNS(
422 | const ExtraFields<EventType::TorchOp>& e, |
423 | const bool finished, |
424 | const std::weak_ptr<Result>& parent) { |
425 | if (finished && e.end_time_ns_ == std::numeric_limits<time_t>::min()) { |
426 | auto p = parent.lock(); |
427 | if (p) { |
428 | return p->endTimeNS(); |
429 | } |
430 | } |
431 | return e.end_time_ns_; |
432 | } |
433 | |
auto kinetoEventCorrelationID(
435 | const ExtraFields<EventType::Kineto>& e, |
436 | const std::weak_ptr<Result>& parent) { |
437 | if (e.correlation_id_) { |
438 | return e.correlation_id_; |
439 | } |
440 | auto p = parent.lock(); |
441 | return p ? p->correlationID() : 0; |
442 | } |
443 | } // namespace |
444 | |
445 | #define ATTRIBUTE(event_type, expr) \ |
446 | [&](const ExtraFields<EventType::event_type>& e) { \ |
447 | (void)e; \ |
448 | return expr; \ |
449 | } |
450 | |
451 | std::string Result::name() const { |
452 | return visit(c10::overloaded( |
453 | ATTRIBUTE(Vulkan, std::string(e.name_)), |
      ATTRIBUTE(Allocation, std::string("[memory]")),
      ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")),
456 | ATTRIBUTE(PyCall, toString(e)), |
457 | ATTRIBUTE(PyCCall, std::string(e.function_name_.str())), |
458 | [](const auto& e) -> std::string { return e.name_; })); |
459 | } |
460 | |
461 | libkineto::ActivityType Result::kinetoType() const { |
462 | return visit(c10::overloaded( |
463 | ATTRIBUTE(TorchOp, scopeToType(e.scope_)), |
464 | ATTRIBUTE(Backend, scopeToType(e.scope_)), |
465 | ATTRIBUTE(Vulkan, libkineto::ActivityType::CPU_OP), |
466 | ATTRIBUTE(Allocation, libkineto::ActivityType::CPU_INSTANT_EVENT), |
467 | ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT), |
468 | ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION), |
469 | ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION), |
470 | ATTRIBUTE(Kineto, e.activity_type_))); |
471 | } |
472 | |
473 | uint64_t Result::correlationID() const { |
474 | return visit(c10::overloaded( |
475 | ATTRIBUTE(TorchOp, e.correlation_id_), |
476 | ATTRIBUTE(Kineto, kinetoEventCorrelationID(e, parent_)), |
477 | [&](const auto&) -> uint64_t { return 0; })); |
478 | } |
479 | |
480 | int64_t Result::endTimeNS() const { |
481 | auto end_time_ns = visit(c10::overloaded( |
482 | ATTRIBUTE(TorchOp, torchOpEndNS(e, finished_, parent_)), |
483 | ATTRIBUTE(Backend, e.end_time_us_ * 1000), |
484 | ATTRIBUTE( |
485 | Vulkan, start_time_ns_ + (e.in_tree_building_ ? 0 : e.duration_ns_)), |
486 | ATTRIBUTE(Allocation, start_time_ns_), |
487 | ATTRIBUTE(OutOfMemory, start_time_ns_), |
488 | ATTRIBUTE(Kineto, start_time_ns_ + e.duration_us_ * 1000), |
489 | [&](const auto& e) -> int64_t { return e.end_time_ns_; })); |
490 | |
491 | // In rare cases we're willing to tolerate ops which are missing an end time |
492 | // so long as they can borrow their parent's end time. A consequence of this, |
493 | // however, is that `endTimeNS` may not make sense until tree construction is |
494 | // complete. |
495 | auto end_time_is_valid = |
496 | !finished_ || SOFT_ASSERT(end_time_ns >= start_time_ns_, name()); |
497 | return end_time_is_valid ? end_time_ns : start_time_ns_; |
498 | } |
499 | |
500 | uint64_t Result::endTID() const { |
501 | return visit(c10::overloaded( |
502 | ATTRIBUTE(TorchOp, e.end_tid_), |
503 | [&](const auto&) -> uint64_t { return start_tid_; })); |
504 | } |
505 | |
506 | c10::DeviceType Result::deviceType() const { |
507 | using torch::autograd::profiler::deviceTypeFromActivity; |
508 | return visit(c10::overloaded( |
509 | ATTRIBUTE(Vulkan, c10::DeviceType::Vulkan), |
510 | ATTRIBUTE(Allocation, e.device_type_), |
511 | ATTRIBUTE(OutOfMemory, e.device_type_), |
512 | ATTRIBUTE(Kineto, deviceTypeFromActivity(e.activity_type_)), |
513 | [&](const auto&) { return c10::DeviceType::CPU; })); |
514 | } |
515 | #undef ATTRIBUTE |
516 | |
517 | ThreadLocalSubqueue::ThreadLocalSubqueue( |
518 | const uint64_t tid, |
519 | const ProfilerConfig& config) |
520 | : tid_{tid}, config_{config}, kineto_info_{kineto::kineto_ids()} { |
521 | torch::profiler::impl::kineto::recordThreadInfo(); |
522 | if (!config_.experimental_config.performance_events.empty()) { |
523 | perf_profiler_ = |
524 | std::make_unique<torch::profiler::impl::linux_perf::PerfProfiler>(); |
525 | perf_profiler_->Configure(config_.experimental_config.performance_events); |
526 | } |
527 | } |
528 | |
529 | RecordQueue::RecordQueue( |
530 | const ProfilerConfig& config, |
531 | std::set<ActivityType> activities) |
532 | : id_(++queue_id_), config_{config}, activities_{std::move(activities)} { |
533 | if (tracePython()) { |
534 | python_tracer_ = python_tracer::PythonTracerBase::make(this); |
535 | } |
536 | } |
537 | |
538 | bool RecordQueue::tracePython() const { |
539 | return config_.with_stack && activities_.count(ActivityType::CPU); |
540 | } |
541 | |
542 | ThreadLocalSubqueue* RecordQueue::getSubqueue() { |
543 | // In the most common case, a thread will want to write to the same sub-queue |
544 | // that it wrote to last call. The only time that isn't true is if: |
545 | // A) The profiler context has ended and we are in a new one. |
546 | // B) Two profilers are active in different TLS contexts, and this thread |
547 | // is a worker helping with intra-op parallelism. |
548 | // Since we expect this to be the OVERWHELMINGLY common case (>99%), we add a |
549 | // special thread_local cache so that we can skip the overall `flat_hash_map` |
550 | // (and corresponding lock). |
551 | if (id_ == sub_queue_cache_.key_) { |
552 | return sub_queue_cache_.ref_; |
553 | } |
554 | |
555 | const auto tid = at::RecordFunction::currentThreadId(); |
556 | std::lock_guard<std::mutex> guard(sub_queue_mutex_); |
557 | auto it = sub_queues_.find(tid); |
558 | if (it == sub_queues_.end()) { |
559 | it = sub_queues_ |
560 | .emplace(tid, std::make_unique<ThreadLocalSubqueue>(tid, config_)) |
561 | .first; |
562 | } |
563 | |
564 | sub_queue_cache_ = SubQueueThreadCache{id_, it->second.get()}; |
565 | return it->second.get(); |
566 | } |
567 | |
568 | void RecordQueue::stop() { |
569 | if (python_tracer_) { |
570 | python_tracer_->stop(); |
571 | } |
572 | } |
573 | |
574 | namespace { |
575 | void mark_finished(std::shared_ptr<Result>& r) { |
576 | TORCH_INTERNAL_ASSERT(!r->finished_, r->name()); |
577 | r->finished_ = true; |
578 | TORCH_INTERNAL_ASSERT(r->endTimeNS() >= r->start_time_ns_, r->name()); |
579 | } |
580 | |
static constexpr const char* indexKey = "Ev Idx";
582 | |
583 | void passEventsToKineto( |
584 | const std::vector<std::shared_ptr<Result>>& results, |
585 | uint64_t start_time_us, |
586 | uint64_t end_time_us) { |
587 | using namespace torch::profiler::impl::kineto; |
  TraceWrapper cpu_trace(start_time_us, "PyTorch Profiler");
589 | |
590 | // Generate Kineto events for each event recorded by the PyTorch profiler. |
591 | for (const auto i : c10::irange(results.size())) { |
592 | const auto& e = results[i]; |
593 | const auto* activity = cpu_trace.addCPUActivity( |
594 | e->name(), |
595 | e->kinetoType(), |
596 | e->kineto_info_, |
597 | e->correlationID(), |
598 | e->start_time_ns_ / 1000, |
599 | e->endTimeNS() / 1000); |
600 | |
601 | TORCH_INTERNAL_ASSERT(activity || !kKinetoAvailable); |
602 | if (activity) { |
603 | addMetadata(activity, indexKey, std::to_string(i)); |
604 | } |
605 | } |
606 | |
607 | // Kineto adds the events that it collected. |
608 | cpu_trace.transferCpuTrace(end_time_us); |
609 | } |
610 | |
611 | #ifdef USE_KINETO |
612 | // There are two mechanisms that we use to connect Profiler and Kineto events. |
613 | // The first is the correlation ID. The profiler pushes a unique integer at the |
614 | // start of an op and pops it at the end. Kineto then associates the events |
615 | // that it collects with that correlation ID and sets the linked activity of |
616 | // the events that it collected to point to the profiler op. |
617 | // |
618 | // However, this is not a sufficient description because it does not retain |
619 | // dependency information between kineto ops. Consider a call to `torch.add`. |
620 | // Three events will be collected: |
621 | // `aten::add` (TorchOp, collected by profiler) |
622 | // `cudaLaunchKernel` (CUDA runtime event, collected by Kineto) |
623 | // `at::vectorized_...` (GPU kernel, collected by Kineto) |
// If we only relied on correlation IDs we would set both Kineto events as
// children of `aten::add`, rather than building the correct chain:
//   `aten::add -> cudaLaunchKernel -> at::vectorized_...`
627 | // |
628 | // Kineto surfaces this information through a second concept called a "flow". |
629 | // In this example, the `cudaLaunchKernel` event is the start of a flow and the |
630 | // GPU kernel has the same flow id but is not a start event. Thus, when merging |
631 | // the Kineto events into the call tree we first add all events which are flow |
632 | // start nodes. We then merge the rest, trying to pair them with flow starts |
633 | // and falling back to correlation ID if necessary. For any nodes without |
634 | // linked events the caller is determined using the normal tree construction |
635 | // algorithm. |
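// A concrete pairing sketch (hypothetical values): if `cudaLaunchKernel`
// arrives with an async CPU-GPU flow {id: 7, start: true} and the GPU kernel
// arrives with {id: 7, start: false}, the first pass of `setParents` records
// 7 -> launch event in `flow_map`, and the second pass makes the kernel a
// child of that launch rather than of the CPU op it merely correlates with.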
636 | class TransferEvents { |
637 | using itrace_t = libkineto::ITraceActivity; |
638 | using activity_t = torch::profiler::impl::kineto::activity_t; |
639 | |
640 | public: |
641 | TransferEvents( |
642 | std::vector<std::shared_ptr<Result>>& results, |
643 | trace_ptr_t& trace) |
644 | : results_{results} { |
645 | auto* trace_activities_ptr = trace->get()->activities(); |
646 | TORCH_INTERNAL_ASSERT(trace_activities_ptr != nullptr); |
647 | trace_activities_ = *trace_activities_ptr; |
648 | reassociate(); |
649 | extractEventsFromTrace(); |
650 | setParents(); |
651 | } |
652 | |
653 | private: |
  static long long extractIndex(const std::string& metadata_json) {
    static const auto prefix = fmt::format("\"{}\": ", indexKey);
656 | auto pos = metadata_json.find(prefix); |
657 | return (pos == std::string::npos) ? unmatchedIndex : [&]() { |
658 | auto end = metadata_json.find(',', pos); |
659 | end = (end == std::string::npos) ? metadata_json.size() : end; |
660 | return std::stoll(metadata_json.substr(pos + prefix.size(), end)); |
661 | }(); |
662 | } |
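  // e.g. (a sketch): for metadata_json == R"({"Ev Idx": 42, "other": 1})" the
  // prefix is found at the "Ev Idx" key, std::stoll parses the leading digits
  // after it, and 42 is returned; if the key is absent, unmatchedIndex (-1) is
  // returned instead.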
663 | |
664 | std::shared_ptr<Result> lookup(const itrace_t* key) { |
665 | if (key == nullptr) { |
666 | return nullptr; |
667 | } |
668 | |
669 | // First check the map. |
670 | auto it = kineto_events_.find(key); |
671 | if (it != kineto_events_.end()) { |
672 | return it->second; |
673 | } |
674 | |
675 | // Then fallback to the encoded metadata. |
    const auto index = extractIndex(key ? key->metadataJson() : "");
677 | if (index != unmatchedIndex) { |
678 | auto out = results_.get().at(index); |
679 | kineto_events_[key] = out; |
680 | return out; |
681 | } |
682 | |
683 | // And finally give up. |
684 | return nullptr; |
685 | } |
686 | |
687 | void reassociate() { |
688 | // Match profiler events with the corresponding kineto events. Kineto may |
689 | // have moved or copied the activities, so we have to recover the |
690 | // relationship between `libkineto::ITraceActivity` and `Result`. |
691 | for (const auto* activity : trace_activities_) { |
692 | TORCH_INTERNAL_ASSERT(activity != nullptr); |
693 | auto e = lookup(activity); |
694 | if (e != nullptr) { |
695 | TORCH_INTERNAL_ASSERT(e->kineto_activity_ == nullptr); |
696 | e->kineto_activity_ = static_cast<const activity_t*>(activity); |
697 | } |
698 | } |
699 | if (results_.get().size() != kineto_events_.size()) { |
700 | TORCH_WARN(fmt::format( |
701 | "Failed to recover relationship between all profiler and kineto events: " |
          "{} vs. {} reassociated.",
703 | results_.get().size(), |
704 | kineto_events_.size())); |
705 | } |
706 | } |
707 | |
708 | std::shared_ptr<Result> resultFromActivity(const itrace_t* activity) { |
709 | TORCH_INTERNAL_ASSERT(activity != nullptr); |
710 | |
711 | // Kineto is inconsistent with types, so we have to cast to int32. |
712 | torch::profiler::impl::kineto::DeviceAndResource device_and_resource{ |
713 | static_cast<int32_t>(activity->deviceId()), |
714 | static_cast<int32_t>(activity->resourceId())}; |
715 | |
716 | auto event = Result::create( |
717 | activity->timestamp() * 1000, |
718 | noTID, // Placeholder |
719 | device_and_resource, |
720 | ExtraFields<EventType::Kineto>{ |
721 | activity->name(), |
722 | activity->duration(), |
723 | static_cast<uint64_t>(activity->correlationId()), |
724 | activity->type(), |
725 | {/*id=*/static_cast<uint32_t>(activity->flowId()), |
726 | /*type=*/static_cast<uint32_t>(activity->flowType()), |
727 | /*start=*/activity->flowStart()}}); |
728 | |
729 | // NB: It's tempting to set `event->kineto_activity_`; however we can only |
730 | // guarantee that the events we passed to Kineto are of type |
731 | // `GenericTraceActivity`. Others may derive from ITraceActivity and thus |
732 | // are not safe to cast. |
733 | return event; |
734 | } |
735 | |
736 | std::shared_ptr<Result> toResult(const itrace_t* activity) { |
737 | auto e = lookup(activity); |
738 | |
739 | // Until we are very sure that we can reassociate kineto and profiler |
740 | // events we need to be very defensive. |
741 | const auto type = activity->type(); |
742 | if (e == nullptr && |
743 | (type == libkineto::ActivityType::CPU_OP || |
744 | type == libkineto::ActivityType::CPU_INSTANT_EVENT || |
745 | type == libkineto::ActivityType::USER_ANNOTATION || |
746 | type == libkineto::ActivityType::PYTHON_FUNCTION)) { |
747 | TORCH_WARN_ONCE( |
748 | "Detected an event which was likely passed to kineto by the PyTorch " |
          "profiler, but is not present in the set of known events: ",
750 | activity->name(), |
751 | " This most likely means that Kineto has not " |
752 | "maintained address stability for this event. Please report this to " |
          "the PyTorch team.");
754 | return nullptr; |
755 | } |
756 | |
757 | if (e == nullptr) { |
758 | e = resultFromActivity(activity); |
759 | results_.get().push_back(e); |
760 | kineto_events_[activity] = e; |
761 | } |
762 | return e; |
763 | } |
764 | |
  void extractEventsFromTrace() {
766 | for (const auto* activity : trace_activities_) { |
767 | auto e = toResult(activity); |
768 | const auto* linked_activity = activity->linkedActivity(); |
769 | if (e && linked_activity) { |
770 | e->visit(c10::overloaded( |
771 | [&](ExtraFields<EventType::Kineto>& i) { |
772 | i.linked_activity_ = toResult(linked_activity); |
773 | }, |
774 | [](auto&) { TORCH_INTERNAL_ASSERT(false); })); |
775 | } |
776 | } |
777 | } |
778 | |
779 | void setKinetoTID( |
780 | std::shared_ptr<Result>& r, |
781 | std::shared_ptr<Result> parent) { |
782 | r->visit(c10::overloaded( |
783 | [&](ExtraFields<EventType::Kineto>& i) { |
784 | TORCH_INTERNAL_ASSERT(r->start_tid_ == noTID); |
785 | r->start_tid_ = parent ? parent->start_tid_ |
786 | : at::RecordFunction::currentThreadId(); |
787 | }, |
788 | [](auto&) {})); |
789 | |
790 | for (auto& child : r->children_) { |
791 | setKinetoTID(child, r); |
792 | } |
793 | } |
794 | |
795 | void setParents() { |
796 | // First pass: Collect start events and set parent to linked event. |
797 | ska::flat_hash_map<int, std::shared_ptr<Result>> flow_map; |
798 | for (auto& e : results_.get()) { |
799 | TORCH_INTERNAL_ASSERT(e != nullptr); |
800 | e->visit(c10::overloaded( |
801 | [&](const ExtraFields<EventType::Kineto>& i) { |
802 | if (i.flow.type == libkineto::kLinkAsyncCpuGpu && i.flow.start) { |
803 | auto inserted = flow_map.insert({i.flow.id, e}); |
804 | #ifdef USE_ROCM |
            if (!inserted.second) {
              TORCH_WARN_ONCE(
                  "ROCTracer produced duplicate flow start: ", i.flow.id);
            }
809 | #else // USE_ROCM |
810 | TORCH_INTERNAL_ASSERT(inserted.second); |
811 | #endif // USE_ROCM |
812 | } |
813 | TORCH_INTERNAL_ASSERT(e->parent_.expired()); |
814 | e->parent_ = i.linked_activity_; |
815 | }, |
816 | [](const auto&) {})); |
817 | } |
818 | |
819 | // Second pass |
820 | for (auto& e : results_.get()) { |
821 | e->visit(c10::overloaded( |
822 | [&](const ExtraFields<EventType::Kineto>& i) { |
823 | // Flow takes priority over linked event. |
824 | const auto it = flow_map.find(i.flow.id); |
825 | if (it != flow_map.end() && |
826 | i.flow.type == libkineto::kLinkAsyncCpuGpu && !i.flow.start) { |
827 | e->parent_ = it->second; |
828 | } |
829 | |
830 | // If a parent was set we have to do some bookkeeping. |
831 | auto parent = e->parent_.lock(); |
832 | if (parent) { |
833 | parent->children_.push_back(e); |
834 | mark_finished(e); |
835 | } |
836 | }, |
837 | [](const auto&) {})); |
838 | } |
839 | |
840 | // Set TIDs now that we have established lineage. |
841 | for (auto& e : results_.get()) { |
842 | if (e->parent_.expired()) { |
843 | setKinetoTID(e, nullptr); |
844 | } |
845 | } |
846 | } |
847 | |
848 | static constexpr long long unmatchedIndex = -1; |
849 | static constexpr auto noTID = std::numeric_limits<uint64_t>::max(); |
850 | std::reference_wrapper<std::vector<std::shared_ptr<Result>>> results_; |
851 | std::vector<const itrace_t*> trace_activities_; |
852 | ska::flat_hash_map<const itrace_t*, std::shared_ptr<Result>> kineto_events_; |
853 | }; |
854 | #else |
855 | class TransferEvents { |
856 | public: |
857 | template <class... Args> |
858 | TransferEvents(Args&&...) {} |
859 | }; |
860 | #endif |
861 | |
862 | trace_ptr_t addKinetoEvents( |
863 | std::vector<std::shared_ptr<Result>>& results, |
864 | uint64_t start_time_us, |
865 | uint64_t end_time_us, |
866 | const ProfilerConfig& config) { |
867 | using namespace torch::profiler::impl::kineto; |
868 | passEventsToKineto(results, start_time_us, end_time_us); |
869 | |
  // In on-demand mode, Kineto is controlled directly by other machinery.
871 | if (config.global()) { |
872 | return nullptr; |
873 | } |
874 | |
875 | auto trace = std::make_unique<ActivityTraceWrapper>(stopTrace()); |
876 | TORCH_INTERNAL_ASSERT(trace || !kKinetoAvailable); |
877 | TransferEvents transfer{results, trace}; |
878 | return trace; |
879 | } |
880 | |
881 | struct ResultGreater { |
882 | bool operator()(const result_ptr_t& a, const result_ptr_t& b) const { |
883 | return a->endTimeNS() > b->endTimeNS(); |
884 | } |
885 | }; |
886 | |
887 | void set_in_tree_building( |
888 | std::vector<result_ptr_t>& results, |
889 | const bool value) { |
890 | for (result_ptr_t& r : results) { |
891 | r->visit(c10::overloaded( |
892 | [value](ExtraFields<EventType::Vulkan>& i) { |
893 | i.in_tree_building_ = value; |
894 | }, |
895 | [&](auto&) { |
896 | // pass |
897 | })); |
898 | } |
899 | } |
900 | |
901 | void build_tree(std::vector<std::shared_ptr<Result>>& sorted_events) { |
902 | set_in_tree_building(sorted_events, true); |
903 | |
904 | using op_fields = ExtraFields<EventType::TorchOp>; |
905 | ska::flat_hash_map<uint64_t, std::shared_ptr<Result>> stacks; |
906 | std::priority_queue<result_ptr_t, std::vector<result_ptr_t>, ResultGreater> |
907 | end_events_; |
908 | |
909 | auto push_event = [&stacks, &end_events_](std::shared_ptr<Result>& event) { |
910 | // Kineto builds subtrees using correlation ids and flows, so some Kineto |
911 | // events are already marked finished before the main tree building |
    // algorithm. It's fine to ignore them; the root event of these subtrees is
    // not a Kineto op and will be handled normally.
914 | if (c10::holds_alternative<ExtraFields<EventType::Kineto>>( |
915 | event->extra_fields_) && |
916 | event->finished_) { |
917 | return; |
918 | } |
919 | |
920 | TORCH_INTERNAL_ASSERT(event->parent_.expired()); |
921 | for (const auto& child : event->children_) { |
922 | TORCH_INTERNAL_ASSERT(child->finished_); |
923 | } |
924 | TORCH_INTERNAL_ASSERT(!event->finished_); |
925 | |
926 | auto parent_it = stacks.find(event->start_tid_); |
927 | if (parent_it == stacks.end()) { |
928 | auto fwd_tid = event->visit(c10::overloaded( |
929 | [](const op_fields& i) { return i.forward_tid_; }, |
930 | [](const auto&) -> uint64_t { return 0; })); |
931 | if (fwd_tid) { |
932 | parent_it = stacks.find(fwd_tid); |
933 | } |
934 | } |
935 | |
936 | if (parent_it != stacks.end()) { |
937 | event->parent_ = parent_it->second; |
938 | parent_it->second->children_.push_back(event); |
939 | } |
940 | |
941 | if (event->endTimeNS() > event->start_time_ns_) { |
942 | stacks[event->start_tid_] = event; |
943 | end_events_.push(event); |
944 | } else if (event->endTimeNS() == std::numeric_limits<time_t>::min()) { |
945 | // We use min time to indicate the lack of a termination event, so if we |
946 | // encounter such a case we don't push to `end_events_`. |
947 | stacks[event->start_tid_] = event; |
948 | } else { |
949 | mark_finished(event); |
950 | } |
951 | }; |
952 | |
953 | auto pop_event = [&stacks](std::shared_ptr<Result> event) { |
954 | if (event->finished_) { |
955 | // This event was marked finished by a previous `pop_event` call. |
956 | return; |
957 | } |
958 | |
959 | auto start_tid = event->start_tid_; |
960 | auto frame = stacks.at(start_tid); |
961 | |
962 | while (frame.get() != event.get()) { |
963 | TORCH_INTERNAL_ASSERT(frame != nullptr); |
964 | mark_finished(frame); |
965 | TORCH_INTERNAL_ASSERT(!frame->parent_.expired()); |
966 | frame = frame->parent_.lock(); |
967 | } |
968 | |
969 | mark_finished(event); |
970 | stacks.erase(start_tid); |
971 | auto new_frame = event->parent_.lock(); |
972 | if (new_frame != nullptr) { |
973 | stacks[start_tid] = new_frame; |
974 | } |
975 | }; |
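  // Replay sketch (hypothetical timestamps, one thread): given ops
  //   A [0, 100), B [10, 40), C [50, 90)
  // sorted by start time, A is pushed first and becomes the top of the stack;
  // B is pushed as A's child. Before C is pushed, B (end 40 < start 50) is
  // popped and marked finished; C then becomes another child of A. The cleanup
  // loop below pops C and then A once all start events have been processed.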
976 | |
977 | // Stack replay loop. |
978 | for (auto& event : sorted_events) { |
979 | while (!end_events_.empty() && |
980 | end_events_.top()->endTimeNS() < event->start_time_ns_) { |
981 | pop_event(end_events_.top()); |
982 | end_events_.pop(); |
983 | } |
984 | push_event(event); |
985 | } |
986 | |
987 | // Cleanup remaining exit events. |
988 | while (!end_events_.empty()) { |
989 | pop_event(end_events_.top()); |
990 | end_events_.pop(); |
991 | } |
992 | |
993 | set_in_tree_building(sorted_events, false); |
994 | } |
995 | |
996 | /** |
 * Adjust r's duration to be the max of its current duration and the sum of
 * all of its children's adjusted durations, keeping its start time the same.
 * Child durations are adjusted recursively before being summed.
1000 | */ |
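// Example (hypothetical durations): a node spanning [0, 100) ns whose children
// have adjusted durations of 60 ns and 70 ns is stretched to [0, 130) ns since
// 60 + 70 > 100; a node spanning [0, 200) ns would be left unchanged.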
1001 | int64_t adjust_durations_dfs(std::shared_ptr<Result>& r) { |
1002 | if (SOFT_ASSERT(r != nullptr)) { |
1003 | int64_t original_duration = r->endTimeNS() - r->start_time_ns_; |
1004 | int64_t children_total_duration = std::accumulate( |
1005 | r->children_.begin(), |
1006 | r->children_.end(), |
        int64_t{0}, // accumulate in 64 bits; ns durations can overflow int
1008 | [](int64_t acc, std::shared_ptr<Result>& child) { |
1009 | return acc + adjust_durations_dfs(child); |
1010 | }); |
1011 | |
1012 | if (children_total_duration > original_duration) { |
1013 | r->visit(c10::overloaded( |
1014 | [&r, &children_total_duration](ExtraFields<EventType::TorchOp>& i) { |
1015 | i.end_time_ns_ = r->start_time_ns_ + children_total_duration; |
1016 | }, |
1017 | [&children_total_duration](ExtraFields<EventType::Vulkan>& i) { |
1018 | i.duration_ns_ = children_total_duration; |
1019 | }, |
1020 | [](ExtraFields<EventType::Allocation>& _) { |
1021 | // Pass- Allocation events can't have children |
1022 | }, |
1023 | [&](auto&) { |
1024 | SOFT_ASSERT( |
1025 | false, |
                "unexpected event type in mobile profiler adjust_durations_dfs: ",
1027 | r->name()); |
1028 | })); |
1029 | return children_total_duration; |
1030 | } else { |
1031 | return original_duration; |
1032 | } |
1033 | } else { |
1034 | return 0; |
1035 | } |
1036 | } |
1037 | |
1038 | /** |
1039 | * 1) Adjust r's start time to be [new_start_time] (also adjusting end time and |
1040 | keeping duration the same) |
1041 | * 2) Recursively adjust r's children's start times, making them line up such |
1042 | that the last one ends at the same time as r |
1043 | * 3) Return r's final end time |
1044 | */ |
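// Example (hypothetical durations): if r spans [0, 130) ns and has children
// with durations of 60 ns and 70 ns, the children are packed back-to-back
// against r's end time: the first is shifted to [0, 60) and the second to
// [60, 130), so the last child ends exactly when r ends.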
1045 | int64_t adjust_timestamps_dfs( |
1046 | std::shared_ptr<Result>& r, |
1047 | int64_t new_start_time) { |
1048 | if (SOFT_ASSERT(r != nullptr)) { |
1049 | if (r->start_time_ns_ != new_start_time) { |
1050 | // Adjust start time (keeping duration constant) |
1051 | r->visit(c10::overloaded( |
1052 | [&r, &new_start_time](ExtraFields<EventType::TorchOp>& i) { |
1053 | i.end_time_ns_ = |
1054 | new_start_time + (i.end_time_ns_ - r->start_time_ns_); |
1055 | }, |
1056 | [](ExtraFields<EventType::Vulkan>& i) { |
1057 | // Pass- We don't need to manually adjust end time for Vulkan events |
1058 | }, |
1059 | [](ExtraFields<EventType::Allocation>& _) { |
1060 | // Pass- No duration or end time to adjust |
1061 | }, |
1062 | [&](auto&) { |
1063 | SOFT_ASSERT( |
1064 | false, |
                "unexpected event type in mobile profiler adjust_timestamps_dfs: ",
1066 | r->name()); |
1067 | })); |
1068 | r->start_time_ns_ = new_start_time; |
1069 | } |
1070 | int64_t children_total_duration = std::accumulate( |
1071 | r->children_.begin(), |
1072 | r->children_.end(), |
        int64_t{0}, // accumulate in 64 bits; ns durations can overflow int
1074 | [](int64_t acc, std::shared_ptr<Result>& child) { |
1075 | return acc + (child->endTimeNS() - child->start_time_ns_); |
1076 | }); |
1077 | |
1078 | int64_t child_start_time = r->endTimeNS() - children_total_duration; |
1079 | for (std::shared_ptr<Result>& child : r->children_) { |
1080 | child_start_time = adjust_timestamps_dfs(child, child_start_time); |
1081 | } |
1082 | } |
1083 | return r->endTimeNS(); |
1084 | } |
1085 | |
1086 | /** |
1087 | * Adjust timestamps and durations of nodes in [out] such that |
1088 | * - Vulkan event timelines are synchronized with CPU event times |
1089 | * - Parent event timelines fully contain their child timelines |
1090 | * - No overlaps in timelines for nodes at the same depth |
1091 | */ |
1092 | void adjust_timestamps(std::vector<std::shared_ptr<Result>>& out) { |
1093 | if (out.empty()) { |
1094 | return; |
1095 | } |
1096 | |
1097 | int64_t min_start_time = out[0]->start_time_ns_; |
1098 | for (std::shared_ptr<Result>& r : out) { |
1099 | // Only begin traversal for root nodes. |
1100 | if (r->parent_.expired()) { |
1101 | adjust_durations_dfs(r); |
1102 | min_start_time = adjust_timestamps_dfs( |
1103 | r, |
1104 | std::max( |
1105 | r->tag() != EventType::Vulkan |
1106 | ? r->start_time_ns_ |
1107 | : std::numeric_limits<int64_t>::min(), |
1108 | min_start_time)); |
1109 | } |
1110 | } |
1111 | } |
1112 | } // namespace |
1113 | |
1114 | std::pair< |
1115 | std::vector<std::shared_ptr<Result>>, |
1116 | std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>> |
1117 | RecordQueue::getRecords( |
1118 | std::function<time_t(approx_time_t)> time_converter, |
1119 | uint64_t start_time_us, |
1120 | uint64_t end_time_us) { |
1121 | auto converter = [&](approx_time_t t) { |
1122 | return t == std::numeric_limits<approx_time_t>::min() |
1123 | ? std::numeric_limits<time_t>::min() |
1124 | : time_converter(t); |
1125 | }; |
1126 | std::vector<std::shared_ptr<Result>> out; |
1127 | std::vector<python_tracer::CompressedEvent> python_enters; |
1128 | for (auto& subqueue_it : sub_queues_) { |
1129 | auto& queue = *subqueue_it.second; |
1130 | auto materialize = [&](auto& events) { |
1131 | for (auto& i : events) { |
1132 | time_t start_time_ns; |
1133 | if constexpr (std::is_same< |
1134 | std::remove_reference_t<decltype(i)>, |
1135 | ExtraFields<EventType::Backend>>::value) { |
1136 | start_time_ns = i.start_time_us_ * 1000; |
1137 | } else { |
1138 | start_time_ns = converter(i.start_time_); |
1139 | } |
1140 | out.emplace_back(Result::create( |
1141 | /*start_time_ns_=*/start_time_ns, |
1142 | /*start_tid_=*/queue.tid(), |
1143 | /*kineto_info_=*/queue.kineto_info(), |
1144 | /*extra_fields_=*/std::move(i))); |
1145 | } |
1146 | events.clear(); |
1147 | }; |
1148 | |
1149 | queue.torch_ops_.materialize( |
1150 | out, converter, queue.tid(), queue.kineto_info()); |
1151 | materialize(queue.backend_events_); |
1152 | materialize_vulkan( |
1153 | out, queue.vulkan_events_, converter, queue.tid(), queue.kineto_info()); |
1154 | for (auto& i : queue.allocations_) { |
1155 | out.emplace_back(Result::create( |
1156 | /*start_time_ns_=*/converter(i.start_time_), |
1157 | /*start_tid_=*/queue.tid(), |
1158 | /*kineto_info_=*/queue.kineto_info(), |
1159 | /*extra_fields_=*/ExtraFields<EventType::Allocation>(i))); |
1160 | } |
1161 | materialize(queue.ooms_); |
1162 | |
1163 | for (auto& i : queue.py_calls_) { |
1164 | python_enters.push_back( |
1165 | {i.first, queue.tid(), queue.kineto_info(), converter(i.second)}); |
1166 | } |
1167 | } |
1168 | |
1169 | if (python_tracer_) { |
1170 | for (const auto& i : python_tracer_->getEvents( |
1171 | converter, python_enters, end_time_us * 1000)) { |
1172 | out.push_back(i); |
1173 | } |
1174 | python_tracer_.reset(); |
1175 | } |
1176 | |
1177 | if (config_.experimental_config.adjust_timestamps) { |
1178 | std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) { |
1179 | return a->start_time_ns_ < b->start_time_ns_; |
1180 | }); |
1181 | build_tree(out); |
1182 | adjust_timestamps(out); |
1183 | for (auto& r : out) { |
1184 | r->parent_.reset(); |
1185 | // Reset these so that second build_tree can happen |
1186 | r->finished_ = false; |
1187 | r->children_.clear(); |
1188 | } |
1189 | } |
1190 | |
1191 | auto trace = addKinetoEvents(out, start_time_us, end_time_us, config_); |
1192 | |
1193 | std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) { |
1194 | return a->start_time_ns_ < b->start_time_ns_; |
1195 | }); |
1196 | |
1197 | if (config_.report_input_shapes && config_.profile_memory) { |
1198 | calculateUniqueTensorIDs(out); |
1199 | } |
1200 | |
1201 | build_tree(out); |
1202 | return {out, std::move(trace)}; |
1203 | } |
1204 | |
1205 | } // namespace impl |
1206 | } // namespace profiler |
1207 | } // namespace torch |
1208 | |