1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/graph/control_flow.h" |
17 | |
18 | #include <deque> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/node_def_util.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/graph/node_builder.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace { |
28 | // Information about a loop frame structure. |
29 | struct Frame { |
30 | string name; |
31 | |
32 | // Pointer to the parent frame. The root frame has a pointer to itself. |
33 | Frame* parent = nullptr; |
34 | |
35 | // The loop condition of the loop. There should be exactly one loop condition |
36 | // in every loop. |
37 | const Node* loop_cond = nullptr; |
38 | }; |
39 | |
40 | // Verify that the ControlFlowInfo of the graph has valid loop structure. |
41 | Status ValidateControlFlowInfo(const Graph* graph, |
42 | const std::vector<ControlFlowInfo>& cf_info) { |
43 | std::unordered_map<string, Frame> frames; |
44 | for (const Node* node : graph->op_nodes()) { |
45 | const ControlFlowInfo& cf = cf_info[node->id()]; |
46 | if (!cf.frame || !cf.parent_frame) { |
47 | // Skip nodes unreachable from the source node. They might be pruned |
48 | // later. |
49 | continue; |
50 | } |
51 | |
52 | Frame& frame = frames[cf.frame_name]; |
53 | Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; |
54 | if (frame.parent == nullptr) { |
55 | frame.parent = parent; |
56 | frame.name = cf.frame_name; |
57 | } else if (frame.parent != parent) { |
58 | return errors::Internal( |
59 | "Invalid loop structure: Mismatched parent frames for \"" , |
60 | cf.frame_name, "\": \"" , parent->name, "\" vs \"" , frame.parent->name, |
61 | "\". The node giving this error: " , FormatNodeForError(*node), |
62 | ". This is an internal bug, please file a bug report with " |
63 | "instructions on how to reproduce the error." ); |
64 | } |
65 | if (IsLoopCond(node)) { |
66 | // ForwardLoopCounter runs in the same frame as the forward loop and |
67 | // BackPropLoopCounter runs in the same frame as the backprop loop. They |
68 | // are the only cases that multiple loops share the same frame. |
69 | if (frame.loop_cond && |
70 | !absl::StrContains(frame.loop_cond->name(), "LoopCounter" ) && |
71 | !absl::StrContains(node->name(), "LoopCounter" )) { |
72 | return errors::InvalidArgument( |
73 | "Invalid loop structure: Loop \"" , cf.frame_name, |
74 | "\" has more than one LoopCond node: " , FormatNodeForError(*node), |
75 | " and " , FormatNodeForError(*frame.loop_cond), |
76 | ". This is an internal bug, please file a bug report with " |
77 | "instructions on how to reproduce the error." ); |
78 | } |
79 | frame.loop_cond = node; |
80 | } |
81 | } |
82 | return OkStatus(); |
83 | } |
84 | } // namespace |
85 | |
86 | Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info, |
87 | std::vector<string>* unreachable_nodes) { |
88 | info->clear(); |
89 | info->resize(g->num_node_ids()); |
90 | |
91 | std::vector<const Node*> parent_nodes; |
92 | parent_nodes.resize(g->num_node_ids()); |
93 | |
94 | const Node* src_node = g->source_node(); |
95 | ControlFlowInfo& src_info = (*info)[src_node->id()]; |
96 | src_info.frame = src_node; |
97 | src_info.parent_frame = src_node; |
98 | |
99 | string frame_name; |
100 | std::deque<const Node*> ready; |
101 | ready.push_back(src_node); |
102 | while (!ready.empty()) { |
103 | const Node* curr_node = ready.front(); |
104 | ready.pop_front(); |
105 | const ControlFlowInfo& curr_info = (*info)[curr_node->id()]; |
106 | const Node* frame = curr_info.frame; |
107 | const Node* parent = curr_info.parent_frame; |
108 | frame_name = curr_info.frame_name; |
109 | |
110 | if (IsExit(curr_node)) { |
111 | // Exit to the parent frame. |
112 | const ControlFlowInfo& parent_info = (*info)[parent->id()]; |
113 | frame = parent_info.frame; |
114 | parent = parent_info.parent_frame; |
115 | frame_name = parent_info.frame_name; |
116 | } |
117 | |
118 | for (const Edge* out_edge : curr_node->out_edges()) { |
119 | const Node* out = out_edge->dst(); |
120 | int out_id = out->id(); |
121 | ControlFlowInfo* out_info = &(*info)[out_id]; |
122 | const Node* out_parent = out_info->parent_frame; |
123 | bool is_visited = (parent_nodes[out_id] != nullptr); |
124 | |
125 | // Skip Sink/Source nodes. |
126 | if (!out->IsOp()) continue; |
127 | |
128 | // Add to ready queue if not seen. |
129 | if (!is_visited) { |
130 | parent_nodes[out->id()] = curr_node; |
131 | ready.push_back(out); |
132 | } |
133 | |
134 | // Process the node 'out'. |
135 | if (IsEnter(out)) { |
136 | if (is_visited) { |
137 | const string& parent_frame = (*info)[out_parent->id()].frame_name; |
138 | if (parent_frame != frame_name) { |
139 | return errors::InvalidArgument( |
140 | FormatNodeForError(*out), |
141 | " has inputs from different frames. The input " , |
142 | FormatNodeForError(*curr_node), " is in frame '" , frame_name, |
143 | "'. The input " , FormatNodeForError(*parent_nodes[out->id()]), |
144 | " is in frame '" , parent_frame, "'." ); |
145 | } |
146 | } else { |
147 | out_info->frame = out; |
148 | out_info->parent_frame = frame; |
149 | TF_RETURN_IF_ERROR( |
150 | GetNodeAttr(out->attrs(), "frame_name" , &out_info->frame_name)); |
151 | if (out_info->frame_name.empty()) { |
152 | return errors::InvalidArgument("The Enter " , |
153 | FormatNodeForError(*out), |
154 | " must have a frame name." ); |
155 | } |
156 | } |
157 | } else { |
158 | if (is_visited) { |
159 | if (out_info->frame_name != frame_name) { |
160 | return errors::InvalidArgument( |
161 | FormatNodeForError(*out), |
162 | " has inputs from different frames. The input " , |
163 | FormatNodeForError(*curr_node), " is in frame '" , frame_name, |
164 | "'. The input " , FormatNodeForError(*parent_nodes[out->id()]), |
165 | " is in frame '" , out_info->frame_name, "'." ); |
166 | } |
167 | } else { |
168 | out_info->frame = frame; |
169 | out_info->parent_frame = parent; |
170 | out_info->frame_name = frame_name; |
171 | } |
172 | } |
173 | } |
174 | } |
175 | if (unreachable_nodes) { |
176 | for (const Node* node : g->op_nodes()) { |
177 | if (!parent_nodes[node->id()]) { |
178 | unreachable_nodes->push_back(node->name()); |
179 | } |
180 | } |
181 | } |
182 | TF_RETURN_IF_ERROR(ValidateControlFlowInfo(g, *info)); |
183 | return OkStatus(); |
184 | } |
185 | |
186 | } // namespace tensorflow |
187 | |