1 | #pragma once |
2 | |
3 | #include <c10/core/SymInt.h> |
4 | #include <torch/csrc/autograd/python_variable.h> |
5 | #include <torch/csrc/python_headers.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | #include <torch/csrc/utils/python_symnode.h> |
8 | |
9 | namespace torch { |
10 | namespace autograd { |
11 | |
12 | struct UnpackedSlice { |
13 | c10::SymInt start; |
14 | c10::SymInt stop; |
15 | c10::SymInt step; |
16 | }; |
17 | |
18 | // This mirrors Cpython's PySlice_Unpack method |
19 | static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { |
20 | PySliceObject* r = (PySliceObject*)_r; |
21 | /* this is harder to get right than you might think */ |
22 | |
23 | c10::SymInt start_sym, stop_sym, step_sym; |
24 | |
25 | auto clip_val = [](Py_ssize_t val) { |
26 | if (val < c10::SymInt::min_representable_int()) { |
27 | auto r = PyErr_WarnEx( |
28 | PyExc_UserWarning, |
29 | "Truncating the start/stop/step " |
30 | "of slice. This is likely because of " |
31 | "saved old models when the start/stop/step were larger." , |
32 | 1); |
33 | if (r != 0) { |
34 | throw python_error(); |
35 | } |
36 | return (Py_ssize_t)(c10::SymInt::min_representable_int()); |
37 | } |
38 | return val; |
39 | }; |
40 | |
41 | if (r->step == Py_None) { |
42 | step_sym = c10::SymInt(1); |
43 | } else { |
44 | if (torch::is_symint(r->step)) { |
45 | auto step_sym = py::handle(r->step).cast<c10::SymInt>(); |
46 | } else { |
47 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
48 | Py_ssize_t step; |
49 | if (!_PyEval_SliceIndex(r->step, &step)) { |
50 | throw python_error(); |
51 | } |
52 | if (step == 0) { |
53 | PyErr_SetString(PyExc_ValueError, "slice step cannot be zero" ); |
54 | } |
55 | |
56 | step = clip_val(step); |
57 | step_sym = c10::SymInt(step); |
58 | } |
59 | } |
60 | |
61 | if (torch::is_symint(r->start)) { |
62 | start_sym = py::handle(r->start).cast<c10::SymInt>(); |
63 | } else if (r->start == Py_None) { |
64 | start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0); |
65 | } else { |
66 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
67 | Py_ssize_t start; |
68 | if (!_PyEval_SliceIndex(r->start, &start)) { |
69 | throw python_error(); |
70 | } |
71 | start = clip_val(start); |
72 | start_sym = c10::SymInt(start); |
73 | } |
74 | |
75 | if (torch::is_symint(r->stop)) { |
76 | stop_sym = py::handle(r->stop).cast<c10::SymInt>(); |
77 | } else if (r->stop == Py_None) { |
78 | stop_sym = c10::SymInt( |
79 | step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX); |
80 | } else { |
81 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
82 | Py_ssize_t stop; |
83 | if (!_PyEval_SliceIndex(r->stop, &stop)) { |
84 | throw python_error(); |
85 | } |
86 | stop = clip_val(stop); |
87 | stop_sym = c10::SymInt(stop); |
88 | } |
89 | |
90 | return UnpackedSlice{ |
91 | std::move(start_sym), std::move(stop_sym), std::move(step_sym)}; |
92 | } |
93 | |
94 | Py_ssize_t THPVariable_length(PyObject* self); |
95 | PyObject* THPVariable_getitem(PyObject* self, PyObject* index); |
96 | int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value); |
97 | |
98 | Variable valueToTensor( |
99 | c10::TensorOptions options, |
100 | PyObject* value, |
101 | const at::Device& device); |
102 | |
103 | } // namespace autograd |
104 | } // namespace torch |
105 | |