1#pragma once
2#include <ATen/core/ivalue.h>
3#include <pybind11/pybind11.h>
4#include <torch/csrc/jit/python/pybind_utils.h>
5#include <torch/csrc/python_headers.h>
6#include <torch/csrc/utils/pybind.h>
7
8namespace py = pybind11;
9
10namespace c10 {
11namespace ivalue {
12
13// concrete ivalue Holder that hold a py::object
14struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
15 public:
16 static c10::intrusive_ptr<PyObjectHolder> create(py::object py_obj) {
17 return c10::make_intrusive<ConcretePyObjectHolder>(std::move(py_obj));
18 }
19
20 static c10::intrusive_ptr<PyObjectHolder> create(const py::handle& handle) {
21 py::gil_scoped_acquire ag;
22 return c10::make_intrusive<ConcretePyObjectHolder>(
23 handle.cast<py::object>());
24 }
25
26 PyObject* getPyObject() override {
27 return py_obj_.ptr();
28 }
29
30 InferredType tryToInferType() override {
31 pybind11::gil_scoped_acquire ag;
32 return torch::jit::tryToInferType(py_obj_);
33 }
34
35 IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt)
36 override {
37 pybind11::gil_scoped_acquire ag;
38 return torch::jit::toIValue(py_obj_, type, N);
39 }
40
41 std::string toStr() override {
42 pybind11::gil_scoped_acquire ag;
43 return py::str(py_obj_);
44 }
45
46 std::vector<at::Tensor> extractTensors() override {
47 // We could implement this entirely in C++ via pybind11 but it turns out to
48 // be substantially slower. Namely, the total time taken by markCompleted on
49 // a CUDAFuture is 21.5us with this implementation, but goes up to 58.7us
50 // when using C++. The reason is unclear.
51 try {
52 pybind11::gil_scoped_acquire ag;
53 static py::object& extractorFn = *new py::object(
54 py::module::import("torch._jit_internal").attr("_extract_tensors"));
55 return extractorFn(py_obj_).cast<std::vector<at::Tensor>>();
56 } catch (py::error_already_set& e) {
57 auto err = std::runtime_error(
58 c10::str("Cannot extract tensors from value: ", e.what()));
59 {
60 pybind11::gil_scoped_acquire ag;
61 e.restore();
62 PyErr_Clear();
63 }
64 throw err;
65 }
66 }
67
68 // Note [Destructing py::object]
69 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
70 //
71 // (1) Why py_obj_ = py::none(); does not work. Because we also need to
72 // acquire GIL when destructing py::object of None that de-references None.
73 // https://docs.python.org/3/c-api/none.html#c.Py_RETURN_NONE
74 //
75 // https://stackoverflow.com/questions/15287590/why-should-py-increfpy-none-be-required-before-returning-py-none-in-c
76 //
77 // (2) Why we need to call dec_ref() explicitly. Because py::object of
78 // nullptr, on destruction, effectively does nothing because of it calls
79 // Py_XDECREF(NULL) underlying.
80 // https://docs.python.org/3/c-api/refcounting.html#c.Py_XDECREF
81 ~ConcretePyObjectHolder() override {
82 pybind11::gil_scoped_acquire ag;
83 py_obj_.dec_ref();
84 // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
85 // decref on the PyObject again.
86 py_obj_.ptr() = nullptr;
87 }
88
89 // explicit construction to avoid errornous implicit conversion and
90 // copy-initialization
91 explicit ConcretePyObjectHolder(py::object py_obj)
92 : py_obj_(std::move(py_obj)) {}
93
94 private:
95 py::object py_obj_;
96};
97
98} // namespace ivalue
99} // namespace c10
100