1 | #pragma once |
---|---|
2 | |
3 | #include <pybind11/pybind11.h> |
4 | #include <torch/csrc/utils/object_ptr.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | #include <torch/csrc/utils/structseq.h> |
7 | |
8 | namespace six { |
9 | |
10 | // Usually instances of PyStructSequence is also an instance of tuple |
11 | // but in some py2 environment it is not, so we have to manually check |
12 | // the name of the type to determine if it is a namedtupled returned |
13 | // by a pytorch operator. |
14 | |
15 | inline bool isStructSeq(pybind11::handle input) { |
16 | return pybind11::cast<std::string>(input.get_type().attr("__module__")) == |
17 | "torch.return_types"; |
18 | } |
19 | |
20 | inline bool isStructSeq(PyObject* obj) { |
21 | return isStructSeq(pybind11::handle(obj)); |
22 | } |
23 | |
24 | inline bool isTuple(pybind11::handle input) { |
25 | if (PyTuple_Check(input.ptr())) { |
26 | return true; |
27 | } |
28 | return false; |
29 | } |
30 | |
31 | inline bool isTuple(PyObject* obj) { |
32 | return isTuple(pybind11::handle(obj)); |
33 | } |
34 | |
35 | // maybeAsTuple: if the input is a structseq, then convert it to a tuple |
36 | // |
37 | // On Python 3, structseq is a subtype of tuple, so these APIs could be used |
38 | // directly. But on Python 2, structseq is not a subtype of tuple, so we need to |
39 | // manually create a new tuple object from structseq. |
40 | inline THPObjectPtr maybeAsTuple(PyStructSequence* obj) { |
41 | Py_INCREF(obj); |
42 | return THPObjectPtr((PyObject*)obj); |
43 | } |
44 | |
45 | inline THPObjectPtr maybeAsTuple(PyObject* obj) { |
46 | if (isStructSeq(obj)) |
47 | return maybeAsTuple((PyStructSequence*)obj); |
48 | Py_INCREF(obj); |
49 | return THPObjectPtr(obj); |
50 | } |
51 | |
52 | } // namespace six |
53 |