1#pragma once
2
3// Wrap tensor operation outputs as PyObject*
4
5#include <ATen/ScalarOps.h>
6#include <ATen/core/Tensor.h>
7#include <c10/util/irange.h>
8#include <torch/csrc/python_headers.h>
9#include <initializer_list>
10#include <tuple>
11
12#include <torch/csrc/Dtype.h>
13#include <torch/csrc/DynamicTypes.h>
14#include <torch/csrc/Layout.h>
15#include <torch/csrc/QScheme.h>
16#include <torch/csrc/autograd/python_variable.h>
17#include <torch/csrc/autograd/variable.h>
18#include <torch/csrc/utils/python_numbers.h>
19#include <torch/csrc/utils/tensor_qschemes.h>
20
21namespace torch {
22namespace autograd {
23namespace utils {
24
25inline PyObject* wrap(bool value) {
26 if (value) {
27 Py_RETURN_TRUE;
28 } else {
29 Py_RETURN_FALSE;
30 }
31}
32
33inline PyObject* wrap(int64_t value) {
34 return THPUtils_packInt64(value);
35}
36
37inline PyObject* wrap(double value) {
38 return PyFloat_FromDouble(value);
39}
40
41inline PyObject* wrap(c10::complex<double> value) {
42 // I could probably also use FromComplex with a reinterpret cast,
43 // but... eh.
44 return PyComplex_FromDoubles(value.real(), value.imag());
45}
46
47inline PyObject* wrap(void* value) {
48 return THPUtils_packInt64(reinterpret_cast<intptr_t>(value));
49}
50
51inline PyObject* wrap(THPDtype* dtype) {
52 Py_INCREF(dtype);
53 return (PyObject*)dtype;
54}
55
56inline PyObject* wrap(at::ScalarType scalarType) {
57 return wrap(getTHPDtype(scalarType));
58}
59
60inline PyObject* wrap(THPLayout* layout) {
61 Py_INCREF(layout);
62 return (PyObject*)layout;
63}
64
65inline PyObject* wrap(at::Layout layout) {
66 return wrap(getTHPLayout(layout));
67}
68
69inline PyObject* wrap(at::Tensor tensor) {
70 return THPVariable_Wrap(Variable(std::move(tensor)));
71}
72
73inline PyObject* wrap(const at::Scalar& scalar) {
74 return wrap(scalar_to_tensor(scalar));
75}
76
77inline PyObject* wrap(at::QScheme qscheme) {
78 auto* thp_qscheme = torch::utils::getTHPQScheme(qscheme);
79 Py_INCREF(thp_qscheme);
80 return thp_qscheme;
81}
82
83inline PyObject* wrap(at::TensorList tl) {
84 auto r = THPObjectPtr{PyTuple_New(tl.size())};
85 if (!r)
86 throw python_error();
87 for (const auto i : c10::irange(tl.size())) {
88 PyTuple_SET_ITEM(r.get(), i, wrap(tl[i]));
89 }
90 return r.release();
91}
92
93inline PyObject* wrap(at::IntArrayRef list) {
94 auto r = THPObjectPtr{PyTuple_New(list.size())};
95 if (!r)
96 throw python_error();
97 for (const auto i : c10::irange(list.size())) {
98 PyTuple_SET_ITEM(r.get(), i, wrap(list[i]));
99 }
100 return r.release();
101}
102
namespace detail {
// Invokes f(std::get<I>(t), I) for every index I in `Is`. The calls happen in
// increasing index order because the elements of a braced-init-list are
// evaluated left to right.
template <typename F, typename Tuple, size_t... Is>
void apply_with_idx_impl(
    const F& f,
    Tuple& t,
    std::index_sequence<Is...> /*indices*/) {
  // Each call's result is cast to void so a callable whose return type
  // overloads operator, cannot hijack the (call, 0) expansion.
  (void)std::initializer_list<int>{
      (static_cast<void>(f(std::get<Is>(t), Is)), 0)...};
}

// For tuple(a, b, c), calls f(a, 0), f(b, 1), f(c, 2).
// Elements are passed as lvalue references, so `f` may modify or move from them.
template <typename F, typename... Ts>
void apply_with_idx(const F& f, std::tuple<Ts...>& t) {
  apply_with_idx_impl(f, t, std::index_sequence_for<Ts...>{});
}
} // namespace detail
118
119template <typename... Ts>
120PyObject* wrap(std::tuple<Ts...> values) {
121 auto r = THPObjectPtr{PyTuple_New(sizeof...(Ts))};
122 if (!r)
123 throw python_error();
124 detail::apply_with_idx(
125 [&](auto& value, size_t idx) {
126 PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value)));
127 },
128 values);
129 return r.release();
130}
131
132template <typename... Ts>
133PyObject* wrap(PyTypeObject* type, std::tuple<Ts...> values) {
134 auto r = THPObjectPtr{PyStructSequence_New(type)};
135 if (!r)
136 throw python_error();
137 detail::apply_with_idx(
138 [&](auto& value, size_t idx) {
139 PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value)));
140 },
141 values);
142 return r.release();
143}
144
145} // namespace utils
146} // namespace autograd
147} // namespace torch
148