1 | #include <Python.h> |
---|---|
2 | |
3 | #include <vector> |
4 | |
5 | namespace torch { |
6 | namespace autograd { |
7 | |
8 | extern PyObject* THPVariableFunctionsModule; |
9 | |
10 | // Wrapper converts a raised TypeError into returning NotImplemented |
11 | // Used to implement binary arithmetic operators |
12 | template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)> |
13 | inline PyObject* TypeError_to_NotImplemented_( |
14 | PyObject* self, |
15 | PyObject* args, |
16 | PyObject* kwargs) { |
17 | PyObject* ret = Func(self, args, kwargs); |
18 | if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { |
19 | PyErr_Clear(); |
20 | Py_INCREF(Py_NotImplemented); |
21 | ret = Py_NotImplemented; |
22 | } |
23 | return ret; |
24 | } |
25 | |
26 | void initTorchFunctions(); |
27 | |
28 | } // namespace autograd |
29 | } // namespace torch |
30 |