1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/function_optimization_registry.h" |
17 | |
18 | #include "tensorflow/core/framework/metrics.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | void FunctionOptimizationPassRegistry::Init( |
23 | std::unique_ptr<FunctionOptimizationPass> pass) { |
24 | DCHECK(!pass_) << "Only one pass should be set." ; |
25 | pass_ = std::move(pass); |
26 | } |
27 | |
28 | Status FunctionOptimizationPassRegistry::Run( |
29 | const DeviceSet& device_set, const ConfigProto& config_proto, |
30 | std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, |
31 | std::vector<std::string>* control_ret_node_names, |
32 | bool* control_rets_updated) { |
33 | if (!pass_) return OkStatus(); |
34 | |
35 | tensorflow::metrics::ScopedCounter<2> timings( |
36 | tensorflow::metrics::GetGraphOptimizationCounter(), |
37 | {"GraphOptimizationPass" , "FunctionOptimizationPassRegistry" }); |
38 | |
39 | return pass_->Run(device_set, config_proto, graph, flib_def, |
40 | control_ret_node_names, control_rets_updated); |
41 | } |
42 | |
43 | // static |
44 | FunctionOptimizationPassRegistry& FunctionOptimizationPassRegistry::Global() { |
45 | static FunctionOptimizationPassRegistry* kGlobalRegistry = |
46 | new FunctionOptimizationPassRegistry; |
47 | return *kGlobalRegistry; |
48 | } |
49 | |
50 | } // namespace tensorflow |
51 | |