1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/grappler/model_analyzer.h" |
17 | |
18 | #include <iomanip> |
19 | #include "tensorflow/core/framework/op.h" |
20 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
21 | #include "tensorflow/core/grappler/costs/graph_properties.h" |
22 | #include "tensorflow/core/grappler/grappler_item.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace grappler { |
26 | |
27 | ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {} |
28 | |
29 | Status ModelAnalyzer::GenerateReport(bool debug, bool assume_valid_feeds, |
30 | std::ostream& os) { |
31 | GraphProperties properties(item_); |
32 | TF_RETURN_IF_ERROR(properties.InferStatically(assume_valid_feeds)); |
33 | |
34 | for (const auto& node : item_.MainOpsFanin()) { |
35 | PrintNodeInfo(node, properties, debug, os); |
36 | } |
37 | for (const auto& node : item_.EnqueueOpsFanin()) { |
38 | PrintNodeInfo(node, properties, debug, os); |
39 | } |
40 | |
41 | return OkStatus(); |
42 | } |
43 | |
44 | void ModelAnalyzer::PrintNodeInfo(const NodeDef* node, |
45 | const GraphProperties& properties, bool debug, |
46 | std::ostream& os) const { |
47 | os << node->name() << " [" << node->op() << "]" << std::endl; |
48 | if (properties.HasOutputProperties(node->name())) { |
49 | const std::vector<OpInfo::TensorProperties>& props = |
50 | properties.GetOutputProperties(node->name()); |
51 | for (int i = 0, props_size = props.size(); i < props_size; ++i) { |
52 | const OpInfo::TensorProperties& prop = props[i]; |
53 | os << "\t" |
54 | << "output " << i << " (" << DataTypeString(prop.dtype()) |
55 | << ") has shape " ; |
56 | if (prop.shape().unknown_rank()) { |
57 | os << "?" ; |
58 | } else { |
59 | os << "[" ; |
60 | for (int i = 0; i < prop.shape().dim_size(); ++i) { |
61 | if (i > 0) { |
62 | os << ", " ; |
63 | } |
64 | if (prop.shape().dim(i).size() >= 0) { |
65 | // Print the actual dimension. |
66 | os << prop.shape().dim(i).size(); |
67 | } else if (prop.shape().dim(i).size() == -1) { |
68 | // We don't know anything about the dimension. |
69 | os << "?" ; |
70 | } else { |
71 | // Symbolic dimension. |
72 | os << "x" << -prop.shape().dim(i).size(); |
73 | } |
74 | } |
75 | os << "]" ; |
76 | } |
77 | os << std::endl; |
78 | } |
79 | } |
80 | |
81 | if (debug) { |
82 | const OpRegistrationData* op_reg_data; |
83 | Status status = OpRegistry::Global()->LookUp(node->op(), &op_reg_data); |
84 | if (!status.ok()) { |
85 | os << "\tCouldn't find op registration for " << node->op() << std::endl; |
86 | } else if (!op_reg_data->shape_inference_fn) { |
87 | os << "\tCouldn't find shape function for op " << node->op() << std::endl; |
88 | } else if (properties.HasInputProperties(node->name())) { |
89 | const std::vector<OpInfo::TensorProperties>& props = |
90 | properties.GetInputProperties(node->name()); |
91 | for (int i = 0, props_size = props.size(); i < props_size; ++i) { |
92 | const OpInfo::TensorProperties& prop = props[i]; |
93 | if (prop.has_value()) { |
94 | os << "\t" |
95 | << "input " << i << " (" << DataTypeString(prop.dtype()) |
96 | << ") has known value" << std::endl; |
97 | } |
98 | } |
99 | } |
100 | } |
101 | } |
102 | |
103 | } // end namespace grappler |
104 | } // end namespace tensorflow |
105 | |