1#pragma once
2#include <torch/csrc/autograd/profiler.h>
3
4namespace c10d {
// Sentinel value held by a Timer timestamp slot until the corresponding event
// has been recorded; getTimestamp() maps it to nullopt.
constexpr int kUnsetTime{-1};
6
7inline int64_t current_time_in_nanos() {
8 return torch::profiler::impl::getTime();
9}
10
11class TORCH_API Timer {
12 private:
13 // The timestamp of forward call start time in each iteration.
14 int64_t forward_start_time = kUnsetTime;
15 // The timestamp of backward computation start and end time in each
16 // iteration.
17 int64_t backward_compute_start_time = kUnsetTime;
18 int64_t backward_compute_end_time = kUnsetTime;
19 // The timestamp of first communication call start time in each iteration.
20 int64_t backward_comm_start_time = kUnsetTime;
21 // The timestamp of last communication call end time in each iteration.
22 int64_t backward_comm_end_time = kUnsetTime;
23
24 public:
25 enum class Event {
26 kForwardStart,
27 kBackwardComputeStart,
28 kBackwardComputeEnd,
29 kBackwardCommStart,
30 kBackwardCommEnd,
31 };
32
33 // Record the current event, i.e., mark it as having occurred now. Default
34 // CPU implementation.
35 virtual void record(Event event) {
36 getTimeRef(event) = current_time_in_nanos();
37 }
38
39 // Return the difference between when two events occurred, in nanoseconds.
40 // Or nullopt if one of them hasn't been recorded.
41 virtual c10::optional<int64_t> measureDifference(Event start, Event end) = 0;
42
43 virtual ~Timer() = default;
44
45 // Return host-side timestamp, or nullopt if it has not yet been recorded.
46 c10::optional<int64_t> getTimestamp(Event event) {
47 auto time = getTimeRef(event);
48 if (time == kUnsetTime) {
49 return c10::nullopt;
50 } else {
51 return time;
52 }
53 }
54
55 // Return host-side time member variable corresponding to the given event.
56 int64_t& getTimeRef(Event event) {
57 switch (event) {
58 case Event::kForwardStart:
59 return forward_start_time;
60 case Event::kBackwardComputeStart:
61 return backward_compute_start_time;
62 case Event::kBackwardComputeEnd:
63 return backward_compute_end_time;
64 case Event::kBackwardCommStart:
65 return backward_comm_start_time;
66 case Event::kBackwardCommEnd:
67 return backward_comm_end_time;
68 default:
69 TORCH_INTERNAL_ASSERT(false);
70 }
71 }
72};
73
74C10_DECLARE_TYPED_REGISTRY(TimerRegistry, c10::DeviceType, Timer, std::unique_ptr, c10::Device);
75} // namespace c10d
76