1#pragma once
2
3#include <torch/csrc/Types.h>
4#include <functional>
5#include <vector>
6
7typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction;
8// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
9struct THPCopyInfo {
10 PyTypeObject* srcType; // Python type of src tensor/storage
11 THPCopyFunction copy; // copy function
12 bool non_blocking; // true if copy implements an 'non_blocking' copy
13 bool broadcast; // true if the copy implements a broadcast copy
14};
15typedef std::vector<THPCopyInfo> THPCopyList;
16
17inline bool tryTHPCopy(
18 const THPCopyList& v,
19 PyObject* dst,
20 PyObject* src,
21 bool non_blocking,
22 bool broadcast) {
23 for (auto& i : v) {
24 if (i.non_blocking == non_blocking &&
25 PyType_IsSubtype(Py_TYPE(src), i.srcType)) {
26 (i.copy)(dst, src, broadcast);
27 return true;
28 }
29 }
30 return false;
31}
32
33inline bool THPCopy(
34 const THPCopyList& v,
35 PyObject* dst,
36 PyObject* src,
37 bool non_blocking,
38 bool broadcast) {
39 // NOLINTNEXTLINE(bugprone-branch-clone)
40 if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) {
41 return true;
42 } else if (non_blocking && tryTHPCopy(v, dst, src, false, broadcast)) {
43 return true;
44 }
45 THPUtils_setError(
46 "copy from %s to %s isn't implemented",
47 THPUtils_typename(src),
48 THPUtils_typename(dst));
49 return false;
50}
51