#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/reducer.hpp>

#include <mutex>

namespace c10d {

class TORCH_API Logger {
 public:
  explicit Logger(std::shared_ptr<c10d::Reducer> reducer);
  // Set logging data that can be collected during DistributedDataParallel
  // construction time.
  void set_construction_data_and_log(
      const std::string& module_name,
      const std::vector<int>& device_ids,
      int output_device,
      bool broadcast_buffers,
      bool has_sync_bn,
      bool static_graph);
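  //
  // A minimal usage sketch (hypothetical, not from this file; assumes
  // `reducer` is a std::shared_ptr<c10d::Reducer> built by the DDP wrapper):
  //
  //   auto logger = std::make_shared<c10d::Logger>(reducer);
  //   logger->set_construction_data_and_log(
  //       /*module_name=*/"MyModel",
  //       /*device_ids=*/{0},
  //       /*output_device=*/0,
  //       /*broadcast_buffers=*/true,
  //       /*has_sync_bn=*/false,
  //       /*static_graph=*/false);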

  void set_static_graph();

  // An interface for users to get DDPLoggingData and log it in their
  // applications. The logging fields are explained in
  // "struct DDPLoggingData" in "c10/util/Logging.h".
  at::DDPLoggingData get_ddp_logging_data();
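  //
  // Sketch of consuming the returned struct (the "has_error" and "error"
  // keys are the ones written by set_error_and_log() below):
  //
  //   at::DDPLoggingData data = logger->get_ddp_logging_data();
  //   if (data.ints_map["has_error"] != 0) {
  //     LOG(ERROR) << data.strs_map["error"];
  //   }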

  // Stream insertion operator for logging data to a stream under
  // TORCH_DISTRIBUTED_DEBUG.
  friend std::ostream& operator<<(std::ostream& output, const Logger& logger);
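  //
  // Sketch (assumes `logger` is a constructed c10d::Logger and that
  // TORCH_DISTRIBUTED_DEBUG is set appropriately):
  //
  //   std::cout << logger << std::endl;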

  ~Logger() noexcept(false) {
    // Log whether the DDP graph is static in the Logger dtor instead of the
    // Reducer dtor, since the Logger is deleted before the Reducer.
    log_if_graph_static(reducer_->ddp_graph_static());
  }

  // Set environment variables.
  void set_env_variables();
  // Set parameter stats.
  void set_parameter_stats();
  // Get the size of each bucket (in bytes).
  std::vector<int64_t> get_bucket_sizes();
  // Get the variable indices for each bucket.
  std::vector<std::vector<size_t>> get_per_bucket_variable_indices();
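  //
  // Sketch of inspecting the bucket layout (hypothetical `logger` instance):
  //
  //   const std::vector<int64_t> sizes = logger.get_bucket_sizes();
  //   for (size_t i = 0; i < sizes.size(); ++i) {
  //     LOG(INFO) << "bucket " << i << ": " << sizes[i] << " bytes";
  //   }
  //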
  // Set the communication hook name, if one is used.
  void set_comm_hook(const std::string& hook);
  // Set that we are running with uneven input detection (the model.join()
  // context manager).
  void set_uneven_input_join();

  // Reset performance stats at the current iteration.
  void reset_performance_stats();

  // Calculate average stats using the CPU timer and GPU timer
  // that have been recorded in the reducer.
  void calculate_avg_time(
      int64_t& avg_time,
      int64_t& time_duration,
      Timer& timer,
      Timer::Event start_event,
      Timer::Event end_event);

  // Set the absolute time of the event that has been recorded in the reducer.
  void set_event_time(int64_t& event_time, Timer& timer, Timer::Event event);

  // Set stats that can be collected only during the training loop. This is
  // called at the beginning of the forward call to record the runtime stats
  // of sampled iterations that previously ran.
  // GPU performance stats are currently collected only for single-process
  // single-device programs and single-device modules.
  // TODO: to support a single process with multiple devices and multi-device
  // modules, events need to be created and recorded on multiple devices.
  void set_runtime_stats_and_log();
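  //
  // Sketch of the intended call pattern (hypothetical driver code; per the
  // comment above, this runs at the beginning of each forward call):
  //
  //   while (training) {
  //     logger->set_runtime_stats_and_log(); // start of forward
  //     // ... forward, backward, optimizer step ...
  //   }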

  // Called when DDP/the reducer fails with an error. Two fields of the
  // logging data structure will be filled: "has_error", indicating that this
  // iteration encountered an error (and that the other fields are not
  // valid), and "error", a string containing the error message that DDP
  // failed with.
  template <typename... Args>
  void set_error_and_log(const std::string& ddp_error, const Args&... args) {
    ddp_logging_data_->ints_map["has_error"] = 1;
    auto err = c10::str(ddp_error, args...);
    ddp_logging_data_->strs_map["error"] = err;
    // Report the iteration we are erroring at so the user knows how many
    // examples were successfully processed before this error was hit.
    ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
    at::LogPyTorchDDPUsage(*ddp_logging_data_);
  }
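  //
  // Sketch of a call site (hypothetical; `bucket_index` is illustrative,
  // and c10::str() stringifies the variadic trailing arguments):
  //
  //   logger->set_error_and_log(
  //       "DDP failed: ", "bucket ", bucket_index, " could not be reduced");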

  // When running without static graph, called when the reducer is destroyed
  // to log whether the graph was actually static and is a candidate for the
  // static graph optimization.
  void log_if_graph_static(bool is_static);

 private:
  // ddp_logging_data_ holds all the DDP-related logging data fields.
  std::unique_ptr<at::DDPLoggingData> ddp_logging_data_;
  std::shared_ptr<c10d::Reducer> reducer_;
  // Tracks the number of iterations for which runtime stats have been
  // collected so far.
  long num_iterations_stats_recorded_ = 0;
};

} // namespace c10d