1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
18
#include <memory>
#include <string>
#include <utility>
#include <vector>
22
23#include "tensorflow/core/common_runtime/device_set.h"
24#include "tensorflow/core/framework/function.h"
25#include "tensorflow/core/graph/graph.h"
26#include "tensorflow/core/protobuf/config.pb.h"
27
// Classes to maintain a static registry of Graph-based passes to be applied to
// a function graph.
30
31namespace tensorflow {
32
33// A pass to be registered with the FunctionOptimizationPassRegistry. This pass
34// takes in a DeviceSet (available devices for executing the Graph), ConfigProto
35// (session configuration parameters), Graph (computation),
36// FunctionLibraryDefinition (mapping between function names and function
37// definitions of the Graph), control ret/target node names (names of nodes that
38// must execute but their data outputs, if they have any, are irrelevant), and
39// whether control ret nodes (via thier name) were updated. Mutations to the
40// Graph and other associated arguments are performed inplace by the pass.
41class FunctionOptimizationPass {
42 public:
43 virtual ~FunctionOptimizationPass() {}
44 virtual Status Run(const DeviceSet& device_set,
45 const ConfigProto& config_proto,
46 std::unique_ptr<Graph>* graph,
47 FunctionLibraryDefinition* flib_def,
48 std::vector<std::string>* control_ret_node_names,
49 bool* control_rets_updated) = 0;
50};
51
52// A global function optimization pass registry that is used to hold one
53// FunctionOptimizationPass. Passes registered to this registry will run before
54// passes registered in OptimizationPassRegistry.
55class FunctionOptimizationPassRegistry {
56 public:
57 // Initializes registry with a pass. Only one pass should be set. An assertion
58 // will be triggered if the registry already has a pass set and is being
59 // initialized with another pass.
60 void Init(std::unique_ptr<FunctionOptimizationPass> pass);
61
62 // Runs a pass if the registry contains one.
63 Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
64 std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
65 std::vector<std::string>* control_ret_node_names,
66 bool* control_rets_updated);
67
68 // Returns the global registry of function graph passes.
69 static FunctionOptimizationPassRegistry& Global();
70
71 private:
72 std::unique_ptr<FunctionOptimizationPass> pass_;
73};
74
75namespace function_optimization_registration {
76
77class FunctionOptimizationPassRegistration {
78 public:
79 explicit FunctionOptimizationPassRegistration(
80 std::unique_ptr<FunctionOptimizationPass> pass) {
81 FunctionOptimizationPassRegistry::Global().Init(std::move(pass));
82 }
83};
84
85} // namespace function_optimization_registration
86
87} // namespace tensorflow
88
89#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_
90