1 | #include <torch/csrc/distributed/c10d/reducer_timer.hpp> |
2 | |
3 | #include <ATen/cuda/CUDAEvent.h> |
4 | #include <c10/core/DeviceGuard.h> |
5 | |
6 | namespace c10d { |
7 | namespace { |
8 | |
// Nanoseconds per millisecond. CUDA reports event intervals in milliseconds
// (see measureDifference), while the timer returns nanosecond counts.
constexpr int kMilliSecondToNanosSecond = 1000 * 1000;
10 | |
11 | class CudaTimer : public Timer { |
12 | private: |
13 | c10::Device device; |
14 | |
15 | at::cuda::CUDAEvent forward_start = at::cuda::CUDAEvent(cudaEventDefault); |
16 | at::cuda::CUDAEvent backward_compute_start = |
17 | at::cuda::CUDAEvent(cudaEventDefault); |
18 | at::cuda::CUDAEvent backward_compute_end = |
19 | at::cuda::CUDAEvent(cudaEventDefault); |
20 | at::cuda::CUDAEvent backward_comm_start = |
21 | at::cuda::CUDAEvent(cudaEventDefault); |
22 | at::cuda::CUDAEvent backward_comm_end = at::cuda::CUDAEvent(cudaEventDefault); |
23 | |
24 | at::cuda::CUDAEvent& getEvent(Event event) { |
25 | switch (event) { |
26 | case Event::kForwardStart: |
27 | return forward_start; |
28 | case Event::kBackwardComputeStart: |
29 | return backward_compute_start; |
30 | case Event::kBackwardComputeEnd: |
31 | return backward_compute_end; |
32 | case Event::kBackwardCommStart: |
33 | return backward_comm_start; |
34 | case Event::kBackwardCommEnd: |
35 | return backward_comm_end; |
36 | default: |
37 | TORCH_INTERNAL_ASSERT(false); |
38 | } |
39 | } |
40 | |
41 | public: |
42 | explicit CudaTimer(c10::Device dev) : device(dev) {} |
43 | |
44 | void record(Event event) override { |
45 | // Parent class sets the host-side time |
46 | Timer::record(event); |
47 | c10::DeviceGuard g(device); |
48 | getEvent(event).record(); |
49 | } |
50 | |
51 | c10::optional<int64_t> measureDifference(Event start, Event end) override { |
52 | c10::DeviceGuard g(device); |
53 | at::cuda::CUDAEvent& start_event = getEvent(start); |
54 | at::cuda::CUDAEvent& end_event = getEvent(end); |
55 | // It is possible users did not call backward or run codes in |
56 | // no-sync mode, in this case, some cudaEvents like "backward_compute_end" |
57 | // or "backward_comm_start" or "backward_comm_end" will not be recorded. |
58 | // cudaEvent is created when it is first time to be recorded. |
59 | // If it is never recorded/created, skip synchronize and calculation. |
60 | // Otherwise it will throw cuda errors. |
61 | if (!start_event.isCreated() || !end_event.isCreated()) { |
62 | return c10::nullopt; |
63 | } |
64 | // set_runtime_stats_and_log is called at the beginning of forward call, |
65 | // when it is cheap to synchronize the cuda events of previous iteration, |
66 | // as mostly all cuda operations are finished in previous iteration. |
67 | start_event.synchronize(); |
68 | end_event.synchronize(); |
69 | float milliseconds = start_event.elapsed_time(end_event); |
70 | // If gpu_end is not recorded in this iteration, |
71 | // milliseconds will have invalid value. |
72 | // For some cases like DDP runs on non-sync mode, |
73 | // gpu_end can not be recorded in this iteration and thus can not |
74 | // calculate the valid avg_time. |
75 | // In this case, skip calculating the avg_time and return. |
76 | if (milliseconds < 0) { |
77 | return c10::nullopt; |
78 | } |
79 | return int64_t(milliseconds * kMilliSecondToNanosSecond); |
80 | } |
81 | }; |
82 | |
// Registers CudaTimer in TimerRegistry under the c10::kCUDA key. Presumably
// the reducer looks up this registry by device type when creating its timer,
// so CUDA devices get a CudaTimer — NOTE(review): confirm at the lookup site.
C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCUDA, CudaTimer);
84 | |
85 | } // namespace |
86 | } // namespace c10d |
87 | |