1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FORWARD_TYPE_INFERENCE_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_FORWARD_TYPE_INFERENCE_H_
18
19#include "tensorflow/core/common_runtime/optimization_registry.h"
20
21namespace tensorflow {
22
23// TODO(mdan): Rename to just type_inference.
24
// Runs a very basic forward type inference on the graph: it simply propagates
// type information along edges until reaching stability.
//
// The pass is designed to run as a graph diffusion process, refining type
// information until it reaches a fixed point. However, the current
// implementation is a simplification that only ensures that:
// 1. each node is visited at least once
// 2. a successful update of a node's type ID prevents future visits
// 3. each node is visited at most a fixed number of times
//
// If needed, we can drop rule #3 and change rule #2 to consider an update to
// be any deep type change (rather than just the type ID).
//
// The state of the diffusion process is the NodeDef.experimental_full_type
// field, while the diffusion function is the node's corresponding
// OpRegistrationData.fwd_type_fn function.
//
// Type inference errors are returned from Run() as a non-OK Status; see
// WeakForwardTypeInferencePass for a variant that only warns instead.
//
// TODO(mdan): Use a regular union-based algorithm instead?
class ForwardTypeInferencePass : public GraphOptimizationPass {
 public:
 // GraphOptimizationPass entry point. Runs the inference over the graph
 // referenced by `options` (presumably `options.graph` — confirm in the .cc),
 // refining each node's experimental_full_type field. Returns a non-OK Status
 // if type inference fails.
 Status Run(const GraphOptimizationPassOptions& options) override;
};
47
// A version of ForwardTypeInferencePass that prints a warning on error, instead
// of returning an error status. This is done because there are a few graphs
// currently in the wild which don't actually type-check.
// TODO(mdan): Turn this into an error, once all offenders are clean.
class WeakForwardTypeInferencePass : public GraphOptimizationPass {
 public:
 // GraphOptimizationPass entry point. Performs the same inference as
 // ForwardTypeInferencePass::Run, but reports type inference errors as a
 // printed warning rather than propagating them to the caller.
 // NOTE(review): presumably returns an OK Status in that case — confirm in
 // the .cc.
 Status Run(const GraphOptimizationPassOptions& options) override;
};
56
57} // namespace tensorflow
58
59#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FORWARD_TYPE_INFERENCE_H_
60