1#pragma once
2
3#include <pybind11/pybind11.h>
4#include <torch/csrc/autograd/anomaly_mode.h>
5#include <torch/csrc/python_headers.h>
6#include <torch/csrc/utils/auto_gil.h>
7#include <torch/csrc/utils/pybind.h>
8
9namespace torch {
10namespace autograd {
11
12struct PyAnomalyMetadata : public AnomalyMetadata {
13 static constexpr const char* ANOMALY_TRACE_KEY = "traceback_";
14 static constexpr const char* ANOMALY_PARENT_KEY = "parent_";
15
16 PyAnomalyMetadata() {
17 pybind11::gil_scoped_acquire gil;
18 dict_ = PyDict_New();
19 }
20 ~PyAnomalyMetadata() override {
21 // If python is already dead, leak the wrapped python objects
22 if (Py_IsInitialized()) {
23 pybind11::gil_scoped_acquire gil;
24 Py_DECREF(dict_);
25 }
26 }
27 void store_stack() override;
28 void print_stack(const std::string& current_node_name) override;
29 void assign_parent(const std::shared_ptr<Node>& parent_node) override;
30
31 PyObject* dict() {
32 return dict_;
33 }
34
35 private:
36 PyObject* dict_;
37};
38void _print_stack(
39 PyObject* trace_stack,
40 const std::string& current_node_name,
41 bool is_parent);
42
43} // namespace autograd
44} // namespace torch
45