1 | #pragma once |
---|---|
2 | #include <ATen/core/ivalue.h> |
3 | #include <pybind11/pybind11.h> |
4 | #include <torch/csrc/jit/python/pybind_utils.h> |
5 | #include <torch/csrc/python_headers.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | |
8 | namespace py = pybind11; |
9 | |
10 | namespace c10 { |
11 | namespace ivalue { |
12 | |
13 | // Concrete ivalue holder that holds a py::object |
14 | struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { |
15 | public: |
16 | static c10::intrusive_ptr<PyObjectHolder> create(py::object py_obj) { |
17 | return c10::make_intrusive<ConcretePyObjectHolder>(std::move(py_obj)); |
18 | } |
19 | |
20 | static c10::intrusive_ptr<PyObjectHolder> create(const py::handle& handle) { |
21 | py::gil_scoped_acquire ag; |
22 | return c10::make_intrusive<ConcretePyObjectHolder>( |
23 | handle.cast<py::object>()); |
24 | } |
25 | |
26 | PyObject* getPyObject() override { |
27 | return py_obj_.ptr(); |
28 | } |
29 | |
30 | InferredType tryToInferType() override { |
31 | pybind11::gil_scoped_acquire ag; |
32 | return torch::jit::tryToInferType(py_obj_); |
33 | } |
34 | |
35 | IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt) |
36 | override { |
37 | pybind11::gil_scoped_acquire ag; |
38 | return torch::jit::toIValue(py_obj_, type, N); |
39 | } |
40 | |
41 | std::string toStr() override { |
42 | pybind11::gil_scoped_acquire ag; |
43 | return py::str(py_obj_); |
44 | } |
45 | |
46 | std::vector<at::Tensor> extractTensors() override { |
47 | // We could implement this entirely in C++ via pybind11 but it turns out to |
48 | // be substantially slower. Namely, the total time taken by markCompleted on |
49 | // a CUDAFuture is 21.5us with this implementation, but goes up to 58.7us |
50 | // when using C++. The reason is unclear. |
51 | try { |
52 | pybind11::gil_scoped_acquire ag; |
53 | static py::object& extractorFn = *new py::object( |
54 | py::module::import("torch._jit_internal").attr( "_extract_tensors")); |
55 | return extractorFn(py_obj_).cast<std::vector<at::Tensor>>(); |
56 | } catch (py::error_already_set& e) { |
57 | auto err = std::runtime_error( |
58 | c10::str("Cannot extract tensors from value: ", e.what())); |
59 | { |
60 | pybind11::gil_scoped_acquire ag; |
61 | e.restore(); |
62 | PyErr_Clear(); |
63 | } |
64 | throw err; |
65 | } |
66 | } |
67 | |
68 | // Note [Destructing py::object] |
69 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
70 | // |
71 | // (1) Why py_obj_ = py::none(); does not work. Because we also need to |
72 | // acquire GIL when destructing py::object of None that de-references None. |
73 | // https://docs.python.org/3/c-api/none.html#c.Py_RETURN_NONE |
74 | // |
75 | // https://stackoverflow.com/questions/15287590/why-should-py-increfpy-none-be-required-before-returning-py-none-in-c |
76 | // |
77 | // (2) Why we need to call dec_ref() explicitly. Because py::object of |
78 | // nullptr, on destruction, effectively does nothing because of it calls |
79 | // Py_XDECREF(NULL) underlying. |
80 | // https://docs.python.org/3/c-api/refcounting.html#c.Py_XDECREF |
81 | ~ConcretePyObjectHolder() override { |
82 | pybind11::gil_scoped_acquire ag; |
83 | py_obj_.dec_ref(); |
84 | // explicitly setting PyObject* to nullptr to prevent py::object's dtor to |
85 | // decref on the PyObject again. |
86 | py_obj_.ptr() = nullptr; |
87 | } |
88 | |
89 | // explicit construction to avoid errornous implicit conversion and |
90 | // copy-initialization |
91 | explicit ConcretePyObjectHolder(py::object py_obj) |
92 | : py_obj_(std::move(py_obj)) {} |
93 | |
94 | private: |
95 | py::object py_obj_; |
96 | }; |
97 | |
98 | } // namespace ivalue |
99 | } // namespace c10 |
100 |