1 | #pragma once |
2 | #include <torch/csrc/autograd/profiler.h> |
3 | |
4 | namespace c10d { |
// Sentinel held by every Timer timestamp slot until the corresponding event
// has been recorded; getTimestamp() maps it to nullopt.
constexpr int kUnsetTime{-1};
6 | |
7 | inline int64_t current_time_in_nanos() { |
8 | return torch::profiler::impl::getTime(); |
9 | } |
10 | |
11 | class TORCH_API Timer { |
12 | private: |
13 | // The timestamp of forward call start time in each iteration. |
14 | int64_t forward_start_time = kUnsetTime; |
15 | // The timestamp of backward computation start and end time in each |
16 | // iteration. |
17 | int64_t backward_compute_start_time = kUnsetTime; |
18 | int64_t backward_compute_end_time = kUnsetTime; |
19 | // The timestamp of first communication call start time in each iteration. |
20 | int64_t backward_comm_start_time = kUnsetTime; |
21 | // The timestamp of last communication call end time in each iteration. |
22 | int64_t backward_comm_end_time = kUnsetTime; |
23 | |
24 | public: |
25 | enum class Event { |
26 | kForwardStart, |
27 | kBackwardComputeStart, |
28 | kBackwardComputeEnd, |
29 | kBackwardCommStart, |
30 | kBackwardCommEnd, |
31 | }; |
32 | |
33 | // Record the current event, i.e., mark it as having occurred now. Default |
34 | // CPU implementation. |
35 | virtual void record(Event event) { |
36 | getTimeRef(event) = current_time_in_nanos(); |
37 | } |
38 | |
39 | // Return the difference between when two events occurred, in nanoseconds. |
40 | // Or nullopt if one of them hasn't been recorded. |
41 | virtual c10::optional<int64_t> measureDifference(Event start, Event end) = 0; |
42 | |
43 | virtual ~Timer() = default; |
44 | |
45 | // Return host-side timestamp, or nullopt if it has not yet been recorded. |
46 | c10::optional<int64_t> getTimestamp(Event event) { |
47 | auto time = getTimeRef(event); |
48 | if (time == kUnsetTime) { |
49 | return c10::nullopt; |
50 | } else { |
51 | return time; |
52 | } |
53 | } |
54 | |
55 | // Return host-side time member variable corresponding to the given event. |
56 | int64_t& getTimeRef(Event event) { |
57 | switch (event) { |
58 | case Event::kForwardStart: |
59 | return forward_start_time; |
60 | case Event::kBackwardComputeStart: |
61 | return backward_compute_start_time; |
62 | case Event::kBackwardComputeEnd: |
63 | return backward_compute_end_time; |
64 | case Event::kBackwardCommStart: |
65 | return backward_comm_start_time; |
66 | case Event::kBackwardCommEnd: |
67 | return backward_comm_end_time; |
68 | default: |
69 | TORCH_INTERNAL_ASSERT(false); |
70 | } |
71 | } |
72 | }; |
73 | |
// Declares TimerRegistry: a registry keyed by c10::DeviceType whose entries
// appear to produce std::unique_ptr<Timer> from a c10::Device argument.
// NOTE(review): presumably device-specific Timer subclasses register here; the
// matching definition/registration is not in this header, so confirm it there.
C10_DECLARE_TYPED_REGISTRY(TimerRegistry, c10::DeviceType, Timer, std::unique_ptr, c10::Device);
75 | } // namespace c10d |
76 | |