#include <torch/csrc/profiler/standalone/nvtx_observer.h>

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

namespace torch {
namespace profiler {
namespace impl {

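// Thread-local state for the NVTX profiler. Besides the base profiler state,
// it tracks, for every output tensor produced by a profiled op, the op's
// RecordFunction handle and the tensor's output index so that later ops can
// report which op produced each of their inputs.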
struct NVTXThreadLocalState : ProfilerStateBase {
  explicit NVTXThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // Only `report_input_shapes` makes sense in this context.
    TORCH_CHECK(!config.profile_memory);
    TORCH_CHECK(!config.with_stack);
    TORCH_CHECK(!config.with_flops);
    TORCH_CHECK(!config.with_modules);
  }
  ~NVTXThreadLocalState() override = default;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::NVTX;
  }

  // Memory profiling is not supported by the NVTX profiler (the constructor
  // rejects `profile_memory`), so this override is a no-op.
  void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
  }

  static NVTXThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::NVTX);
    return static_cast<NVTXThreadLocalState*>(tls);
  }
  std::pair<at::RecordFunctionHandle, int> getOpIdFromInput(
      const at::Tensor& tensor);

  void setProducerTensorMap(
      at::TensorImpl* tensor,
      at::RecordFunctionHandle op_id,
      int output_nr) {
    producer_tensor_map_[(void*)tensor] =
        std::pair<at::RecordFunctionHandle, int>{op_id, output_nr};
  }

 protected:
  // Maps the address of an output Tensor to the unique op id and output
  // index of that tensor. at::TensorImpl* is the actual type of the key,
  // but void* is used to indicate that the pointer serves only as a map key.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<void*, std::pair<at::RecordFunctionHandle, int>>
      producer_tensor_map_;
};

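// Returns the (producer op id, output index) pair previously recorded for
// `tensor`, or {0, -1} if the tensor is undefined or was not produced by a
// profiled op.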
std::pair<at::RecordFunctionHandle, int> NVTXThreadLocalState::getOpIdFromInput(
    const at::Tensor& tensor) {
  std::pair<at::RecordFunctionHandle, int> producer_op_pair(0, -1);
  if (tensor.defined()) {
    at::TensorImpl* ten_addr = tensor.unsafeGetTensorImpl();
    // Check whether this address was already recorded as an op output.
    auto it = producer_tensor_map_.find((void*)ten_addr);
    if (it != producer_tensor_map_.end()) {
      producer_op_pair = it->second;
    }
  }
  return producer_op_pair;
}

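// Collects the (producer op id, output index) pair for every tensor element
// of an IValue list input (e.g. a TensorList argument).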
std::list<std::pair<at::RecordFunctionHandle, int>> flattenOpIdList(
    const c10::List<c10::IValue>& list,
    const std::string& fn_name) {
  std::list<std::pair<at::RecordFunctionHandle, int>> input_op_id_list;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& input : list) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      auto producer_op_pair = state_ptr->getOpIdFromInput(tensor);
      input_op_id_list.push_back(producer_op_pair);
    }
  }
  return input_op_id_list;
}

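// Walks the inputs of `fn` and returns, for each tensor input (including
// tensors inside list inputs), the op id and output index of the op that
// produced it. Inputs with no recorded producer are reported as {0, -1}.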
std::list<std::pair<at::RecordFunctionHandle, int>> getInputTensorOpIds(
    const at::RecordFunction& fn) {
  std::pair<at::RecordFunctionHandle, int> undefined_op_pair(0, -1);
  std::list<std::pair<at::RecordFunctionHandle, int>> input_producer_ops_;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& input_item : fn.inputs()) {
    if (input_item.isTensor()) {
      const at::Tensor& tensor = input_item.toTensor();
      auto producer_pair = state_ptr->getOpIdFromInput(tensor);
      input_producer_ops_.push_back(producer_pair);
    } else if (input_item.isList()) {
      std::list<std::pair<at::RecordFunctionHandle, int>> tmp_op_ids =
          flattenOpIdList(input_item.toList(), std::string(fn.name()));
      // Extend the current list with the producer op ids gathered from the
      // list input; fall back to the undefined pair if the list contained no
      // tensors.
      if (!tmp_op_ids.empty()) {
        input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids);
      } else {
        input_producer_ops_.emplace_back(undefined_op_pair);
      }
    } else {
      input_producer_ops_.emplace_back(undefined_op_pair);
    }
  }
  return input_producer_ops_;
}

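// Records fn's output tensors in the producer map so that downstream ops can
// look up which op (and which output slot) produced their inputs.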
void updateOutputTensorTracker(const at::RecordFunction& fn) {
  int output_nr = 0;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& s_tensor : fn.outputs()) {
    if (s_tensor.isTensor()) {
      const at::Tensor& tensor = s_tensor.toTensor();
      if (tensor.defined()) {
        auto ten_addr = tensor.unsafeGetTensorImpl();
        state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr);
      }
    }
    output_nr++;
  }
}

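// Entry callback: pushes an NVTX range whose message encodes the op name,
// sequence number, op handle, and (when requested) the input shapes and the
// producer op ids of the inputs. No observer context is needed, so it
// returns nullptr; the exit callback only pops the range.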
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) {
  if (NVTXThreadLocalState::getTLS() != nullptr) {
    auto input_op_ids = getInputTensorOpIds(fn);
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(
            fn.name(),
            fn.seqNr(),
            report_input_shapes ? torch::profiler::impl::inputSizes(fn, true)
                                : std::vector<std::vector<int64_t>>(),
            fn.handle(),
            report_input_shapes
                ? input_op_ids
                : std::list<std::pair<at::RecordFunctionHandle, int>>())
            .c_str());
  }
  return nullptr;
}

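// Installs the NVTX observer for the current thread: stores an
// NVTXThreadLocalState in ThreadLocalDebugInfo and registers RecordFunction
// callbacks that push an NVTX range on op entry and pop it (and record the
// op's output tensors) on exit.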
void pushNVTXCallbacks(
    const ProfilerConfig& config,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE,
      std::make_shared<NVTXThreadLocalState>(config));

  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          state_ptr->config().report_input_shapes
              ? &enterNVTX</*report_input_shapes=*/true>
              : &enterNVTX</*report_input_shapes=*/false>,
          [](const at::RecordFunction& fn, at::ObserverContext* ctx) {
            torch::profiler::impl::cudaStubs()->rangePop();
            updateOutputTensorTracker(fn);
          })
          .needsInputs(config.report_input_shapes)
          .needsOutputs(config.report_input_shapes)
          .needsIds(true)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
}

} // namespace impl
} // namespace profiler
} // namespace torch