1 | #include <c10/util/StringUtil.h> |
2 | #include <fmt/format.h> |
3 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
4 | #include <torch/csrc/distributed/c10d/debug.h> |
5 | #include <torch/csrc/distributed/c10d/logger.hpp> |
6 | #include <string> |
7 | |
8 | #include <c10/util/CallOnce.h> |
9 | |
10 | #ifdef USE_C10D_GLOO |
11 | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
12 | #endif |
13 | |
14 | namespace c10d { |
15 | |
// Iteration counts (of recorded stat samples) at which runtime stats are
// reported; compared against num_iterations_stats_recorded_ in
// set_runtime_stats_and_log(). Note that since data collection
// only runs every ddp_runtime_logging_sample_rate iterations, the actual
// training iterations recorded will be like 10,
// (20-10) * ddp_runtime_logging_sample_rate,
// (50-10) * ddp_runtime_logging_sample_rate and so on.
const int LoggingIterations[] = {10, 20, 50, 100, 500, 800, 1000}; // NOLINT
22 | |
23 | std::ostream& operator<<(std::ostream& output, const Logger& logger) { |
24 | auto& ddp_logging_data = (*logger.ddp_logging_data_); |
25 | |
26 | std::string loggerInfo = fmt::format( |
27 | "[Rank {} / {}] [before iteration {}] Training {} unused_parameter_size={} \n " |
28 | "Avg forward compute time: {} \n Avg backward compute time: {} \n" |
29 | "Avg backward comm. time: {} \n Avg backward comm/comp overlap time: {}" , |
30 | ddp_logging_data.ints_map["rank" ], |
31 | ddp_logging_data.ints_map["world_size" ], |
32 | ddp_logging_data.ints_map["iteration" ], |
33 | ddp_logging_data.strs_map["module_name" ], |
34 | ddp_logging_data.ints_map["unused_parameter_size" ], |
35 | ddp_logging_data.ints_map["avg_forward_compute_time" ], |
36 | ddp_logging_data.ints_map["avg_backward_compute_time" ], |
37 | ddp_logging_data.ints_map["avg_backward_comm_time" ], |
38 | ddp_logging_data.ints_map["avg_backward_compute_comm_overlap_time" ]); |
39 | |
40 | if (!ddp_logging_data.strs_map["comm_hook" ].empty()) { |
41 | loggerInfo += fmt::format( |
42 | "\n Gradient comm. hook: {}" , ddp_logging_data.strs_map["comm_hook" ]); |
43 | } |
44 | |
45 | if (ddp_logging_data.ints_map["join_uneven_inputs" ]) { |
46 | loggerInfo += "\n Uneven input detection with join() enabled." ; |
47 | } |
48 | |
49 | return output << loggerInfo; |
50 | } |
51 | |
52 | Logger::Logger(std::shared_ptr<c10d::Reducer> reducer) |
53 | : reducer_(std::move(reducer)) { |
54 | ddp_logging_data_ = std::make_unique<at::DDPLoggingData>(); |
55 | } |
56 | |
57 | c10::once_flag log_graph_static_flag; |
58 | |
59 | void Logger::log_if_graph_static(bool is_static) { |
60 | c10::call_once(log_graph_static_flag, [this, is_static]() { |
61 | ddp_logging_data_->ints_map["can_set_static_graph" ] = is_static; |
62 | // It is useful to report the iteration that training finished at. |
63 | ddp_logging_data_->ints_map["iteration" ] = reducer_->num_iterations_; |
64 | at::LogPyTorchDDPUsage(*ddp_logging_data_); |
65 | }); |
66 | } |
67 | |
68 | // Environment variables |
69 | void Logger::set_env_variables() { |
70 | ddp_logging_data_->strs_map["master_port" ] = parse_env("MASTER_PORT" ); |
71 | ddp_logging_data_->strs_map["master_addr" ] = parse_env("MASTER_ADDR" ); |
72 | ddp_logging_data_->strs_map["torch_distributed_debug" ] = |
73 | parse_env("TORCH_DISTRIBUTED_DEBUG" ); |
74 | ddp_logging_data_->strs_map["cuda_visible_devices" ] = |
75 | parse_env("CUDA_VISIBLE_DEVICES" ); |
76 | if (reducer_->process_group_->getBackendName() == "nccl" ) { |
77 | ddp_logging_data_->strs_map["nccl_socket_ifname" ] = |
78 | parse_env("NCCL_SOCKET_IFNAME" ); |
79 | ddp_logging_data_->strs_map["nccl_blocking_wait" ] = |
80 | parse_env("NCCL_BLOCKING_WAIT" ); |
81 | ddp_logging_data_->strs_map["nccl_async_error_handling" ] = |
82 | parse_env("NCCL_ASYNC_ERROR_HANDLING" ); |
83 | ddp_logging_data_->strs_map["nccl_debug" ] = parse_env("NCCL_DEBUG" ); |
84 | ddp_logging_data_->strs_map["nccl_nthreads" ] = parse_env("NCCL_NTHREADS" ); |
85 | ddp_logging_data_->strs_map["nccl_ib_timeout" ] = |
86 | parse_env("NCCL_IB_TIMEOUT" ); |
87 | } |
88 | if (reducer_->process_group_->getBackendName() == "gloo" ) { |
89 | ddp_logging_data_->strs_map["gloo_socket_ifname" ] = |
90 | parse_env("GLOO_SOCKET_IFNAME" ); |
91 | ddp_logging_data_->strs_map["gloo_device_transport" ] = |
92 | parse_env("GLOO_DEVICE_TRANSPORT" ); |
93 | |
94 | #ifdef USE_C10D_GLOO |
95 | auto gloo_pg = static_cast<c10d::ProcessGroupGloo*>( |
96 | reducer_->process_group_ |
97 | ->getBackend(c10d::ProcessGroup::BackendType::GLOO) |
98 | .get()); |
99 | auto n_threads = gloo_pg->getNumThreads(); |
100 | ddp_logging_data_->ints_map["gloo_num_threads" ] = n_threads; |
101 | #endif |
102 | } |
103 | } |
104 | |
105 | void Logger::set_parameter_stats() { |
106 | // The number of parameter tensors |
107 | ddp_logging_data_->ints_map["num_parameter_tensors" ] = |
108 | reducer_->params_.size(); |
109 | // Total parameters size (Bytes) |
110 | ddp_logging_data_->ints_map["total_parameter_size_bytes" ] = 0; |
111 | // Parameters' data types, there may be multiple data |
112 | // types for mixed precision training. |
113 | std::set<std::string> unique_dtypes; |
114 | for (const auto& t : reducer_->params_) { |
115 | ddp_logging_data_->ints_map["total_parameter_size_bytes" ] += |
116 | t.numel() * t.element_size(); |
117 | unique_dtypes.insert(std::string(t.dtype().name())); |
118 | } |
119 | ddp_logging_data_->strs_map["dtypes" ] = c10::Join(", " , unique_dtypes); |
120 | } |
121 | |
122 | std::vector<std::vector<size_t>> Logger::get_per_bucket_variable_indices() { |
123 | std::vector<std::vector<size_t>> per_bucket_variable_indices; |
124 | per_bucket_variable_indices.reserve(reducer_->buckets_.size()); |
125 | for (const auto& bucket : reducer_->buckets_) { |
126 | const auto& indices = bucket.variable_indices; |
127 | per_bucket_variable_indices.push_back(indices); |
128 | } |
129 | return per_bucket_variable_indices; |
130 | } |
131 | |
132 | std::vector<int64_t> Logger::get_bucket_sizes() { |
133 | std::vector<int64_t> bucket_sizes; |
134 | for (const auto& bucket : reducer_->buckets_) { |
135 | const auto& variables = bucket.variables; |
136 | int64_t bucket_size = 0; |
137 | for (const auto& v : variables) { |
138 | bucket_size += v.numel() * v.element_size(); |
139 | } |
140 | bucket_sizes.push_back(bucket_size); |
141 | } |
142 | return bucket_sizes; |
143 | } |
144 | |
145 | // Communication hook. Empty string if not set, in which case it will not be |
146 | // logged. |
147 | void Logger::set_comm_hook(const std::string& hook) { |
148 | ddp_logging_data_->strs_map["comm_hook" ] = hook; |
149 | } |
150 | |
151 | // Whether we are running under model.join() context manager for DDP uneven |
152 | // inputs. |
153 | void Logger::set_uneven_input_join() { |
154 | ddp_logging_data_->ints_map["join_uneven_inputs" ] = true; |
155 | } |
156 | |
157 | void Logger::set_static_graph() { |
158 | ddp_logging_data_->ints_map["static_graph" ] = reducer_->static_graph_; |
159 | } |
160 | |
161 | // Data that can be got during DistributedDataParallel construction time |
162 | void Logger::set_construction_data_and_log( |
163 | const std::string& module_name, |
164 | const std::vector<int>& device_ids, |
165 | int output_device, |
166 | bool broadcast_buffers, |
167 | bool has_sync_bn, |
168 | bool static_graph) { |
169 | // No lock is needed, as it will be called in DistributedDataParallel |
170 | // constructor. |
171 | if (static_graph) { |
172 | set_static_graph(); |
173 | } |
174 | ddp_logging_data_->strs_map["module_name" ] = module_name; |
175 | ddp_logging_data_->ints_map["world_size" ] = |
176 | reducer_->process_group_->getSize(); |
177 | ddp_logging_data_->ints_map["rank" ] = reducer_->process_group_->getRank(); |
178 | // In which iteration of the training loop the get_ddp_logging_data() |
179 | // is called to fetch the DDPLoggingData, 0 if the data is fetched |
180 | // before training loop. |
181 | ddp_logging_data_->ints_map["iteration" ] = 0; |
182 | ddp_logging_data_->ints_map["is_multi_device_module" ] = |
183 | reducer_->is_multi_device_module_; |
184 | |
185 | set_parameter_stats(); |
186 | // A list of bucket sizes (Bytes) calculated during construction time |
187 | ddp_logging_data_->strs_map["bucket_sizes" ] = |
188 | c10::Join(", " , get_bucket_sizes()); |
189 | set_env_variables(); |
190 | |
191 | // DistributedDataParallel constructor input parameters |
192 | ddp_logging_data_->strs_map["device_ids" ] = c10::Join(", " , device_ids); |
193 | ddp_logging_data_->ints_map["output_device" ] = output_device; |
194 | ddp_logging_data_->ints_map["broadcast_buffers" ] = broadcast_buffers; |
195 | ddp_logging_data_->ints_map["has_sync_bn" ] = has_sync_bn; |
196 | ddp_logging_data_->ints_map["bucket_cap_bytes" ] = reducer_->bucket_bytes_cap_; |
197 | ddp_logging_data_->ints_map["find_unused_parameters" ] = |
198 | reducer_->find_unused_parameters_; |
199 | ddp_logging_data_->ints_map["gradient_as_bucket_view" ] = |
200 | reducer_->gradient_as_bucket_view_; |
201 | ddp_logging_data_->strs_map["backend_name" ] = |
202 | reducer_->process_group_->getBackendName(); |
203 | |
204 | if (debug_level() != DebugLevel::Off) { |
205 | std::string initInfo = fmt::format( |
206 | "[Rank {}]: DDP Initialized with: \n" , |
207 | ddp_logging_data_->ints_map["rank" ]); |
208 | std::stringstream ddpLoggingDataInfo; |
209 | for (const auto& intItem : ddp_logging_data_->ints_map) { |
210 | ddpLoggingDataInfo << intItem.first << ": " << intItem.second << "\n" ; |
211 | } |
212 | for (const auto& strItem : ddp_logging_data_->strs_map) { |
213 | ddpLoggingDataInfo << strItem.first << ": " << strItem.second << "\n" ; |
214 | } |
215 | LOG(INFO) << initInfo << ddpLoggingDataInfo.str(); |
216 | } |
217 | |
218 | at::LogPyTorchDDPUsage(*ddp_logging_data_); |
219 | } |
220 | |
221 | void Logger::set_event_time( |
222 | int64_t& event_time, |
223 | Timer& timer, |
224 | Timer::Event event) { |
225 | auto timestamp = timer.getTimestamp(event); |
226 | if (timestamp != c10::nullopt) { |
227 | // TODO: should we set this as human-readable time instead of unixtime? |
228 | event_time = *timestamp; |
229 | } |
230 | } |
231 | |
232 | void Logger::calculate_avg_time( |
233 | int64_t& avg_time, |
234 | int64_t& time_duration, |
235 | Timer& timer, |
236 | Timer::Event start_event, |
237 | Timer::Event end_event) { |
238 | TORCH_CHECK(num_iterations_stats_recorded_ > 0); |
239 | c10::optional<int64_t> maybe_time_duration = |
240 | timer.measureDifference(start_event, end_event); |
241 | if (!maybe_time_duration.has_value()) { |
242 | return; |
243 | } |
244 | time_duration = maybe_time_duration.value(); |
245 | avg_time = (time_duration + avg_time * (num_iterations_stats_recorded_ - 1)) / |
246 | num_iterations_stats_recorded_; |
247 | } |
248 | |
249 | void Logger::reset_performance_stats() { |
250 | ddp_logging_data_->ints_map["forward_compute_time" ] = 0; |
251 | ddp_logging_data_->ints_map["backward_comm_time" ] = 0; |
252 | ddp_logging_data_->ints_map["backward_compute_time" ] = 0; |
253 | ddp_logging_data_->ints_map["backward_compute_comm_overlap_time" ] = 0; |
254 | ddp_logging_data_->ints_map["forward_compute_time_start" ] = 0; |
255 | ddp_logging_data_->ints_map["backward_compute_time_start" ] = 0; |
256 | ddp_logging_data_->ints_map["backward_comm_time_start" ] = 0; |
257 | ddp_logging_data_->ints_map["backward_compute_time_end" ] = 0; |
258 | ddp_logging_data_->ints_map["backward_comm_time_end" ] = 0; |
259 | } |
260 | |
// Samples runtime statistics into the logging record under the reducer's
// lock, then reports them to the usage logger at the iterations listed in
// LoggingIterations. Only does work on iterations the reducer marks for
// stats collection; bails out early (after bookkeeping) for multi-device
// CUDA modules and for devices without a registered timer.
void Logger::set_runtime_stats_and_log() {
  // Sync with reducer's data
  std::lock_guard<std::mutex> lock(reducer_->mutex_);
  // Set runtime stats at the sampling iterations.
  if (!reducer_->should_collect_runtime_stats()) {
    return;
  }
  num_iterations_stats_recorded_++;
  // Set ith iteration when the runtime stats are set.
  ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
  // When get_ddp_logging_data() is called, "unused_parameter_size",
  // "has_rebuilt_buckets" and "rebuilt_bucket_sizes" are updated in the latest
  // sampling iteration.
  // If unused_parameters_ is not empty, calculate its sizes.
  // unused_parameters_ is calculated in forward call of
  // each iteration.
  if (reducer_->unused_parameters_.empty() &&
      reducer_->find_unused_parameters_) {
    // No unused params in this iteration
    ddp_logging_data_->ints_map["unused_parameter_size"] = 0;
  }
  // Accumulate byte sizes of the parameters unused this iteration.
  for (const auto& unused_index : reducer_->unused_parameters_) {
    const auto& v = reducer_->params_[unused_index];
    ddp_logging_data_->ints_map["unused_parameter_size"] +=
        v.numel() * v.element_size();
  }
  // rebuilt_bucket_sizes will not change once buckets are rebuilt,
  // so it only needs to set once during whole training loop.
  // Rebuild buckets stats after 1st iteration
  if (ddp_logging_data_->ints_map["has_rebuilt_buckets"] !=
      reducer_->has_rebuilt_bucket_) {
    ddp_logging_data_->ints_map["has_rebuilt_buckets"] =
        reducer_->has_rebuilt_bucket_;
    ddp_logging_data_->strs_map["rebuilt_bucket_sizes"] =
        c10::Join(", ", get_bucket_sizes());
    // Log per-bucket variable indices
    std::vector<std::string> per_bucket_variable_indices;
    auto indices = get_per_bucket_variable_indices();
    per_bucket_variable_indices.reserve(indices.size());
    for (const auto& bucket_indices : indices) {
      per_bucket_variable_indices.push_back(c10::Join(" ", bucket_indices));
    }
    ddp_logging_data_->strs_map["rebuilt_per_bucket_param_indices"] =
        c10::Join(", ", per_bucket_variable_indices);
  }
  // Log gradient ready order
  if (!reducer_->grad_ready_order_indices_.empty()) {
    // Note that the indices are for the previous iteration as
    // this function is called in forward pass, and we last computed gradient
    // ready order in the last backward pass.
    ddp_logging_data_->strs_map["prev_iteration_grad_ready_order_indices"] =
        c10::Join(", ", reducer_->grad_ready_order_indices_);
  }

  // Zero all timing entries before re-computing them below.
  reset_performance_stats();

  // Cuda time stats are only collected for single device modules.
  if (reducer_->params_[0].is_cuda() && reducer_->is_multi_device_module_) {
    TORCH_WARN_ONCE(
        "Cuda time stats are not collected for multi-device modules.");
    return;
  }
  if (!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu()) {
    TORCH_WARN_ONCE(
        "Time stats are currently only collected for CPU and CUDA devices. "
        "Please refer to CpuTimer or CudaTimer for how to register timer "
        "for other device type.");
    return;
  }
  TORCH_INTERNAL_ASSERT(reducer_->timer_);
  // Running averages and raw durations for each training phase.
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_forward_compute_time"],
      ddp_logging_data_->ints_map["forward_compute_time"],
      *reducer_->timer_,
      Timer::Event::kForwardStart,
      Timer::Event::kBackwardComputeStart);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_backward_compute_time"],
      ddp_logging_data_->ints_map["backward_compute_time"],
      *reducer_->timer_,
      Timer::Event::kBackwardComputeStart,
      Timer::Event::kBackwardComputeEnd);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_backward_comm_time"],
      ddp_logging_data_->ints_map["backward_comm_time"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommStart,
      Timer::Event::kBackwardCommEnd);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_backward_compute_comm_overlap_time"],
      ddp_logging_data_->ints_map["backward_compute_comm_overlap_time"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommStart,
      Timer::Event::kBackwardComputeEnd);

  // Raw event timestamps for this sampling iteration.
  set_event_time(
      ddp_logging_data_->ints_map["forward_compute_time_start"],
      *reducer_->timer_,
      Timer::Event::kForwardStart);
  set_event_time(
      ddp_logging_data_->ints_map["backward_compute_time_start"],
      *reducer_->timer_,
      Timer::Event::kBackwardComputeStart);
  set_event_time(
      ddp_logging_data_->ints_map["backward_comm_time_start"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommStart);
  set_event_time(
      ddp_logging_data_->ints_map["backward_compute_time_end"],
      *reducer_->timer_,
      Timer::Event::kBackwardComputeEnd);
  set_event_time(
      ddp_logging_data_->ints_map["backward_comm_time_end"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommEnd);

  // Log runtime stats to stderr if TORCH_DISTRIBUTED_DEBUG=DETAIL is enabled.
  if (debug_level() == DebugLevel::Detail) {
    LOG(INFO) << *this;
  }

  // Log runtime (e.g. avg performance) stats at the beginning and also
  // after a larger number of iterations. Choosing 10/1000/10000 is
  // not scientific here, it assumes most of applications will run
  // at least 10 iterations. stats could have smaller variance if
  // selected num_iterations_ is larger.
  if (std::find(
          std::begin(LoggingIterations),
          std::end(LoggingIterations),
          num_iterations_stats_recorded_) != std::end(LoggingIterations)) {
    at::LogPyTorchDDPUsage(*ddp_logging_data_);
  }
}
394 | |
395 | at::DDPLoggingData Logger::get_ddp_logging_data() { |
396 | std::lock_guard<std::mutex> lock(reducer_->mutex_); |
397 | return *ddp_logging_data_; |
398 | } |
399 | |
400 | } // namespace c10d |
401 | |