1 | #include <torch/csrc/autograd/profiler_legacy.h> |
2 | |
3 | #include <torch/csrc/autograd/function.h> |
4 | #include <torch/csrc/jit/frontend/tracer.h> |
5 | #include <torch/csrc/jit/runtime/interpreter.h> |
6 | #include <torch/csrc/jit/runtime/operator.h> |
7 | |
8 | #include <ATen/code_template.h> |
9 | #include <ATen/core/op_registration/op_registration.h> |
10 | #include <torch/library.h> |
11 | |
12 | #include <fstream> |
13 | #include <list> |
14 | #include <mutex> |
15 | #include <sstream> |
16 | #include <string> |
17 | #include <vector> |
18 | |
19 | #include <ATen/record_function.h> |
20 | #include <c10/core/Allocator.h> |
21 | #include <c10/util/ThreadLocalDebugInfo.h> |
22 | #include <c10/util/irange.h> |
23 | |
24 | #include <iostream> |
25 | |
26 | namespace torch { |
27 | namespace autograd { |
28 | namespace profiler { |
29 | |
30 | // We decompose the profiler logic into the following components: |
31 | // |
32 | // ThreadLocalDebugInfo: |
33 | // |
34 | // ThreadLocalDebugInfo is a thread local mapping from slots into |
35 | // the debug information structs. |
36 | // ThreadLocalDebugInfo is automatically propagated across thread |
37 | // boundaries, including the cases of: |
38 | // - launching async jobs with at::launch |
39 | // - executing JIT continuations |
40 | // - moving from the forward threads into autograd (backward) threads |
41 | // |
42 | // Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard |
43 | // which can be used to add or overwrite an entry in the thread local |
44 | // mapping. A corresponding entry is removed when the guard is destroyed, |
45 | // potentially revealing the previously set value for the same slot. |
46 | // |
// For the async tasks, slots previously set in the main thread before
48 | // launching of an async task are shared and visible in the async task. |
49 | // |
50 | // On the other hand, any adding or overwriting of the mapping by the |
51 | // async task is not visible to the main thread and any modification |
52 | // (including removal of the entries) in the main thread is not visible |
// to the async task if it happens after launching the task.
54 | // |
55 | // We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store profiler config, |
56 | // as well as a list of events that happen during profiling. |
57 | // An instance of ThreadLocalDebugInfo is created each time we enter |
58 | // profiler (i.e. enter profiling context manager/call enableConfig) and |
59 | // uniquely identifies a profiling run. |
60 | // |
61 | // We automatically propagate ThreadLocalDebugInfo into async tasks, |
62 | // as well as across JIT continuations and autograd thread, so all |
63 | // the operations that happen between profiling start and end |
64 | // (not necessarily within the same thread) are recorded. |
65 | // Unless the profiling slot is overwritten as in the case of nested |
66 | // profiling ranges (in this case events for the subrange are handled |
67 | // by the nested profiler) |
68 | // |
69 | // When we exit a profiling range (either by exiting profiling context |
70 | // manager or by calling disableProfiler), we remove the previously set |
71 | // profiling entry for the given thread local mapping, and consolidate |
72 | // events in the profiling result |
73 | // |
74 | // |
75 | // ThreadLocalState: |
76 | // |
77 | // ThreadLocalState takes a 'snapshot' of thread local variables |
78 | // using provided getters. It is used together with ThreadLocalStateGuard |
79 | // to transfer the snapshot across thread boundary and set the thread local |
80 | // values as in the parent task. |
81 | // |
82 | // Profiler uses ThreadLocalState to propagate profiler's thread local state. |
83 | // ThreadLocalState also automatically propagates profiler callbacks. |
84 | // |
85 | // |
86 | // at::RecordFunction and observers |
87 | // |
88 | // Profiler uses observers mechanism to add a pair of thread local callbacks |
89 | // that are executed on a number of predetermined ranges, including: |
90 | // - c10/ATen ops |
91 | // - TorchScript functions/methods |
92 | // - user defined named ranges (see `record_function` python context manager) |
93 | // |
94 | // Profiler setups a pair of callbacks that record profiling events and save |
95 | // them into the thread local profiler struct (ThreadLocalDebugInfo, |
96 | // PROFILER_STATE slot) |
97 | // |
98 | // |
99 | // Thus, the overall logic is: |
100 | // |
101 | // enableProfiler: |
102 | // - checks that profiler is not enabled (otherwise throws) |
103 | // - pushes new ThreadLocalDebugInfo (slot PROFILER_STATE) as the profiler |
104 | // config for the current thread |
105 | // - pushes profiling callbacks for the current thread |
106 | // |
107 | // disableProfiler: |
108 | // - pops PROFILER_STATE slot from the current ThreadLocalDebugInfo and |
109 | // consolidates events |
110 | // - removes profiling callbacks |
111 | // |
112 | // ThreadLocalState: |
113 | // - propagates ThreadLocalDebugInfo across threads |
114 | // - propagates profiler callbacks across threads |
115 | // |
116 | // Profiler callbacks: |
117 | // - get the current profiling state (PROFILER slot in ThreadLocalDebugInfo) |
118 | // - save profiling events into the profiling state |
119 | // |
120 | |
121 | namespace { |
122 | using torch::profiler::impl::ActiveProfilerType; |
123 | using torch::profiler::impl::ProfilerStateBase; |
124 | |
// Thread-local state for the legacy (autograd) profiler. Holds the
// profiler config, the per-thread event lists, and (optionally) events
// collected from remote nodes. Installed into ThreadLocalDebugInfo under
// the PROFILER_STATE slot (see the file-level comment above).
struct ProfilerLegacyThreadLocalState : public ProfilerStateBase {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit ProfilerLegacyThreadLocalState(
      const torch::profiler::impl::ProfilerConfig& config)
      : ProfilerStateBase(config), remoteProfiledEvents_{c10::nullopt} {}
  ~ProfilerLegacyThreadLocalState() override = default;

  // Returns the current thread's legacy profiler state, or nullptr when
  // no profiler is active on this thread. Debug-asserts that any state
  // found is of the LEGACY type before downcasting.
  static ProfilerLegacyThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY);
    return static_cast<ProfilerLegacyThreadLocalState*>(tls);
  }

  // Merges all per-thread event lists (plus any remote events) into one
  // result. Note: this erases the underlying per-thread lists.
  thread_event_lists consolidate();

  // Records a single Mark event with the given name.
  void mark(std::string name, bool include_cuda = true);

  // Appends (or installs) a batch of events reported by a remote node.
  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  // Records a PushRange event for `fn`, optionally with input shapes.
  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      std::vector<std::vector<int64_t>>&& shapes = {});

  // Records the matching PopRange event for `fn`.
  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  // Records a MemoryAlloc event when memory profiling is enabled.
  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      size_t /* total_allocated, unused for legacy */,
      size_t /* total_reserved, unused for legacy */,
      c10::Device device) override;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::LEGACY;
  }

  // Drops the callback handle without unregistering the callbacks; used
  // when disabling the profiler from a thread other than the one that
  // enabled it (see disableProfilerLegacy).
  void leakHandle() {
    handle_ = 0;
  }

 protected:
  // Returns the event list for `thread_id` (current thread when < 0),
  // creating it on first use. Thread-safe via state_mutex_.
  RangeEventList& getEventList(int64_t thread_id = -1);

  // Guards event_lists_map_ and remoteProfiledEvents_.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex state_mutex_;
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
      event_lists_map_;

  // Events reported by remote nodes, grouped per reporting call.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};
180 | |
181 | thread_event_lists ProfilerLegacyThreadLocalState::consolidate() { |
182 | std::lock_guard<std::mutex> g(state_mutex_); |
183 | thread_event_lists result; |
184 | for (auto& kv : event_lists_map_) { |
185 | auto& list = kv.second; |
186 | result.emplace_back(list->consolidate()); |
187 | } |
188 | // Consolidate remote events if applicable as well. |
189 | if (remoteProfiledEvents_) { |
190 | result.insert( |
191 | result.end(), |
192 | std::make_move_iterator(remoteProfiledEvents_->begin()), |
193 | std::make_move_iterator(remoteProfiledEvents_->end())); |
194 | } |
195 | return result; |
196 | } |
197 | |
198 | void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) { |
199 | if (config_.disabled()) { |
200 | return; |
201 | } |
202 | if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { |
203 | torch::profiler::impl::cudaStubs()->mark(name.c_str()); |
204 | } else { |
205 | LegacyEvent evt( |
206 | EventKind::Mark, |
207 | at::StringView(std::move(name)), |
208 | at::RecordFunction::currentThreadId(), |
209 | include_cuda && |
210 | config_.state == torch::profiler::impl::ProfilerState::CUDA); |
211 | evt.setNodeId(at::RecordFunction::getDefaultNodeId()); |
212 | getEventList().record(std::move(evt)); |
213 | } |
214 | } |
215 | |
216 | void ProfilerLegacyThreadLocalState::setOrAddRemoteProfiledEvents( |
217 | std::vector<LegacyEvent>&& remoteProfiledEvents) { |
218 | // Lock to serialize access from multiple callback threads. |
219 | std::lock_guard<std::mutex> guard(state_mutex_); |
220 | if (remoteProfiledEvents_) { |
221 | (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents); |
222 | } else { |
223 | remoteProfiledEvents_ = {std::move(remoteProfiledEvents)}; |
224 | } |
225 | } |
226 | |
// Records a PushRange event for `fn`, optionally capturing input shapes,
// FLOP estimates, and (on non-mobile builds) the TorchScript / Python
// callstack. In NVTX mode it emits an NVTX range push instead of
// storing a LegacyEvent.
void ProfilerLegacyThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes)
            .c_str());
  } else {
    LegacyEvent evt(
        EventKind::PushRange,
        at::StringView(std::string(fn.name())),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle(),
        std::move(shapes),
        at::RecordFunction::getDefaultNodeId(),
        fn.isAsync());
    evt.setSequenceNr(fn.seqNr());
    evt.setFwdThreadId(fn.forwardThreadId());
    evt.setScope((uint8_t)fn.scope());
    if (config_.with_flops) {
      // Save extra op arguments so an approximate FLOP count can be
      // derived for this op.
      evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
      evt.setFlops(torch::profiler::impl::computeFlops(
          std::string(fn.name()), evt.extraArgs()));
    }

// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
    // backward nodes source range corresponds to the forward node
    // TODO: consider using C++ stack trace
    if (config_.with_stack &&
        fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
      auto cs =
          torch::profiler::impl::prepareCallstack(jit::currentCallstack());
      if (cs.empty()) {
        // No TorchScript callstack available; fall back to the Python
        // tracer's callstack.
        cs = torch::profiler::impl::prepareCallstack(
            jit::tracer::pythonCallstack());
      }
      evt.setStack(callstackStr(cs));
    }
#endif
    getEventList().record(std::move(evt));
  }
}
275 | |
276 | void ProfilerLegacyThreadLocalState::popRange( |
277 | const at::RecordFunction& fn, |
278 | const bool record_cuda) { |
279 | if (config_.disabled()) { |
280 | return; |
281 | } |
282 | if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { |
283 | torch::profiler::impl::cudaStubs()->rangePop(); |
284 | } else { |
285 | // In some cases RecordFunction (and popRange) may be |
286 | // called on a different thread than pushRange |
287 | // As a convention, we put the async pop on the original |
288 | // thread and save current thread id in pop event |
289 | LegacyEvent evt( |
290 | EventKind::PopRange, |
291 | at::StringView("" ), |
292 | at::RecordFunction::currentThreadId(), |
293 | record_cuda, |
294 | fn.handle()); |
295 | evt.setNodeId(at::RecordFunction::getDefaultNodeId()); |
296 | getEventList(fn.threadId()).record(std::move(evt)); |
297 | } |
298 | } |
299 | |
300 | void ProfilerLegacyThreadLocalState::reportMemoryUsage( |
301 | void* /* unused */, |
302 | int64_t alloc_size, |
303 | size_t /* total_allocated, unused for legacy */, |
304 | size_t /* total_reserved, unused for legacy */, |
305 | c10::Device device) { |
306 | if (config_.profile_memory && !config_.disabled()) { |
307 | uint64_t thread_id = at::RecordFunction::currentThreadId(); |
308 | LegacyEvent evt( |
309 | EventKind::MemoryAlloc, |
310 | at::StringView("" ), |
311 | thread_id, |
312 | config_.state == torch::profiler::impl::ProfilerState::CUDA); |
313 | evt.updateMemoryStats(alloc_size, device); |
314 | getEventList(thread_id).record(std::move(evt)); |
315 | } |
316 | } |
317 | |
318 | RangeEventList& ProfilerLegacyThreadLocalState::getEventList( |
319 | int64_t thread_id) { |
320 | if (thread_id < 0) { |
321 | thread_id = at::RecordFunction::currentThreadId(); |
322 | } |
323 | RangeEventList* list_ptr = nullptr; |
324 | std::lock_guard<std::mutex> guard(state_mutex_); |
325 | auto it = event_lists_map_.find(thread_id); |
326 | if (it != event_lists_map_.end()) { |
327 | list_ptr = it->second.get(); |
328 | } else { |
329 | auto event_list = std::make_shared<RangeEventList>(); |
330 | event_lists_map_[thread_id] = event_list; |
331 | list_ptr = event_list.get(); |
332 | } |
333 | return *list_ptr; |
334 | } |
335 | |
// Index of each field within the IValue list produced by
// LegacyEvent::toIValue() and consumed by LegacyEvent::fromIValue().
// The two functions must agree on this layout.
enum EventIValueIdx {
  KIND = 0,
  NAME,
  THREAD_ID,
  HANDLE,
  NODE_ID,
  CPU_MEM_USAGE,
  CPU_NS,
  CUDA_RECORDED,
  CUDA_MEM_USAGE,
  CUDA_DEVICE,
  CUDA_US,
  SHAPES,
  NUM_EVENT_IVALUE_IDX // must be last in list
};
351 | |
// Ops for which CUDA event recording is skipped even when CUDA
// profiling is requested (see the start/end callbacks below): these are
// cheap view/metadata operations whose CUDA events would add overhead
// without yielding useful timing.
const std::unordered_set<std::string> disable_cuda_profiling = {
    "aten::view",
    "aten::t",
    "aten::transpose",
    "aten::stride",
    "aten::empty",
    "aten::empty_like",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::expand",
    "aten::resize_",
    "aten::squeeze",
    "aten::unsqueeze",
    "aten::slice",
    "aten::_unsafe_view",
    "aten::size"};
368 | |
// Registers the legacy profiler's pair of thread-local RecordFunction
// callbacks (start -> pushRange, end -> popRange) and stores the
// resulting handle on the current thread's profiler state so the
// callbacks can be removed when profiling is disabled.
void pushProfilingCallbacksLegacy() {
  auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return nullptr;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            // Skip CUDA timing for cheap view/metadata ops.
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }

            if (state_ptr->config().report_input_shapes) {
              auto sizes = torch::profiler::impl::inputSizes(fn);
              state_ptr->pushRange(fn, record_cuda, std::move(sizes));
            } else {
              state_ptr->pushRange(fn, record_cuda);
            }

            // No per-invocation observer context is needed.
            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext*) {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            // Mirror the start callback's CUDA-skip decision so the
            // push/pop pair agrees.
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }
            state_ptr->popRange(fn, record_cuda);
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .needsIds(true));
  registration_state_ptr->setCallbackHandle(handle);
}
415 | |
416 | } // namespace |
417 | |
418 | void enableProfilerLegacy( |
419 | const torch::profiler::impl::ProfilerConfig& new_config) { |
420 | TORCH_CHECK( |
421 | new_config.state != torch::profiler::impl::ProfilerState::NVTX || |
422 | torch::profiler::impl::cudaStubs()->enabled(), |
423 | "Can't use NVTX profiler - PyTorch was compiled without CUDA" ); |
424 | |
425 | TORCH_CHECK(new_config.state != torch::profiler::impl::ProfilerState::KINETO); |
426 | |
427 | auto state_ptr = ProfilerLegacyThreadLocalState::getTLS(); |
428 | TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread" ); |
429 | auto state = std::make_shared<ProfilerLegacyThreadLocalState>(new_config); |
430 | c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state); |
431 | |
432 | pushProfilingCallbacksLegacy(); |
433 | |
434 | state->mark("__start_profile" , false); |
435 | } |
436 | |
// Disables legacy profiling and (by default) returns the consolidated
// events. Options allow skipping the TLS cleanup — used when the caller
// is not the thread that enabled profiling — and/or skipping the final
// event consolidation.
thread_event_lists disableProfilerLegacy(
    c10::optional<ProfilerDisableOptions> profilerDisableOptions) {
  auto cleanupTLSState =
      profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
  auto consolidate =
      profilerDisableOptions ? profilerDisableOptions->consolidate : true;
  // all the DebugInfoBase objects are scope based and supposed to use
  // DebugInfoGuard
  std::shared_ptr<c10::DebugInfoBase> state;
  if (cleanupTLSState) {
    state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  } else {
    // Leave the slot in place and just observe the state.
    state =
        c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE);
  }

  auto state_ptr = static_cast<ProfilerLegacyThreadLocalState*>(state.get());
  TORCH_CHECK(
      state_ptr && !state_ptr->config().disabled(),
      "Can't disable profiler when it's not running");

  // Either unregister the callbacks, or deliberately leak the handle so
  // they are not removed out from under the thread that enabled them.
  cleanupTLSState ? state_ptr->removeCallback() : state_ptr->leakHandle();
  if (!consolidate ||
      state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) {
    // Nothing to return: either the caller doesn't want events, or NVTX
    // mode stored no LegacyEvents.
    return thread_event_lists();
  }

  state_ptr->mark("__stop_profile", false);
  // Note that this will erase the underlying events.
  return state_ptr->consolidate();
}
468 | |
469 | void addEventList(std::vector<LegacyEvent>&& profiledEvents) { |
470 | auto state_ptr = ProfilerLegacyThreadLocalState::getTLS(); |
471 | TORCH_CHECK(state_ptr, "Profiler must be enabled." ); |
472 | state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents)); |
473 | } |
474 | |
475 | void LegacyEvent::record(bool record_cuda) { |
476 | if (record_cuda) { |
477 | torch::profiler::impl::cudaStubs()->record(&device_, &cuda_event, &cpu_ns_); |
478 | return; |
479 | } |
480 | cpu_ns_ = torch::profiler::impl::getTime(); |
481 | } |
482 | |
// Reconstructs a LegacyEvent from the IValue list produced by
// toIValue(); used to ship profiling events across RPC. The field order
// must match EventIValueIdx.
/* static */ LegacyEvent LegacyEvent::fromIValue(
    const at::IValue& eventIValue) {
  TORCH_INTERNAL_ASSERT(
      eventIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = eventIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() >= NUM_EVENT_IVALUE_IDX,
      "Expected at least ",
      NUM_EVENT_IVALUE_IDX,
      " elements to reconstruct LegacyEvent.");

  // Reconstruct input shapes from ivalues.
  auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
  TORCH_INTERNAL_ASSERT(
      shapeListIValue.isList(),
      "Expected profiler shapes IValue to contain type c10::impl::GenericList.");

  // Outer list: one entry per input; inner list: that input's dims.
  auto shapeList = shapeListIValue.toList();
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(shapeList.size());
  for (const auto i : c10::irange(shapeList.size())) {
    std::vector<int64_t> s;
    auto shapeIValue = shapeList.get(i);
    TORCH_INTERNAL_ASSERT(
        shapeIValue.isList(),
        "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.")
    auto curShapesList = shapeIValue.toList();
    s.reserve(curShapesList.size());
    for (const auto j : c10::irange(curShapesList.size())) {
      s.emplace_back(curShapesList.get(j).toInt());
    }
    shapes.emplace_back(s);
  }

  LegacyEvent evt(
      static_cast<EventKind>(
          ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
      at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
      ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id
      static_cast<at::RecordFunctionHandle>(
          ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle
      std::move(shapes), // input shapes
      ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id
      true, // is remote
      ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage
      ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns
      ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded
      ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage
      ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt(), // device
      ivalues.get(EventIValueIdx::CUDA_US).toInt() // cuda_us
  );
  return evt;
}
537 | |
538 | at::IValue LegacyEvent::toIValue() const { |
539 | c10::impl::GenericList eventIValueList(at::AnyType::get()); |
540 | eventIValueList.reserve(NUM_EVENT_IVALUE_IDX); |
541 | eventIValueList.emplace_back(static_cast<int64_t>(kind_)); |
542 | eventIValueList.emplace_back(std::string(name_.str())); |
543 | eventIValueList.emplace_back(static_cast<int64_t>(thread_id_)); |
544 | eventIValueList.emplace_back(static_cast<double>(handle_)); |
545 | eventIValueList.emplace_back(node_id_); |
546 | eventIValueList.emplace_back(cpu_memory_usage_); |
547 | eventIValueList.emplace_back(cpu_ns_); |
548 | // CUDA event information |
549 | bool cuda_profiling_enabled = hasCuda(); |
550 | eventIValueList.emplace_back(cuda_profiling_enabled); |
551 | eventIValueList.emplace_back(static_cast<int64_t>(cuda_memory_usage_)); |
552 | eventIValueList.emplace_back(device_); |
553 | eventIValueList.emplace_back(cuda_us_); |
554 | // Shapes |
555 | c10::impl::GenericList shapesList = |
556 | c10::impl::GenericList(at::ListType::create(at::IntType::get())); |
557 | shapesList.reserve(shapes_.size()); |
558 | for (const auto& shape : shapes_) { |
559 | c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get()); |
560 | s.reserve(shape.size()); |
561 | for (const auto& k : shape) { |
562 | s.emplace_back(k); |
563 | } |
564 | shapesList.emplace_back(s); |
565 | } |
566 | eventIValueList.emplace_back(shapesList); |
567 | return at::IValue(eventIValueList); |
568 | } |
569 | |
// Returns the elapsed CUDA time in microseconds between this event and
// `e`. For a pair of remote events the pre-computed cuda_us_ values are
// subtracted directly; otherwise the time is measured from the recorded
// CUDA events via the CUDA stubs. Throws if either event lacks CUDA
// info or the events are on different devices.
double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
  TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
  TORCH_CHECK(
      e.device() == device(),
      c10::str(
          "Events are not on the same device: ", e.device(), " vs ", device()));
  if (isRemote() && e.isRemote()) {
    // validate that cuda_us_ has been set properly.
    TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
    return static_cast<double>(e.cuda_us_ - cuda_us_);
  }
  return torch::profiler::impl::cudaStubs()->elapsed(
      &cuda_event, &e.cuda_event);
}
584 | |
// Chrome-trace template for one complete ("X" phase) event, formatted
// by writeProfilerEventsToStream; ${ts}/${dur} are microseconds
// relative to the profiling start mark.
static const at::jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");
595 | |
596 | void writeProfilerEventsToStream( |
597 | std::ostream& out, |
598 | const std::vector<LegacyEvent*>& events) { |
599 | TORCH_CHECK(out, "Could not open file" ); |
600 | LegacyEvent* profiler_start = nullptr; |
601 | for (LegacyEvent* e : events) { |
602 | if (0 == strcmp(e->name(), "__start_profile" )) { |
603 | profiler_start = e; |
604 | break; |
605 | } |
606 | } |
607 | TORCH_CHECK(profiler_start, "Could not find __start_profile mark" ); |
608 | |
609 | struct PairHash { |
610 | size_t operator()( |
611 | std::pair<at::RecordFunctionHandle, int> p) const noexcept { |
612 | return std::hash<at::RecordFunctionHandle>()(p.first) ^ |
613 | std::hash<int64_t>()(p.second); |
614 | } |
615 | }; |
616 | std::unordered_map< |
617 | std::pair<at::RecordFunctionHandle, int64_t>, |
618 | LegacyEvent*, |
619 | PairHash> |
620 | events_map; |
621 | out << "[\n" ; |
622 | bool first = true; |
623 | for (LegacyEvent* evt : events) { |
624 | if (evt->kindStr() == "push" ) { |
625 | events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt; |
626 | } else if (evt->kindStr() == "pop" ) { |
627 | if (!first) { |
628 | out << ",\n" ; |
629 | } |
630 | first = false; |
631 | auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId())); |
632 | TORCH_CHECK(it != events_map.end(), "Unmatched pop event" ); |
633 | LegacyEvent* evt_start = it->second; |
634 | events_map.erase(it); |
635 | |
636 | at::jit::TemplateEnv env; |
637 | env.s("name" , evt_start->name()); |
638 | env.d("ts" , profiler_start->cpuElapsedUs(*evt_start)); |
639 | env.d("dur" , evt_start->cpuElapsedUs(*evt)); |
640 | env.d("tid" , evt_start->threadId()); |
641 | out << event_template.format(env); |
642 | } |
643 | } |
644 | out << "]\n" ; |
645 | } |
646 | |
// Starts CPU profiling; the trace is written to `out` when this object
// is destroyed.
RecordProfile::RecordProfile(std::ostream& out) : out_(out) {
  init();
}
650 | |
// Opens `filename` for writing and starts CPU profiling; the trace is
// written to that file when this object is destroyed.
RecordProfile::RecordProfile(const std::string& filename)
    : file_(new std::ofstream(filename)), out_(*file_) {
  init();
}
655 | |
// Shared constructor body: enables the legacy profiler in CPU-only mode.
void RecordProfile::init() {
  enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
      torch::profiler::impl::ProfilerState::CPU));
}
660 | |
661 | RecordProfile::~RecordProfile() { |
662 | try { |
663 | thread_event_lists event_lists = disableProfilerLegacy(); |
664 | std::vector<LegacyEvent*> events; |
665 | for (auto& l : event_lists) { |
666 | for (auto& e : l) { |
667 | events.push_back(&e); |
668 | } |
669 | } |
670 | processEvents(events); |
671 | } catch (const std::exception& e) { |
672 | LOG(ERROR) << e.what() << std::endl; |
673 | } catch (...) { |
674 | LOG(ERROR) << "Unknown error" << std::endl; |
675 | } |
676 | } |
677 | |
// Writes the collected events to the output stream as a Chrome trace.
void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
  writeProfilerEventsToStream(out_, events);
}
681 | |
682 | } // namespace profiler |
683 | } // namespace autograd |
684 | } // namespace torch |
685 | |