1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/python/grappler/cost_analyzer.h"

#include <algorithm>
#include <cmath>
#include <iomanip>

#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
22
23namespace tensorflow {
24namespace grappler {
25
// Constructs an analyzer for `item`, wiring both a measurement-based and an
// analytical cost estimator against the same `cluster`. `suffix` is stored
// verbatim for later use. Keeps a pointer to `item`, so the GrapplerItem must
// outlive this CostAnalyzer.
CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                           const string& suffix)
    : item_(&item),
      // 10 and 0 are passed positionally -- presumably measurement steps and
      // a threading/warmup knob of MeasuringCostEstimator; TODO confirm
      // against its constructor signature.
      measure_estimator_(cluster, 10, 0),
      analytical_estimator_(cluster, /*use_static_shapes=*/false,
                            /*use_aggressive_shape_inference=*/true),
      suffix_(suffix) {}
33
// Runs the full analysis pipeline in strict order -- gather measured and
// analytical costs, merge analytical data into the measured records,
// aggregate per-op statistics, then write the formatted report to `os`.
// `per_node_report` enables the per-node section; `verbose` switches it to
// the full proto dump. Always returns OkStatus(): prediction failures inside
// GatherCosts are logged, not propagated.
Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
                                    bool verbose) {
  GatherCosts();
  PreprocessCosts();
  AnalyzeCosts();
  PrintAnalysis(os, per_node_report, verbose);
  return OkStatus();
}
42
// Runs `cost_estimator` over item_'s graph. On return, `*cost_graph` (if
// non-null) holds the cost graph produced during prediction and
// `*total_time` holds Costs::execution_time.count() (the report labels these
// values as nanoseconds).
// NOTE(review): `*total_time` is assigned before the status check, so it is
// written even when prediction fails -- callers see whatever value the
// failed prediction left in `costs`.
void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator,
                                CostGraphDef* cost_graph, int64_t* total_time) {
  // CHECK-fails (crashes) if the estimator cannot be initialized for item_.
  TF_CHECK_OK(cost_estimator->Initialize(*item_));
  RunMetadata run_metadata;
  Costs costs;
  const Status status =
      cost_estimator->PredictCosts(item_->graph, &run_metadata, &costs);
  if (cost_graph) {
    // Take the cost graph out of run_metadata without copying it.
    cost_graph->Swap(run_metadata.mutable_cost_graph());
  }
  *total_time = costs.execution_time.count();
  if (!status.ok()) {
    // Best-effort: errors are logged and swallowed rather than returned.
    LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "
               << status.error_message();
    return;
  }
}
60
61void CostAnalyzer::GatherCosts() {
62 CostGraphDef cost_graph_measured;
63 PredictCosts(&measure_estimator_, &cost_graph_measured,
64 &total_time_measured_);
65 VLOG(1) << "Graph size: " << item_->graph.node_size();
66 VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size();
67
68 CostGraphDef cost_graph_analytical;
69 PredictCosts(&analytical_estimator_, &cost_graph_analytical,
70 &total_time_analytical_);
71 VLOG(1) << "cost_graph_analytical size: "
72 << cost_graph_analytical.node_size();
73
74 CostGraphDef cost_graph_analytical_filtered;
75 CostGraphDef cost_graph_measured_filtered;
76 std::map<string, const CostGraphDef_Node*> measured_nodes;
77 for (const auto& node : cost_graph_measured.node()) {
78 measured_nodes[node.name()] = &node;
79 }
80 for (const auto& node : cost_graph_analytical.node()) {
81 auto it = measured_nodes.find(node.name());
82 // Filter the nodes that are not the cost nodes returned by
83 // MeasuringCostEstimator.
84 if (it == measured_nodes.end()) {
85 continue;
86 }
87 auto added_node_analytical = cost_graph_analytical_filtered.add_node();
88 auto added_node_measured = cost_graph_measured_filtered.add_node();
89 *added_node_analytical = node;
90 *added_node_measured = *(it->second);
91 }
92 VLOG(1) << "cost_graph_analytical_filtered size: "
93 << cost_graph_analytical_filtered.node_size();
94
95 // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and
96 // op_perf_ cover the same set of nodes.
97 op_perf_analytical_ = CostGraphToOpPerformanceData(
98 cost_graph_analytical_filtered, item_->graph);
99 op_perf_ =
100 CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph);
101}
102
103void CostAnalyzer::PreprocessCosts() {
104 for (int i = 0; i < op_perf_.op_performance_size(); i++) {
105 OpPerformance* perf = op_perf_.mutable_op_performance(i);
106 const OpPerformance& analytical = op_perf_analytical_.op_performance(i);
107 perf->set_compute_time(analytical.compute_time());
108 perf->set_memory_time(analytical.memory_time());
109 double measured_cost = perf->compute_cost();
110
111 double analytical_compute_cost = analytical.compute_time();
112 if (analytical_compute_cost == 0) {
113 // Negative infinity indicates unavailable data.
114 perf->set_compute_efficiency(-INFINITY);
115 } else {
116 perf->set_compute_efficiency(analytical_compute_cost / measured_cost);
117 }
118
119 double analytical_memory_cost = analytical.memory_time();
120 if (analytical_memory_cost == 0) {
121 // Negative infinity indicates unavailable data.
122 perf->set_memory_efficiency(-INFINITY);
123 } else {
124 perf->set_memory_efficiency(analytical_memory_cost / measured_cost);
125 }
126 }
127}
128
129void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
130 for (const auto& op : ops) {
131 ops_.push_back(op.second);
132 }
133 struct CompareByTime {
134 bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const {
135 return a.time > b.time;
136 }
137 };
138 std::stable_sort(ops_.begin(), ops_.end(), CompareByTime());
139}
140
141void CostAnalyzer::AnalyzeCosts() {
142 std::map<string, OpPerfSummary> ops;
143 for (const auto& op_perf : op_perf_.op_performance()) {
144 string op_name = op_perf.op().op();
145 ops[op_name].count++;
146 ops[op_name].time += op_perf.compute_cost();
147 ops[op_name].compute_time += op_perf.compute_time();
148 ops[op_name].memory_time += op_perf.memory_time();
149 ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time();
150 ops[op_name].time_lower +=
151 std::max(op_perf.compute_time(), op_perf.memory_time());
152 ops[op_name].name = op_name;
153 }
154 SortOpsByTime(ops);
155
156 total_time_measured_serialized_ = 0;
157 total_time_analytical_upper_ = 0;
158 total_time_analytical_lower_ = 0;
159 for (const auto& op : ops_) {
160 total_time_measured_serialized_ += op.time;
161 total_time_analytical_upper_ += op.time_upper;
162 total_time_analytical_lower_ += op.time_lower;
163 }
164}
165
// Writes the formatted cost report to `os`: overall totals/efficiencies, a
// per-op-type table sorted by measured time, and (optionally) a per-node
// section. Column layout relies on exact setw/setprecision sequencing, so
// the statement order here is load-bearing.
void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report,
                                 bool verbose) const {
  os << std::endl;
  // Summary block: measured vs analytical totals and their ratios.
  os << std::left << std::setw(50)
     << "Total time measured in ns (serialized): " << std::right
     << std::setw(20) << total_time_measured_serialized_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time measured in ns (actual): " << std::right << std::setw(20)
     << total_time_measured_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (upper bound): " << std::right
     << std::setw(20) << total_time_analytical_upper_ << std::endl;
  os << std::left << std::setw(50)
     << "Total time analytical in ns (lower bound): " << std::right
     << std::setw(20) << total_time_analytical_lower_ << std::endl;
  // NOTE(review): if total_time_measured_ is 0 these ratios are inf/nan.
  double efficiency_upper = static_cast<double>(total_time_analytical_upper_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical upper/actual): " << std::right
     << std::setw(20) << efficiency_upper << std::endl;
  double efficiency_lower = static_cast<double>(total_time_analytical_lower_) /
                            static_cast<double>(total_time_measured_);
  os << std::left << std::setw(50)
     << "Overall efficiency (analytical lower/actual): " << std::right
     << std::setw(20) << efficiency_lower << std::endl;
  os << std::endl;

  // Per-op-type table header. Widths are +1/+2 so the comma/percent sign
  // fits next to each right-aligned field.
  int width = 35;
  int width_narrow = 15;
  int width_wide = 20;
  os << std::setw(width + 1) << "Op,";
  os << std::setw(width_narrow + 1) << "Count,";
  os << std::setw(width_wide + 1) << "Measured time (ns),";
  os << std::setw(width_narrow + 2) << "Time percent,";
  os << std::setw(width_narrow + 2) << "Acc percent,";
  os << std::setw(width_wide + 1) << "Analytical upper,";
  os << std::setw(width_wide + 1) << "Analytical lower,";
  os << std::setw(width_narrow + 2) << "Overall eff";
  os << std::setw(width_narrow + 2) << "Compute eff";
  os << std::setw(width_narrow + 2) << "Memory eff" << std::endl;
  // ops_ is already sorted by descending time, so acc_percent accumulates
  // from the most expensive op down.
  float acc_percent = 0;
  for (const auto& op : ops_) {
    double percent = static_cast<double>(op.time) /
                     static_cast<double>(total_time_measured_serialized_);
    double eff =
        static_cast<double>(op.time_upper) / static_cast<double>(op.time);
    double compute_eff =
        static_cast<double>(op.compute_time) / static_cast<double>(op.time);
    double memory_eff =
        static_cast<double>(op.memory_time) / static_cast<double>(op.time);
    os << std::setw(width) << op.name << ",";
    os << std::setw(width_narrow) << op.count << ",";
    os << std::setw(width_wide) << op.time << ",";
    os << std::setw(width_narrow) << std::setprecision(2) << percent * 100
       << "%,";
    acc_percent += percent;
    os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100
       << "%,";
    os << std::setw(width_wide) << op.time_upper << ",";
    os << std::setw(width_wide) << op.time_lower << ",";
    os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,";
    os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100
       << "%,";
    os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100
       << "%,";
    os << std::endl;
  }
  os << std::endl;

  if (per_node_report) {
    if (verbose) {
      // Full dump of every OpPerformance proto.
      os << "Below is the full per-node report:" << std::endl;
      os << op_perf_.DebugString();
    } else {
      // Compact per-node table: one row per node plus its input shapes.
      os << "Below is the per-node report summary:" << std::endl;
      // These intentionally shadow the outer table widths.
      int width = 35;
      int width_narrow = 15;
      int width_wide = 20;
      os << std::setw(width + 1) << "Op,";
      os << std::setw(width_wide + 1) << "Measured time (ns),";
      os << std::setw(width_wide + 1) << "Compute time (ns),";
      os << std::setw(width_wide + 1) << "Memory time (ns),";
      os << std::setw(width_narrow + 2) << "Compute eff,";
      os << std::setw(width_narrow + 2) << "Memory eff,";
      os << " Inputs" << std::endl;
      for (int i = 0; i < op_perf_.op_performance_size(); i++) {
        const auto& perf = op_perf_.op_performance(i);
        string op_name = perf.op().op();
        os << std::setw(width) << op_name << ",";
        os << std::setw(width_wide) << perf.compute_cost() << ",";
        os << std::setw(width_wide) << perf.compute_time() << ",";
        os << std::setw(width_wide) << perf.memory_time() << ",";
        os << std::setw(width_narrow) << std::setprecision(2)
           << perf.compute_efficiency() * 100 << "%,";
        os << std::setw(width_narrow) << std::setprecision(2)
           << perf.memory_efficiency() * 100 << "%,";
        // Input shapes rendered as [(d0, d1, ...), ...]; scalar (rank-0)
        // inputs are skipped entirely.
        os << " [";
        for (int j = 0; j < perf.op().inputs_size(); j++) {
          const auto& shape = perf.op().inputs(j).shape();
          if (shape.dim_size() > 0) {
            os << "(";
            std::vector<int> dims;
            for (int k = 0; k < shape.dim_size(); k++) {
              os << shape.dim(k).size();
              if (k < shape.dim_size() - 1) {
                os << ", ";
              }
            }
            os << ")";
            if (j < perf.op().inputs_size() - 1) {
              os << ", ";
            }
          }
        }
        os << "]" << std::endl;
      }
      os << std::endl;
    }
  }
}
286} // end namespace grappler
287} // end namespace tensorflow
288