/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/forward_type_inference.h"

#include <functional>
#include <list>
#include <queue>
#include <string>
#include <string_view>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {

namespace {

int MAX_VISITS_PER_NODE = 3;

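// Caches of type inference functions, keyed by node id. ForwardInferMap holds
// each node's forward inference function; ReverseInferMap holds the index of
// the input whose source node's type is inferred, together with the reverse
// inference function that produces it.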
typedef absl::flat_hash_map<
    int, std::reference_wrapper<ForwardTypeInferenceFn const>>
    ForwardInferMap;
typedef absl::flat_hash_map<
    int, std::pair<int, std::reference_wrapper<ForwardTypeInferenceFn const>>>
    ReverseInferMap;

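// Returns true if all of the nodes that contribute type information to `n`
// have already been closed: the destinations of its data out-edges that carry
// a reverse inference function and, if `n` itself has a forward inference
// function, the sources of its data in-edges.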
bool all_sources_closed(const Node& n, const absl::flat_hash_set<int>& closed,
                        const ForwardInferMap& forward,
                        const ReverseInferMap& reverse) {
  for (const auto& e : n.out_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    int dst_id = e->dst()->id();
    if (reverse.contains(dst_id) && !closed.contains(dst_id)) {
      return false;
    }
  }
  if (forward.contains(n.id())) {
    for (const auto& e : n.in_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      if (!closed.contains(e->src()->id())) {
        return false;
      }
    }
  }
  return true;
}

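// Collects the full types of `n`'s data inputs, indexed by input port. Inputs
// whose source node has no inferred type yet default to an unset FullTypeDef.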
std::vector<std::reference_wrapper<const FullTypeDef>> input_types(
    const Node& n) {
  static FullTypeDef* no_type = new FullTypeDef();

  std::vector<std::reference_wrapper<const FullTypeDef>> input_types;
  for (const auto& in_edge : n.in_edges()) {
    if (in_edge->IsControlEdge()) {
      continue;
    }
    input_types.push_back(*no_type);
  }
  for (const auto& in_edge : n.in_edges()) {
    if (in_edge->IsControlEdge()) {
      continue;
    }
    VLOG(5) << " in edge: " << in_edge->DebugString();
    NodeDef* ndef = in_edge->src()->mutable_def();
    if (ndef->has_experimental_type()) {
      const auto& t = ndef->experimental_type();
      if (t.type_id() != TFT_UNSET) {
        DCHECK(t.type_id() == TFT_PRODUCT) << ndef->DebugString();
        DCHECK(t.args_size() > in_edge->src_output()) << ndef->DebugString();
        input_types.at(in_edge->dst_input()) = t.args(in_edge->src_output());
      }
    }
  }
  return input_types;
}

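// Merges the inferred type `t` into `target`'s output type. Only refinements
// (subtypes) of any existing type are accepted; `updated` is set to true when
// the node's type actually changes.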
Status updated_inferred_type(Node* target, const FullTypeDef& t,
                             bool& updated) {
  if (t.type_id() == TFT_UNSET) {
    VLOG(3) << " " << target->name() << " no inferred type";
    return OkStatus();
  }

  if (target->def().has_experimental_type()) {
    const auto existing = target->def().experimental_type();
    if (full_type::IsSubtype(existing, t)) {
      VLOG(3) << " " << target->name() << " no new type info";
      return OkStatus();
    } else if (!full_type::IsSubtype(t, existing)) {
      // The only allowable type mismatches are those which would further
      // specialize the existing type.
      return Status(
          error::INVALID_ARGUMENT,
          absl::StrCat("type mismatch for node '", target->name(),
                       "': expected a subtype of:\n", existing.DebugString(),
                       "\n got:\n", t.DebugString(), "\n "));
    }
  }

  *(target->mutable_def()->mutable_experimental_type()) = t;
  updated = true;
  VLOG(3) << " " << target->name() << " updated";
  return OkStatus();
}

}  // namespace

Status ForwardTypeInferencePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "ForwardTypeInferencePass::Run";

  DCHECK(options.graph != nullptr);
  Graph* g = options.graph->get();
  DCHECK(g != nullptr);
  FunctionLibraryDefinition* flib_def = options.flib_def;
  DCHECK(flib_def != nullptr);

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("forward_type_inference_before", *g, flib_def);
  }

  for (Node* n : g->nodes()) {
    // TODO(mdan): Needed?
    n->UpdateProperties();
  }

  // Cache type inference functions, to avoid repeated flib_def lookups.
  ForwardInferMap forward;
  ReverseInferMap reverse;
  for (Node* n : g->nodes()) {
    VLOG(4) << "\n node: " << n->def().DebugString()
            << "\n op def: " << n->op_def().DebugString();
    const OpRegistrationData* reg;
    TF_RETURN_IF_ERROR(flib_def->LookUp(n->op_def().name(), &reg));
    if (reg->fwd_type_fn != nullptr) {
      forward.emplace(n->id(), reg->fwd_type_fn);
    }
    if (reg->rev_type_fn != nullptr) {
      reverse.emplace(n->id(), std::make_pair(reg->rev_type_input,
                                              std::cref(reg->rev_type_fn)));
    }
  }

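  // Applies `n`'s forward type inference function, if it has one, to the types
  // of its inputs and merges the result into `n`'s own output type.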
  auto infer_forward = [&forward](Node* n, bool& updated) {
    if (!forward.contains(n->id())) {
      return OkStatus();
    }
    VLOG(4) << " " << n->name() << " has forward function";

    // TODO(b/224775462): Populate with types from function references.
    TypeRefMap type_vars;
    auto in_types = input_types(*n);
    const auto& infer_ret = forward.at(n->id())(in_types, type_vars);

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        infer_ret.status(),
        absl::StrCat("while inferring type of node '", n->name(), "'"));

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        updated_inferred_type(n, *infer_ret, updated),
        "while updating its output type.");

    return OkStatus();
  };

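  // Applies `n`'s reverse type inference function, if it has one, and merges
  // the result into the output type of the node feeding the designated input
  // of `n`.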
  auto infer_reverse = [&reverse](Node* n, bool& updated) {
    if (!reverse.contains(n->id())) {
      return OkStatus();
    }
    VLOG(4) << " " << n->name() << " has reverse function";

    // TODO(b/224775462): Populate with types from function references.
    TypeRefMap type_vars;
    auto in_types = input_types(*n);
    auto rev_idx_and_fn = reverse.at(n->id());
    const auto& infer_ret = rev_idx_and_fn.second(in_types, type_vars);

    const Edge* e;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        n->input_edge(rev_idx_and_fn.first, &e),
        absl::StrCat("while querying input ", rev_idx_and_fn.first, " of '",
                     n->name(), "'"));

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        infer_ret.status(),
        absl::StrCat("while inferring type of node '", e->src()->name(),
                     "' via '", n->name(), "'"));

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        updated_inferred_type(e->src(), *infer_ret, updated),
        absl::StrCat("while updating its output type inferred from '",
                     n->name(), "'"));

    return OkStatus();
  };

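  // Process nodes from a worklist until all are closed or the visit budget is
  // exhausted.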
  std::list<int> queue;
  absl::flat_hash_set<int> in_queue;
  absl::flat_hash_map<int, int> visit_count;
  // Open nodes. A node is open if it has never been visited.
  absl::flat_hash_set<int> open;
  // Closed nodes. A closed node will never be visited again.
  absl::flat_hash_set<int> closed;

  // Upper bound on the number of passes. The worst case is a cycle in which no
  // nodes have type info, in which case there will be max_passes iterations,
  // each visiting one node.
  int max_passes = g->num_nodes();

  int visits = 0;

  // Start with niladic nodes. If none exist, a random one will be selected at
  // the end of the first iteration.
  for (Node* n : g->nodes()) {
    const int nid = n->id();
    bool niladic = true;
    for (const auto& e : n->in_edges()) {
      if (!e->IsControlEdge()) {
        niladic = false;
        break;
      }
    }
    if (niladic) {
      queue.emplace_back(nid);
      in_queue.emplace(nid);
    }
    open.emplace(nid);
    visit_count.emplace(nid, 0);
  }

  for (int i = 0; i < max_passes; i++) {
    VLOG(2) << "Iteration " << i << ", " << queue.size() << " nodes in queue";

    while (!queue.empty()) {
      int nid = queue.front();
      Node* n = g->FindNodeId(nid);
      VLOG(3) << " visiting " << n->name();
      visits++;
      visit_count[nid]++;
      DCHECK(!closed.contains(nid));

      bool updated = false;
      TF_RETURN_IF_ERROR(infer_forward(n, updated));
      TF_RETURN_IF_ERROR(infer_reverse(n, updated));

      VLOG(4) << " done " << n->def().DebugString();

      queue.pop_front();
      in_queue.erase(nid);
      open.erase(nid);

      // Update the graph to a fixed point, with the number of visits to any
      // one node limited by MAX_VISITS_PER_NODE.
      if (visit_count.at(nid) >= MAX_VISITS_PER_NODE) {
        VLOG(3) << " closing " << n->name() << " - visit limit reached";
        closed.emplace(nid);
      } else if (all_sources_closed(*n, closed, forward, reverse)) {
        VLOG(3) << " closing " << n->name() << " - all sources closed";
        closed.emplace(nid);
      }

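      // Re-enqueue data consumers of this node whose type may now be
      // refinable, either because this node's type changed or because all of
      // their sources are now closed.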
      for (const auto& out_edge : n->out_edges()) {
        if (out_edge->IsControlEdge()) {
          continue;
        }
        Node* c = out_edge->dst();
        int cid = c->id();
        if (closed.contains(cid) || in_queue.contains(cid)) {
          continue;
        }
        if (updated || all_sources_closed(*c, closed, forward, reverse)) {
          queue.emplace_back(cid);
          in_queue.emplace(cid);
        }
      }
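      // If anything was updated and this node has a reverse function, also
      // revisit the node feeding the reverse function's designated input,
      // since its type may have just changed.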
      if (updated && reverse.contains(nid)) {
        const Edge* e;
        TF_RETURN_IF_ERROR(n->input_edge(reverse.at(nid).first, &e));
        Node* c = e->src();
        int cid = c->id();
        if (!closed.contains(cid) && !in_queue.contains(cid)) {
          queue.emplace_back(cid);
          in_queue.emplace(cid);
        }
      }
    }

    VLOG(2) << "Done iteration " << i << ", " << closed.size()
            << " nodes closed";

    if (open.empty()) {
      VLOG(1) << "Finished after " << i + 1 << " iterations; done "
              << closed.size() << " of " << g->num_nodes() << " nodes in "
              << visits << " visits";
      break;
    } else {
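      // Not done yet: seed the next iteration with an arbitrary open node, so
      // that nodes not reachable from a niladic node (e.g. nodes in cycles)
      // are eventually visited.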
      queue.emplace_back(*(open.begin()));
    }
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("forward_type_inference_after", *g, flib_def);
  }

  return OkStatus();
}

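// The "weak" variant runs the same inference but never fails the surrounding
// graph optimization: errors are logged as a warning and an OK status is
// returned.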
Status WeakForwardTypeInferencePass::Run(
    const GraphOptimizationPassOptions& options) {
  ForwardTypeInferencePass pass;
  const auto& pass_status = pass.Run(options);
  if (!pass_status.ok()) {
    LOG_FIRST_N(WARNING, 1)
        << "Type inference failed. This indicates an invalid graph that "
           "escaped type checking. Error message: "
        << pass_status.ToString();
  }
  return OkStatus();
}

// Note: This needs to run last because Placer needs it.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 99999,
                      WeakForwardTypeInferencePass);

}  // namespace tensorflow