#include <torch/csrc/autograd/profiler_legacy.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/code_template.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include <cstring>
#include <fstream>
#include <list>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>

#include <ATen/record_function.h>
#include <c10/core/Allocator.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/irange.h>

#include <iostream>

namespace torch {
namespace autograd {
namespace profiler {
// We decompose the profiler logic into the following components:
//
// ThreadLocalDebugInfo:
//
// ThreadLocalDebugInfo is a thread-local mapping from slots to
// debug information structs.
// ThreadLocalDebugInfo is automatically propagated across thread
// boundaries, including the cases of:
// - launching async jobs with at::launch
// - executing JIT continuations
// - moving from the forward threads into autograd (backward) threads
//
// Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard,
// which can be used to add or overwrite an entry in the thread-local
// mapping. The corresponding entry is removed when the guard is destroyed,
// potentially revealing a previously set value for the same slot.
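//
// For example (a minimal sketch; `MyDebugInfo` is a hypothetical
// c10::DebugInfoBase subclass, and TEST_INFO is just a spare slot):
//
//   {
//     c10::DebugInfoGuard guard(
//         c10::DebugInfoKind::TEST_INFO, std::make_shared<MyDebugInfo>());
//     // Within this scope ThreadLocalDebugInfo::get(TEST_INFO) returns
//     // the new entry; on scope exit any previous value becomes
//     // visible again.
//   }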
//
// For async tasks, slots set in the main thread before launching the
// task are shared with and visible in the async task.
//
// On the other hand, any addition or overwrite of the mapping by the
// async task is not visible to the main thread, and any modification
// (including removal of the entries) in the main thread is not visible
// to the async task if it happens after launching the task.
//
// We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store the profiler
// config, as well as a list of events that happen during profiling.
// An instance of ThreadLocalDebugInfo is created each time we enter the
// profiler (i.e. enter the profiling context manager / call
// enableProfiler) and uniquely identifies a profiling run.
//
// We automatically propagate ThreadLocalDebugInfo into async tasks,
// as well as across JIT continuations and autograd threads, so all
// the operations that happen between profiling start and end
// (not necessarily within the same thread) are recorded, unless the
// profiling slot is overwritten, as in the case of nested profiling
// ranges (in which case events for the subrange are handled by the
// nested profiler).
//
// When we exit a profiling range (either by exiting the profiling context
// manager or by calling disableProfiler), we remove the previously set
// profiling entry for the given thread-local mapping, and consolidate
// events into the profiling result.
//
//
// ThreadLocalState:
//
// ThreadLocalState takes a 'snapshot' of thread-local variables
// using provided getters. It is used together with ThreadLocalStateGuard
// to transfer the snapshot across thread boundaries and set the
// thread-local values as they were in the parent task.
//
// The profiler uses ThreadLocalState to propagate its thread-local state.
// ThreadLocalState also automatically propagates profiler callbacks.
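//
// A minimal sketch of the propagation pattern (assuming at::ThreadLocalState
// and at::ThreadLocalStateGuard from ATen/ThreadLocalState.h):
//
//   at::ThreadLocalState snapshot; // capture in the parent thread
//   std::thread worker([snapshot]() {
//     at::ThreadLocalStateGuard guard(snapshot); // restore in the child
//     // ... thread-local values (including the profiler state) now
//     // match the parent's values at capture time ...
//   });
//   worker.join();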
//
//
// at::RecordFunction and observers:
//
// The profiler uses the observer mechanism to add a pair of thread-local
// callbacks that are executed on a number of predetermined ranges,
// including:
// - c10/ATen ops
// - TorchScript functions/methods
// - user-defined named ranges (see the `record_function` python context
//   manager)
//
// The profiler sets up a pair of callbacks that record profiling events
// and save them into the thread-local profiler struct (ThreadLocalDebugInfo,
// PROFILER_STATE slot).
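//
// For instance, a user-defined range can be emitted from C++ with the
// RECORD_FUNCTION macro (ATen/record_function.h); the callbacks below
// observe it like any other range (a sketch):
//
//   void my_op(const at::Tensor& t) {
//     RECORD_FUNCTION("my_op", std::vector<c10::IValue>({t}));
//     // ... op body; the range closes when the scope exits ...
//   }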
//
//
// Thus, the overall logic is:
//
// enableProfiler:
// - checks that the profiler is not already enabled (otherwise throws)
// - pushes a new ThreadLocalDebugInfo (slot PROFILER_STATE) as the profiler
//   config for the current thread
// - pushes profiling callbacks for the current thread
//
// disableProfiler:
// - pops the PROFILER_STATE slot from the current ThreadLocalDebugInfo and
//   consolidates events
// - removes profiling callbacks
//
// ThreadLocalState:
// - propagates ThreadLocalDebugInfo across threads
// - propagates profiler callbacks across threads
//
// Profiler callbacks:
// - get the current profiling state (PROFILER_STATE slot in
//   ThreadLocalDebugInfo)
// - save profiling events into the profiling state
//
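// A typical end-to-end use from C++, in terms of the functions defined in
// this file (a sketch):
//
//   enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
//       torch::profiler::impl::ProfilerState::CPU));
//   // ... run the workload to be profiled ...
//   thread_event_lists lists = disableProfilerLegacy();
//   // `lists` holds per-thread vectors of LegacyEvent, ready for
//   // post-processing (e.g. writeProfilerEventsToStream).
//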

namespace {
using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::ProfilerStateBase;

struct ProfilerLegacyThreadLocalState : public ProfilerStateBase {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit ProfilerLegacyThreadLocalState(
      const torch::profiler::impl::ProfilerConfig& config)
      : ProfilerStateBase(config), remoteProfiledEvents_{c10::nullopt} {}
  ~ProfilerLegacyThreadLocalState() override = default;

  static ProfilerLegacyThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY);
    return static_cast<ProfilerLegacyThreadLocalState*>(tls);
  }

  thread_event_lists consolidate();

  void mark(std::string name, bool include_cuda = true);

  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      std::vector<std::vector<int64_t>>&& shapes = {});

  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      size_t /* total_allocated, unused for legacy */,
      size_t /* total_reserved, unused for legacy */,
      c10::Device device) override;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::LEGACY;
  }

  void leakHandle() {
    handle_ = 0;
  }

 protected:
  RangeEventList& getEventList(int64_t thread_id = -1);

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex state_mutex_;
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
      event_lists_map_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};

thread_event_lists ProfilerLegacyThreadLocalState::consolidate() {
  std::lock_guard<std::mutex> g(state_mutex_);
  thread_event_lists result;
  for (auto& kv : event_lists_map_) {
    auto& list = kv.second;
    result.emplace_back(list->consolidate());
  }
  // Consolidate remote events if applicable as well.
  if (remoteProfiledEvents_) {
    result.insert(
        result.end(),
        std::make_move_iterator(remoteProfiledEvents_->begin()),
        std::make_move_iterator(remoteProfiledEvents_->end()));
  }
  return result;
}

void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->mark(name.c_str());
  } else {
    LegacyEvent evt(
        EventKind::Mark,
        at::StringView(std::move(name)),
        at::RecordFunction::currentThreadId(),
        include_cuda &&
            config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::setOrAddRemoteProfiledEvents(
    std::vector<LegacyEvent>&& remoteProfiledEvents) {
  // Lock to serialize access from multiple callback threads.
  std::lock_guard<std::mutex> guard(state_mutex_);
  if (remoteProfiledEvents_) {
    (*remoteProfiledEvents_).emplace_back(std::move(remoteProfiledEvents));
  } else {
    remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
  }
}

void ProfilerLegacyThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes)
            .c_str());
  } else {
    LegacyEvent evt(
        EventKind::PushRange,
        at::StringView(std::string(fn.name())),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle(),
        std::move(shapes),
        at::RecordFunction::getDefaultNodeId(),
        fn.isAsync());
    evt.setSequenceNr(fn.seqNr());
    evt.setFwdThreadId(fn.forwardThreadId());
    evt.setScope((uint8_t)fn.scope());
    if (config_.with_flops) {
      evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
      evt.setFlops(torch::profiler::impl::computeFlops(
          std::string(fn.name()), evt.extraArgs()));
    }

// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
    // A backward node's source range corresponds to its forward node.
    // TODO: consider using C++ stack trace
    if (config_.with_stack &&
        fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
      auto cs =
          torch::profiler::impl::prepareCallstack(jit::currentCallstack());
      if (cs.empty()) {
        cs = torch::profiler::impl::prepareCallstack(
            jit::tracer::pythonCallstack());
      }
      evt.setStack(callstackStr(cs));
    }
#endif
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::popRange(
    const at::RecordFunction& fn,
    const bool record_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePop();
  } else {
    // In some cases RecordFunction (and popRange) may be called on a
    // different thread than pushRange. As a convention, we put the async
    // pop on the original thread and save the current thread id in the
    // pop event.
    LegacyEvent evt(
        EventKind::PopRange,
        at::StringView(""),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle());
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList(fn.threadId()).record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::reportMemoryUsage(
    void* /* unused */,
    int64_t alloc_size,
    size_t /* total_allocated, unused for legacy */,
    size_t /* total_reserved, unused for legacy */,
    c10::Device device) {
  if (config_.profile_memory && !config_.disabled()) {
    uint64_t thread_id = at::RecordFunction::currentThreadId();
    LegacyEvent evt(
        EventKind::MemoryAlloc,
        at::StringView(""),
        thread_id,
        config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.updateMemoryStats(alloc_size, device);
    getEventList(thread_id).record(std::move(evt));
  }
}

RangeEventList& ProfilerLegacyThreadLocalState::getEventList(
    int64_t thread_id) {
  if (thread_id < 0) {
    thread_id = at::RecordFunction::currentThreadId();
  }
  RangeEventList* list_ptr = nullptr;
  std::lock_guard<std::mutex> guard(state_mutex_);
  auto it = event_lists_map_.find(thread_id);
  if (it != event_lists_map_.end()) {
    list_ptr = it->second.get();
  } else {
    auto event_list = std::make_shared<RangeEventList>();
    event_lists_map_[thread_id] = event_list;
    list_ptr = event_list.get();
  }
  return *list_ptr;
}

enum EventIValueIdx {
  KIND = 0,
  NAME,
  THREAD_ID,
  HANDLE,
  NODE_ID,
  CPU_MEM_USAGE,
  CPU_NS,
  CUDA_RECORDED,
  CUDA_MEM_USAGE,
  CUDA_DEVICE,
  CUDA_US,
  SHAPES,
  NUM_EVENT_IVALUE_IDX // must be last in list
};

const std::unordered_set<std::string> disable_cuda_profiling = {
    "aten::view",
    "aten::t",
    "aten::transpose",
    "aten::stride",
    "aten::empty",
    "aten::empty_like",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::expand",
    "aten::resize_",
    "aten::squeeze",
    "aten::unsqueeze",
    "aten::slice",
    "aten::_unsafe_view",
    "aten::size"};

void pushProfilingCallbacksLegacy() {
  auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return nullptr;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }

            if (state_ptr->config().report_input_shapes) {
              auto sizes = torch::profiler::impl::inputSizes(fn);
              state_ptr->pushRange(fn, record_cuda, std::move(sizes));
            } else {
              state_ptr->pushRange(fn, record_cuda);
            }

            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext*) {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }
            state_ptr->popRange(fn, record_cuda);
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .needsIds(true));
  registration_state_ptr->setCallbackHandle(handle);
}

} // namespace

void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig& new_config) {
  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::NVTX ||
          torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::KINETO,
      "Kineto profiler is not supported by the legacy profiler API");

  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
  auto state = std::make_shared<ProfilerLegacyThreadLocalState>(new_config);
  c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);

  pushProfilingCallbacksLegacy();

  state->mark("__start_profile", false);
}

thread_event_lists disableProfilerLegacy(
    c10::optional<ProfilerDisableOptions> profilerDisableOptions) {
  auto cleanupTLSState =
      profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
  auto consolidate =
      profilerDisableOptions ? profilerDisableOptions->consolidate : true;
  // All DebugInfoBase objects are scope-based and are supposed to be
  // managed by DebugInfoGuard.
  std::shared_ptr<c10::DebugInfoBase> state;
  if (cleanupTLSState) {
    state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  } else {
    state =
        c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE);
  }

  auto state_ptr = static_cast<ProfilerLegacyThreadLocalState*>(state.get());
  TORCH_CHECK(
      state_ptr && !state_ptr->config().disabled(),
      "Can't disable profiler when it's not running");

  cleanupTLSState ? state_ptr->removeCallback() : state_ptr->leakHandle();
  if (!consolidate ||
      state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) {
    return thread_event_lists();
  }

  state_ptr->mark("__stop_profile", false);
  // Note that this will erase the underlying events.
  return state_ptr->consolidate();
}

void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(state_ptr, "Profiler must be enabled.");
  state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
}

void LegacyEvent::record(bool record_cuda) {
  if (record_cuda) {
    torch::profiler::impl::cudaStubs()->record(&device_, &cuda_event, &cpu_ns_);
    return;
  }
  cpu_ns_ = torch::profiler::impl::getTime();
}

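// LegacyEvent serializes to and from IValue (fromIValue/toIValue below);
// this is how remotely profiled events (e.g. over RPC) are transported.
// A round trip is a minimal sketch of the contract:
//
//   LegacyEvent copy = LegacyEvent::fromIValue(evt.toIValue());
//
// with the reconstructed event marked as remote (see the
// `true /* is remote */` argument below).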
/* static */ LegacyEvent LegacyEvent::fromIValue(
    const at::IValue& eventIValue) {
  TORCH_INTERNAL_ASSERT(
      eventIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = eventIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() >= NUM_EVENT_IVALUE_IDX,
      "Expected at least ",
      NUM_EVENT_IVALUE_IDX,
      " elements to reconstruct LegacyEvent.");

  // Reconstruct input shapes from ivalues.
  auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
  TORCH_INTERNAL_ASSERT(
      shapeListIValue.isList(),
      "Expected profiler shapes IValue to contain type c10::impl::GenericList.");

  auto shapeList = shapeListIValue.toList();
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(shapeList.size());
  for (const auto i : c10::irange(shapeList.size())) {
    std::vector<int64_t> s;
    auto shapeIValue = shapeList.get(i);
    TORCH_INTERNAL_ASSERT(
        shapeIValue.isList(),
        "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.");
    auto curShapesList = shapeIValue.toList();
    s.reserve(curShapesList.size());
    for (const auto j : c10::irange(curShapesList.size())) {
      s.emplace_back(curShapesList.get(j).toInt());
    }
    shapes.emplace_back(s);
  }

  LegacyEvent evt(
      static_cast<EventKind>(
          ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
      at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
      ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id
      static_cast<at::RecordFunctionHandle>(
          ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle
      std::move(shapes), // input shapes
      ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id
      true, // is remote
      ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage
      ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns
      ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded
      ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage
      ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt(), // device
      ivalues.get(EventIValueIdx::CUDA_US).toInt() // cuda_us
  );
  return evt;
}

at::IValue LegacyEvent::toIValue() const {
  c10::impl::GenericList eventIValueList(at::AnyType::get());
  eventIValueList.reserve(NUM_EVENT_IVALUE_IDX);
  eventIValueList.emplace_back(static_cast<int64_t>(kind_));
  eventIValueList.emplace_back(std::string(name_.str()));
  eventIValueList.emplace_back(static_cast<int64_t>(thread_id_));
  eventIValueList.emplace_back(static_cast<double>(handle_));
  eventIValueList.emplace_back(node_id_);
  eventIValueList.emplace_back(cpu_memory_usage_);
  eventIValueList.emplace_back(cpu_ns_);
  // CUDA event information
  bool cuda_profiling_enabled = hasCuda();
  eventIValueList.emplace_back(cuda_profiling_enabled);
  eventIValueList.emplace_back(static_cast<int64_t>(cuda_memory_usage_));
  eventIValueList.emplace_back(device_);
  eventIValueList.emplace_back(cuda_us_);
  // Shapes
  c10::impl::GenericList shapesList =
      c10::impl::GenericList(at::ListType::create(at::IntType::get()));
  shapesList.reserve(shapes_.size());
  for (const auto& shape : shapes_) {
    c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get());
    s.reserve(shape.size());
    for (const auto& k : shape) {
      s.emplace_back(k);
    }
    shapesList.emplace_back(s);
  }
  eventIValueList.emplace_back(shapesList);
  return at::IValue(eventIValueList);
}

double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
  TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
  TORCH_CHECK(
      e.device() == device(),
      c10::str(
          "Events are not on the same device: ", e.device(), " vs ", device()));
  if (isRemote() && e.isRemote()) {
    // Validate that cuda_us_ has been set properly.
    TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
    return static_cast<double>(e.cuda_us_ - cuda_us_);
  }
  return torch::profiler::impl::cudaStubs()->elapsed(
      &cuda_event, &e.cuda_event);
}

static const at::jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");
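// The template above emits Chrome trace "complete" events ("ph": "X"); the
// JSON array produced below can be loaded in chrome://tracing (or any
// viewer that understands the Chrome trace event format).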

void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events) {
  TORCH_CHECK(out, "Could not open file");
  LegacyEvent* profiler_start = nullptr;
  for (LegacyEvent* e : events) {
    if (0 == strcmp(e->name(), "__start_profile")) {
      profiler_start = e;
      break;
    }
  }
  TORCH_CHECK(profiler_start, "Could not find __start_profile mark");

  struct PairHash {
    size_t operator()(
        std::pair<at::RecordFunctionHandle, int64_t> p) const noexcept {
      return std::hash<at::RecordFunctionHandle>()(p.first) ^
          std::hash<int64_t>()(p.second);
    }
  };
  std::unordered_map<
      std::pair<at::RecordFunctionHandle, int64_t>,
      LegacyEvent*,
      PairHash>
      events_map;
  out << "[\n";
  bool first = true;
  for (LegacyEvent* evt : events) {
    if (evt->kindStr() == "push") {
      events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
    } else if (evt->kindStr() == "pop") {
      if (!first) {
        out << ",\n";
      }
      first = false;
      auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
      TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
      LegacyEvent* evt_start = it->second;
      events_map.erase(it);

      at::jit::TemplateEnv env;
      env.s("name", evt_start->name());
      env.d("ts", profiler_start->cpuElapsedUs(*evt_start));
      env.d("dur", evt_start->cpuElapsedUs(*evt));
      env.d("tid", evt_start->threadId());
      out << event_template.format(env);
    }
  }
  out << "]\n";
}

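// RecordProfile is an RAII convenience: profiling starts in the constructor
// and, on destruction, the recorded events are consolidated and written out
// as a Chrome trace. E.g. (a sketch):
//
//   {
//     torch::autograd::profiler::RecordProfile guard("trace.json");
//     // ... run the code to be profiled ...
//   } // the profile is written to trace.json here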
RecordProfile::RecordProfile(std::ostream& out) : out_(out) {
  init();
}

RecordProfile::RecordProfile(const std::string& filename)
    : file_(new std::ofstream(filename)), out_(*file_) {
  init();
}

void RecordProfile::init() {
  enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
      torch::profiler::impl::ProfilerState::CPU));
}

RecordProfile::~RecordProfile() {
  try {
    thread_event_lists event_lists = disableProfilerLegacy();
    std::vector<LegacyEvent*> events;
    for (auto& l : event_lists) {
      for (auto& e : l) {
        events.push_back(&e);
      }
    }
    processEvents(events);
  } catch (const std::exception& e) {
    LOG(ERROR) << e.what() << std::endl;
  } catch (...) {
    LOG(ERROR) << "Unknown error" << std::endl;
  }
}

void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
  writeProfilerEventsToStream(out_, events);
}

} // namespace profiler
} // namespace autograd
} // namespace torch