1 | #pragma once |
---|---|
2 | |
3 | #include <pybind11/pybind11.h> |
4 | #include <torch/csrc/autograd/anomaly_mode.h> |
5 | #include <torch/csrc/python_headers.h> |
6 | #include <torch/csrc/utils/auto_gil.h> |
7 | #include <torch/csrc/utils/pybind.h> |
8 | |
9 | namespace torch { |
10 | namespace autograd { |
11 | |
12 | struct PyAnomalyMetadata : public AnomalyMetadata { |
13 | static constexpr const char* ANOMALY_TRACE_KEY = "traceback_"; |
14 | static constexpr const char* ANOMALY_PARENT_KEY = "parent_"; |
15 | |
16 | PyAnomalyMetadata() { |
17 | pybind11::gil_scoped_acquire gil; |
18 | dict_ = PyDict_New(); |
19 | } |
20 | ~PyAnomalyMetadata() override { |
21 | // If python is already dead, leak the wrapped python objects |
22 | if (Py_IsInitialized()) { |
23 | pybind11::gil_scoped_acquire gil; |
24 | Py_DECREF(dict_); |
25 | } |
26 | } |
27 | void store_stack() override; |
28 | void print_stack(const std::string& current_node_name) override; |
29 | void assign_parent(const std::shared_ptr<Node>& parent_node) override; |
30 | |
31 | PyObject* dict() { |
32 | return dict_; |
33 | } |
34 | |
35 | private: |
36 | PyObject* dict_; |
37 | }; |
38 | void _print_stack( |
39 | PyObject* trace_stack, |
40 | const std::string& current_node_name, |
41 | bool is_parent); |
42 | |
43 | } // namespace autograd |
44 | } // namespace torch |
45 |