1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/graph/control_flow.h"

#include <deque>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
25
26namespace tensorflow {
27namespace {
28// Information about a loop frame structure.
29struct Frame {
30 string name;
31
32 // Pointer to the parent frame. The root frame has a pointer to itself.
33 Frame* parent = nullptr;
34
35 // The loop condition of the loop. There should be exactly one loop condition
36 // in every loop.
37 const Node* loop_cond = nullptr;
38};
39
40// Verify that the ControlFlowInfo of the graph has valid loop structure.
41Status ValidateControlFlowInfo(const Graph* graph,
42 const std::vector<ControlFlowInfo>& cf_info) {
43 std::unordered_map<string, Frame> frames;
44 for (const Node* node : graph->op_nodes()) {
45 const ControlFlowInfo& cf = cf_info[node->id()];
46 if (!cf.frame || !cf.parent_frame) {
47 // Skip nodes unreachable from the source node. They might be pruned
48 // later.
49 continue;
50 }
51
52 Frame& frame = frames[cf.frame_name];
53 Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
54 if (frame.parent == nullptr) {
55 frame.parent = parent;
56 frame.name = cf.frame_name;
57 } else if (frame.parent != parent) {
58 return errors::Internal(
59 "Invalid loop structure: Mismatched parent frames for \"",
60 cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name,
61 "\". The node giving this error: ", FormatNodeForError(*node),
62 ". This is an internal bug, please file a bug report with "
63 "instructions on how to reproduce the error.");
64 }
65 if (IsLoopCond(node)) {
66 // ForwardLoopCounter runs in the same frame as the forward loop and
67 // BackPropLoopCounter runs in the same frame as the backprop loop. They
68 // are the only cases that multiple loops share the same frame.
69 if (frame.loop_cond &&
70 !absl::StrContains(frame.loop_cond->name(), "LoopCounter") &&
71 !absl::StrContains(node->name(), "LoopCounter")) {
72 return errors::InvalidArgument(
73 "Invalid loop structure: Loop \"", cf.frame_name,
74 "\" has more than one LoopCond node: ", FormatNodeForError(*node),
75 " and ", FormatNodeForError(*frame.loop_cond),
76 ". This is an internal bug, please file a bug report with "
77 "instructions on how to reproduce the error.");
78 }
79 frame.loop_cond = node;
80 }
81 }
82 return OkStatus();
83}
84} // namespace
85
86Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
87 std::vector<string>* unreachable_nodes) {
88 info->clear();
89 info->resize(g->num_node_ids());
90
91 std::vector<const Node*> parent_nodes;
92 parent_nodes.resize(g->num_node_ids());
93
94 const Node* src_node = g->source_node();
95 ControlFlowInfo& src_info = (*info)[src_node->id()];
96 src_info.frame = src_node;
97 src_info.parent_frame = src_node;
98
99 string frame_name;
100 std::deque<const Node*> ready;
101 ready.push_back(src_node);
102 while (!ready.empty()) {
103 const Node* curr_node = ready.front();
104 ready.pop_front();
105 const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
106 const Node* frame = curr_info.frame;
107 const Node* parent = curr_info.parent_frame;
108 frame_name = curr_info.frame_name;
109
110 if (IsExit(curr_node)) {
111 // Exit to the parent frame.
112 const ControlFlowInfo& parent_info = (*info)[parent->id()];
113 frame = parent_info.frame;
114 parent = parent_info.parent_frame;
115 frame_name = parent_info.frame_name;
116 }
117
118 for (const Edge* out_edge : curr_node->out_edges()) {
119 const Node* out = out_edge->dst();
120 int out_id = out->id();
121 ControlFlowInfo* out_info = &(*info)[out_id];
122 const Node* out_parent = out_info->parent_frame;
123 bool is_visited = (parent_nodes[out_id] != nullptr);
124
125 // Skip Sink/Source nodes.
126 if (!out->IsOp()) continue;
127
128 // Add to ready queue if not seen.
129 if (!is_visited) {
130 parent_nodes[out->id()] = curr_node;
131 ready.push_back(out);
132 }
133
134 // Process the node 'out'.
135 if (IsEnter(out)) {
136 if (is_visited) {
137 const string& parent_frame = (*info)[out_parent->id()].frame_name;
138 if (parent_frame != frame_name) {
139 return errors::InvalidArgument(
140 FormatNodeForError(*out),
141 " has inputs from different frames. The input ",
142 FormatNodeForError(*curr_node), " is in frame '", frame_name,
143 "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
144 " is in frame '", parent_frame, "'.");
145 }
146 } else {
147 out_info->frame = out;
148 out_info->parent_frame = frame;
149 TF_RETURN_IF_ERROR(
150 GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
151 if (out_info->frame_name.empty()) {
152 return errors::InvalidArgument("The Enter ",
153 FormatNodeForError(*out),
154 " must have a frame name.");
155 }
156 }
157 } else {
158 if (is_visited) {
159 if (out_info->frame_name != frame_name) {
160 return errors::InvalidArgument(
161 FormatNodeForError(*out),
162 " has inputs from different frames. The input ",
163 FormatNodeForError(*curr_node), " is in frame '", frame_name,
164 "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
165 " is in frame '", out_info->frame_name, "'.");
166 }
167 } else {
168 out_info->frame = frame;
169 out_info->parent_frame = parent;
170 out_info->frame_name = frame_name;
171 }
172 }
173 }
174 }
175 if (unreachable_nodes) {
176 for (const Node* node : g->op_nodes()) {
177 if (!parent_nodes[node->id()]) {
178 unreachable_nodes->push_back(node->name());
179 }
180 }
181 }
182 TF_RETURN_IF_ERROR(ValidateControlFlowInfo(g, *info));
183 return OkStatus();
184}
185
186} // namespace tensorflow
187