1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ |
18 | |
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
27 | |
// Classes to maintain a static registry of Graph-based passes to be applied to
// a function graph.
30 | |
31 | namespace tensorflow { |
32 | |
33 | // A pass to be registered with the FunctionOptimizationPassRegistry. This pass |
34 | // takes in a DeviceSet (available devices for executing the Graph), ConfigProto |
35 | // (session configuration parameters), Graph (computation), |
36 | // FunctionLibraryDefinition (mapping between function names and function |
37 | // definitions of the Graph), control ret/target node names (names of nodes that |
38 | // must execute but their data outputs, if they have any, are irrelevant), and |
39 | // whether control ret nodes (via thier name) were updated. Mutations to the |
40 | // Graph and other associated arguments are performed inplace by the pass. |
41 | class FunctionOptimizationPass { |
42 | public: |
43 | virtual ~FunctionOptimizationPass() {} |
44 | virtual Status Run(const DeviceSet& device_set, |
45 | const ConfigProto& config_proto, |
46 | std::unique_ptr<Graph>* graph, |
47 | FunctionLibraryDefinition* flib_def, |
48 | std::vector<std::string>* control_ret_node_names, |
49 | bool* control_rets_updated) = 0; |
50 | }; |
51 | |
52 | // A global function optimization pass registry that is used to hold one |
53 | // FunctionOptimizationPass. Passes registered to this registry will run before |
54 | // passes registered in OptimizationPassRegistry. |
55 | class FunctionOptimizationPassRegistry { |
56 | public: |
57 | // Initializes registry with a pass. Only one pass should be set. An assertion |
58 | // will be triggered if the registry already has a pass set and is being |
59 | // initialized with another pass. |
60 | void Init(std::unique_ptr<FunctionOptimizationPass> pass); |
61 | |
62 | // Runs a pass if the registry contains one. |
63 | Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, |
64 | std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, |
65 | std::vector<std::string>* control_ret_node_names, |
66 | bool* control_rets_updated); |
67 | |
68 | // Returns the global registry of function graph passes. |
69 | static FunctionOptimizationPassRegistry& Global(); |
70 | |
71 | private: |
72 | std::unique_ptr<FunctionOptimizationPass> pass_; |
73 | }; |
74 | |
75 | namespace function_optimization_registration { |
76 | |
77 | class FunctionOptimizationPassRegistration { |
78 | public: |
79 | explicit FunctionOptimizationPassRegistration( |
80 | std::unique_ptr<FunctionOptimizationPass> pass) { |
81 | FunctionOptimizationPassRegistry::Global().Init(std::move(pass)); |
82 | } |
83 | }; |
84 | |
85 | } // namespace function_optimization_registration |
86 | |
87 | } // namespace tensorflow |
88 | |
89 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ |
90 | |