1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/forward_type_inference.h" |
17 | |
18 | #include <functional> |
19 | #include <queue> |
20 | #include <string> |
21 | #include <string_view> |
22 | |
23 | #include "absl/container/flat_hash_set.h" |
24 | #include "tensorflow/core/framework/full_type.pb.h" |
25 | #include "tensorflow/core/framework/full_type_util.h" |
26 | #include "tensorflow/core/framework/op_def_builder.h" |
27 | #include "tensorflow/core/platform/errors.h" |
28 | #include "tensorflow/core/util/dump_graph.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | namespace { |
33 | |
34 | int MAX_VISITS_PER_NODE = 3; |
35 | |
36 | typedef absl::flat_hash_map< |
37 | int, std::reference_wrapper<ForwardTypeInferenceFn const>> |
38 | ForwardInferMap; |
39 | typedef absl::flat_hash_map< |
40 | int, std::pair<int, std::reference_wrapper<ForwardTypeInferenceFn const>>> |
41 | ReverseInferMap; |
42 | |
43 | bool all_sources_closed(const Node& n, const absl::flat_hash_set<int>& closed, |
44 | const ForwardInferMap& forward, |
45 | const ReverseInferMap& reverse) { |
46 | for (const auto& e : n.out_edges()) { |
47 | if (e->IsControlEdge()) { |
48 | continue; |
49 | } |
50 | int dst_id = e->dst()->id(); |
51 | if (reverse.contains(dst_id) && !closed.contains(dst_id)) { |
52 | return false; |
53 | } |
54 | } |
55 | if (forward.contains(n.id())) { |
56 | for (const auto& e : n.in_edges()) { |
57 | if (e->IsControlEdge()) { |
58 | continue; |
59 | } |
60 | if (!closed.contains(e->src()->id())) { |
61 | return false; |
62 | } |
63 | } |
64 | } |
65 | return true; |
66 | } |
67 | |
68 | std::vector<std::reference_wrapper<const FullTypeDef>> input_types( |
69 | const Node& n) { |
70 | static FullTypeDef* no_type = new FullTypeDef(); |
71 | |
72 | std::vector<std::reference_wrapper<const FullTypeDef>> input_types; |
73 | for (const auto& in_edge : n.in_edges()) { |
74 | if (in_edge->IsControlEdge()) { |
75 | continue; |
76 | } |
77 | input_types.push_back(*no_type); |
78 | } |
79 | for (const auto& in_edge : n.in_edges()) { |
80 | if (in_edge->IsControlEdge()) { |
81 | continue; |
82 | } |
83 | VLOG(5) << " in edge: " << in_edge->DebugString(); |
84 | NodeDef* ndef = in_edge->src()->mutable_def(); |
85 | if (ndef->has_experimental_type()) { |
86 | const auto& t = ndef->experimental_type(); |
87 | if (t.type_id() != TFT_UNSET) { |
88 | DCHECK(t.type_id() == TFT_PRODUCT) << ndef->DebugString(); |
89 | DCHECK(t.args_size() > in_edge->src_output()) << ndef->DebugString(); |
90 | input_types.at(in_edge->dst_input()) = t.args(in_edge->src_output()); |
91 | } |
92 | } |
93 | } |
94 | return input_types; |
95 | } |
96 | |
97 | Status updated_inferred_type(Node* target, const FullTypeDef& t, |
98 | bool& updated) { |
99 | if (t.type_id() == TFT_UNSET) { |
100 | VLOG(3) << " " << target->name() << " no inferred type" ; |
101 | return OkStatus(); |
102 | } |
103 | |
104 | if (target->def().has_experimental_type()) { |
105 | const auto existing = target->def().experimental_type(); |
106 | if (full_type::IsSubtype(existing, t)) { |
107 | VLOG(3) << " " << target->name() << " no new type info" ; |
108 | return OkStatus(); |
109 | } else if (!full_type::IsSubtype(t, existing)) { |
110 | // The only allowable type mismatches are those which would further |
111 | // specialize the existing type. |
112 | return Status( |
113 | error::INVALID_ARGUMENT, |
114 | absl::StrCat("type mismatch for node '" , target->name(), |
115 | "': expected a subtype of:\n" , existing.DebugString(), |
116 | "\n got:\n" , t.DebugString(), "\n " )); |
117 | } |
118 | } |
119 | |
120 | *(target->mutable_def()->mutable_experimental_type()) = t; |
121 | updated = true; |
122 | VLOG(3) << " " << target->name() << " updated" ; |
123 | return OkStatus(); |
124 | } |
125 | |
126 | } // namespace |
127 | |
// Runs type inference over the graph to a fixed point, annotating each node's
// NodeDef with its inferred full type. Uses a worklist: nodes are visited,
// their forward/reverse inference functions applied, and their neighbors
// re-enqueued until all nodes are closed or the pass budget is exhausted.
Status ForwardTypeInferencePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "ForwardTypeInferencePass::Run";

  DCHECK(options.graph != nullptr);
  Graph* g = options.graph->get();
  DCHECK(g != nullptr);
  FunctionLibraryDefinition* flib_def = options.flib_def;
  DCHECK(flib_def != nullptr);

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("forward_type_inference_before", *g, flib_def);
  }

  for (Node* n : g->nodes()) {
    // TODO(mdan): Needed?
    n->UpdateProperties();
  }

  // Cache type inference functions, to avoid repeated flib_def lookups.
  ForwardInferMap forward;
  ReverseInferMap reverse;
  for (Node* n : g->nodes()) {
    VLOG(4) << "\n node: " << n->def().DebugString()
            << "\n op def: " << n->op_def().DebugString();
    const OpRegistrationData* reg;
    TF_RETURN_IF_ERROR(flib_def->LookUp(n->op_def().name(), &reg));
    if (reg->fwd_type_fn != nullptr) {
      forward.emplace(n->id(), reg->fwd_type_fn);
    }
    if (reg->rev_type_fn != nullptr) {
      reverse.emplace(n->id(), std::make_pair(reg->rev_type_input,
                                              std::cref(reg->rev_type_fn)));
    }
  }

  // Applies n's forward inference function, if any, to its input types and
  // merges the result into n's own type. Sets `updated` iff the type changed.
  auto infer_forward = [&forward](Node* n, bool& updated) {
    if (!forward.contains(n->id())) {
      return OkStatus();
    }
    VLOG(4) << " " << n->name() << " has forward function";

    // TODO(b/224775462): Populate with types from function references.
    TypeRefMap type_vars;
    auto in_types = input_types(*n);
    const auto& infer_ret = forward.at(n->id())(in_types, type_vars);

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        infer_ret.status(),
        absl::StrCat("while inferring type of node '", n->name(), "'"));

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        updated_inferred_type(n, *infer_ret, updated),
        "while updating its output type.");

    return OkStatus();
  };

  // Applies n's reverse inference function, if any. A reverse function infers
  // and updates the type of the PRODUCER of one designated input of n (the
  // input index stored alongside the function). Sets `updated` iff that
  // producer's type changed.
  auto infer_reverse = [&reverse](Node* n, bool& updated) {
    if (!reverse.contains(n->id())) {
      return OkStatus();
    }
    VLOG(4) << " " << n->name() << " has reverse function";

    // TODO(b/224775462): Populate with types from function references.
    TypeRefMap type_vars;
    auto in_types = input_types(*n);
    auto rev_idx_and_fn = reverse.at(n->id());
    const auto& infer_ret = rev_idx_and_fn.second(in_types, type_vars);

    const Edge* e;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        n->input_edge(rev_idx_and_fn.first, &e),
        absl::StrCat("while querying input ", rev_idx_and_fn.first, " of '",
                     n->name(), "'"));

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        infer_ret.status(),
        absl::StrCat("while inferring type of node '", e->src()->name(),
                     "' via '", n->name(), "'"));

    // NOTE(review): the trailing "," below looks like a typo for "'" in the
    // error context string — confirm before changing, as tests may match it.
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        updated_inferred_type(e->src(), *infer_ret, updated),
        absl::StrCat("while updating its output type inferred from '",
                     n->name(), ","));

    return OkStatus();
  };

  // Worklist of node ids to visit, with a membership set to avoid duplicates.
  std::list<int> queue;
  absl::flat_hash_set<int> in_queue;
  absl::flat_hash_map<int, int> visit_count;
  // Open nodes. A node is open if it has never been visited.
  absl::flat_hash_set<int> open;
  // Closed nodes. A closed node will never be visited again.
  absl::flat_hash_set<int> closed;

  // Upper bound. Worst-case is a cycle in which no nodes have type info,
  // case in which there will be max_passes iterations, each visiting one node.
  int max_passes = g->num_nodes();

  int visits = 0;

  // Start with niladic nodes. If none exist, a random one will be selected at
  // the end of first iteration.
  for (Node* n : g->nodes()) {
    const int nid = n->id();
    bool niladic = true;
    for (const auto& e : n->in_edges()) {
      if (!e->IsControlEdge()) {
        niladic = false;
        break;
      }
    }
    if (niladic) {
      queue.emplace_back(nid);
      in_queue.emplace(nid);
    }
    open.emplace(nid);
    visit_count.emplace(nid, 0);
  }

  for (int i = 0; i < max_passes; i++) {
    VLOG(2) << "Iteration " << i << ", " << queue.size() << " nodes in queue";

    // Drain the current worklist; enqueued neighbors are processed within the
    // same iteration.
    while (!queue.empty()) {
      int nid = queue.front();
      Node* n = g->FindNodeId(nid);
      VLOG(3) << " visiting " << n->name();
      visits++;
      visit_count[nid]++;
      DCHECK(!closed.contains(nid));

      bool updated = false;
      TF_RETURN_IF_ERROR(infer_forward(n, updated));
      TF_RETURN_IF_ERROR(infer_reverse(n, updated));

      VLOG(4) << " done " << n->def().DebugString();

      queue.pop_front();
      in_queue.erase(nid);
      open.erase(nid);

      // Update the graph to fixed point, with iterations limited
      // by MAX_VISITS_PER_NODE.
      if (visit_count.at(nid) >= MAX_VISITS_PER_NODE) {
        VLOG(3) << " closing " << n->name() << " - visit limit reached";
        closed.emplace(nid);
      } else if (all_sources_closed(*n, closed, forward, reverse)) {
        VLOG(3) << " closing " << n->name() << " - all sources closed";
        closed.emplace(nid);
      }

      // Propagate to data consumers: enqueue any that either saw new input
      // type information (`updated`) or now have all their sources closed.
      for (const auto& out_edge : n->out_edges()) {
        if (out_edge->IsControlEdge()) {
          continue;
        }
        Node* c = out_edge->dst();
        int cid = c->id();
        if (closed.contains(cid) || in_queue.contains(cid)) {
          continue;
        }
        if (updated || all_sources_closed(*c, closed, forward, reverse)) {
          queue.emplace_back(cid);
          in_queue.emplace(cid);
        }
      }
      // If a reverse function changed the producer of the designated input,
      // revisit that producer so the change can propagate further.
      if (updated && reverse.contains(nid)) {
        const Edge* e;
        TF_RETURN_IF_ERROR(n->input_edge(reverse.at(nid).first, &e));
        Node* c = e->src();
        int cid = c->id();
        if (!closed.contains(cid) && !in_queue.contains(cid)) {
          queue.emplace_back(cid);
          in_queue.emplace(cid);
        }
      }
    }

    VLOG(2) << "Done iteration " << i << ", " << closed.size()
            << " nodes closed";

    if (open.empty()) {
      VLOG(1) << "Finished after " << i + 1 << " iterations; done "
              << closed.size() << " of " << g->num_nodes() << " nodes in "
              << visits << " visits";
      break;
    } else {
      // Seed the next iteration with an arbitrary still-open node (e.g. when
      // the graph has no niladic nodes, or a cycle stalled propagation).
      queue.emplace_back(*(open.begin()));
    }
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("forward_type_inference_after", *g, flib_def);
  }

  return OkStatus();
}
326 | |
327 | Status WeakForwardTypeInferencePass::Run( |
328 | const GraphOptimizationPassOptions& options) { |
329 | ForwardTypeInferencePass pass; |
330 | const auto& pass_status = pass.Run(options); |
331 | if (!pass_status.ok()) { |
332 | LOG_FIRST_N(WARNING, 1) |
333 | << "Type inference failed. This indicates an " |
334 | "invalid graph that escaped type checking. Error message: " |
335 | << pass_status.ToString(); |
336 | } |
337 | return OkStatus(); |
338 | } |
339 | |
// Note: This needs to run last because Placer needs it.
// The high priority value (99999) places this pass at the end of the
// PRE_PLACEMENT group. The weak variant is registered so inference failures
// warn rather than abort graph construction.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 99999,
                      WeakForwardTypeInferencePass);
343 | |
344 | } // namespace tensorflow |
345 | |