/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/step_stats_collector.h"

#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {
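// Caps for the OOM allocation report below: reporting stops after
// kMaxAllocReportNodes nodes, or once kMaxAllocReportFraction of a
// device/allocator's total live bytes has been reported (see
// ReportAllocsOnResourceExhausted).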
const int kMaxAllocReportNodes = 100;
const float kMaxAllocReportFraction = 0.99;

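// Live-allocation statistics aggregated per <device, allocator> pair, with
// node names grouped by allocation size so the largest consumers can be
// reported first.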
struct AllocStats {
  std::map<int64_t, std::vector<string>> nodes_by_size;
  int64_t total_bytes = 0;
  int64_t total_nodes = 0;
};

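// _Send/_Recv nodes (and their host variants) transfer tensors between
// devices; the collector does not create per-node stats for them (see
// CreateNodeExecStats below).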
bool IsRecv(const NodeDef* node) {
  return node->op() == "_Recv" || node->op() == "_HostRecv";
}

bool IsSend(const NodeDef* node) {
  return node->op() == "_Send" || node->op() == "_HostSend";
}

}  // namespace

NodeExecStatsWrapper::NodeExecStatsWrapper(
    const NodeDef* node, StepStatsCollector* step_stats_collector)
    : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
                           step_stats_collector) {
  stats_->set_node_name(node->name());
}

NodeExecStatsWrapper::NodeExecStatsWrapper(
    std::unique_ptr<NodeExecStats> stats, const NodeDef* node,
    StepStatsCollector* step_stats_collector)
    : stats_(std::move(stats)),
      node_(node),
      step_stats_collector_(step_stats_collector) {}

void NodeExecStatsWrapper::Done(const string& device) {
  // TODO(tucker): merge with the DetailText function in session.cc in a common
  // location.
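  // The resulting timeline label has the form (names are illustrative):
  //   "[cpu 1.0MB] my_node = MatMul(a, b)"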
  DCHECK(node_);
  string memory;
  for (auto& all : stats_->memory()) {
    int64_t tot = all.total_bytes();
    if (tot >= 0.1 * 1048576.0) {
      int64_t peak = all.peak_bytes();
      if (peak > 0) {
        memory =
            strings::StrCat(memory, "[", all.allocator_name(),
                            strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
                                            peak / 1048576.0));
      } else {
        memory = strings::StrCat(memory, "[", all.allocator_name(),
                                 strings::Printf(" %.1fMB] ", tot / 1048576.0));
      }
    }
  }
  const AttrSlice attrs(*node_);
  string text;
  if (IsSend(node_)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    string recv_device;
    TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
    text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(",
                           tensor_name, " @", recv_device, ")");
  } else if (IsRecv(node_)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    string send_device;
    TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
    text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(",
                           tensor_name, " @", send_device, ")");
  } else {
    text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(",
                           absl::StrJoin(node_->input(), ", "), ")");
  }
  stats_->set_timeline_label(text);
  step_stats_collector_->Save(device, this);
}

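// Each timing event below is recorded at both microsecond and nanosecond
// granularity; the *_micros fields are presumably kept for backward
// compatibility with consumers that predate the *_nanos fields.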
void NodeExecStatsWrapper::RecordExecutorStarted() {
  int64_t now_nanos = Env::Default()->NowNanos();
  stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
  stats_->set_all_start_nanos(now_nanos);
}

void NodeExecStatsWrapper::RecordComputeStarted() {
  int64_t now_nanos = Env::Default()->NowNanos();
  DCHECK_NE(stats_->all_start_micros(), 0);
  DCHECK_NE(stats_->all_start_nanos(), 0);
  stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
                                  stats_->all_start_micros());
  stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
}

void NodeExecStatsWrapper::RecordComputeEnded() {
  int64_t now_nanos = Env::Default()->NowNanos();
  DCHECK_NE(stats_->all_start_micros(), 0);
  DCHECK_NE(stats_->all_start_nanos(), 0);
  stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
                                stats_->all_start_micros());
  stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
}

void NodeExecStatsWrapper::RecordExecutorEnded() {
  int64_t now_nanos = Env::Default()->NowNanos();
  DCHECK_NE(stats_->all_start_micros(), 0);
  DCHECK_NE(stats_->all_start_nanos(), 0);
  stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
                                 stats_->all_start_micros());
  stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
}

void NodeExecStatsWrapper::SetScheduled(int64_t nanos) {
  stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
  stats_->set_scheduled_nanos(nanos);
}

void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
  for (const auto& allocator_pair : ctx->ConsumeWrappedAllocators()) {
    AddAllocation(allocator_pair.first, allocator_pair.second);
  }
  auto* ms = stats_->mutable_memory_stats();
  ms->set_temp_memory_size(ctx->temp_memory_allocated());
  for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
    ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
  }
  ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
}

void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
  DCHECK(tensor);
  NodeOutput* node_output = stats_->add_output();
  node_output->set_slot(slot);
  tensor->FillDescription(node_output->mutable_tensor_description());
}

void NodeExecStatsWrapper::AddAllocation(
    Allocator* allocator, TrackingAllocator* tracking_allocator) {
  AllocatorMemoryUsed* memory = stats_->add_memory();
  memory->set_allocator_name(allocator->Name());
  auto sizes = tracking_allocator->GetSizes();
  memory->set_total_bytes(std::get<0>(sizes));
  memory->set_peak_bytes(std::get<1>(sizes));
  memory->set_live_bytes(std::get<2>(sizes));

  absl::optional<AllocatorStats> stats = allocator->GetStats();
  if (stats) {
    memory->set_allocator_bytes_in_use(stats->bytes_in_use);
  }
  allocations_.push_back(std::make_pair(memory, tracking_allocator));
}

void NodeExecStatsWrapper::Finalize() {
  for (auto& alloc : allocations_) {
    AllocatorMemoryUsed* memory = alloc.first;
    for (auto& record : alloc.second->GetRecordsAndUnRef()) {
      auto* r = memory->add_allocation_records();
      r->set_alloc_bytes(record.alloc_bytes);
      r->set_alloc_micros(record.alloc_micros);
    }
  }
  allocations_.clear();
}

StepStatsCollector::StepStatsCollector(StepStats* step_stats)
    : finalized_(false), step_stats_(step_stats) {}

static int ExtractGpuWithStreamAll(string device_name) {
  // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
  // and if it does, return the GPU id (always non-negative). If it doesn't,
  // return -1.
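  //
  // For example, "/job:localhost/replica:0/task:0/device:GPU:2/stream:all"
  // yields 2, while a device name without the "/stream:all" suffix yields -1.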

  // The best way to parse this regexp using a scanner is to parse it in
  // reverse, starting from the end.
  std::reverse(device_name.begin(), device_name.end());
  strings::Scanner scanner(device_name);
  // Check that the string ends with '/stream:all'.
  scanner.OneLiteral("lla:maerts/");
  // Capture the digits if present.
  scanner.RestartCapture().Many(strings::Scanner::DIGIT).StopCapture();
  // Check that the digits are preceded by the 'device:GPU:' string.
  scanner.OneLiteral(":UPG:ecived");
  StringPiece capture;
  bool matched = scanner.GetResult(nullptr, &capture);

  if (!matched) {
    return -1;
  } else {
    // Convert the captured string into an integer. But first we need to put
    // the digits back in order.
    string ordered_capture(capture);
    std::reverse(ordered_capture.begin(), ordered_capture.end());
    int gpu_id;
    CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
    return gpu_id;
  }
}

static int ExtractGpuWithoutStream(string device_name) {
  // Check if the device name matches the ".*gpu:(\\d+)$" regexp,
  // and if it does, return the GPU id (always non-negative). If it doesn't,
  // return -1.
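  //
  // For example, "/job:localhost/replica:0/task:0/device:GPU:2" yields 2,
  // while a device name that does not end in "device:GPU:<id>" yields -1.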

  // The best way to parse this regexp using a scanner is to parse it in
  // reverse, starting from the end.
  std::reverse(device_name.begin(), device_name.end());
  strings::Scanner scanner(device_name);
  // Capture the trailing digits if present.
  scanner.RestartCapture().Many(strings::Scanner::DIGIT).StopCapture();
  // Check that the digits are preceded by the 'device:GPU:' string.
  scanner.OneLiteral(":UPG:ecived");
  StringPiece capture;
  bool matched = scanner.GetResult(nullptr, &capture);

  if (!matched) {
    return -1;
  } else {
    // Convert the captured string into an integer. But first we need to put
    // the digits back in order.
    string ordered_capture(capture);
    std::reverse(ordered_capture.begin(), ordered_capture.end());
    int gpu_id;
    CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
    return gpu_id;
  }
}

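// Folds the collected step stats into the per-graph cost models. Per-output
// memory sizes and allocation ids come from the regular (software) traces;
// execution times prefer the GPU hardware ("stream:all") traces when present.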
void StepStatsCollector::BuildCostModel(
    CostModelManager* cost_model_manager,
    const std::unordered_map<string, const Graph*>& device_map) {
  mutex_lock lock(mu_);

  if (!finalized_) {
    FinalizeInternal();
  }
  // Hardware stats for GPUs are available under a fake device named
  // "gpu:<id>/stream:all". Use them instead of the regular stats whenever
  // they're available to extract the execution stats of a particular node,
  // since they're more accurate. However, hardware traces don't record memory
  // usage, so we still have to rely on regular traces to track memory usage.
  struct DeviceStats {
    const DeviceStepStats* regular_stats;
    const DeviceStepStats* hardware_stats;
  };

  std::unordered_map<StringPiece, DeviceStats, StringPieceHasher>
      per_device_stats;
  std::unordered_map<int, const DeviceStepStats*> gpu_hardware_stats;

  for (int i = 0; i < step_stats_->dev_stats_size(); ++i) {
    const DeviceStepStats& device_stats = step_stats_->dev_stats(i);
    const string& device_name = device_stats.device();
    const int gpu_id = ExtractGpuWithStreamAll(device_name);
    if (gpu_id >= 0) {
      // These are GPU hardware stats.
      gpu_hardware_stats.emplace(gpu_id, &device_stats);
    } else {
      // These are regular stats.
      per_device_stats.emplace(device_name,
                               DeviceStats{&device_stats, nullptr});
    }
  }

  for (auto& itr : per_device_stats) {
    const StringPiece device_name = itr.first;
    const int gpu_id = ExtractGpuWithoutStream(string(device_name));
    if (gpu_id >= 0) {
      // Reference the GPU hardware stats in addition to the regular stats
      // for this GPU device if they're available.
      if (gpu_hardware_stats.find(gpu_id) != gpu_hardware_stats.end()) {
        itr.second.hardware_stats = gpu_hardware_stats.find(gpu_id)->second;
      }
    }
  }

  for (const auto& itr : device_map) {
    const StringPiece device = itr.first;
    if (per_device_stats.find(device) == per_device_stats.end()) {
      continue;
    }

    const Graph* graph = itr.second;
    CostModel* cm = cost_model_manager->FindOrCreateCostModel(graph);
    cm->IncrementUpdateTimes();

    std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node;
    for (Node* n : graph->nodes()) {
      name_to_node.emplace(n->name(), n);
    }

    const DeviceStats& dev_stats = per_device_stats.find(device)->second;

    std::unordered_map<string, NodeExecStats> name_to_hw_node_stats;
    if (dev_stats.hardware_stats) {
      for (const auto& node_stats : dev_stats.hardware_stats->node_stats()) {
        string node_name = node_stats.node_name();
        // Strip the op-type suffix (e.g. ":Conv2D") from the end of the node
        // name.
        size_t pos = node_name.find_first_of(':');
        if (pos != std::string::npos) {
          node_name = node_name.substr(0, pos);
        }
        // Certain ops (e.g. Conv2D) are implemented with multiple GPU kernels,
        // which results in multiple NodeExecStats with the same node name. For
        // such ops, we sum up the time for all its GPU kernels.
        if (name_to_hw_node_stats.find(node_name) !=
            name_to_hw_node_stats.end()) {
          int64_t time = name_to_hw_node_stats[node_name].op_end_rel_micros();
          name_to_hw_node_stats[node_name].set_op_end_rel_micros(
              time + node_stats.op_end_rel_micros());
        } else {
          name_to_hw_node_stats.emplace(node_name, node_stats);
        }
      }
    }

    for (int i = 0; i < dev_stats.regular_stats->node_stats_size(); ++i) {
      const NodeExecStats& stats = dev_stats.regular_stats->node_stats(i);
      const Node* node = name_to_node[stats.node_name()];
      if (node) {
        for (int i = 0; i < stats.output_size(); ++i) {
          const auto& output = stats.output(i);
          int output_slot = output.slot();
          cm->RecordMaxMemorySize(node, output_slot,
                                  Bytes(output.tensor_description()
                                            .allocation_description()
                                            .allocated_bytes()),
                                  output.tensor_description().shape(),
                                  node->output_types()[output_slot]);
          cm->RecordAllocationId(node, output_slot,
                                 output.tensor_description()
                                     .allocation_description()
                                     .allocation_id());
        }
        cm->RecordMemoryStats(node, stats.memory_stats());
        // Use hardware stats to record the execution time if they're
        // available; otherwise, use the regular (less accurate) stats.
        string node_name = dev_stats.regular_stats->node_stats(i).node_name();
        if (dev_stats.hardware_stats &&
            name_to_hw_node_stats.find(node_name) !=
                name_to_hw_node_stats.end()) {
          const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name];
          cm->RecordMaxExecutionTime(
              node, Microseconds(hw_stats.op_end_rel_micros()));
        } else {
          cm->RecordMaxExecutionTime(node,
                                     Microseconds(stats.op_end_rel_micros()));
        }
      }
    }
  }
}

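// Note: this overload takes ownership of `node_stats_pb`; the proto is
// wrapped in a NodeExecStatsWrapper, which the collector then owns.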
void StepStatsCollector::Save(const string& device,
                              NodeExecStats* node_stats_pb) {
  Save(device,
       new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
                                nullptr, this));
}

void StepStatsCollector::Save(const string& device,
                              NodeExecStatsWrapper* node_stats) {
  if (!node_stats) return;
  VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
  {
    mutex_lock l(mu_);
    if (finalized_) {
      LOG(WARNING) << "stats saved after finalize will not be collected.";
    }
    if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
      VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
      delete node_stats;
      return;
    }
    auto& device_stats = dev_stats_[device];
    device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
    collected_nodes_++;
  }
}

void StepStatsCollector::SaveThreadName(const string& device,
                                        const uint32 thread_id,
                                        const string& thread_name) {
  VLOG(1) << "Save dev " << device << " thread id " << thread_id << " name "
          << thread_name;
  {
    mutex_lock l(mu_);
    if (finalized_) {
      LOG(WARNING) << "thread_name saved after finalize will not be collected.";
    }
    auto& thread_names_map = thread_names_[device];
    thread_names_map[thread_id] = thread_name;
  }
}

NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
    const NodeDef* node) {
  // Only collect statistics for non-transfer nodes.
  if (IsSend(node) || IsRecv(node)) {
    return nullptr;
  }
  return new NodeExecStatsWrapper(node, this);
}

string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
  mutex_lock l(mu_);
  if (err.find("OOM") == err.npos) {
    return "";
  }
  // <device, allocator> -> AllocStats
  std::map<std::pair<string, string>, AllocStats> allocs_map;
  string report = "\n";
  for (const auto& dev_stat : dev_stats_) {
    const string& device = dev_stat.first;
    // Only print the device that has OOM.
    // TODO(xpan): Extract device from err first to speed it up.
    if (err.find(device) == err.npos) {
      continue;
    }
    // NodeExecStatsWrapper*
    for (const auto& stats : dev_stat.second) {
      // std::pair<AllocatorMemoryUsed*, TrackingAllocator*>
      for (const auto& alloc : stats->allocations_) {
        // Only print the allocator that has OOM.
        // TODO(xpan): Extract device from err first to speed it up.
        if (err.find(alloc.first->allocator_name()) == err.npos) {
          continue;
        }
        auto dev_allocator =
            std::make_pair(dev_stat.first, alloc.first->allocator_name());
        AllocStats& dev_allocs_stats = allocs_map[dev_allocator];
        TrackingAllocator* tracking_alloc = alloc.second;
        gtl::InlinedVector<AllocRecord, 4> cur_records =
            tracking_alloc->GetCurrentRecords();
        int64_t cur_bytes = 0;
        for (const auto& r : cur_records) {
          cur_bytes += r.alloc_bytes;
        }
        if (cur_bytes > 0) {
          dev_allocs_stats.total_bytes += cur_bytes;
          dev_allocs_stats.total_nodes++;
          dev_allocs_stats.nodes_by_size[cur_bytes].push_back(
              stats->stats()->node_name());
        }
      }
    }
  }

  for (const auto& dev_allocs_it : allocs_map) {
    const auto& dev = dev_allocs_it.first;
    const AllocStats& dev_allocs_stats = dev_allocs_it.second;
    int64_t reported_bytes = 0;
    int64_t reported_nodes = 0;
    bool done = false;
    strings::StrAppend(&report, "\nCurrent usage from device: ", dev.first,
                       ", allocator: ", dev.second, "\n");
    // Print the allocation stats of the <device, allocator> pair.
    for (auto it = dev_allocs_stats.nodes_by_size.rbegin();
         it != dev_allocs_stats.nodes_by_size.rend(); ++it) {
      for (const string& node_name : it->second) {
        reported_bytes += it->first;
        strings::StrAppend(&report, "  ",
                           strings::HumanReadableNumBytes(it->first), " from ",
                           node_name, "\n");
        if (++reported_nodes > kMaxAllocReportNodes ||
            reported_bytes >=
                dev_allocs_stats.total_bytes * kMaxAllocReportFraction) {
          done = true;
          break;
        }
      }
      if (done) break;
    }
    int64_t remain_nodes = dev_allocs_stats.total_nodes - reported_nodes;
    int64_t remain_bytes = dev_allocs_stats.total_bytes - reported_bytes;
    if (remain_nodes > 0) {
      strings::StrAppend(&report, "  Remaining ", remain_nodes, " nodes with ",
                         strings::HumanReadableNumBytes(remain_bytes), "\n");
    }
  }
  return report;
}

void StepStatsCollector::Finalize() {
  mutex_lock l(mu_);
  FinalizeInternal();
}

void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
  mutex_lock l(mu_);
  CHECK(step_stats_);
  FinalizeInternal();
  step_stats->Swap(step_stats_);
  collected_nodes_ = 0;
}

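// Assumes `mu_` is already held: every call site in this file
// (BuildCostModel, Finalize, FinalizeAndSwap) acquires a mutex_lock before
// calling FinalizeInternal().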
void StepStatsCollector::FinalizeInternal() {
  if (!step_stats_ || finalized_) {
    return;
  }
  finalized_ = true;
  std::map<string, DeviceStepStats*> dev_stats_pb;
  for (auto& ds : *step_stats_->mutable_dev_stats()) {
    dev_stats_pb[ds.device()] = &ds;
  }
  for (const auto& dev_stat : dev_stats_) {
    if (dev_stats_pb.find(dev_stat.first) == dev_stats_pb.end()) {
      DeviceStepStats* ndev_stat = step_stats_->add_dev_stats();
      ndev_stat->set_device(dev_stat.first);
      dev_stats_pb[dev_stat.first] = ndev_stat;
    }
    DeviceStepStats* dss = dev_stats_pb.at(dev_stat.first);
    for (auto& stats : dev_stat.second) {
      stats->Finalize();
      stats->stats()->Swap(dss->add_node_stats());
    }
  }
  for (const auto& device_thread : thread_names_) {
    if (dev_stats_pb.find(device_thread.first) == dev_stats_pb.end()) {
      // Skip devices without a DeviceStepStats entry.
      continue;
    }
    DeviceStepStats* dss = dev_stats_pb.at(device_thread.first);
    for (const auto& thread_name : device_thread.second) {
      (*dss->mutable_thread_names())[thread_name.first] = thread_name.second;
    }
  }
}
}  // namespace tensorflow