1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/grappler/cost_analyzer.h" |
17 | |
18 | #include <iomanip> |
19 | #include "tensorflow/core/grappler/costs/utils.h" |
20 | #include "tensorflow/core/grappler/grappler_item.h" |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace grappler { |
25 | |
// Constructs an analyzer for `item` that will run both a measurement-based
// and an analytical cost estimator against `cluster`. `suffix` is stored for
// later use (not referenced in this file). `item` is held by pointer and must
// outlive this object.
CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                           const string& suffix)
    : item_(&item),
      // 10 and 0 configure the measuring estimator; presumably measurement
      // steps and warmup/threads — TODO confirm against
      // MeasuringCostEstimator's constructor signature.
      measure_estimator_(cluster, 10, 0),
      // Dynamic shapes with aggressive inference: estimate even when static
      // shape information is incomplete.
      analytical_estimator_(cluster, /*use_static_shapes=*/false,
                            /*use_aggressive_shape_inference=*/true),
      suffix_(suffix) {}
33 | |
// Runs the full pipeline and writes a human-readable cost report to `os`:
// gather measured + analytical costs, merge/derive efficiencies, aggregate
// per op type, then print. `per_node_report` appends a per-node section;
// `verbose` makes that section a full proto dump. Always returns OK (errors
// during estimation are logged, not propagated).
Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
                                    bool verbose) {
  GatherCosts();
  PreprocessCosts();
  AnalyzeCosts();
  PrintAnalysis(os, per_node_report, verbose);
  return OkStatus();
}
42 | |
43 | void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator, |
44 | CostGraphDef* cost_graph, int64_t* total_time) { |
45 | TF_CHECK_OK(cost_estimator->Initialize(*item_)); |
46 | RunMetadata run_metadata; |
47 | Costs costs; |
48 | const Status status = |
49 | cost_estimator->PredictCosts(item_->graph, &run_metadata, &costs); |
50 | if (cost_graph) { |
51 | cost_graph->Swap(run_metadata.mutable_cost_graph()); |
52 | } |
53 | *total_time = costs.execution_time.count(); |
54 | if (!status.ok()) { |
55 | LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": " |
56 | << status.error_message(); |
57 | return; |
58 | } |
59 | } |
60 | |
61 | void CostAnalyzer::GatherCosts() { |
62 | CostGraphDef cost_graph_measured; |
63 | PredictCosts(&measure_estimator_, &cost_graph_measured, |
64 | &total_time_measured_); |
65 | VLOG(1) << "Graph size: " << item_->graph.node_size(); |
66 | VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size(); |
67 | |
68 | CostGraphDef cost_graph_analytical; |
69 | PredictCosts(&analytical_estimator_, &cost_graph_analytical, |
70 | &total_time_analytical_); |
71 | VLOG(1) << "cost_graph_analytical size: " |
72 | << cost_graph_analytical.node_size(); |
73 | |
74 | CostGraphDef cost_graph_analytical_filtered; |
75 | CostGraphDef cost_graph_measured_filtered; |
76 | std::map<string, const CostGraphDef_Node*> measured_nodes; |
77 | for (const auto& node : cost_graph_measured.node()) { |
78 | measured_nodes[node.name()] = &node; |
79 | } |
80 | for (const auto& node : cost_graph_analytical.node()) { |
81 | auto it = measured_nodes.find(node.name()); |
82 | // Filter the nodes that are not the cost nodes returned by |
83 | // MeasuringCostEstimator. |
84 | if (it == measured_nodes.end()) { |
85 | continue; |
86 | } |
87 | auto added_node_analytical = cost_graph_analytical_filtered.add_node(); |
88 | auto added_node_measured = cost_graph_measured_filtered.add_node(); |
89 | *added_node_analytical = node; |
90 | *added_node_measured = *(it->second); |
91 | } |
92 | VLOG(1) << "cost_graph_analytical_filtered size: " |
93 | << cost_graph_analytical_filtered.node_size(); |
94 | |
95 | // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and |
96 | // op_perf_ cover the same set of nodes. |
97 | op_perf_analytical_ = CostGraphToOpPerformanceData( |
98 | cost_graph_analytical_filtered, item_->graph); |
99 | op_perf_ = |
100 | CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph); |
101 | } |
102 | |
103 | void CostAnalyzer::PreprocessCosts() { |
104 | for (int i = 0; i < op_perf_.op_performance_size(); i++) { |
105 | OpPerformance* perf = op_perf_.mutable_op_performance(i); |
106 | const OpPerformance& analytical = op_perf_analytical_.op_performance(i); |
107 | perf->set_compute_time(analytical.compute_time()); |
108 | perf->set_memory_time(analytical.memory_time()); |
109 | double measured_cost = perf->compute_cost(); |
110 | |
111 | double analytical_compute_cost = analytical.compute_time(); |
112 | if (analytical_compute_cost == 0) { |
113 | // Negative infinity indicates unavailable data. |
114 | perf->set_compute_efficiency(-INFINITY); |
115 | } else { |
116 | perf->set_compute_efficiency(analytical_compute_cost / measured_cost); |
117 | } |
118 | |
119 | double analytical_memory_cost = analytical.memory_time(); |
120 | if (analytical_memory_cost == 0) { |
121 | // Negative infinity indicates unavailable data. |
122 | perf->set_memory_efficiency(-INFINITY); |
123 | } else { |
124 | perf->set_memory_efficiency(analytical_memory_cost / measured_cost); |
125 | } |
126 | } |
127 | } |
128 | |
129 | void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) { |
130 | for (const auto& op : ops) { |
131 | ops_.push_back(op.second); |
132 | } |
133 | struct CompareByTime { |
134 | bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const { |
135 | return a.time > b.time; |
136 | } |
137 | }; |
138 | std::stable_sort(ops_.begin(), ops_.end(), CompareByTime()); |
139 | } |
140 | |
141 | void CostAnalyzer::AnalyzeCosts() { |
142 | std::map<string, OpPerfSummary> ops; |
143 | for (const auto& op_perf : op_perf_.op_performance()) { |
144 | string op_name = op_perf.op().op(); |
145 | ops[op_name].count++; |
146 | ops[op_name].time += op_perf.compute_cost(); |
147 | ops[op_name].compute_time += op_perf.compute_time(); |
148 | ops[op_name].memory_time += op_perf.memory_time(); |
149 | ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time(); |
150 | ops[op_name].time_lower += |
151 | std::max(op_perf.compute_time(), op_perf.memory_time()); |
152 | ops[op_name].name = op_name; |
153 | } |
154 | SortOpsByTime(ops); |
155 | |
156 | total_time_measured_serialized_ = 0; |
157 | total_time_analytical_upper_ = 0; |
158 | total_time_analytical_lower_ = 0; |
159 | for (const auto& op : ops_) { |
160 | total_time_measured_serialized_ += op.time; |
161 | total_time_analytical_upper_ += op.time_upper; |
162 | total_time_analytical_lower_ += op.time_lower; |
163 | } |
164 | } |
165 | |
// Writes the report to `os`: (1) overall totals and efficiency ratios,
// (2) a per-op-type table sorted by measured time, and (3) optionally a
// per-node section (full proto dump when `verbose`, else a summary table).
// Code is left byte-identical: stream formatting state (setw/setprecision)
// carries across statements, so the output is sensitive to exact ordering.
void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report,
                                 bool verbose) const {
  os << std::endl;
  // Header totals: label left-aligned in 50 columns, value right-aligned
  // in 20.
  os << std::left << std::setw(50)
     << "Total time measured in ns (serialized): " << std::right
     << std::setw(20) << total_time_measured_serialized_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time measured in ns (actual): " << std::right << std::setw(20)
     << total_time_measured_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (upper bound): " << std::right
     << std::setw(20) << total_time_analytical_upper_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (lower bound): " << std::right
     << std::setw(20) << total_time_analytical_lower_ << std::endl;
  // Efficiency ratios; no zero guard here, so a zero measured time yields
  // inf/nan in the report.
  double efficiency_upper = static_cast<double>(total_time_analytical_upper_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical upper/actual): " << std::right
     << std::setw(20) << efficiency_upper << std::endl;
  double efficiency_lower = static_cast<double>(total_time_analytical_lower_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical lower/actual): " << std::right
     << std::setw(20) << efficiency_lower << std::endl;
  os << std::endl;

  // Per-op-type table header. Column widths are +1/+2 to leave room for the
  // separator characters.
  int width = 35;
  int width_narrow = 15;
  int width_wide = 20;
  os << std::setw(width + 1) << "Op," ;
  os << std::setw(width_narrow + 1) << "Count," ;
  os << std::setw(width_wide + 1) << "Measured time (ns)," ;
  os << std::setw(width_narrow + 2) << "Time percent," ;
  os << std::setw(width_narrow + 2) << "Acc percent," ;
  os << std::setw(width_wide + 1) << "Analytical upper," ;
  os << std::setw(width_wide + 1) << "Analytical lower," ;
  os << std::setw(width_narrow + 2) << "Overall eff" ;
  os << std::setw(width_narrow + 2) << "Compute eff" ;
  os << std::setw(width_narrow + 2) << "Memory eff" << std::endl;
  // ops_ is already sorted by descending time, so acc_percent accumulates
  // the running share of total measured (serialized) time.
  float acc_percent = 0;
  for (const auto& op : ops_) {
    double percent = static_cast<double>(op.time) /
                     static_cast<double>(total_time_measured_serialized_);
    double eff =
        static_cast<double>(op.time_upper) / static_cast<double>(op.time);
    double compute_eff =
        static_cast<double>(op.compute_time) / static_cast<double>(op.time);
    double memory_eff =
        static_cast<double>(op.memory_time) / static_cast<double>(op.time);
    os << std::setw(width) << op.name << "," ;
    os << std::setw(width_narrow) << op.count << "," ;
    os << std::setw(width_wide) << op.time << "," ;
    os << std::setw(width_narrow) << std::setprecision(2) << percent * 100
       << "%," ;
    acc_percent += percent;
    os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100
       << "%," ;
    os << std::setw(width_wide) << op.time_upper << "," ;
    os << std::setw(width_wide) << op.time_lower << "," ;
    os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%," ;
    os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100
       << "%," ;
    os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100
       << "%," ;
    os << std::endl;
  }
  os << std::endl;

  if (per_node_report) {
    if (verbose) {
      // Full proto text dump of every per-node record.
      os << "Below is the full per-node report:" << std::endl;
      os << op_perf_.DebugString();
    } else {
      // Compact per-node table: times, efficiencies, and input shapes.
      os << "Below is the per-node report summary:" << std::endl;
      int width = 35;
      int width_narrow = 15;
      int width_wide = 20;
      os << std::setw(width + 1) << "Op," ;
      os << std::setw(width_wide + 1) << "Measured time (ns)," ;
      os << std::setw(width_wide + 1) << "Compute time (ns)," ;
      os << std::setw(width_wide + 1) << "Memory time (ns)," ;
      os << std::setw(width_narrow + 2) << "Compute eff," ;
      os << std::setw(width_narrow + 2) << "Memory eff," ;
      os << " Inputs" << std::endl;
      for (int i = 0; i < op_perf_.op_performance_size(); i++) {
        const auto& perf = op_perf_.op_performance(i);
        string op_name = perf.op().op();
        os << std::setw(width) << op_name << "," ;
        os << std::setw(width_wide) << perf.compute_cost() << "," ;
        os << std::setw(width_wide) << perf.compute_time() << "," ;
        os << std::setw(width_wide) << perf.memory_time() << "," ;
        os << std::setw(width_narrow) << std::setprecision(2)
           << perf.compute_efficiency() * 100 << "%," ;
        os << std::setw(width_narrow) << std::setprecision(2)
           << perf.memory_efficiency() * 100 << "%," ;
        os << " [" ;
        // Print each input's shape as (d0, d1, ...); scalar/unknown shapes
        // (dim_size() == 0) are skipped entirely.
        for (int j = 0; j < perf.op().inputs_size(); j++) {
          const auto& shape = perf.op().inputs(j).shape();
          if (shape.dim_size() > 0) {
            os << "(" ;
            std::vector<int> dims;
            for (int k = 0; k < shape.dim_size(); k++) {
              os << shape.dim(k).size();
              if (k < shape.dim_size() - 1) {
                os << ", " ;
              }
            }
            os << ")" ;
            if (j < perf.op().inputs_size() - 1) {
              os << ", " ;
            }
          }
        }
        os << "]" << std::endl;
      }
      os << std::endl;
    }
  }
}
286 | } // end namespace grappler |
287 | } // end namespace tensorflow |
288 | |