#include <c10/core/CopyBytes.h>
#include <c10/util/Logging.h>

#include <type_traits>
3
4namespace c10 {
5
6// First dimension of the array is `bool async`: 0 is sync,
7// 1 is async (non-blocking)
8// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
9static CopyBytesFunction g_copy_bytes[2][COMPILE_TIME_MAX_DEVICE_TYPES]
10 [COMPILE_TIME_MAX_DEVICE_TYPES];
11
12_CopyBytesFunctionRegisterer::_CopyBytesFunctionRegisterer(
13 DeviceType fromType,
14 DeviceType toType,
15 CopyBytesFunction func_sync,
16 CopyBytesFunction func_async) {
17 auto from = static_cast<int>(fromType);
18 auto to = static_cast<int>(toType);
19 if (!func_async) {
20 // default to the sync function
21 func_async = func_sync;
22 }
23 CHECK(
24 g_copy_bytes[0][from][to] == nullptr &&
25 g_copy_bytes[1][from][to] == nullptr)
26 << "Duplicate registration for device type pair "
27 << c10::DeviceTypeName(fromType) << ", " << c10::DeviceTypeName(toType);
28 g_copy_bytes[0][from][to] = func_sync;
29 g_copy_bytes[1][from][to] = func_async;
30}
31
32void CopyBytes(
33 size_t nbytes,
34 const void* src,
35 Device src_device,
36 void* dst,
37 Device dst_device,
38 bool async) {
39 auto ptr = g_copy_bytes[async ? 1 : 0][static_cast<int>(src_device.type())]
40 [static_cast<int>(dst_device.type())];
41 CAFFE_ENFORCE(
42 ptr,
43 "No function found for copying from ",
44 c10::DeviceTypeName(src_device.type()),
45 " to ",
46 c10::DeviceTypeName(dst_device.type()));
47 ptr(nbytes, src, src_device, dst, dst_device);
48}
49
50} // namespace c10
51