#include <torch/csrc/autograd/profiler_python.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include <Python.h>
#include <frameobject.h>

#include <ATen/core/TensorBase.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Optional.h>
#include <c10/util/StringUtil.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>

namespace py = pybind11;

namespace torch {
namespace profiler {
namespace impl {
namespace {
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
static constexpr size_t CallTypeSize = 4;
using no_ephemeral_t = std::tuple<>;
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();

// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
struct CodeLocation {
  CodeLocation() = default;
  explicit CodeLocation(PyFrameObject* frame)
      : line_number_{PyFrame_GetLineNumber(frame)} {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    filename_ = THPUtils_unpackStringView(code->co_filename).data();
    name_ = THPUtils_unpackStringView(code->co_name).data();
  }

  bool operator==(const CodeLocation& other) const {
    return filename_ == other.filename_ && name_ == other.name_ &&
        line_number_ == other.line_number_;
  }

  const char* filename_{nullptr};
  const char* name_{nullptr};
  int line_number_{0};
};

template <CallType C>
PyCodeObject* getCode();

template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
  static auto module_call_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.nn")
                   .attr("Module")
                   .attr("__call__")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return module_call_code;
}

template <>
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
  static auto optimizer_step_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.optim")
                   .attr("Optimizer")
                   .attr("_optimizer_step_code")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return optimizer_step_code;
}

} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch

template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
  size_t operator()(const torch::profiler::impl::CodeLocation& x) {
    return c10::get_hash(x.filename_, x.name_, x.line_number_);
  }
};

namespace torch {
namespace profiler {
namespace impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. =======
// ============================================================================
template <template <CallType> class ClassT>
class CallTypeHelper final {
 private:
  static_assert(
      CallType::PyCall == 0,
      "CallTypeHelper uses integer math which depends on a zero start.");
  static constexpr size_t End = CallTypeSize;

  template <size_t... I>
  static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
      std::index_sequence<I...>);

  template <size_t C, typename T, typename FunctorT, typename... Args>
  static void map(T& t, FunctorT& f, Args&&... args) {
    f(std::get<C>(t), args...);
    c10::guts::if_constexpr<C + 1 < End>(
        [&](auto _) { map<C + 1>(_(t), f, std::forward<Args>(args)...); });
  }

 public:
  using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));

  template <typename FunctorT, typename... Args>
  static void map(tuple_type& t, FunctorT& f, Args&&... args) {
    map<0>(t, f, std::forward<Args>(args)...);
  }
};
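
// Illustrative note: for a tuple `t` of per-CallType state and a functor `f`,
// `CallTypeHelper<ClassT>::map(t, f, args...)` unrolls at compile time to
//   f(std::get<0>(t), args...);  // ClassT<CallType::PyCall>
//   f(std::get<1>(t), args...);  // ClassT<CallType::PyModuleCall>
//   ...
// visiting every specialization in enum order. `PostProcess` (below) relies
// on this to apply one functor across all of the per-type trace key caches.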

// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a python function
// the key is the (PyCodeObject*, int) pair that defines the bytecode of the
// function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
// For a bound C function it is a (non-owning) pointer to the bound function.
// A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first level key
// during tracing. We lookup the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
//   1) Determine the type represented by a TraceKey by checking which
//      sub-cache of the thread local cache it appears in.
//   2) Look up the pair of CallKeys from the thread local cache.
//   3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache:
//   1) Add an entry to the `CallType` enum.
//   2) Add a specialization of Config which defines key_t, ephemeral_t and
//      cache_t.
//   3) Add a specialization of ValueCache::store and ValueCache::load.
//      (A hypothetical example is sketched after the `Config` declaration
//      below.)
//
// -------------------------
// -- Ephemeral arguments --
// -------------------------
// The value cache mechanism assumes that `key_t` is enough to specify the
// correct value. However it may not be possible to materialize a value using
// only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
// inputs which can be used to populate the value cache. Ephemeral inputs come
// with two caveats:
//   1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
//   2) They should be used to access data that is not expected to change from
//      call to call, such as the name of a function.

template <CallType>
struct Config;
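
// For illustration only: a hypothetical `PyFooCall` event type (`FooKey` and
// the sub-cache layout are made up for this sketch) would follow the recipe
// above roughly as
//
//   template <>
//   struct Config<CallType::PyFooCall> {
//     using key_t = FooKey;                // small, trivially copyable
//     using ephemeral_t = no_ephemeral_t;  // or e.g. PyFrameObject*
//     using cache_t = ska::flat_hash_map<key_t, at::StringView>;
//     static constexpr EventType event_type = EventType::PyCall;
//   };
//
// together with matching `ValueCache::store<CallType::PyFooCall>` and
// `ValueCache::load<CallType::PyFooCall>` specializations.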

template <>
struct Config<CallType::PyCall> {
  using key_t = CodeLocation;
  using ephemeral_t = no_ephemeral_t;
  using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
  static constexpr EventType event_type = EventType::PyCall;
};

template <typename Key, typename Cls, typename ParameterInfo>
struct ExtendedPyCallConfig {
  using key_t = Key;
  using cls_t = Cls;
  using ephemeral_t = PyFrameObject*;

  struct ClsAndParameters {
    cls_t cls_;
    std::vector<ParameterInfo> parameters_;
  };

  struct Cache {
    // `nn.Module.forward` or `optim.Optimizer._optimizer_step_code`
    c10::optional<CodeLocation> location_;
    ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
    ska::flat_hash_map<cls_t, at::StringView> cls_names_;
  };
  using cache_t = Cache;

  static constexpr EventType event_type = EventType::PyCall;
};

template <>
struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
                                            PyModuleSelf,
                                            PyModuleCls,
                                            NNModuleInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
                                               PyOptimizerSelf,
                                               PyOptimizerCls,
                                               OptimizerInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyCCall> {
  using key_t = PyMethod;
  using ephemeral_t = PyObject*;
  using cache_t = ska::flat_hash_map<key_t, at::StringView>;
  static constexpr EventType event_type = EventType::PyCCall;
};

// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
template <CallType C>
class Callsite {
 public:
  static constexpr CallType call_type = C;
  using key_t = typename Config<C>::key_t;

  static_assert(
      std::is_trivially_copyable<key_t>::value,
      "Key should be trivial, as it is passed by value.");

  template <typename U>
  Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}

  bool operator==(const Callsite<C>& other) const {
    return value_ == other.value_ && caller_ == other.caller_;
  }

  key_t value_;
  Config<CallType::PyCall>::key_t caller_;
};
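
// For example (illustrative): for an `nn.Module` call, `value_` is the
// `PyModuleSelf` pointer of the module instance and `caller_` is the
// `CodeLocation` of the frame which invoked `Module.__call__`.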

// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;

class ValueCache {
 public:
  ValueCache() = default;
  ValueCache(const ValueCache&) = delete;

  template <CallType C>
  void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);

  template <CallType C>
  auto load(const Callsite<C>& callsite, size_t python_tid) const {
    auto caller = load<CallType::PyCall>(callsite.caller_);
    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
    return ExtraFields<Config<C>::event_type>{
        /*end_time_ns=*/std::numeric_limits<time_t>::min(),
        python_tid,
        caller.frame_state_,
        load<C>(callsite.value_)};
  }

  c10::optional<TensorMetadata> recordIfTensor(py::handle p);
  std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
      py::dict tensor_map);
  void trimPrefixes();

 private:
  template <CallType C>
  typename ExtraFields<Config<C>::event_type>::args_t load(
      const typename Config<C>::key_t&) const;

  template <CallType C>
  using State = typename Config<C>::cache_t;

  CallTypeHelper<State>::tuple_type state_;
};

template <CallType C>
typename Config<C>::cls_t set_class(
    ValueCache* value_cache,
    typename Config<C>::cache_t& cache,
    const typename Config<C>::key_t& key,
    const typename Config<C>::ephemeral_t& frame) {
  if (C10_UNLIKELY(!cache.location_.has_value())) {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
    cache.location_ = PyCallKey(frame);
    value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
  }

  auto cls_handle = py::handle((PyObject*)key).attr("__class__");
  auto cls = typename Config<C>::cls_t(cls_handle.ptr());
  if (cache.cls_names_.find(cls) == cache.cls_names_.end()) {
    cache.cls_names_[cls] =
        at::StringView(py::str(cls_handle.attr("__name__")));
  }
  return cls;
}

TensorMetadata toTensorMetadata(PyObject* self) {
  TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
  const auto& t = THPVariable_Unpack(self);
  RawTensorMetadata m{t};
  return TensorMetadata{
      m,
      t.sizes().vec(),
      m.layout_ == at::kStrided ? t.strides().vec() : std::vector<int64_t>()};
}

c10::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
  return THPVariable_CheckExact(p.ptr())
      ? c10::optional<TensorMetadata>{toTensorMetadata(p.ptr())}
      : c10::nullopt;
}

std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
    py::dict tensor_map) {
  std::vector<std::pair<std::string, TensorMetadata>> out;
  for (auto& it : tensor_map) {
    auto* value = it.second.ptr();
    if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(value)) {
      out.emplace_back(
          py::cast<std::string>(it.first), toTensorMetadata(value));
    }
  }
  return out;
}

template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
  auto& locations = std::get<CallType::PyCall>(state_);
  if (C10_UNLIKELY(locations.find(key) == locations.end())) {
    locations[key] = {
        key.line_number_,
        at::StringView(key.filename_),
        at::StringView(key.name_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
    const PyCallKey& key) const {
  return {std::get<CallType::PyCall>(state_).at(key), c10::nullopt};
}

template <>
void ValueCache::store<CallType::PyModuleCall>(
    const PyModuleCallKey& key,
    Config<CallType::PyModuleCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);

    py::dict params = py::handle((PyObject*)key).attr("_parameters");
    std::vector<NNModuleInfo::ParameterInfo> params_;
    for (auto& it : params) {
      auto* p = it.second.ptr();
      if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(p)) {
        params_.push_back(
            {it.first.cast<std::string>(),
             toTensorMetadata(p),
             recordIfTensor(py::getattr(it.second, "grad", py::none()))});
      }
    }
    cache.cls_and_parameters_[key] = {cls, std::move(params_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
    const PyModuleCallKey& key) const {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  const auto& cls = cls_and_parameters.cls_;
  NNModuleInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/std::move(info),
      /*optimizer_info_=*/c10::nullopt};
}

template <>
void ValueCache::store<CallType::PyOptimizerCall>(
    const PyOptimizerCallKey& key,
    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
    const py::handle self{(PyObject*)key};
    std::vector<OptimizerInfo::ParameterInfo> params;

    for (const auto& i : (py::list)self.attr("param_groups")) {
      for (auto& param : py::cast<py::dict>(i).attr("get")("params")) {
        if (THPVariable_CheckExact(param.ptr())) {
          // While `self.state` is permitted to store data in an arbitrary way,
          // all generic optimizers (SGD, Adam, etc) use param as the key since
          // the state in question is tied to particular parameters. We can
          // relax this assumption if the need arises.
          params.push_back(
              {toTensorMetadata(param.ptr()),
               recordIfTensor(py::getattr(param, "grad", py::none())),
               unpackTensorMap(py::cast<py::dict>(self.attr("state"))
                                   .attr("get")(param, py::dict()))});
        }
      }
    }

    cache.cls_and_parameters_[key] = {cls, std::move(params)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  auto cls = cls_and_parameters.cls_;
  OptimizerInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/c10::nullopt,
      /*optimizer_info_=*/std::move(info)};
}

template <>
void ValueCache::store<CallType::PyCCall>(
    const PyCCallKey& key,
    Config<CallType::PyCCall>::ephemeral_t arg) {
  auto& names = std::get<CallType::PyCCall>(state_);
  if (C10_UNLIKELY(names.find(key) == names.end())) {
    names[key] = at::StringView(py::repr(arg));
  }
}

template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
    const PyCCallKey& key) const {
  return std::get<CallType::PyCCall>(state_).at(key);
}

// TODO: Use re2.
void ValueCache::trimPrefixes() {
  static const auto prefixes = []() {
    pybind11::gil_scoped_acquire gil;
    return py::module::import("torch.profiler.python_tracer")
        .attr("_prefix_regex")()
        .cast<std::vector<std::string>>();
  }();

  for (auto& it : std::get<CallType::PyCall>(state_)) {
    std::string filename = it.second.filename_.str();
    for (const auto& p : prefixes) {
      if (filename.compare(0, p.size(), p) == 0) {
        filename.erase(0, p.size());
        it.second.filename_ = at::StringView(filename);
        break;
      }
    }
  }
}

// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;

TraceKey nextKey() {
  static std::atomic<uint64_t> key{0};
  return TraceKey{++key};
}

template <CallType C>
struct TraceKeyCacheState {
  struct Hash {
    size_t operator()(const Callsite<C>& key) {
      return c10::get_hash(key.value_, key.caller_);
    }
  };

  TraceKey intern(
      Callsite<C> callsite,
      typename Config<C>::ephemeral_t ephemeral,
      ValueCache& value_cache) {
    auto it = state_.find(callsite);
    if (C10_UNLIKELY(it == state_.end())) {
      value_cache.store<C>(callsite.value_, ephemeral);
      value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
      it = state_.insert({callsite, nextKey()}).first;
    }
    return it->second;
  }

  auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
    return std::make_pair(
        value_cache.load<C>(callsite.value_),
        value_cache.load<CallType::PyCall>(callsite.caller_));
  }

  ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};

// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// PyObject that allows different threads to record events without colliding.
// It is passed as the second argument when enabling tracing via
// `PyEval_SetProfile`.
struct ThreadLocalResults;
struct TraceContext {
  PyObject_HEAD;
  ThreadLocalResults* thread_local_results_;
};

// CPython boilerplate to define `TraceContext` as a proper python object.
static PyTypeObject TraceContextType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */
    sizeof(TraceContext), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0,
    /* tp_vectorcall_offset */ // NOLINT: modernize-use-nullptr
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "Python tracer TLS", /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    PyType_GenericNew, /* tp_new */
    nullptr /* tp_free */
};

class gil_and_restore_thread {
 public:
  gil_and_restore_thread()
      : gil_(), initial_thread_state_{PyThreadState_Get()} {}
  ~gil_and_restore_thread() {
    PyThreadState_Swap(initial_thread_state_);

    // `gil_scoped_acquire` is a bit fragile in on-demand mode:
    // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
    if (!Py_IsInitialized()) {
      gil_.disarm();
    }
  }

  PyThreadState* initial_thread_state() const {
    return initial_thread_state_;
  }

 private:
  pybind11::gil_scoped_acquire gil_;
  PyThreadState* initial_thread_state_;
};

// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
class PythonTracer;
struct ThreadLocalResults {
  ThreadLocalResults(
      PyThreadState* thread_state,
      ValueCache* value_cache,
      PythonTracer* active_tracer)
      : thread_state_{thread_state},
        ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
        value_cache_{value_cache},
        active_tracer_{active_tracer} {
    ctx_->thread_local_results_ = this;
  }

  ThreadLocalResults() = delete;
  ThreadLocalResults(const ThreadLocalResults&) = delete;
  ThreadLocalResults(ThreadLocalResults&&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&&) = delete;

  ~ThreadLocalResults() {
    Py_DECREF((PyObject*)ctx_);
  }

  template <CallType C, EventType E, typename Ephemeral, typename... Args>
  TraceKey intern(Ephemeral ephemeral, Args... args) {
    static_assert(
        Config<C>::event_type == E,
        "ThreadLocalResults.intern called from the wrong typed context.");
    auto callsite = Callsite<C>(std::forward<Args>(args)...);
    return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
  }

  static constexpr size_t BLOCK_SIZE = 1024;

  PyThreadState* thread_state_;
  TraceContext* ctx_;
  ValueCache* value_cache_;
  PythonTracer* active_tracer_;
  CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
  AppendOnlyList<approx_time_t, BLOCK_SIZE> exit_times_;
  AppendOnlyList<approx_time_t, BLOCK_SIZE> c_exit_times_;
};

// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
class PythonTracer final : public python_tracer::PythonTracerBase {
 public:
  PythonTracer(torch::profiler::impl::RecordQueue* queue);
  ~PythonTracer() override;

  static int pyProfileFn(
      PyObject* obj,
      PyFrameObject* frame,
      int what,
      PyObject* arg);

  void stop() override;
  std::vector<std::shared_ptr<Result>> getEvents(
      std::function<time_t(approx_time_t)> time_converter,
      std::vector<python_tracer::CompressedEvent>& enters,
      time_t end_time_ns) override;

  struct StartFrame {
    TraceKey trace_key_;
    approx_time_t start_time;
  };

 private:
  void recordPyCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      bool is_startup_frame);

  void recordCCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      PyObject* arg);

  const std::vector<PyThreadState*> interpreterThreads() const;

  std::atomic<bool> active_lock_{false};
  bool active_{false};

  torch::profiler::impl::RecordQueue* queue_;
  PyInterpreterState* interpreter_;
  PyCodeObject* module_call_code_;
  PyCodeObject* optimizer_hook_;

  std::vector<StartFrame> start_frames_;
  std::deque<ThreadLocalResults> thread_local_results_;
  ValueCache value_cache_;
};

const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
  pybind11::gil_scoped_acquire gil;
  std::vector<PyThreadState*> out;
  if (SOFT_ASSERT(interpreter_)) {
    auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
    while (thread_state != nullptr) {
      out.push_back(thread_state);
      thread_state = PyThreadState_Next(thread_state);
    }
  }
  return out;
}

PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
    : queue_(queue),
      interpreter_(nullptr),
      module_call_code_(getCode<CallType::PyModuleCall>()),
      optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
  TORCH_CHECK(queue_ != nullptr);

  bool expected{false};
  active_ = active_lock_.compare_exchange_strong(expected, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }

  gil_and_restore_thread gil;
  interpreter_ = PyInterpreterState_Get();

  if (!gil.initial_thread_state()) {
    TORCH_WARN("PyThreadState_Get returned NULL");
    return;
  }

  // Register the tracer in each thread.
  for (const auto thread_state : interpreterThreads()) {
    PyThreadState_Swap(thread_state);

    thread_local_results_.emplace_back(thread_state, &value_cache_, this);
    auto* ctx = thread_local_results_.back().ctx_;

    // When we begin profiling there are already frames on the Python
    // interpreter stack. To ensure a complete trace, we must push calls
    // to all the prior frames onto our event stack. (We stop at depth=128)

    std::vector<THPFrameObjectPtr> current_stack;
    auto frame = PyEval_GetFrame();
    Py_XINCREF(frame);

    size_t depth = 0; // Make sure we can't infinite loop.
    while (frame != nullptr) {
      current_stack.emplace_back(frame);
      if (++depth == 128) {
        break;
      }

      // NB: `PyFrame_GetBack` returns a strong reference.
      frame = PyFrame_GetBack(frame);
    }

    for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
      recordPyCall(thread_local_results_.back(), it->get(), true);
      auto frame_refcount = Py_REFCNT(it->get());

      // We hold one reference in `current_stack`, and the interpreter holds
      // another.
      TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
    }

    // Note:
    //   This profile will not compose with other CPython profilers, and
    //   cannot be round tripped via `sys.settrace(sys.gettrace())`
    PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
  }
}

void PythonTracer::stop() {
  gil_and_restore_thread gil;
  if (active_) {
    for (const auto thread_state : interpreterThreads()) {
      if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
        PyThreadState_Swap(thread_state);
        PyEval_SetProfile(nullptr, nullptr);
      }
    }

    auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
    active_ = false;
    SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
  }
}

PythonTracer::~PythonTracer() {
  if (active_) {
    TORCH_WARN("`PythonTracer::stop()` was not called.");
    stop();
  }
}
void PythonTracer::recordPyCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    bool is_startup_frame) {
  static constexpr auto E = EventType::PyCall;
  const auto key = [&]() -> TraceKey {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    if (code.get() == module_call_code_) {
      // By default, CPython stores locals in a "fast" format, with an array
      // of names and an array of values. Consequently, frame->f_locals is
      // NULL since the interpreter has no need to populate it.
      //
      // If these arrays were part of the public API then we could very
      // quickly access `self`. Unfortunately they are not, and moreover are
      // not stable across versions. As a result, we are forced to call
      // `PyFrame_GetLocals` (a fast-to-locals conversion) which forces the
      // interpreter to materialize the full dict of locals.
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyModuleCall, E>(
          frame, self.get(), back.get());
    } else if (code.get() == optimizer_hook_) {
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyOptimizerCall, E>(
          frame, self.get(), back.get());
    } else {
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      auto f_back = (back.get() != nullptr) ? back.get() : frame;
      return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
    }
  }();
  const auto time = getApproximateTime();
  is_startup_frame ? start_frames_.push_back({key, time})
                   : queue_->getSubqueue()->emplace_py_call(key, time);
}

void PythonTracer::recordCCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    PyObject* arg) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(Py_TYPE(arg) == &PyCFunction_Type);
  auto fn = reinterpret_cast<PyCFunctionObject*>(arg);

  // NB: For C calls a new frame is not created, so we use `frame` rather than
  // `frame->f_back`.
  auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
      arg, (void*)(fn->m_ml), frame);
  queue_->getSubqueue()->emplace_py_call(key, getApproximateTime());
}

// ============================================================================
// == Post processing =========================================================
// ============================================================================
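// Replay overview (a summary of `PostProcess` below): enter events arrive as
// `CompressedEvent`s sorted by time, while the raw exit timestamps recorded
// by each thread are loaded into a per-event-type min-heap. Walking the
// enters in order, every exit that precedes the next enter pops the top of
// the owning thread's replay stack; exits carry only a time and a Python TID,
// so the match to a frame is purely positional (LIFO per thread). Frames
// still open when profiling stops are closed at `end_time_ns`.
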
struct Exit {
  bool operator>(const Exit& other) const {
    return t_ > other.t_;
  }

  time_t t_;
  size_t python_tid_;
};

class PostProcess {
 public:
  PostProcess(
      std::function<time_t(approx_time_t)> time_converter,
      std::deque<ThreadLocalResults>& tls,
      const ValueCache& value_cache,
      time_t end_time_ns)
      : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
    for (size_t python_tid : c10::irange(tls.size())) {
      CallTypeHelper<TraceKeyCacheState>::map(
          tls[python_tid].trace_keys_, *this, value_cache, python_tid);

      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
    }
  }

  void set_start_frames(
      const std::vector<PythonTracer::StartFrame>& start_frames,
      std::vector<python_tracer::CompressedEvent>& enters) {
    for (const auto& frame : start_frames) {
      enters.push_back(
          {frame.trace_key_,
           NoTID, // Allows us to detect unhandled start frames
           {},
           time_converter_(frame.start_time)});
    }
  }

  template <CallType C>
  void operator()(
      const TraceKeyCacheState<C>& trace_cache,
      const ValueCache& value_cache,
      size_t python_tid) {
    for (const auto& it : trace_cache.state_) {
      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
          {it.second, value_cache.load(it.first, python_tid)});
      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
    }
  }

  template <EventType E, size_t N>
  void addExits(AppendOnlyList<approx_time_t, N>& exits, size_t python_tid) {
    for (const auto i : exits) {
      get_state<E>().exits_.push({time_converter_(i), python_tid});
    }
  }

  std::vector<std::shared_ptr<Result>> run(
      std::vector<python_tracer::CompressedEvent>& enters) {
    std::stable_sort(
        enters.begin(), enters.end(), [](const auto a, const auto b) {
          return a.enter_t_ < b.enter_t_;
        });
    std::vector<std::shared_ptr<Result>> out;
    populate<EventType::PyCall>(enters, out);
    populate<EventType::PyCCall>(enters, out);
    return out;
  }

 private:
  template <EventType E>
  void populate(
      std::vector<python_tracer::CompressedEvent>& enters,
      std::vector<std::shared_ptr<Result>>& out) {
    using stack_t = std::vector<std::shared_ptr<Result>>;
    const auto initial_size = out.size();
    auto pop = [](stack_t& stack, time_t t) {
      TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.");
      c10::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
      stack.pop_back();
    };

    ska::flat_hash_map<size_t, stack_t> stacks;
    auto& state = get_state<E>();
    for (const auto& enter : enters) {
      auto fields_it = state.fields_.find(enter.key_);
      if (fields_it != state.fields_.end()) {
        while (!state.exits_.empty() &&
               state.exits_.top().t_ < enter.enter_t_) {
          auto& exit = state.exits_.top();
          pop(stacks[exit.python_tid_], exit.t_);
          state.exits_.pop();
        }
        out.push_back(Result::create(
            enter.enter_t_,
            enter.system_tid_,
            enter.kineto_info_,
            fields_it->second));

        stacks[fields_it->second.python_tid_].push_back(out.back());
      }
    }

    // Handle events which were still running when profiling ended.
    for (auto& i : stacks) {
      while (!i.second.empty()) {
        pop(i.second, end_time_);
      }
    }

    // Assign system TIDs to start events based on the system TID of the next
    // observed event with the same Python TID.
    ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
        tid_map;
    auto it = out.rbegin();
    for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
      const auto python_tid =
          c10::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
      if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
        const auto& tid_info =
            tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
                .first->second;
        (*it)->start_tid_ = tid_info.first;
        (*it)->kineto_info_ = tid_info.second;
      }
      tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
      ++it;
    }
  }

  template <EventType E>
  struct State {
    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
  };

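  // `state_` is a two element tuple: slot 0 holds the PyCall state and slot 1
  // holds the PyCCall state; `get_state` selects between them at compile
  // time.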
  template <EventType E>
  auto& get_state() {
    return std::get < E == EventType::PyCall ? 0 : 1 > (state_);
  }

  time_t end_time_;
  std::function<time_t(approx_time_t)> time_converter_;
  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};

struct PythonIDVisitor {
  void operator()(ExtraFields<EventType::PyCall>& py_call) {
    py_call.id_ = ++current_python_id_;
    if (py_call.module_.has_value()) {
      auto& m = py_call.module_;
      auto& module_ids = module_ids_[m->cls_];
      m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
    }
  }

  void operator()(ExtraFields<EventType::PyCCall>& py_call) {
    py_call.id_ = ++current_python_id_;
  }

  template <typename T>
  void operator()(T&) {}

  size_t current_python_id_{0};
  ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
      module_ids_;
};

std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
    std::function<time_t(approx_time_t)> time_converter,
    std::vector<python_tracer::CompressedEvent>& enters,
    time_t end_time_ns) {
  value_cache_.trimPrefixes();
  PostProcess post_process(
      std::move(time_converter),
      thread_local_results_,
      value_cache_,
      end_time_ns);
  post_process.set_start_frames(start_frames_, enters);
  auto out = post_process.run(enters);

  std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
    return a->start_time_ns_ < b->start_time_ns_;
  });

  PythonIDVisitor id_visitor;
  for (auto& i : out) {
    c10::visit(id_visitor, i->extra_fields_);
  }

  return out;
}

// ============================================================================
// == API =====================================================================
// ============================================================================
int PythonTracer::pyProfileFn(
    PyObject* obj,
    PyFrameObject* frame,
    int what,
    PyObject* arg) {
  auto& local_results =
      *reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
  switch (what) {
    case PyTrace_CALL:
      local_results.active_tracer_->recordPyCall(local_results, frame, false);
      break;

    case PyTrace_C_CALL:
      local_results.active_tracer_->recordCCall(local_results, frame, arg);
      break;

    case PyTrace_EXCEPTION:
    case PyTrace_RETURN:
      local_results.exit_times_.emplace_back(getApproximateTime());
      break;

    case PyTrace_C_EXCEPTION:
    case PyTrace_C_RETURN:
      local_results.c_exit_times_.emplace_back(getApproximateTime());
      break;
  }
  return 0;
}

std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
    torch::profiler::impl::RecordQueue* queue) {
  return std::make_unique<PythonTracer>(queue);
}
} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch

namespace torch {
namespace autograd {
namespace profiler {
namespace python_tracer {

void init() {
  pybind11::gil_scoped_acquire gil;
  TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
  torch::profiler::impl::python_tracer::registerTracer(
      &torch::profiler::impl::getTracer);
}
} // namespace python_tracer
} // namespace profiler
} // namespace autograd
} // namespace torch