1#include <torch/csrc/profiler/standalone/itt_observer.h>
2
3#include <torch/csrc/profiler/stubs/base.h>
4#include <torch/csrc/profiler/util.h>
5
6namespace torch {
7namespace profiler {
8namespace impl {
9
10struct ITTThreadLocalState : ProfilerStateBase {
11 explicit ITTThreadLocalState(const ProfilerConfig& config)
12 : ProfilerStateBase(config) {
13 // Only `report_input_shapes` makes sense in this context.
14 TORCH_CHECK(!config.profile_memory);
15 TORCH_CHECK(!config.with_stack);
16 TORCH_CHECK(!config.with_flops);
17 TORCH_CHECK(!config.with_modules);
18 }
19 ~ITTThreadLocalState() override = default;
20
21 ActiveProfilerType profilerType() override {
22 return ActiveProfilerType::ITT;
23 }
24
25 void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
26 }
27
28 static ITTThreadLocalState* getTLS() {
29 auto tls = ProfilerStateBase::get(/*global=*/false);
30 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
31 tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
32 return static_cast<ITTThreadLocalState*>(tls);
33 }
34};
35
36template <bool report_input_shapes>
37std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
38 if (ITTThreadLocalState::getTLS() != nullptr) {
39 torch::profiler::impl::ittStubs()->rangePush(fn.name());
40 }
41 return nullptr;
42}
43
44void pushITTCallbacks(
45 const ProfilerConfig& config,
46 const std::unordered_set<at::RecordScope>& scopes) {
47 TORCH_CHECK(
48 torch::profiler::impl::ittStubs()->enabled(),
49 "Can't use ITT profiler - PyTorch was compiled without ITT");
50
51 c10::ThreadLocalDebugInfo::_push(
52 c10::DebugInfoKind::PROFILER_STATE,
53 std::make_shared<ITTThreadLocalState>(config));
54
55 auto state_ptr = ITTThreadLocalState::getTLS();
56 TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
57
58 auto handle = at::addThreadLocalCallback(
59 at::RecordFunctionCallback(
60 state_ptr->config().report_input_shapes
61 ? &enterITT</*report_input_shapes=*/true>
62 : &enterITT</*report_input_shapes=*/false>,
63 [](const at::RecordFunction&, at::ObserverContext*) {
64 torch::profiler::impl::ittStubs()->rangePop();
65 })
66 .needsInputs(config.report_input_shapes)
67 .scopes(scopes));
68 state_ptr->setCallbackHandle(handle);
69}
70
71} // namespace impl
72} // namespace profiler
73} // namespace torch
74