#include <torch/csrc/profiler/standalone/nvtx_observer.h>

#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
5 | |
6 | namespace torch { |
7 | namespace profiler { |
8 | namespace impl { |
9 | |
10 | struct NVTXThreadLocalState : ProfilerStateBase { |
11 | explicit NVTXThreadLocalState(const ProfilerConfig& config) |
12 | : ProfilerStateBase(config) { |
13 | // Only `report_input_shapes` makes sense in this context. |
14 | TORCH_CHECK(!config.profile_memory); |
15 | TORCH_CHECK(!config.with_stack); |
16 | TORCH_CHECK(!config.with_flops); |
17 | TORCH_CHECK(!config.with_modules); |
18 | } |
19 | ~NVTXThreadLocalState() override = default; |
20 | |
21 | ActiveProfilerType profilerType() override { |
22 | return ActiveProfilerType::NVTX; |
23 | } |
24 | |
25 | void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override { |
26 | } |
27 | |
28 | static NVTXThreadLocalState* getTLS() { |
29 | auto tls = ProfilerStateBase::get(/*global=*/false); |
30 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
31 | tls == nullptr || tls->profilerType() == ActiveProfilerType::NVTX); |
32 | return static_cast<NVTXThreadLocalState*>(tls); |
33 | } |
34 | std::pair<at::RecordFunctionHandle, int> getOpIdFromInput( |
35 | const at::Tensor& tensor); |
36 | |
37 | void setProducerTensorMap( |
38 | at::TensorImpl* tensor, |
39 | at::RecordFunctionHandle op_id, |
40 | int output_nr) { |
41 | producer_tensor_map_[(void*)tensor] = |
42 | std::pair<at::RecordFunctionHandle, int>{op_id, output_nr}; |
43 | } |
44 | |
45 | protected: |
46 | // Maps the address of an output Tensor to a unique op id and output |
47 | // index of the tensor. |
48 | // at::TensorImpl* is the actual type of the key, but using void* |
49 | // to indicate the pointer is just being used as a key |
50 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
51 | std::unordered_map<void*, std::pair<at::RecordFunctionHandle, int>> |
52 | producer_tensor_map_; |
53 | }; |
54 | |
55 | std::pair<at::RecordFunctionHandle, int> NVTXThreadLocalState::getOpIdFromInput( |
56 | const at::Tensor& tensor) { |
57 | std::pair<at::RecordFunctionHandle, int> producer_op_pair(0, -1); |
58 | if (tensor.defined()) { |
59 | at::TensorImpl* ten_addr = tensor.unsafeGetTensorImpl(); |
60 | // See if Address is in the map already |
61 | if (producer_tensor_map_.count((void*)ten_addr) > 0) { |
62 | producer_op_pair = producer_tensor_map_[(void*)ten_addr]; |
63 | } |
64 | } |
65 | return producer_op_pair; |
66 | } |
67 | |
68 | std::list<std::pair<at::RecordFunctionHandle, int>> flattenOpIdList( |
69 | c10::List<c10::IValue> list, |
70 | std::string fn_name) { |
71 | std::list<std::pair<at::RecordFunctionHandle, int>> input_op_id_list; |
72 | auto state_ptr = NVTXThreadLocalState::getTLS(); |
73 | TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set" ); |
74 | for (const c10::IValue& input : list) { |
75 | if (input.isTensor()) { |
76 | const at::Tensor& tensor = input.toTensor(); |
77 | auto producer_op_pair = state_ptr->getOpIdFromInput(tensor); |
78 | input_op_id_list.push_back(producer_op_pair); |
79 | } |
80 | } |
81 | return input_op_id_list; |
82 | } |
83 | |
84 | std::list<std::pair<at::RecordFunctionHandle, int>> getInputTensorOpIds( |
85 | const at::RecordFunction& fn) { |
86 | std::pair<at::RecordFunctionHandle, int> undefined_op_pair(0, -1); |
87 | std::list<std::pair<at::RecordFunctionHandle, int>> input_producer_ops_; |
88 | auto state_ptr = NVTXThreadLocalState::getTLS(); |
89 | TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set" ); |
90 | for (const c10::IValue& input_item : fn.inputs()) { |
91 | if (input_item.isTensor()) { |
92 | const at::Tensor& tensor = input_item.toTensor(); |
93 | auto producer_pair = state_ptr->getOpIdFromInput(tensor); |
94 | input_producer_ops_.push_back(producer_pair); |
95 | } else { |
96 | if (input_item.isList()) { |
97 | std::list<std::pair<at::RecordFunctionHandle, int>> tmp_op_ids = |
98 | flattenOpIdList(input_item.toList(), std::string(fn.name())); |
99 | // Extend the current sizes array by the array returned from input sizes |
100 | if (!tmp_op_ids.empty()) { |
101 | input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids); |
102 | } else { |
103 | input_producer_ops_.emplace_back(undefined_op_pair); |
104 | } |
105 | } else { |
106 | input_producer_ops_.emplace_back(undefined_op_pair); |
107 | } |
108 | } |
109 | } |
110 | return input_producer_ops_; |
111 | } |
112 | |
113 | void updateOutputTensorTracker(const at::RecordFunction& fn) { |
114 | int output_nr = 0; |
115 | auto state_ptr = NVTXThreadLocalState::getTLS(); |
116 | TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set" ); |
117 | for (const c10::IValue& s_tensor : fn.outputs()) { |
118 | if (s_tensor.isTensor()) { |
119 | const at::Tensor& tensor = s_tensor.toTensor(); |
120 | if (tensor.defined()) { |
121 | auto ten_addr = tensor.unsafeGetTensorImpl(); |
122 | state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr); |
123 | } |
124 | } |
125 | output_nr++; |
126 | } |
127 | } |
128 | |
129 | template <bool report_input_shapes> |
130 | std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) { |
131 | if (NVTXThreadLocalState::getTLS() != nullptr) { |
132 | auto input_op_ids = getInputTensorOpIds(fn); |
133 | torch::profiler::impl::cudaStubs()->rangePush( |
134 | torch::profiler::impl::getNvtxStr( |
135 | fn.name(), |
136 | fn.seqNr(), |
137 | report_input_shapes ? torch::profiler::impl::inputSizes(fn, true) |
138 | : std::vector<std::vector<int64_t>>(), |
139 | fn.handle(), |
140 | report_input_shapes |
141 | ? input_op_ids |
142 | : std::list<std::pair<at::RecordFunctionHandle, int>>()) |
143 | .c_str()); |
144 | } |
145 | return nullptr; |
146 | } |
147 | |
148 | void pushNVTXCallbacks( |
149 | const ProfilerConfig& config, |
150 | const std::unordered_set<at::RecordScope>& scopes) { |
151 | TORCH_CHECK( |
152 | torch::profiler::impl::cudaStubs()->enabled(), |
153 | "Can't use NVTX profiler - PyTorch was compiled without CUDA" ); |
154 | |
155 | c10::ThreadLocalDebugInfo::_push( |
156 | c10::DebugInfoKind::PROFILER_STATE, |
157 | std::make_shared<NVTXThreadLocalState>(config)); |
158 | |
159 | auto state_ptr = NVTXThreadLocalState::getTLS(); |
160 | TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set" ); |
161 | |
162 | auto handle = at::addThreadLocalCallback( |
163 | at::RecordFunctionCallback( |
164 | state_ptr->config().report_input_shapes |
165 | ? &enterNVTX</*report_input_shapes=*/true> |
166 | : &enterNVTX</*report_input_shapes=*/false>, |
167 | [](const at::RecordFunction& fn, at::ObserverContext* ctx) { |
168 | torch::profiler::impl::cudaStubs()->rangePop(); |
169 | updateOutputTensorTracker(fn); |
170 | }) |
171 | .needsInputs(config.report_input_shapes) |
172 | .needsOutputs(config.report_input_shapes) |
173 | .needsIds(true) |
174 | .scopes(scopes)); |
175 | state_ptr->setCallbackHandle(handle); |
176 | } |
177 | |
178 | } // namespace impl |
179 | } // namespace profiler |
180 | } // namespace torch |
181 | |