#include <torch/csrc/autograd/profiler_python.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include <Python.h>
#include <frameobject.h>

#include <ATen/core/TensorBase.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Optional.h>
#include <c10/util/StringUtil.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>

namespace py = pybind11;

namespace torch {
namespace profiler {
namespace impl {
namespace {
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
static constexpr size_t CallTypeSize = 4;
using no_ephemeral_t = std::tuple<>;
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();

// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
struct CodeLocation {
  CodeLocation() = default;
  explicit CodeLocation(PyFrameObject* frame)
      : line_number_{PyFrame_GetLineNumber(frame)} {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    filename_ = THPUtils_unpackStringView(code->co_filename).data();
    name_ = THPUtils_unpackStringView(code->co_name).data();
  }

  bool operator==(const CodeLocation& other) const {
    return filename_ == other.filename_ && name_ == other.name_ &&
        line_number_ == other.line_number_;
  }

  const char* filename_{nullptr};
  const char* name_{nullptr};
  int line_number_{0};
};

template <CallType C>
PyCodeObject* getCode();

template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
  static auto module_call_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.nn")
                   .attr("Module")
                   .attr("__call__")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return module_call_code;
}

template <>
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
  static auto optimizer_step_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.optim")
                   .attr("Optimizer")
                   .attr("_optimizer_step_code")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return optimizer_step_code;
}

} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch

template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
  size_t operator()(const torch::profiler::impl::CodeLocation& x) {
    return c10::get_hash(x.filename_, x.name_, x.line_number_);
  }
};

namespace torch {
namespace profiler {
namespace impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. ======
// ============================================================================
template <template <CallType> class ClassT>
class CallTypeHelper final {
 private:
  static_assert(
      CallType::PyCall == 0,
      "CallTypeHelper uses integer math which depends on a zero start.");
  static constexpr size_t End = CallTypeSize;

  template <size_t... I>
  static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
      std::index_sequence<I...>);

  template <size_t C, typename T, typename FunctorT, typename... Args>
  static void map(T& t, FunctorT& f, Args&&... args) {
    f(std::get<C>(t), args...);
    c10::guts::if_constexpr<C + 1 < End>(
        [&](auto _) { map<C + 1>(_(t), f, std::forward<Args>(args)...); });
  }

 public:
  using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));

  template <typename FunctorT, typename... Args>
  static void map(tuple_type& t, FunctorT& f, Args&&... args) {
    map<0>(t, f, std::forward<Args>(args)...);
  }
};
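
// Illustrative sketch (not compiled): `map` simply visits the tuple element
// for each `CallType` in enum order. The functor `f` and the extra argument
// below are placeholders for this sketch; `TraceKeyCacheState` is the
// instantiation actually used later in this file.
//
//   CallTypeHelper<TraceKeyCacheState>::tuple_type t;
//   SomeFunctor f;
//   CallTypeHelper<TraceKeyCacheState>::map(t, f, extra_arg);
//   // expands (roughly) to:
//   //   f(std::get<0>(t), extra_arg);  // CallType::PyCall
//   //   f(std::get<1>(t), extra_arg);  // CallType::PyModuleCall
//   //   f(std::get<2>(t), extra_arg);  // CallType::PyCCall
//   //   f(std::get<3>(t), extra_arg);  // CallType::PyOptimizerCall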

// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a python function
// the key is the (PyCodeObject*, int) pair that defines the bytecode of the
// function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
// For a bound C function it is a (non-owning) pointer to the bound function.
// A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first-level key
// during tracing. We look up the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
//   1) Determine the type represented by a TraceKey by checking which
//      sub-cache of the thread local cache it appears in.
//   2) Look up the pair of CallKeys from the thread local cache.
//   3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache (an illustrative sketch follows this
// note):
//   1) Add an entry to the `CallType` enum.
//   2) Add a specialization of Config which defines key_t, ephemeral_t and
//      cache_t.
//   3) Add a specialization of ValueCache::store and ValueCache::load.
//
// -------------------------
// -- Ephemeral arguments --
// -------------------------
// The value cache mechanism assumes that `key_t` is enough to specify the
// correct value. However, it may not be possible to materialize a value using
// only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
// inputs which can be used to populate the value cache. Ephemeral inputs come
// with two caveats:
//   1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
//   2) They should be used to access data that is not expected to change from
//      call to call, such as the name of a function.
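//
// Purely as an illustration of steps (1)-(3) above, and not part of this
// file, a hypothetical new call type might look roughly like the following.
// (`PyHypotheticalCall` and `HypotheticalKey` are made-up names used only for
// this sketch; the real specializations follow below.)
//
//   enum CallType { ..., PyHypotheticalCall };               // step (1)
//
//   template <>
//   struct Config<CallType::PyHypotheticalCall> {            // step (2)
//     using key_t = HypotheticalKey;       // small, trivially copyable
//     using ephemeral_t = PyFrameObject*;  // only valid during `store`
//     using cache_t = ska::flat_hash_map<key_t, at::StringView>;
//     static constexpr EventType event_type = EventType::PyCall;
//   };
//
//   template <>                                              // step (3)
//   void ValueCache::store<CallType::PyHypotheticalCall>(
//       const HypotheticalKey&, PyFrameObject*);
//   template <>
//   ExtraFields<EventType::PyCall>::args_t
//   ValueCache::load<CallType::PyHypotheticalCall>(const HypotheticalKey&) const;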

template <CallType>
struct Config;

template <>
struct Config<CallType::PyCall> {
  using key_t = CodeLocation;
  using ephemeral_t = no_ephemeral_t;
  using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
  static constexpr EventType event_type = EventType::PyCall;
};

template <typename Key, typename Cls, typename ParameterInfo>
struct ExtendedPyCallConfig {
  using key_t = Key;
  using cls_t = Cls;
  using ephemeral_t = PyFrameObject*;

  struct ClsAndParameters {
    cls_t cls_;
    std::vector<ParameterInfo> parameters_;
  };

  struct Cache {
    // `nn.Module.__call__` or `optim.Optimizer._optimizer_step_code`
    c10::optional<CodeLocation> location_;
    ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
    ska::flat_hash_map<cls_t, at::StringView> cls_names_;
  };
  using cache_t = Cache;

  static constexpr EventType event_type = EventType::PyCall;
};

template <>
struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
                                            PyModuleSelf,
                                            PyModuleCls,
                                            NNModuleInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
                                               PyOptimizerSelf,
                                               PyOptimizerCls,
                                               OptimizerInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyCCall> {
  using key_t = PyMethod;
  using ephemeral_t = PyObject*;
  using cache_t = ska::flat_hash_map<key_t, at::StringView>;
  static constexpr EventType event_type = EventType::PyCCall;
};

// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
template <CallType C>
class Callsite {
 public:
  static constexpr CallType call_type = C;
  using key_t = typename Config<C>::key_t;

  static_assert(
      std::is_trivially_copyable<key_t>::value,
      "Key should be trivial, as it is passed by value.");

  template <typename U>
  Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}

  bool operator==(const Callsite<C>& other) const {
    return value_ == other.value_ && caller_ == other.caller_;
  }

  key_t value_;
  Config<CallType::PyCall>::key_t caller_;
};

// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;

class ValueCache {
 public:
  ValueCache() = default;
  ValueCache(const ValueCache&) = delete;

  template <CallType C>
  void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);

  template <CallType C>
  auto load(const Callsite<C>& callsite, size_t python_tid) const {
    auto caller = load<CallType::PyCall>(callsite.caller_);
    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
    return ExtraFields<Config<C>::event_type>{
        /*end_time_ns=*/std::numeric_limits<time_t>::min(),
        python_tid,
        caller.frame_state_,
        load<C>(callsite.value_)};
  }

  c10::optional<TensorMetadata> recordIfTensor(py::handle p);
  std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
      py::dict tensor_map);
  void trimPrefixes();

 private:
  template <CallType C>
  typename ExtraFields<Config<C>::event_type>::args_t load(
      const typename Config<C>::key_t&) const;

  template <CallType C>
  using State = typename Config<C>::cache_t;

  CallTypeHelper<State>::tuple_type state_;
};
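
// A rough sketch of the two-phase lifecycle, for orientation only (the real
// call sites are `TraceKeyCacheState::intern` during tracing and `PostProcess`
// afterwards). `value_cache`, `key`, `frame`, and `callsite` are placeholders
// for this sketch.
//
//   // Profiling phase (GIL held, hot path): cache whatever the key alone
//   // cannot reconstruct later, using the ephemeral frame pointer.
//   value_cache.store<CallType::PyModuleCall>(key, frame);
//
//   // Post processing (after tracing): expand a Callsite into ExtraFields
//   // using only the cached state.
//   auto fields = value_cache.load(callsite, /*python_tid=*/0);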

template <CallType C>
typename Config<C>::cls_t set_class(
    ValueCache* value_cache,
    typename Config<C>::cache_t& cache,
    const typename Config<C>::key_t& key,
    const typename Config<C>::ephemeral_t& frame) {
  if (C10_UNLIKELY(!cache.location_.has_value())) {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
    cache.location_ = PyCallKey(frame);
    value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
  }

  auto cls_handle = py::handle((PyObject*)key).attr("__class__");
  auto cls = typename Config<C>::cls_t(cls_handle.ptr());
  if (cache.cls_names_.find(cls) == cache.cls_names_.end()) {
    cache.cls_names_[cls] =
        at::StringView(py::str(cls_handle.attr("__name__")));
  }
  return cls;
}

TensorMetadata toTensorMetadata(PyObject* self) {
  TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
  const auto& t = THPVariable_Unpack(self);
  RawTensorMetadata m{t};
  return TensorMetadata{
      m,
      t.sizes().vec(),
      m.layout_ == at::kStrided ? t.strides().vec() : std::vector<int64_t>()};
}

c10::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
  return THPVariable_CheckExact(p.ptr())
      ? c10::optional<TensorMetadata>{toTensorMetadata(p.ptr())}
      : c10::nullopt;
}

std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
    py::dict tensor_map) {
  std::vector<std::pair<std::string, TensorMetadata>> out;
  for (auto& it : tensor_map) {
    auto* value = it.second.ptr();
    if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(value)) {
      out.emplace_back(
          py::cast<std::string>(it.first), toTensorMetadata(value));
    }
  }
  return out;
}

template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
  auto& locations = std::get<CallType::PyCall>(state_);
  if (C10_UNLIKELY(locations.find(key) == locations.end())) {
    locations[key] = {
        key.line_number_,
        at::StringView(key.filename_),
        at::StringView(key.name_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
    const PyCallKey& key) const {
  return {std::get<CallType::PyCall>(state_).at(key), c10::nullopt};
}

template <>
void ValueCache::store<CallType::PyModuleCall>(
    const PyModuleCallKey& key,
    Config<CallType::PyModuleCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);

    py::dict params = py::handle((PyObject*)key).attr("_parameters");
    std::vector<NNModuleInfo::ParameterInfo> params_;
    for (auto& it : params) {
      auto* p = it.second.ptr();
      if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(p)) {
        params_.push_back(
            {it.first.cast<std::string>(),
             toTensorMetadata(p),
             recordIfTensor(py::getattr(it.second, "grad", py::none()))});
      }
    }
    cache.cls_and_parameters_[key] = {cls, std::move(params_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
    const PyModuleCallKey& key) const {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  const auto& cls = cls_and_parameters.cls_;
  NNModuleInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/std::move(info),
      /*optimizer_info_=*/c10::nullopt};
}

template <>
void ValueCache::store<CallType::PyOptimizerCall>(
    const PyOptimizerCallKey& key,
    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
    const py::handle self{(PyObject*)key};
    std::vector<OptimizerInfo::ParameterInfo> params;

    for (const auto& i : (py::list)self.attr("param_groups")) {
      for (auto& param : py::cast<py::dict>(i).attr("get")("params")) {
        if (THPVariable_CheckExact(param.ptr())) {
          // While `self.state` is permitted to store data in an arbitrary way,
          // all generic optimizers (SGD, Adam, etc) use param as the key since
          // the state in question is tied to particular parameters. We can
          // relax this assumption if the need arises.
          params.push_back(
              {toTensorMetadata(param.ptr()),
               recordIfTensor(py::getattr(param, "grad", py::none())),
               unpackTensorMap(py::cast<py::dict>(self.attr("state"))
                                   .attr("get")(param, py::dict()))});
        }
      }
    }

    cache.cls_and_parameters_[key] = {cls, std::move(params)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  auto cls = cls_and_parameters.cls_;
  OptimizerInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/c10::nullopt,
      /*optimizer_info_=*/std::move(info)};
}

template <>
void ValueCache::store<CallType::PyCCall>(
    const PyCCallKey& key,
    Config<CallType::PyCCall>::ephemeral_t arg) {
  auto& names = std::get<CallType::PyCCall>(state_);
  if (C10_UNLIKELY(names.find(key) == names.end())) {
    names[key] = at::StringView(py::repr(arg));
  }
}

template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
    const PyCCallKey& key) const {
  return std::get<CallType::PyCCall>(state_).at(key);
}

// TODO: Use re2.
void ValueCache::trimPrefixes() {
  static const auto prefixes = []() {
    pybind11::gil_scoped_acquire gil;
    return py::module::import("torch.profiler.python_tracer")
        .attr("_prefix_regex")()
        .cast<std::vector<std::string>>();
  }();

  for (auto& it : std::get<CallType::PyCall>(state_)) {
    std::string filename = it.second.filename_.str();
    for (const auto& p : prefixes) {
      if (filename.compare(0, p.size(), p) == 0) {
        filename.erase(0, p.size());
        it.second.filename_ = at::StringView(filename);
        break;
      }
    }
  }
}

// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;

TraceKey nextKey() {
  static std::atomic<uint64_t> key{0};
  return TraceKey{++key};
}

template <CallType C>
struct TraceKeyCacheState {
  struct Hash {
    size_t operator()(const Callsite<C>& key) {
      return c10::get_hash(key.value_, key.caller_);
    }
  };

  TraceKey intern(
      Callsite<C> callsite,
      typename Config<C>::ephemeral_t ephemeral,
      ValueCache& value_cache) {
    auto it = state_.find(callsite);
    if (C10_UNLIKELY(it == state_.end())) {
      value_cache.store<C>(callsite.value_, ephemeral);
      value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
      it = state_.insert({callsite, nextKey()}).first;
    }
    return it->second;
  }

  auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
    return std::make_pair(
        value_cache.load<C>(callsite.value_),
        value_cache.load<CallType::PyCall>(callsite.caller_));
  }

  ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};

// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// PyObject that allows different threads to record events without colliding.
// It is passed as the second argument when enabling tracing via
// `PyEval_SetProfile`.
struct ThreadLocalResults;
struct TraceContext {
  PyObject_HEAD;
  ThreadLocalResults* thread_local_results_;
};

// CPython boilerplate to define `TraceContext` as a proper python object.
static PyTypeObject TraceContextType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */
    sizeof(TraceContext), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0,
    /* tp_vectorcall_offset */ // NOLINT: modernize-use-nullptr
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "Python tracer TLS", /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    PyType_GenericNew, /* tp_new */
    nullptr /* tp_free */
};

class gil_and_restore_thread {
 public:
  gil_and_restore_thread()
      : gil_(), initial_thread_state_{PyThreadState_Get()} {}
  ~gil_and_restore_thread() {
    PyThreadState_Swap(initial_thread_state_);

    // `gil_scoped_acquire` is a bit fragile in on-demand mode:
    // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
    if (!Py_IsInitialized()) {
      gil_.disarm();
    }
  }

  PyThreadState* initial_thread_state() const {
    return initial_thread_state_;
  }

 private:
  pybind11::gil_scoped_acquire gil_;
  PyThreadState* initial_thread_state_;
};

// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
class PythonTracer;
struct ThreadLocalResults {
  ThreadLocalResults(
      PyThreadState* thread_state,
      ValueCache* value_cache,
      PythonTracer* active_tracer)
      : thread_state_{thread_state},
        ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
        value_cache_{value_cache},
        active_tracer_{active_tracer} {
    ctx_->thread_local_results_ = this;
  }

  ThreadLocalResults() = delete;
  ThreadLocalResults(const ThreadLocalResults&) = delete;
  ThreadLocalResults(ThreadLocalResults&&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&&) = delete;

  ~ThreadLocalResults() {
    Py_DECREF((PyObject*)ctx_);
  }

  template <CallType C, EventType E, typename Ephemeral, typename... Args>
  TraceKey intern(Ephemeral ephemeral, Args... args) {
    static_assert(
        Config<C>::event_type == E,
        "ThreadLocalResults.intern called from the wrong typed context.");
    auto callsite = Callsite<C>(std::forward<Args>(args)...);
    return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
  }

  static constexpr size_t BLOCK_SIZE = 1024;

  PyThreadState* thread_state_;
  TraceContext* ctx_;
  ValueCache* value_cache_;
  PythonTracer* active_tracer_;
  CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
  AppendOnlyList<approx_time_t, BLOCK_SIZE> exit_times_;
  AppendOnlyList<approx_time_t, BLOCK_SIZE> c_exit_times_;
};

// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
class PythonTracer final : public python_tracer::PythonTracerBase {
 public:
  PythonTracer(torch::profiler::impl::RecordQueue* queue);
  ~PythonTracer() override;

  static int pyProfileFn(
      PyObject* obj,
      PyFrameObject* frame,
      int what,
      PyObject* arg);

  void stop() override;
  std::vector<std::shared_ptr<Result>> getEvents(
      std::function<time_t(approx_time_t)> time_converter,
      std::vector<python_tracer::CompressedEvent>& enters,
      time_t end_time_ns) override;

  struct StartFrame {
    TraceKey trace_key_;
    approx_time_t start_time;
  };

 private:
  void recordPyCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      bool is_startup_frame);

  void recordCCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      PyObject* arg);

  const std::vector<PyThreadState*> interpreterThreads() const;

  std::atomic<bool> active_lock_{false};
  bool active_{false};

  torch::profiler::impl::RecordQueue* queue_;
  PyInterpreterState* interpreter_;
  PyCodeObject* module_call_code_;
  PyCodeObject* optimizer_hook_;

  std::vector<StartFrame> start_frames_;
  std::deque<ThreadLocalResults> thread_local_results_;
  ValueCache value_cache_;
};

const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
  pybind11::gil_scoped_acquire gil;
  std::vector<PyThreadState*> out;
  if (SOFT_ASSERT(interpreter_)) {
    auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
    while (thread_state != nullptr) {
      out.push_back(thread_state);
      thread_state = PyThreadState_Next(thread_state);
    }
  }
  return out;
}

PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
    : queue_(queue),
      interpreter_(nullptr),
      module_call_code_(getCode<CallType::PyModuleCall>()),
      optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
  TORCH_CHECK(queue_ != nullptr);

  bool expected{false};
  active_ = active_lock_.compare_exchange_strong(expected, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }

  gil_and_restore_thread gil;
  interpreter_ = PyInterpreterState_Get();

  if (!gil.initial_thread_state()) {
    TORCH_WARN("PyThreadState_Get returned NULL");
    return;
  }

  // Register the tracer in each thread.
  for (const auto thread_state : interpreterThreads()) {
    PyThreadState_Swap(thread_state);

    thread_local_results_.emplace_back(thread_state, &value_cache_, this);
    auto* ctx = thread_local_results_.back().ctx_;

    // When we begin profiling there are already frames on the Python
    // interpreter stack. To ensure a complete trace, we must push calls
    // to all the prior frames onto our event stack. (We stop at depth=128.)

    std::vector<THPFrameObjectPtr> current_stack;
    auto frame = PyEval_GetFrame();
    Py_XINCREF(frame);

    size_t depth = 0; // Make sure we can't infinite loop.
    while (frame != nullptr) {
      current_stack.emplace_back(frame);
      if (++depth == 128) {
        break;
      }

      // NB: `PyFrame_GetBack` returns a strong reference.
      frame = PyFrame_GetBack(frame);
    }

    for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
      recordPyCall(thread_local_results_.back(), it->get(), true);
      auto frame_refcount = Py_REFCNT(it->get());

      // We hold one reference in `current_stack`, and the interpreter holds
      // another.
      TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
    }

    // Note:
    //   This profile will not compose with other CPython profilers, and
    //   cannot be round tripped via `sys.settrace(sys.gettrace())`.
    PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
  }
}

void PythonTracer::stop() {
  gil_and_restore_thread gil;
  if (active_) {
    for (const auto thread_state : interpreterThreads()) {
      if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
        PyThreadState_Swap(thread_state);
        PyEval_SetProfile(nullptr, nullptr);
      }
    }

    auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
    active_ = false;
    SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
  }
}

PythonTracer::~PythonTracer() {
  if (active_) {
    TORCH_WARN("`PythonTracer::stop()` was not called.");
    stop();
  }
}

void PythonTracer::recordPyCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    bool is_startup_frame) {
  static constexpr auto E = EventType::PyCall;
  const auto key = [&]() -> TraceKey {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    if (code.get() == module_call_code_) {
      // By default, CPython stores locals in a "fast" format, with an array
      // of names and an array of values. Consequently, frame->f_locals is
      // NULL since the interpreter has no need to populate it.
      //
      // If these arrays were part of the public API then we could very
      // quickly access `self`. Unfortunately they are not, and moreover are
      // not stable across versions. As a result, we are forced to call
      // `PyFrame_FastToLocals` which forces the interpreter to materialize
      // the full dict of locals.
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyModuleCall, E>(
          frame, self.get(), back.get());
    } else if (code.get() == optimizer_hook_) {
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyOptimizerCall, E>(
          frame, self.get(), back.get());
    } else {
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      auto f_back = (back.get() != nullptr) ? back.get() : frame;
      return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
    }
  }();
  const auto time = getApproximateTime();
  is_startup_frame ? start_frames_.push_back({key, time})
                   : queue_->getSubqueue()->emplace_py_call(key, time);
}

void PythonTracer::recordCCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    PyObject* arg) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(Py_TYPE(arg) == &PyCFunction_Type);
  auto fn = reinterpret_cast<PyCFunctionObject*>(arg);

  // NB: For C calls a new frame is not created, so we use `frame` rather than
  // `frame->f_back`.
  auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
      arg, (void*)(fn->m_ml), frame);
  queue_->getSubqueue()->emplace_py_call(key, getApproximateTime());
}

// ============================================================================
// == Post processing =========================================================
// ============================================================================
struct Exit {
  bool operator>(const Exit& other) const {
    return t_ > other.t_;
  }

  time_t t_;
  size_t python_tid_;
};

class PostProcess {
 public:
  PostProcess(
      std::function<time_t(approx_time_t)> time_converter,
      std::deque<ThreadLocalResults>& tls,
      const ValueCache& value_cache,
      time_t end_time_ns)
      : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
    for (size_t python_tid : c10::irange(tls.size())) {
      CallTypeHelper<TraceKeyCacheState>::map(
          tls[python_tid].trace_keys_, *this, value_cache, python_tid);

      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
    }
  }

  void set_start_frames(
      const std::vector<PythonTracer::StartFrame>& start_frames,
      std::vector<python_tracer::CompressedEvent>& enters) {
    for (const auto& frame : start_frames) {
      enters.push_back(
          {frame.trace_key_,
           NoTID, // Allows us to detect unhandled start frames
           {},
           time_converter_(frame.start_time)});
    }
  }

  template <CallType C>
  void operator()(
      const TraceKeyCacheState<C>& trace_cache,
      const ValueCache& value_cache,
      size_t python_tid) {
    for (const auto& it : trace_cache.state_) {
      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
          {it.second, value_cache.load(it.first, python_tid)});
      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
    }
  }

  template <EventType E, size_t N>
  void addExits(AppendOnlyList<approx_time_t, N>& exits, size_t python_tid) {
    for (const auto i : exits) {
      get_state<E>().exits_.push({time_converter_(i), python_tid});
    }
  }

  std::vector<std::shared_ptr<Result>> run(
      std::vector<python_tracer::CompressedEvent>& enters) {
    std::stable_sort(
        enters.begin(), enters.end(), [](const auto a, const auto b) {
          return a.enter_t_ < b.enter_t_;
        });
    std::vector<std::shared_ptr<Result>> out;
    populate<EventType::PyCall>(enters, out);
    populate<EventType::PyCCall>(enters, out);
    return out;
  }

 private:
  template <EventType E>
  void populate(
      std::vector<python_tracer::CompressedEvent>& enters,
      std::vector<std::shared_ptr<Result>>& out) {
    using stack_t = std::vector<std::shared_ptr<Result>>;
    const auto initial_size = out.size();
    auto pop = [](stack_t& stack, time_t t) {
      TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.");
      c10::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
      stack.pop_back();
    };

    ska::flat_hash_map<size_t, stack_t> stacks;
    auto& state = get_state<E>();
    for (const auto& enter : enters) {
      auto fields_it = state.fields_.find(enter.key_);
      if (fields_it != state.fields_.end()) {
        while (!state.exits_.empty() &&
               state.exits_.top().t_ < enter.enter_t_) {
          auto& exit = state.exits_.top();
          pop(stacks[exit.python_tid_], exit.t_);
          state.exits_.pop();
        }
        out.push_back(Result::create(
            enter.enter_t_,
            enter.system_tid_,
            enter.kineto_info_,
            fields_it->second));

        stacks[fields_it->second.python_tid_].push_back(out.back());
      }
    }

    // Handle events which were still running when profiling ended.
    for (auto& i : stacks) {
      while (!i.second.empty()) {
        pop(i.second, end_time_);
      }
    }

    // Assign system TIDs to start events based on the system TID of the next
    // observed event with the same Python TID.
    ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
        tid_map;
    auto it = out.rbegin();
    for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
      const auto python_tid =
          c10::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
      if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
        const auto& tid_info =
            tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
                .first->second;
        (*it)->start_tid_ = tid_info.first;
        (*it)->kineto_info_ = tid_info.second;
      }
      tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
      ++it;
    }
  }

  template <EventType E>
  struct State {
    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
  };

  template <EventType E>
  auto& get_state() {
    return std::get<E == EventType::PyCall ? 0 : 1>(state_);
  }

  time_t end_time_;
  std::function<time_t(approx_time_t)> time_converter_;
  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};
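
// A small worked example of the replay in `PostProcess::populate`, all on one
// Python thread (times are arbitrary values used only for illustration):
//
//   enters (sorted by time): {key=A, t=10}, {key=B, t=20}, {key=C, t=50}
//   exits  (min-heap):        30, 40
//
//   t=10: push A                                   stack: [A]
//   t=20: push B                                   stack: [A, B]
//   t=50: exits 30 and 40 precede this enter, so pop B (end=30), then
//         pop A (end=40); push C                   stack: [C]
//   after the loop: C is still open, so it is closed with `end_time_`.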

struct PythonIDVisitor {
  void operator()(ExtraFields<EventType::PyCall>& py_call) {
    py_call.id_ = ++current_python_id_;
    if (py_call.module_.has_value()) {
      auto& m = py_call.module_;
      auto& module_ids = module_ids_[m->cls_];
      m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
    }
  }

  void operator()(ExtraFields<EventType::PyCCall>& py_call) {
    py_call.id_ = ++current_python_id_;
  }

  template <typename T>
  void operator()(T&) {}

  size_t current_python_id_{0};
  ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
      module_ids_;
};

std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
    std::function<time_t(approx_time_t)> time_converter,
    std::vector<python_tracer::CompressedEvent>& enters,
    time_t end_time_ns) {
  value_cache_.trimPrefixes();
  PostProcess post_process(
      std::move(time_converter),
      thread_local_results_,
      value_cache_,
      end_time_ns);
  post_process.set_start_frames(start_frames_, enters);
  auto out = post_process.run(enters);

  std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
    return a->start_time_ns_ < b->start_time_ns_;
  });

  PythonIDVisitor id_visitor;
  for (auto& i : out) {
    c10::visit(id_visitor, i->extra_fields_);
  }

  return out;
}

// ============================================================================
// == API =====================================================================
// ============================================================================
int PythonTracer::pyProfileFn(
    PyObject* obj,
    PyFrameObject* frame,
    int what,
    PyObject* arg) {
  auto& local_results =
      *reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
  switch (what) {
    case PyTrace_CALL:
      local_results.active_tracer_->recordPyCall(local_results, frame, false);
      break;

    case PyTrace_C_CALL:
      local_results.active_tracer_->recordCCall(local_results, frame, arg);
      break;

    case PyTrace_EXCEPTION:
    case PyTrace_RETURN:
      local_results.exit_times_.emplace_back(getApproximateTime());
      break;

    case PyTrace_C_EXCEPTION:
    case PyTrace_C_RETURN:
      local_results.c_exit_times_.emplace_back(getApproximateTime());
      break;
  }
  return 0;
}

std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
    torch::profiler::impl::RecordQueue* queue) {
  return std::make_unique<PythonTracer>(queue);
}
} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch

namespace torch {
namespace autograd {
namespace profiler {
namespace python_tracer {

void init() {
  pybind11::gil_scoped_acquire gil;
  TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
  torch::profiler::impl::python_tracer::registerTracer(
      &torch::profiler::impl::getTracer);
}
} // namespace python_tracer
} // namespace profiler
} // namespace autograd
} // namespace torch