#pragma once

#include <cstdint>
#include <forward_list>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <torch/csrc/Export.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
17
18namespace torch {
19namespace autograd {
20
21struct Node;
22
23namespace profiler {
24
// Kind of a LegacyEvent; kindStr() maps each value to its string form.
enum class C10_API_ENUM EventKind : uint16_t {
  Mark, // instantaneous marker ("mark")
  PushRange, // start of a profiled range ("push")
  PopRange, // end of a profiled range ("pop")
  MemoryAlloc, // memory allocation/free record ("memory_alloc")
};
31
32// To be deprecated, once we switch to Kineto profiling
33struct TORCH_API LegacyEvent {
34 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
35 LegacyEvent(
36 EventKind kind,
37 at::StringView name,
38 uint16_t thread_id,
39 bool record_cuda,
40 at::RecordFunctionHandle handle = 0,
41 std::vector<std::vector<int64_t>>&& shapes = {},
42 int node_id = -1,
43 bool is_async = false)
44 : name_(std::move(name)),
45 kind_(kind),
46 thread_id_(thread_id),
47 handle_(handle),
48 shapes_(shapes),
49 node_id_(node_id),
50 is_async_(is_async) {
51 record(record_cuda);
52 }
53
54 // Constructor to be used in conjunction with LegacyEvent::fromIValue.
55 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
56 LegacyEvent(
57 EventKind kind,
58 at::StringView name,
59 uint16_t thread_id,
60 at::RecordFunctionHandle handle,
61 std::vector<std::vector<int64_t>>&& shapes,
62 int node_id,
63 bool is_remote,
64 int64_t cpu_memory_usage,
65 int64_t cpu_ns,
66 bool cuda_recorded,
67 int64_t cuda_memory_usage = 0,
68 int device = -1,
69 double cuda_us = -1)
70 : cpu_ns_(cpu_ns),
71 name_(std::move(name)),
72 kind_(kind),
73 thread_id_(thread_id),
74 handle_(handle),
75 shapes_(shapes),
76 cpu_memory_usage_(cpu_memory_usage),
77 cuda_memory_usage_(cuda_memory_usage),
78 device_(device),
79 node_id_(node_id),
80 is_remote_(is_remote),
81 cuda_us_(cuda_us) {
82 // Sanity check values that were deserialized
83 TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
84 if (cuda_recorded) {
85 TORCH_INTERNAL_ASSERT(device_ >= 0);
86 TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
87 }
88 }
89
90 // Returns IValues corresponding to event structure, to be used for
91 // serialization.
92 at::IValue toIValue() const;
93
94 // Reconstructs an event from IValues given by toIValue.
95 static LegacyEvent fromIValue(const at::IValue& eventIValue);
96
97 void record(bool record_cuda);
98
99 std::string kindStr() const {
100 switch (kind_) {
101 case EventKind::Mark:
102 return "mark";
103 case EventKind::PushRange:
104 return "push";
105 case EventKind::PopRange:
106 return "pop";
107 case EventKind::MemoryAlloc:
108 return "memory_alloc";
109 }
110 throw std::runtime_error("unknown event kind");
111 }
112
113 EventKind kind() const {
114 return kind_;
115 }
116
117 const char* name() const {
118 return name_.str();
119 }
120
121 uint64_t threadId() const {
122 return thread_id_;
123 }
124
125 std::vector<std::vector<int64_t>> shapes() const {
126 return shapes_;
127 }
128
129 double cpuElapsedUs(const LegacyEvent& e) const {
130 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
131 return static_cast<double>(e.cpu_ns_ - cpu_ns_) / (1000.0);
132 }
133
134 void setCpuUs(int64_t cpu_us) {
135 cpu_ns_ = static_cast<double>(cpu_us) * 1000.0;
136 }
137
138 double cpuUs() const {
139 return static_cast<double>(cpu_ns_) / (1000.0);
140 }
141
142 double cudaElapsedUs(const LegacyEvent& e) const;
143
144 bool hasCuda() const {
145 return cuda_event != nullptr || (isRemote() && device_ != -1);
146 }
147
148 int device() const {
149 return device_;
150 }
151
152 void updateMemoryStats(int64_t alloc_size, c10::Device device) {
153 if (device.is_cuda() || device.type() == c10::DeviceType::HIP) {
154 cuda_memory_usage_ = alloc_size;
155 } else if (
156 device.is_cpu() || device.type() == c10::DeviceType::MKLDNN ||
157 device.type() == c10::DeviceType::IDEEP) {
158 cpu_memory_usage_ = alloc_size;
159 } else {
160 LOG(WARNING) << "Unsupported memory profiling device: " << device;
161 }
162 }
163
164 int64_t cpuMemoryUsage() const {
165 return cpu_memory_usage_;
166 }
167
168 int64_t cudaMemoryUsage() const {
169 return cuda_memory_usage_;
170 }
171
172 at::RecordFunctionHandle handle() const {
173 return handle_;
174 }
175
176 // Node ID corresponding to this event.
177 int nodeId() const {
178 return node_id_;
179 }
180
181 // Set Node ID on this event.
182 void setNodeId(int node_id) {
183 node_id_ = node_id;
184 }
185
186 void setName(at::StringView newName_) {
187 name_ = std::move(newName_);
188 }
189
190 bool isRemote() const {
191 return is_remote_;
192 }
193
194 void setCudaUs(int64_t cuda_us) {
195 cuda_us_ = cuda_us;
196 }
197
198 void setSequenceNr(int64_t sequence_nr) {
199 sequence_nr_ = sequence_nr;
200 }
201
202 int64_t sequenceNr() const {
203 return sequence_nr_;
204 }
205
206 void setCorrelationId(uint64_t correlation_id) {
207 correlation_id_ = correlation_id;
208 }
209
210 uint64_t correlationId() const {
211 return correlation_id_;
212 }
213
214 const std::vector<std::string>& stack() const {
215 return stack_;
216 }
217
218 void setStack(const std::vector<std::string>& stack) {
219 stack_ = stack;
220 }
221
222 uint64_t fwdThreadId() const {
223 return fwd_thread_id_;
224 }
225
226 void setFwdThreadId(uint64_t fwd_thread_id) {
227 fwd_thread_id_ = fwd_thread_id;
228 }
229
230 uint8_t scope() const {
231 return scope_;
232 }
233
234 void setScope(uint8_t scope) {
235 scope_ = scope;
236 }
237
238 const std::unordered_map<std::string, c10::IValue>& extraArgs() const {
239 return extra_args_;
240 }
241
242 void setExtraArgs(std::unordered_map<std::string, c10::IValue>&& save_args) {
243 extra_args_ = std::move(save_args);
244 }
245
246 uint64_t flops() {
247 return flops_;
248 }
249
250 bool isAsync() {
251 return is_async_;
252 }
253
254 void setFlops(uint64_t flops) {
255 flops_ = flops;
256 }
257
258 private:
259 // signed to allow for negative intervals, initialized for safety.
260 int64_t cpu_ns_ = 0;
261 at::StringView name_;
262 EventKind kind_;
263 uint64_t thread_id_;
264 uint64_t fwd_thread_id_;
265 at::RecordFunctionHandle handle_{0};
266 std::vector<std::vector<int64_t>> shapes_;
267 int64_t cpu_memory_usage_ = 0;
268 int64_t cuda_memory_usage_ = 0;
269 int device_ = -1;
270 torch::profiler::impl::ProfilerEventStub cuda_event = nullptr;
271 int node_id_ = 0;
272 bool is_remote_ = false;
273 int64_t cuda_us_ = -1;
274 int64_t sequence_nr_ = -1;
275 bool is_async_ = false;
276
277 std::vector<std::string> stack_;
278 uint8_t scope_;
279 uint64_t correlation_id_;
280 // Extra arguments for computing op flops
281 std::unordered_map<std::string, c10::IValue> extra_args_;
282 uint64_t flops_ = 0;
283};
284
285// a linked-list of fixed sized vectors, to avoid
286// a std::vector resize from taking a large amount of time inside
287// a profiling event
288struct RangeEventList {
289 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,modernize-use-equals-default)
290 RangeEventList() {
291 events_.reserve(kReservedCapacity);
292 }
293
294 template <typename... Args>
295 void record(Args&&... args) {
296 std::lock_guard<std::mutex> guard(mutex_);
297 events_.emplace_back(std::forward<Args>(args)...);
298 }
299
300 std::vector<LegacyEvent> consolidate() {
301 std::lock_guard<std::mutex> lock(mutex_);
302 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
303 std::vector<LegacyEvent> result;
304 result.insert(
305 result.begin(),
306 std::make_move_iterator(events_.begin()),
307 std::make_move_iterator(events_.end()));
308 events_.erase(events_.begin(), events_.end());
309 return result;
310 }
311
312 size_t size() {
313 std::lock_guard<std::mutex> lock(mutex_);
314 return events_.size();
315 }
316
317 private:
318 // This mutex is used to serialize access when different threads are writing
319 // to the same instance of RangeEventList.
320 std::mutex mutex_;
321 std::vector<LegacyEvent> events_;
322
323 static const size_t kReservedCapacity = 1024;
324};
325
326// A struct to control settings of disableProfiler options.
327struct TORCH_API ProfilerDisableOptions {
328 ProfilerDisableOptions() = default;
329 ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
330 : cleanupTLSState(shouldCleanupTLSState),
331 consolidate(shouldConsolidate) {}
332 // Whether we should clean up profiler states that are thread local, such as
333 // ThreadLocalDebugInfo and thread local RecordFunction callbacks.
334 bool cleanupTLSState = true;
335 // Whether we should consolidate all currently recorded profiled events. If
336 // false, will not consolidate and other threads can continue to write to the
337 // event lists.
338 bool consolidate = true;
339};
340
// NOTE: profiler mode is thread local, with automatic propagation
// across thread boundary (e.g. at::launch tasks)
TORCH_API void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig&);
// Events grouped into per-thread lists (presumably one inner vector per
// recording thread — verify against the .cpp).
using thread_event_lists = std::vector<std::vector<LegacyEvent>>;
// Stops the legacy profiler and returns the recorded events; see
// ProfilerDisableOptions for TLS-cleanup and consolidation behavior.
TORCH_API thread_event_lists disableProfilerLegacy(
    c10::optional<ProfilerDisableOptions> profilerDisableOptions =
        c10::nullopt);

// Adds profiledEvents to the current thread-local recorded events.
// NOTE(review): an earlier comment mentioned a `fromNodeId` parameter that
// this declaration does not have; each event carries its own node ID.
TORCH_API void addEventList(std::vector<LegacyEvent>&& profiledEvents);
// Writes profiled events to a stream.
TORCH_API void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events);
357
// RAII guard that profiles its scope and writes the result out on
// destruction.
//
// Usage:
// {
//   RecordProfile guard("filename.trace");
//   // code you want to profile
// }
// Then open filename.trace in chrome://tracing
struct TORCH_API RecordProfile {
  // Write the trace to an existing stream owned by the caller.
  RecordProfile(std::ostream& out);
  // Write the trace to the named file (presumably opened into file_ by
  // the .cpp — confirm there).
  RecordProfile(const std::string& filename);

  ~RecordProfile();

 private:
  void init();
  // Owned output file when constructed from a filename; otherwise unused
  // (definitions live in the .cpp).
  std::unique_ptr<std::ofstream> file_;
  std::ostream& out_;
  // Serializes the given events to out_ (defined in the .cpp).
  void processEvents(const std::vector<LegacyEvent*>& events);
};
376
377// A guard that enables the legacy profiler, taking in an optional callback to
378// process the results Usage:
379// {
380// TLSLegacyProfilerGuard g([](thread_event_lists profilerResults) {
381// // process profilerResults
382// });
383// Code to profile
384// }
385struct TORCH_API TLSLegacyProfilerGuard {
386 explicit TLSLegacyProfilerGuard(
387 const torch::profiler::impl::ProfilerConfig& cfg,
388 c10::optional<std::function<void(const thread_event_lists&)>>
389 resultCallback = c10::nullopt,
390 c10::optional<ProfilerDisableOptions> profilerDisableOptions =
391 c10::nullopt)
392 : cb_(std::move(resultCallback)),
393 // NOLINTNEXTLINE(performance-move-const-arg)
394 profilerDisableOptions_(std::move(profilerDisableOptions)) {
395 enableProfilerLegacy(cfg);
396 }
397 ~TLSLegacyProfilerGuard() {
398 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
399 thread_event_lists event_lists =
400 disableProfilerLegacy(profilerDisableOptions_);
401 if (cb_) {
402 try {
403 (*cb_)(event_lists);
404 } catch (const std::exception& e) {
405 LOG(ERROR) << "Got error processing profiler events: " << e.what();
406 }
407 }
408 }
409
410 private:
411 c10::optional<std::function<void(const thread_event_lists&)>> cb_;
412 const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
413};
414
415} // namespace profiler
416} // namespace autograd
417} // namespace torch
418