1#pragma once
2
3// This header provides C++ wrappers around commonly used CUDA API functions.
4// The benefit of using C++ here is that we can raise an exception in the
5// event of an error, rather than explicitly pass around error codes. This
6// leads to more natural APIs.
7//
8// The naming convention used here matches the naming convention of torch.cuda
9
10#include <c10/core/Device.h>
11#include <c10/core/impl/GPUTrace.h>
12#include <c10/cuda/CUDAException.h>
13#include <c10/cuda/CUDAMacros.h>
14#include <cuda_runtime_api.h>
15namespace c10 {
16namespace cuda {
17
// NB: In the past, we were inconsistent about whether or not this reported
// an error if there were driver problems or not. Based on experience
// interacting with users, it seems that people basically ~never want this
// function to fail; it should just return zero if things are not working.
// Oblige them.
// It may still log a warning the first time the user invokes it.
24C10_CUDA_API DeviceIndex device_count() noexcept;
25
// Version of device_count that throws if no devices are detected
27C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
28
29C10_CUDA_API DeviceIndex current_device();
30
31C10_CUDA_API void set_device(DeviceIndex device);
32
33C10_CUDA_API void device_synchronize();
34
35C10_CUDA_API void warn_or_error_on_sync();
36
37enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
38
39// this is a holder for c10 global state (similar to at GlobalContext)
40// currently it's used to store cuda synchronization warning state,
41// but can be expanded to hold other related global state, e.g. to
42// record stream usage
43class WarningState {
44 public:
45 void set_sync_debug_mode(SyncDebugMode l) {
46 sync_debug_mode = l;
47 }
48
49 SyncDebugMode get_sync_debug_mode() {
50 return sync_debug_mode;
51 }
52
53 private:
54 SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
55};
56
57C10_CUDA_API __inline__ WarningState& warning_state() {
58 static WarningState warning_state_;
59 return warning_state_;
60}
61// the subsequent functions are defined in the header because for performance
62// reasons we want them to be inline
63C10_CUDA_API void __inline__ memcpy_and_sync(
64 void* dst,
65 void* src,
66 int64_t nbytes,
67 cudaMemcpyKind kind,
68 cudaStream_t stream) {
69 if (C10_UNLIKELY(
70 warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
71 warn_or_error_on_sync();
72 }
73 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
74 if (C10_UNLIKELY(interp)) {
75 (*interp)->trace_gpu_stream_synchronization(
76 reinterpret_cast<uintptr_t>(stream));
77 }
78#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
79 C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
80#else
81 C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
82 C10_CUDA_CHECK(cudaStreamSynchronize(stream));
83#endif
84}
85
// Block the host until all work queued on `stream` has completed.
//
// If the global SyncDebugMode is not L_DISABLED, this first warns or
// errors via warn_or_error_on_sync() (host/device synchronization can hide
// performance problems), and it notifies any registered Python GPU-trace
// hook before performing the actual wait.
C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
  // Surface the (blocking) sync to the user if sync debugging is enabled.
  if (C10_UNLIKELY(
          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
    warn_or_error_on_sync();
  }
  // Let the Python-side GPU trace (if any) observe this synchronization.
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_stream_synchronization(
        reinterpret_cast<uintptr_t>(stream));
  }
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
98
99} // namespace cuda
100} // namespace c10
101