1 | #include <sstream> |
2 | |
3 | #include <nvToolsExt.h> |
4 | |
5 | #include <c10/cuda/CUDAGuard.h> |
6 | #include <c10/util/irange.h> |
7 | #include <torch/csrc/profiler/stubs/base.h> |
8 | #include <torch/csrc/profiler/util.h> |
9 | |
10 | namespace torch { |
11 | namespace profiler { |
12 | namespace impl { |
13 | namespace { |
14 | |
15 | static inline void cudaCheck(cudaError_t result, const char* file, int line) { |
16 | if (result != cudaSuccess) { |
17 | std::stringstream ss; |
18 | ss << file << ":" << line << ": " ; |
19 | if (result == cudaErrorInitializationError) { |
20 | // It is common for users to use DataLoader with multiple workers |
21 | // and the autograd profiler. Throw a nice error message here. |
22 | ss << "CUDA initialization error. " |
23 | << "This can occur if one runs the profiler in CUDA mode on code " |
24 | << "that creates a DataLoader with num_workers > 0. This operation " |
25 | << "is currently unsupported; potential workarounds are: " |
26 | << "(1) don't use the profiler in CUDA mode or (2) use num_workers=0 " |
27 | << "in the DataLoader or (3) Don't profile the data loading portion " |
28 | << "of your code. https://github.com/pytorch/pytorch/issues/6313 " |
29 | << "tracks profiler support for multi-worker DataLoader." ; |
30 | } else { |
31 | ss << cudaGetErrorString(result); |
32 | } |
33 | throw std::runtime_error(ss.str()); |
34 | } |
35 | } |
36 | #define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__); |
37 | |
38 | struct CUDAMethods : public ProfilerStubs { |
39 | void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns) |
40 | const override { |
41 | if (device) { |
42 | TORCH_CUDA_CHECK(cudaGetDevice(device)); |
43 | } |
44 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
45 | CUevent_st* cuda_event_ptr; |
46 | TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr)); |
47 | *event = std::shared_ptr<CUevent_st>(cuda_event_ptr, [](CUevent_st* ptr) { |
48 | TORCH_CUDA_CHECK(cudaEventDestroy(ptr)); |
49 | }); |
50 | auto stream = at::cuda::getCurrentCUDAStream(); |
51 | if (cpu_ns) { |
52 | *cpu_ns = torch::profiler::impl::getTime(); |
53 | } |
54 | TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream)); |
55 | } |
56 | |
57 | float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2) |
58 | const override { |
59 | TORCH_CUDA_CHECK(cudaEventSynchronize(event->get())); |
60 | TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get())); |
61 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
62 | float ms; |
63 | TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get())); |
64 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) |
65 | return ms * 1000.0; |
66 | } |
67 | |
68 | void mark(const char* name) const override { |
69 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
70 | ::nvtxMark(name); |
71 | } |
72 | |
73 | void rangePush(const char* name) const override { |
74 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
75 | ::nvtxRangePushA(name); |
76 | } |
77 | |
78 | void rangePop() const override { |
79 | ::nvtxRangePop(); |
80 | } |
81 | |
82 | void onEachDevice(std::function<void(int)> op) const override { |
83 | at::cuda::OptionalCUDAGuard device_guard; |
84 | for (const auto i : c10::irange(at::cuda::device_count())) { |
85 | device_guard.set_index(i); |
86 | op(i); |
87 | } |
88 | } |
89 | |
90 | void synchronize() const override { |
91 | TORCH_CUDA_CHECK(cudaDeviceSynchronize()); |
92 | } |
93 | |
94 | bool enabled() const override { |
95 | return true; |
96 | } |
97 | }; |
98 | |
99 | struct RegisterCUDAMethods { |
100 | RegisterCUDAMethods() { |
101 | static CUDAMethods methods; |
102 | registerCUDAMethods(&methods); |
103 | } |
104 | }; |
105 | RegisterCUDAMethods reg; |
106 | |
107 | } // namespace |
108 | } // namespace impl |
109 | } // namespace profiler |
110 | } // namespace torch |
111 | |