1 | #pragma once |
2 | |
3 | #include <torch/csrc/Types.h> |
4 | #include <functional> |
5 | #include <vector> |
6 | |
7 | typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction; |
8 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
9 | struct THPCopyInfo { |
10 | PyTypeObject* srcType; // Python type of src tensor/storage |
11 | THPCopyFunction copy; // copy function |
12 | bool non_blocking; // true if copy implements an 'non_blocking' copy |
13 | bool broadcast; // true if the copy implements a broadcast copy |
14 | }; |
15 | typedef std::vector<THPCopyInfo> THPCopyList; |
16 | |
17 | inline bool tryTHPCopy( |
18 | const THPCopyList& v, |
19 | PyObject* dst, |
20 | PyObject* src, |
21 | bool non_blocking, |
22 | bool broadcast) { |
23 | for (auto& i : v) { |
24 | if (i.non_blocking == non_blocking && |
25 | PyType_IsSubtype(Py_TYPE(src), i.srcType)) { |
26 | (i.copy)(dst, src, broadcast); |
27 | return true; |
28 | } |
29 | } |
30 | return false; |
31 | } |
32 | |
33 | inline bool THPCopy( |
34 | const THPCopyList& v, |
35 | PyObject* dst, |
36 | PyObject* src, |
37 | bool non_blocking, |
38 | bool broadcast) { |
39 | // NOLINTNEXTLINE(bugprone-branch-clone) |
40 | if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) { |
41 | return true; |
42 | } else if (non_blocking && tryTHPCopy(v, dst, src, false, broadcast)) { |
43 | return true; |
44 | } |
45 | THPUtils_setError( |
46 | "copy from %s to %s isn't implemented" , |
47 | THPUtils_typename(src), |
48 | THPUtils_typename(dst)); |
49 | return false; |
50 | } |
51 | |