1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
17 | |
18 | #include "tensorflow/core/framework/metrics.h" |
19 | #include "tensorflow/core/util/dump_graph.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | // static |
24 | OptimizationPassRegistry* OptimizationPassRegistry::Global() { |
25 | static OptimizationPassRegistry* global_optimization_registry = |
26 | new OptimizationPassRegistry; |
27 | return global_optimization_registry; |
28 | } |
29 | |
30 | void OptimizationPassRegistry::Register( |
31 | Grouping grouping, int phase, std::unique_ptr<GraphOptimizationPass> pass) { |
32 | groups_[grouping][phase].push_back(std::move(pass)); |
33 | } |
34 | |
35 | Status OptimizationPassRegistry::RunGrouping( |
36 | Grouping grouping, const GraphOptimizationPassOptions& options) { |
37 | auto dump_graph = [&](std::string& prefix) { |
38 | if (options.graph) { |
39 | DumpGraphToFile( |
40 | strings::StrCat(prefix, "_" , |
41 | reinterpret_cast<uintptr_t>((*options.graph).get())), |
42 | **options.graph, options.flib_def); |
43 | } |
44 | if (options.partition_graphs) { |
45 | for (auto& part : *options.partition_graphs) { |
46 | DumpGraphToFile( |
47 | strings::StrCat(prefix, "_partition_" , part.first, "_" , |
48 | reinterpret_cast<uintptr_t>(part.second.get())), |
49 | *part.second, options.flib_def); |
50 | } |
51 | } |
52 | }; |
53 | |
54 | VLOG(1) << "Starting optimization of a group " << grouping; |
55 | if (VLOG_IS_ON(3)) { |
56 | std::string prefix = strings::StrCat("before_grouping_" , grouping); |
57 | dump_graph(prefix); |
58 | } |
59 | auto group = groups_.find(grouping); |
60 | if (group != groups_.end()) { |
61 | static const char* kGraphOptimizationCategory = "GraphOptimizationPass" ; |
62 | tensorflow::metrics::ScopedCounter<2> group_timings( |
63 | tensorflow::metrics::GetGraphOptimizationCounter(), |
64 | {kGraphOptimizationCategory, "*" }); |
65 | for (auto& phase : group->second) { |
66 | VLOG(1) << "Running optimization phase " << phase.first; |
67 | for (auto& pass : phase.second) { |
68 | VLOG(1) << "Running optimization pass: " << pass->name(); |
69 | |
70 | tensorflow::metrics::ScopedCounter<2> pass_timings( |
71 | tensorflow::metrics::GetGraphOptimizationCounter(), |
72 | {kGraphOptimizationCategory, pass->name()}); |
73 | Status s = pass->Run(options); |
74 | |
75 | if (!s.ok()) return s; |
76 | pass_timings.ReportAndStop(); |
77 | if (VLOG_IS_ON(5)) { |
78 | std::string prefix = |
79 | strings::StrCat("after_group_" , grouping, "_phase_" , phase.first, |
80 | "_" , pass->name()); |
81 | dump_graph(prefix); |
82 | } |
83 | } |
84 | } |
85 | group_timings.ReportAndStop(); |
86 | } |
87 | VLOG(1) << "Finished optimization of a group " << grouping; |
88 | if (VLOG_IS_ON(3) || |
89 | (VLOG_IS_ON(2) && grouping == Grouping::POST_REWRITE_FOR_EXEC)) { |
90 | std::string prefix = strings::StrCat("after_grouping_" , grouping); |
91 | dump_graph(prefix); |
92 | } |
93 | return OkStatus(); |
94 | } |
95 | |
96 | void OptimizationPassRegistry::LogGrouping(Grouping grouping, int vlog_level) { |
97 | auto group = groups_.find(grouping); |
98 | if (group != groups_.end()) { |
99 | for (auto& phase : group->second) { |
100 | for (auto& pass : phase.second) { |
101 | VLOG(vlog_level) << "Registered optimization pass grouping " << grouping |
102 | << " phase " << phase.first << ": " << pass->name(); |
103 | } |
104 | } |
105 | } |
106 | } |
107 | |
108 | void OptimizationPassRegistry::LogAllGroupings(int vlog_level) { |
109 | for (auto group = groups_.begin(); group != groups_.end(); ++group) { |
110 | LogGrouping(group->first, vlog_level); |
111 | } |
112 | } |
113 | |
114 | } // namespace tensorflow |
115 | |