1#include <torch/csrc/distributed/c10d/reducer_timer.hpp>
2
3#include <ATen/cuda/CUDAEvent.h>
4#include <c10/core/DeviceGuard.h>
5
6namespace c10d {
7namespace {
8
// Number of nanoseconds in one millisecond. Used to convert a duration
// expressed in milliseconds into nanoseconds.
constexpr int kMilliSecondToNanosSecond = 1'000'000;
10
11class CudaTimer : public Timer {
12 private:
13 c10::Device device;
14
15 at::cuda::CUDAEvent forward_start = at::cuda::CUDAEvent(cudaEventDefault);
16 at::cuda::CUDAEvent backward_compute_start =
17 at::cuda::CUDAEvent(cudaEventDefault);
18 at::cuda::CUDAEvent backward_compute_end =
19 at::cuda::CUDAEvent(cudaEventDefault);
20 at::cuda::CUDAEvent backward_comm_start =
21 at::cuda::CUDAEvent(cudaEventDefault);
22 at::cuda::CUDAEvent backward_comm_end = at::cuda::CUDAEvent(cudaEventDefault);
23
24 at::cuda::CUDAEvent& getEvent(Event event) {
25 switch (event) {
26 case Event::kForwardStart:
27 return forward_start;
28 case Event::kBackwardComputeStart:
29 return backward_compute_start;
30 case Event::kBackwardComputeEnd:
31 return backward_compute_end;
32 case Event::kBackwardCommStart:
33 return backward_comm_start;
34 case Event::kBackwardCommEnd:
35 return backward_comm_end;
36 default:
37 TORCH_INTERNAL_ASSERT(false);
38 }
39 }
40
41 public:
42 explicit CudaTimer(c10::Device dev) : device(dev) {}
43
44 void record(Event event) override {
45 // Parent class sets the host-side time
46 Timer::record(event);
47 c10::DeviceGuard g(device);
48 getEvent(event).record();
49 }
50
51 c10::optional<int64_t> measureDifference(Event start, Event end) override {
52 c10::DeviceGuard g(device);
53 at::cuda::CUDAEvent& start_event = getEvent(start);
54 at::cuda::CUDAEvent& end_event = getEvent(end);
55 // It is possible users did not call backward or run codes in
56 // no-sync mode, in this case, some cudaEvents like "backward_compute_end"
57 // or "backward_comm_start" or "backward_comm_end" will not be recorded.
58 // cudaEvent is created when it is first time to be recorded.
59 // If it is never recorded/created, skip synchronize and calculation.
60 // Otherwise it will throw cuda errors.
61 if (!start_event.isCreated() || !end_event.isCreated()) {
62 return c10::nullopt;
63 }
64 // set_runtime_stats_and_log is called at the beginning of forward call,
65 // when it is cheap to synchronize the cuda events of previous iteration,
66 // as mostly all cuda operations are finished in previous iteration.
67 start_event.synchronize();
68 end_event.synchronize();
69 float milliseconds = start_event.elapsed_time(end_event);
70 // If gpu_end is not recorded in this iteration,
71 // milliseconds will have invalid value.
72 // For some cases like DDP runs on non-sync mode,
73 // gpu_end can not be recorded in this iteration and thus can not
74 // calculate the valid avg_time.
75 // In this case, skip calculating the avg_time and return.
76 if (milliseconds < 0) {
77 return c10::nullopt;
78 }
79 return int64_t(milliseconds * kMilliSecondToNanosSecond);
80 }
81};
82
// Registers CudaTimer in TimerRegistry under the c10::kCUDA key.
// NOTE(review): TimerRegistry is declared outside this file; presumably this
// is how a Timer gets created for CUDA devices. Confirm against its users.
C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCUDA, CudaTimer);
84
85} // namespace
86} // namespace c10d
87