1 | #include <torch/csrc/profiler/standalone/itt_observer.h> |
2 | |
3 | #include <torch/csrc/profiler/stubs/base.h> |
4 | #include <torch/csrc/profiler/util.h> |
5 | |
6 | namespace torch { |
7 | namespace profiler { |
8 | namespace impl { |
9 | |
10 | struct ITTThreadLocalState : ProfilerStateBase { |
11 | explicit ITTThreadLocalState(const ProfilerConfig& config) |
12 | : ProfilerStateBase(config) { |
13 | // Only `report_input_shapes` makes sense in this context. |
14 | TORCH_CHECK(!config.profile_memory); |
15 | TORCH_CHECK(!config.with_stack); |
16 | TORCH_CHECK(!config.with_flops); |
17 | TORCH_CHECK(!config.with_modules); |
18 | } |
19 | ~ITTThreadLocalState() override = default; |
20 | |
21 | ActiveProfilerType profilerType() override { |
22 | return ActiveProfilerType::ITT; |
23 | } |
24 | |
25 | void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override { |
26 | } |
27 | |
28 | static ITTThreadLocalState* getTLS() { |
29 | auto tls = ProfilerStateBase::get(/*global=*/false); |
30 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
31 | tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT); |
32 | return static_cast<ITTThreadLocalState*>(tls); |
33 | } |
34 | }; |
35 | |
36 | template <bool report_input_shapes> |
37 | std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) { |
38 | if (ITTThreadLocalState::getTLS() != nullptr) { |
39 | torch::profiler::impl::ittStubs()->rangePush(fn.name()); |
40 | } |
41 | return nullptr; |
42 | } |
43 | |
44 | void pushITTCallbacks( |
45 | const ProfilerConfig& config, |
46 | const std::unordered_set<at::RecordScope>& scopes) { |
47 | TORCH_CHECK( |
48 | torch::profiler::impl::ittStubs()->enabled(), |
49 | "Can't use ITT profiler - PyTorch was compiled without ITT" ); |
50 | |
51 | c10::ThreadLocalDebugInfo::_push( |
52 | c10::DebugInfoKind::PROFILER_STATE, |
53 | std::make_shared<ITTThreadLocalState>(config)); |
54 | |
55 | auto state_ptr = ITTThreadLocalState::getTLS(); |
56 | TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set" ); |
57 | |
58 | auto handle = at::addThreadLocalCallback( |
59 | at::RecordFunctionCallback( |
60 | state_ptr->config().report_input_shapes |
61 | ? &enterITT</*report_input_shapes=*/true> |
62 | : &enterITT</*report_input_shapes=*/false>, |
63 | [](const at::RecordFunction&, at::ObserverContext*) { |
64 | torch::profiler::impl::ittStubs()->rangePop(); |
65 | }) |
66 | .needsInputs(config.report_input_shapes) |
67 | .scopes(scopes)); |
68 | state_ptr->setCallbackHandle(handle); |
69 | } |
70 | |
71 | } // namespace impl |
72 | } // namespace profiler |
73 | } // namespace torch |
74 | |