#pragma once

#include <mutex>

#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/reducer.hpp>
5 | |
6 | namespace c10d { |
7 | |
8 | class TORCH_API Logger { |
9 | public: |
10 | explicit Logger(std::shared_ptr<c10d::Reducer> reducer); |
11 | // Set logging data that can be got during DistributedDataParallel |
12 | // construction time. |
13 | void set_construction_data_and_log( |
14 | const std::string& module_name, |
15 | const std::vector<int>& device_ids, |
16 | int output_device, |
17 | bool broadcast_buffers, |
18 | bool has_sync_bn, |
19 | bool static_graph |
20 | ); |
21 | |
22 | void set_static_graph(); |
23 | |
24 | // An interface for users to get DDPLoggingData and log them |
25 | // in the applications. Explanation of logging fields are in |
26 | // "struct DDPLoggingData" of "torch/c10/util/Logging.h". |
27 | at::DDPLoggingData get_ddp_logging_data(); |
28 | |
29 | // Stream insertion operator for logging data to stream under |
30 | // TORCH_DISTRIBUTED_DEBUG. |
31 | friend std::ostream& operator<<(std::ostream& output, const Logger& logger); |
32 | |
33 | ~Logger() noexcept(false) { |
34 | // Log if DDP graph is static in Logger dtor instead of Reducer dtor since |
35 | // Logger is deleted before Reducer. |
36 | log_if_graph_static(reducer_->ddp_graph_static()); |
37 | } |
38 | |
39 | // Set environment variables. |
40 | void set_env_variables(); |
41 | // Set parameters stats. |
42 | void set_parameter_stats(); |
43 | // Get size of each bucket (Bytes). |
44 | std::vector<int64_t> get_bucket_sizes(); |
45 | // Get variable indices for each bucket. |
46 | std::vector<std::vector<size_t>> get_per_bucket_variable_indices(); |
47 | // Set comm. hook, if used |
48 | void set_comm_hook(const std::string& hook); |
49 | // Set running with uneven input detection (model.join() context manager) |
50 | void set_uneven_input_join(); |
51 | |
52 | // Reset performance stats at current iteration |
53 | void reset_performance_stats(); |
54 | |
55 | // Calculate avg stats using cpu timer and gpu timer |
56 | // that has been recorded in reducer. |
57 | void calculate_avg_time( |
58 | int64_t& avg_time, |
59 | int64_t& time_duration, |
60 | Timer& timer, |
61 | Timer::Event start_event, |
62 | Timer::Event end_event); |
63 | |
64 | // Set the absolute time of the event that has been recorded in reducer. |
65 | void set_event_time( |
66 | int64_t& event_time, |
67 | Timer& timer, |
68 | Timer::Event event |
69 | ); |
70 | // Set stats that can be collected only during |
71 | // training loop. It is called at the beginning of forward call |
72 | // to record the run time stats of sampled iterations that previouly ran. |
73 | // GPU performance stats are collected only for single process |
74 | // single device program and single device module right now. |
75 | // TODO to support single process multiple devices and multi device modules, |
76 | // events need to be created and recorded on multiple devices. |
77 | void set_runtime_stats_and_log(); |
78 | |
79 | // Called when DDP/reducer is failing with an error. The |
80 | // logging data structure will have two fields filled: "has_error" indicating |
81 | // that this iteration encountered an error and other fields are not valid, |
82 | // and "error", a string which contains the error message that DDP failed |
83 | // with. |
84 | template <typename... Args> |
85 | void set_error_and_log(const std::string& ddp_error, const Args&... args) { |
86 | ddp_logging_data_->ints_map["has_error" ] = 1; |
87 | auto err = c10::str(ddp_error, args...); |
88 | ddp_logging_data_->strs_map["error" ] = err; |
89 | // Report the iteration we are erroring at so user knows how many examples |
90 | // successfully processed before this error was hit. |
91 | ddp_logging_data_->ints_map["iteration" ] = reducer_->num_iterations_; |
92 | at::LogPyTorchDDPUsage(*ddp_logging_data_); |
93 | } |
94 | |
95 | // When running without static graph, called when reducer is destroyed to log |
96 | // if graph was actually static and is a candidate for static graph |
97 | // optimization. |
98 | void log_if_graph_static(bool is_static); |
99 | |
100 | |
101 | private: |
102 | // ddp_logging_data_ is used to hold all the ddp related logging |
103 | // data fields. |
104 | std::unique_ptr<at::DDPLoggingData> ddp_logging_data_; |
105 | std::shared_ptr<c10d::Reducer> reducer_; |
106 | // track the number of iterations when runtime stats are collected so far. |
107 | long num_iterations_stats_recorded_ = 0; |
108 | }; |
109 | |
110 | } // namespace c10d |
111 | |