1 | #pragma once |
2 | |
3 | #include <c10/cuda/CUDAMacros.h> |
4 | |
5 | #include <memory> |
6 | #include <mutex> |
7 | #include <string> |
8 | #include <vector> |
9 | |
10 | #ifdef USE_CUDA |
11 | #define TORCH_USE_CUDA_DSA |
12 | #endif |
13 | |
/// Number of assertion failure messages we can store. If this is too small,
/// threads that fail an assertion after the buffer is full will fail silently
/// (their messages are dropped rather than recorded).
constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10;
/// Maximum length of each string recorded in `DeviceAssertionData`
/// (assertion message, filename, function name).
constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512;
18 | |
19 | namespace c10 { |
20 | namespace cuda { |
21 | |
/// Holds information about a single device-side assertion failure.
/// Held in managed memory and accessed by both the CPU and the GPU.
struct DeviceAssertionData {
  /// Stringification of the assertion that failed
  char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN];
  /// File the assertion was in
  char filename[C10_CUDA_DSA_MAX_STR_LEN];
  /// Name of the function the assertion was in
  char function_name[C10_CUDA_DSA_MAX_STR_LEN];
  /// Line number the assertion was at
  int line_number;
  /// Number uniquely identifying the kernel launch that triggered the assertion
  uint32_t caller;
  /// block_id (one entry per grid dimension) of the thread that failed the
  /// assertion
  int32_t block_id[3];
  /// thread_id (one entry per block dimension) of the thread that failed the
  /// assertion
  int32_t thread_id[3];
};
40 | |
/// Used to hold assertions generated by the device.
/// Held in managed memory and accessed by both the CPU and the GPU.
struct DeviceAssertionsData {
  /// Total number of assertions found; a subset of these will be recorded
  /// in `assertions` (at most C10_CUDA_DSA_ASSERTION_COUNT of them)
  int32_t assertion_count;
  /// An array of assertions that will be written to in a race-free manner
  DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT];
};
50 | |
/// Used to hold info about kernel launches so that we can run kernels
/// asynchronously and still associate launches with device-side
/// assertion failures
struct CUDAKernelLaunchInfo {
  /// Filename of the code where the kernel was launched from
  const char* launch_filename;
  /// Function from which the kernel was launched
  const char* launch_function;
  /// Line number of where the code was launched from
  uint32_t launch_linenum;
  /// Backtrace of where the kernel was launched from, only populated if
  /// CUDAKernelLaunchRegistry::gather_launch_stacktrace is True
  std::string launch_stacktrace;
  /// Kernel that was launched
  const char* kernel_name;
  /// Device the kernel was launched on
  int device;
  /// Stream the kernel was launched on
  int32_t stream;
  /// A number that uniquely identifies the kernel launch
  uint64_t generation_number;
};
73 | |
/// Circular buffer used to hold information about kernel launches.
/// This is later used to reconstruct how a device-side kernel assertion
/// failure occurred. CUDAKernelLaunchRegistry is used as a singleton.
class C10_CUDA_API CUDAKernelLaunchRegistry {
 private:
  /// Assume that this is the max number of kernel launches that might ever be
  /// enqueued across all streams on a single device
  static constexpr int max_kernel_launches = 1024;
  /// How many kernel launch infos we've inserted. Used to ensure that circular
  /// queue doesn't provide false information by always increasing, but also to
  /// mark where we are inserting into the queue
#ifdef TORCH_USE_CUDA_DSA
  uint64_t generation_number = 0;
#endif
  /// Shared mutex between writer and accessor to ensure multi-threaded safety.
  mutable std::mutex read_write_mutex;
  /// Used to prevent race conditions in GPU memory allocation
  mutable std::mutex gpu_alloc_mutex;
  /// Pointer to managed memory keeping track of device-side assertions. There
  /// is one entry for each possible device the process might work with. Unused
  /// entries are nullptrs. We could also use an unordered_set here, but this
  /// vector design will be faster and the wasted memory is small since we
  /// expect the number of GPUs per node will always be small
  /// (The unique_ptr's custom deleter handles freeing the managed allocation.)
  std::vector<
      std::unique_ptr<DeviceAssertionsData, void (*)(DeviceAssertionsData*)>>
      uvm_assertions;
  /// A single circular buffer holds information about every kernel launch the
  /// process makes across all devices.
  std::vector<CUDAKernelLaunchInfo> kernel_launches;
  /// Reads the environment to decide whether launch stacktraces are gathered
  bool check_env_for_enable_launch_stacktracing() const;
  /// Reads the environment to decide whether DSA is enabled at run-time
  bool check_env_for_dsa_enabled() const;

 public:
  CUDAKernelLaunchRegistry();
  /// Register a new kernel launch and obtain a generation number back to be
  /// passed to the kernel
  /// NOTE(review): returns uint32_t while the internal `generation_number`
  /// counter is uint64_t — presumably the value is narrowed for the device
  /// side; confirm against the definition.
  uint32_t insert(
      const char* launch_filename,
      const char* launch_function,
      const uint32_t launch_linenum,
      const char* kernel_name,
      const int32_t stream_id);
  /// Get copies of the kernel launch registry and each device's assertion
  /// failure buffer so they can be inspected without raising race conditions
  std::
      pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>>
      snapshot() const;
  /// Get a pointer to the current device's assertion failure buffer. If no such
  /// buffer exists then one is created. This means that the first kernel launch
  /// made on each device will be slightly slower because memory allocations are
  /// required
  DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device();
  /// Gets the global singleton of the registry
  static CUDAKernelLaunchRegistry& get_singleton_ref();
  /// If not all devices support DSA, we disable it
  const bool do_all_devices_support_managed_memory = false;
  /// Whether or not to gather stack traces when launching kernels
  bool gather_launch_stacktrace = false;
  /// Whether or not host-side DSA is enabled or disabled at run-time
  /// Note: Device-side code cannot be enabled/disabled at run-time
  bool enabled_at_runtime = false;
  /// Whether or not a device has indicated a failure
  bool has_failed() const;
  /// Whether device-side assertions were compiled in (fixed at build time by
  /// TORCH_USE_CUDA_DSA; see the top of this file)
#ifdef TORCH_USE_CUDA_DSA
  const bool enabled_at_compile_time = true;
#else
  const bool enabled_at_compile_time = false;
#endif
};
143 | |
144 | std::string c10_retrieve_device_side_assertion_info(); |
145 | |
146 | } // namespace cuda |
147 | } // namespace c10 |
148 | |
// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH
// requires the same input arguments. We introduce the following macro to
// standardize these:
//   - `assertions_data`: pointer to the device's DeviceAssertionsData buffer
//   - `assertion_caller_id`: number identifying the kernel launch
// Both are marked [[maybe_unused]] so kernels that never assert still
// compile without unused-parameter warnings.
#define TORCH_DSA_KERNEL_ARGS \
  [[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \
  [[maybe_unused]] uint32_t assertion_caller_id
155 | |
// This macro can be used to pass the DSA arguments onward to another
// function whose parameter list was declared with TORCH_DSA_KERNEL_ARGS.
#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id
159 | |