1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/graph/validate.h"

#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
27 | |
28 | namespace tensorflow { |
29 | namespace graph { |
30 | |
31 | Status ValidateGraphDef(const GraphDef& graph_def, |
32 | const OpRegistryInterface& op_registry) { |
33 | Status s; |
34 | const int version = graph_def.versions().producer(); |
35 | for (const NodeDef& node_def : graph_def.node()) { |
36 | // Look up the OpDef for the node_def's op name. |
37 | const OpDef* op_def; |
38 | TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def.op(), &op_def)); |
39 | TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def)); |
40 | TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, version)); |
41 | } |
42 | |
43 | return s; |
44 | } |
45 | |
46 | Status ValidateGraphDefAgainstOpRegistry( |
47 | const GraphDef& graph_def, const OpRegistryInterface& op_registry) { |
48 | GraphDef copy(graph_def); |
49 | TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©, op_registry, 0)); |
50 | return ValidateGraphDef(copy, op_registry); |
51 | } |
52 | |
53 | Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, |
54 | const OpList& op_list) { |
55 | OpListOpRegistry registry(&op_list); |
56 | return ValidateGraphDefAgainstOpRegistry(graph_def, registry); |
57 | } |
58 | |
59 | void GetOpListForValidation(OpList* op_list, const OpRegistry& op_registry) { |
60 | op_registry.Export(false, op_list); |
61 | RemoveDescriptionsFromOpList(op_list); |
62 | } |
63 | |
64 | Status ValidateGraphHasNoCycle(const Graph& graph) { |
65 | // A node is ready when all of its inputs have been visited. |
66 | std::vector<const Node*> ready; |
67 | std::vector<int> pending_count(graph.num_node_ids(), 0); |
68 | |
69 | for (int i = 0; i < graph.num_node_ids(); ++i) { |
70 | const Node* n = graph.FindNodeId(i); |
71 | if (n == nullptr) continue; |
72 | pending_count[i] = n->in_edges().size(); |
73 | if (n->IsMerge()) { |
74 | // While-loop cycles are legal cycles so we manually adjust the |
75 | // pending_count to make sure that the loop is visited. |
76 | for (const Edge* e : n->in_edges()) { |
77 | if (!e->IsControlEdge() && e->src()->IsNextIteration()) { |
78 | pending_count[i]--; |
79 | } |
80 | } |
81 | } |
82 | if (pending_count[i] == 0) { |
83 | ready.push_back(n); |
84 | } |
85 | } |
86 | |
87 | int processed = 0; |
88 | while (!ready.empty()) { |
89 | const Node* node = ready.back(); |
90 | ready.pop_back(); |
91 | ++processed; |
92 | |
93 | for (const Edge* out : node->out_edges()) { |
94 | const int output_id = out->dst()->id(); |
95 | pending_count[output_id]--; |
96 | if (pending_count[output_id] == 0) { |
97 | ready.push_back(out->dst()); |
98 | } |
99 | } |
100 | } |
101 | |
102 | if (processed < graph.num_nodes()) { |
103 | std::vector<string> nodes_in_cycle; |
104 | for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; |
105 | ++i) { |
106 | if (pending_count[i] != 0) { |
107 | nodes_in_cycle.push_back(graph.FindNodeId(i)->name()); |
108 | } |
109 | } |
110 | return errors::InvalidArgument( |
111 | "Graph is invalid, contains a cycle with " , |
112 | graph.num_nodes() - processed, |
113 | " nodes, including: " , absl::StrJoin(nodes_in_cycle, ", " )); |
114 | } |
115 | return OkStatus(); |
116 | } |
117 | |
118 | Status VerifyNoDuplicateNodeNames(const GraphDef& graph) { |
119 | absl::flat_hash_set<absl::string_view> nodes; |
120 | for (const auto& node : graph.node()) { |
121 | if (nodes.contains(node.name())) { |
122 | return errors::AlreadyExists("Node already exists: " , node.name()); |
123 | } |
124 | nodes.insert(node.name()); |
125 | } |
126 | return OkStatus(); |
127 | } |
128 | |
129 | } // namespace graph |
130 | } // namespace tensorflow |
131 | |