1 | #pragma once |
2 | |
3 | // This header provides C++ wrappers around commonly used CUDA API functions. |
4 | // The benefit of using C++ here is that we can raise an exception in the |
5 | // event of an error, rather than explicitly pass around error codes. This |
6 | // leads to more natural APIs. |
7 | // |
8 | // The naming convention used here matches the naming convention of torch.cuda |
9 | |
10 | #include <c10/core/Device.h> |
11 | #include <c10/core/impl/GPUTrace.h> |
12 | #include <c10/cuda/CUDAException.h> |
13 | #include <c10/cuda/CUDAMacros.h> |
14 | #include <cuda_runtime_api.h> |
15 | namespace c10 { |
16 | namespace cuda { |
17 | |
// Returns the number of CUDA devices; declared noexcept, so it never throws.
//
// NB: In the past, we were inconsistent about whether or not this reported
// an error if there were driver problems. Based on experience interacting
// with users, it seems that people basically ~never want this function to
// fail; it should just return zero if things are not working. Oblige them.
// It may still log a warning to the user the first time it is invoked.
C10_CUDA_API DeviceIndex device_count() noexcept;

// Version of device_count() that throws if no devices are detected.
C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
28 | |
// The following are declarations only; their definitions live outside this
// header. NOTE(review): the descriptions below are inferred from the names
// and from the call sites in this header — confirm against the implementation.

// Presumably returns the index of the current CUDA device.
C10_CUDA_API DeviceIndex current_device();

// Presumably makes `device` the current CUDA device.
C10_CUDA_API void set_device(DeviceIndex device);

// Presumably blocks until all outstanding work on the current device is done.
C10_CUDA_API void device_synchronize();

// Called by memcpy_and_sync() and stream_synchronize() right before they
// block on a stream, but only when warning_state() reports a sync debug mode
// other than SyncDebugMode::L_DISABLED. Presumably it warns (L_WARN) or
// throws (L_ERROR) depending on that mode.
C10_CUDA_API void warn_or_error_on_sync();
36 | |
// How host<->stream synchronizations (memcpy_and_sync(), stream_synchronize())
// are reported:
//   L_DISABLED - not reported (the default held by WarningState);
//   L_WARN, L_ERROR - warn_or_error_on_sync() is invoked before syncing;
//                     presumably it warns or raises respectively.
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

// This is a holder for c10 global state (similar to at::GlobalContext).
// Currently it is used to store the CUDA synchronization-warning state, but it
// can be expanded to hold other related global state, e.g. to record stream
// usage.
class WarningState {
 public:
  // Sets the sync debug mode consulted before host<->stream synchronizations.
  void set_sync_debug_mode(SyncDebugMode mode) {
    sync_debug_mode = mode;
  }

  // Returns the current sync debug mode. This is a pure read, so it is const
  // and can be called through a `const WarningState&`.
  SyncDebugMode get_sync_debug_mode() const {
    return sync_debug_mode;
  }

 private:
  // NOTE(review): plain (non-atomic) field. Concurrent set/get from different
  // threads would be a data race — confirm that writers are serialized.
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};
56 | |
// Returns a reference to the WarningState instance consulted by
// memcpy_and_sync() and stream_synchronize(). The instance is a function-local
// static: it is constructed on the first call (C++11 makes that initialization
// thread-safe) and has static storage duration.
// NOTE(review): deliberately left untouched. This is an inline, exported
// (C10_CUDA_API) function holding a local static. Whether all shared libraries
// observe the same instance depends on the toolchain/linker, and the static's
// mangled name (which embeds `warning_state_`) participates in that.
// Renaming/restructuring it could split the state — confirm before changing.
C10_CUDA_API __inline__ WarningState& warning_state() {
  static WarningState warning_state_;
  return warning_state_;
}
61 | // the subsequent functions are defined in the header because for performance |
62 | // reasons we want them to be inline |
63 | C10_CUDA_API void __inline__ memcpy_and_sync( |
64 | void* dst, |
65 | void* src, |
66 | int64_t nbytes, |
67 | cudaMemcpyKind kind, |
68 | cudaStream_t stream) { |
69 | if (C10_UNLIKELY( |
70 | warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { |
71 | warn_or_error_on_sync(); |
72 | } |
73 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
74 | if (C10_UNLIKELY(interp)) { |
75 | (*interp)->trace_gpu_stream_synchronization( |
76 | reinterpret_cast<uintptr_t>(stream)); |
77 | } |
78 | #if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301) |
79 | C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); |
80 | #else |
81 | C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); |
82 | C10_CUDA_CHECK(cudaStreamSynchronize(stream)); |
83 | #endif |
84 | } |
85 | |
86 | C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) { |
87 | if (C10_UNLIKELY( |
88 | warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { |
89 | warn_or_error_on_sync(); |
90 | } |
91 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
92 | if (C10_UNLIKELY(interp)) { |
93 | (*interp)->trace_gpu_stream_synchronization( |
94 | reinterpret_cast<uintptr_t>(stream)); |
95 | } |
96 | C10_CUDA_CHECK(cudaStreamSynchronize(stream)); |
97 | } |
98 | |
99 | } // namespace cuda |
100 | } // namespace c10 |
101 | |