1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FORWARD_TYPE_INFERENCE_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_FORWARD_TYPE_INFERENCE_H_ |
18 | |
19 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | // TODO(mdan): Rename to just type_inference. |
24 | |
25 | // Run a very basic type inference on the graph. It simply propagates type |
26 | // information along edges, until reaching stability. |
27 | // |
28 | // The pass is designed to run as a graph diffusion process, refining type |
29 | // information until it reaches a fixed point. However, the current |
30 | // implementation is a simplification that only ensures that: |
31 | // 1. each node is visited at least once |
32 | // 2. a successful update of a node's type ID prevents future visits |
33 | // 3. each node is visited at most a fixed number of times |
34 | // |
35 | // If needed, we can drop rule #3 and change rule #2 to consider an update to |
36 | // be any deep type change (rather than just the type ID). |
37 | // |
38 | // The state of the diffusion process is the NodeDef.experimental_full_type |
39 | // field, while the diffusion function is the node's corresponding |
40 | // OpRegistrationData.fwd_type_fn function. |
41 | // |
42 | // TODO(mdan): Use a regular union-based algorithm instead? |
43 | class ForwardTypeInferencePass : public GraphOptimizationPass { |
44 | public: |
45 | Status Run(const GraphOptimizationPassOptions& options) override; |
46 | }; |
47 | |
48 | // A version of ForwardTypeInferencePass that prints a warning on error, instead |
49 | // of returning error status. This is done because there are a few graphs |
50 | // currently in the wild which don't actually type check. |
51 | // TODO(mdan): Turn this into an error, once all offenders are clean. |
52 | class WeakForwardTypeInferencePass : public GraphOptimizationPass { |
53 | public: |
54 | Status Run(const GraphOptimizationPassOptions& options) override; |
55 | }; |
56 | |
57 | } // namespace tensorflow |
58 | |
59 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FORWARD_TYPE_INFERENCE_H_ |
60 | |