1 | #include <c10/core/CopyBytes.h> |
2 | #include <c10/util/Logging.h> |
3 | |
4 | namespace c10 { |
5 | |
6 | // First dimension of the array is `bool async`: 0 is sync, |
7 | // 1 is async (non-blocking) |
8 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
9 | static CopyBytesFunction g_copy_bytes[2][COMPILE_TIME_MAX_DEVICE_TYPES] |
10 | [COMPILE_TIME_MAX_DEVICE_TYPES]; |
11 | |
12 | _CopyBytesFunctionRegisterer::_CopyBytesFunctionRegisterer( |
13 | DeviceType fromType, |
14 | DeviceType toType, |
15 | CopyBytesFunction func_sync, |
16 | CopyBytesFunction func_async) { |
17 | auto from = static_cast<int>(fromType); |
18 | auto to = static_cast<int>(toType); |
19 | if (!func_async) { |
20 | // default to the sync function |
21 | func_async = func_sync; |
22 | } |
23 | CHECK( |
24 | g_copy_bytes[0][from][to] == nullptr && |
25 | g_copy_bytes[1][from][to] == nullptr) |
26 | << "Duplicate registration for device type pair " |
27 | << c10::DeviceTypeName(fromType) << ", " << c10::DeviceTypeName(toType); |
28 | g_copy_bytes[0][from][to] = func_sync; |
29 | g_copy_bytes[1][from][to] = func_async; |
30 | } |
31 | |
32 | void CopyBytes( |
33 | size_t nbytes, |
34 | const void* src, |
35 | Device src_device, |
36 | void* dst, |
37 | Device dst_device, |
38 | bool async) { |
39 | auto ptr = g_copy_bytes[async ? 1 : 0][static_cast<int>(src_device.type())] |
40 | [static_cast<int>(dst_device.type())]; |
41 | CAFFE_ENFORCE( |
42 | ptr, |
43 | "No function found for copying from " , |
44 | c10::DeviceTypeName(src_device.type()), |
45 | " to " , |
46 | c10::DeviceTypeName(dst_device.type())); |
47 | ptr(nbytes, src, src_device, dst, dst_device); |
48 | } |
49 | |
50 | } // namespace c10 |
51 | |