1 | #pragma once |
2 | |
3 | // Wrap tensor operation outputs as PyObject* |
4 | |
5 | #include <ATen/ScalarOps.h> |
6 | #include <ATen/core/Tensor.h> |
7 | #include <c10/util/irange.h> |
8 | #include <torch/csrc/python_headers.h> |
9 | #include <initializer_list> |
10 | #include <tuple> |
11 | |
12 | #include <torch/csrc/Dtype.h> |
13 | #include <torch/csrc/DynamicTypes.h> |
14 | #include <torch/csrc/Layout.h> |
15 | #include <torch/csrc/QScheme.h> |
16 | #include <torch/csrc/autograd/python_variable.h> |
17 | #include <torch/csrc/autograd/variable.h> |
18 | #include <torch/csrc/utils/python_numbers.h> |
19 | #include <torch/csrc/utils/tensor_qschemes.h> |
20 | |
21 | namespace torch { |
22 | namespace autograd { |
23 | namespace utils { |
24 | |
25 | inline PyObject* wrap(bool value) { |
26 | if (value) { |
27 | Py_RETURN_TRUE; |
28 | } else { |
29 | Py_RETURN_FALSE; |
30 | } |
31 | } |
32 | |
33 | inline PyObject* wrap(int64_t value) { |
34 | return THPUtils_packInt64(value); |
35 | } |
36 | |
37 | inline PyObject* wrap(double value) { |
38 | return PyFloat_FromDouble(value); |
39 | } |
40 | |
41 | inline PyObject* wrap(c10::complex<double> value) { |
42 | // I could probably also use FromComplex with a reinterpret cast, |
43 | // but... eh. |
44 | return PyComplex_FromDoubles(value.real(), value.imag()); |
45 | } |
46 | |
47 | inline PyObject* wrap(void* value) { |
48 | return THPUtils_packInt64(reinterpret_cast<intptr_t>(value)); |
49 | } |
50 | |
51 | inline PyObject* wrap(THPDtype* dtype) { |
52 | Py_INCREF(dtype); |
53 | return (PyObject*)dtype; |
54 | } |
55 | |
56 | inline PyObject* wrap(at::ScalarType scalarType) { |
57 | return wrap(getTHPDtype(scalarType)); |
58 | } |
59 | |
60 | inline PyObject* wrap(THPLayout* layout) { |
61 | Py_INCREF(layout); |
62 | return (PyObject*)layout; |
63 | } |
64 | |
65 | inline PyObject* wrap(at::Layout layout) { |
66 | return wrap(getTHPLayout(layout)); |
67 | } |
68 | |
69 | inline PyObject* wrap(at::Tensor tensor) { |
70 | return THPVariable_Wrap(Variable(std::move(tensor))); |
71 | } |
72 | |
73 | inline PyObject* wrap(const at::Scalar& scalar) { |
74 | return wrap(scalar_to_tensor(scalar)); |
75 | } |
76 | |
77 | inline PyObject* wrap(at::QScheme qscheme) { |
78 | auto* thp_qscheme = torch::utils::getTHPQScheme(qscheme); |
79 | Py_INCREF(thp_qscheme); |
80 | return thp_qscheme; |
81 | } |
82 | |
83 | inline PyObject* wrap(at::TensorList tl) { |
84 | auto r = THPObjectPtr{PyTuple_New(tl.size())}; |
85 | if (!r) |
86 | throw python_error(); |
87 | for (const auto i : c10::irange(tl.size())) { |
88 | PyTuple_SET_ITEM(r.get(), i, wrap(tl[i])); |
89 | } |
90 | return r.release(); |
91 | } |
92 | |
93 | inline PyObject* wrap(at::IntArrayRef list) { |
94 | auto r = THPObjectPtr{PyTuple_New(list.size())}; |
95 | if (!r) |
96 | throw python_error(); |
97 | for (const auto i : c10::irange(list.size())) { |
98 | PyTuple_SET_ITEM(r.get(), i, wrap(list[i])); |
99 | } |
100 | return r.release(); |
101 | } |
102 | |
namespace detail {
// Calls `fn(std::get<Idx>(tup), Idx)` for each index in the pack, strictly
// in ascending order. The pack expansion lives inside a braced initializer,
// whose clauses C++ evaluates left to right. The leading 0 keeps the array
// non-empty when the tuple has no elements.
template <typename Fn, typename TupleT, size_t... Idx>
void apply_with_idx_impl(
    const Fn& fn,
    TupleT& tup,
    std::index_sequence<Idx...> /*indices*/) {
  int sequencer[] = {0, (fn(std::get<Idx>(tup), Idx), 0)...};
  (void)sequencer;
}

// Visits every element of `tup` together with its position: for
// tuple(a, b, c) this calls fn(a, 0), fn(b, 1), fn(c, 2). Elements are
// passed as lvalue references, so `fn` may modify or move from them.
template <typename Fn, typename... Ts>
void apply_with_idx(const Fn& fn, std::tuple<Ts...>& tup) {
  using Indices = std::index_sequence_for<Ts...>;
  apply_with_idx_impl(fn, tup, Indices{});
}
} // namespace detail
118 | |
119 | template <typename... Ts> |
120 | PyObject* wrap(std::tuple<Ts...> values) { |
121 | auto r = THPObjectPtr{PyTuple_New(sizeof...(Ts))}; |
122 | if (!r) |
123 | throw python_error(); |
124 | detail::apply_with_idx( |
125 | [&](auto& value, size_t idx) { |
126 | PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value))); |
127 | }, |
128 | values); |
129 | return r.release(); |
130 | } |
131 | |
132 | template <typename... Ts> |
133 | PyObject* wrap(PyTypeObject* type, std::tuple<Ts...> values) { |
134 | auto r = THPObjectPtr{PyStructSequence_New(type)}; |
135 | if (!r) |
136 | throw python_error(); |
137 | detail::apply_with_idx( |
138 | [&](auto& value, size_t idx) { |
139 | PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value))); |
140 | }, |
141 | values); |
142 | return r.release(); |
143 | } |
144 | |
145 | } // namespace utils |
146 | } // namespace autograd |
147 | } // namespace torch |
148 | |