1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/graph_optimizer.h" |
17 | |
18 | #include "tensorflow/core/common_runtime/constant_folding.h" |
19 | #include "tensorflow/core/common_runtime/function_utils.h" |
20 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
21 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
22 | #include "tensorflow/core/framework/metrics.h" |
23 | #include "tensorflow/core/graph/algorithm.h" |
24 | #include "tensorflow/core/graph/node_builder.h" |
25 | #include "tensorflow/core/graph/optimizer_cse.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) { |
30 | if (opts_.opt_level() >= OptimizerOptions::L1) { |
31 | opts_.set_do_common_subexpression_elimination(true); |
32 | opts_.set_do_constant_folding(true); |
33 | } |
34 | } |
35 | |
36 | GraphOptimizer::~GraphOptimizer() {} |
37 | |
// Runs the enabled rewrites (list/array-converter removal, dead-node and
// identity removal, constant folding, CSE, and function inlining) in a
// fixed-point loop: the loop repeats while any pass reports a change, up to
// kMaxRounds iterations.
//
// `runtime`, `env`, and `device` are consumed by ConstantFold to evaluate
// foldable subgraphs. On return, *graph is replaced by a clone of the
// optimized graph (see the comment before g->Clone() below).
void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
                              const Device* device,
                              std::unique_ptr<Graph>* graph,
                              const Options& options) {
  // Category label under which all pass timings below are reported.
  static const char* kGraphOptimizerCategory = "GraphOptimizerPass" ;

  Graph* g = graph->get();
  DumpGraph("Initial" , g);
  bool changed = true;
  const int kMaxRounds = 10;
  for (int rounds = 0; rounds < kMaxRounds; ++rounds) {
    changed = false;
    if (RemoveListArrayConverter(g)) {
      DumpGraph("RemoveListArrayConverter" , g);
      changed = true;
    }

    // Accumulates the time spent in all inlining-related passes of this
    // round. The counter is stopped/restarted around the non-inlining
    // passes and reported once at the end of the round (ReportAndStop),
    // or implicitly when it goes out of scope.
    tensorflow::metrics::ScopedCounter<2> inlining_timings(
        tensorflow::metrics::GetGraphOptimizationCounter(),
        {kGraphOptimizerCategory, "function_inlining" });
    if (opts_.do_function_inlining() && RemoveDeadNodes(g)) {
      DumpGraph("RemoveDeadNodes" , g);
      changed = true;
    }
    if (opts_.do_function_inlining() && RemoveIdentityNodes(g)) {
      DumpGraph("RemoveIdentityNodes" , g);
      changed = true;
    }
    if (opts_.do_function_inlining()) {
      // Pause the inlining timer while constant folding / CSE run below.
      inlining_timings.AccumulateAndStop();
    }

    if (opts_.do_constant_folding()) {
      tensorflow::metrics::ScopedCounter<2> timings(
          tensorflow::metrics::GetGraphOptimizationCounter(),
          {kGraphOptimizerCategory, "constant_folding" });

      ConstantFoldingOptions cf_opts;
      cf_opts.shape_map = options.shape_map;
      cf_opts.consider = options.cf_consider_fn;
      // 0 (or negative) means "use the ConstantFoldingOptions default".
      if (opts_.max_folded_constant_in_bytes() > 0) {
        cf_opts.max_constant_size_in_bytes =
            opts_.max_folded_constant_in_bytes();
      }
      bool was_mutated;
      // Constant folding is best-effort: a failure leaves the graph usable,
      // so the status is deliberately ignored.
      ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
          .IgnoreError();
      if (was_mutated) {
        RemoveDeadNodes(g);
        DumpGraph("ConstFolding" , g);
        changed = true;
      }
    }

    if (opts_.do_function_inlining()) {
      inlining_timings.Start();
      if (FixupSourceAndSinkEdges(g)) {
        DumpGraph("FixupSourceAndSinkEdges" , g);
        changed = true;
      }
      inlining_timings.AccumulateAndStop();
    }

    if (opts_.do_common_subexpression_elimination()) {
      tensorflow::metrics::ScopedCounter<2> timings(
          tensorflow::metrics::GetGraphOptimizationCounter(),
          {kGraphOptimizerCategory, "common_subexpression_elimination" });
      if (OptimizeCSE(g, options.cse_consider_fn)) {
        DumpGraph("OptimizeCSE" , g);
        changed = true;
      }
    }
    if (opts_.do_function_inlining()) {
      inlining_timings.Start();
      ExpandInlineFunctionsOptions expand_inline_opts;
      expand_inline_opts.native_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::SingleDevice();

      // Force single device placement strategy for multi-device function body.
      if (options.inline_with_single_device_body_placer) {
        expand_inline_opts.multi_device_options.inlined_function_body_placer =
            InlinedFunctionBodyPlacer::SingleDevice();
      }

      if (!options.inline_multi_device_functions) {
        // GraphOptimizer is running:
        // (1) After partitioning when executing with a Session API.
        // (2) For a single device function body after instantiation.
        // We can't inline multi-device functions in these cases, because it
        // might lead to multiple device assignments.
        expand_inline_opts.multi_device_options.disable_inlining = true;
      }
      if (options.inline_impl_selection_group_functions) {
        expand_inline_opts.native_options
            .inline_impl_selection_group_functions = true;
        expand_inline_opts.multi_device_options
            .inline_impl_selection_group_functions = true;
      }

      if (options.ignore_noinline) {
        expand_inline_opts.multi_device_options.ignore_noinline = true;
        expand_inline_opts.native_options.ignore_noinline = true;
      }

      bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
      if (was_mutated) {
        DumpGraph("ExpandInlineFunctions" , g);
        changed = true;
      }

      // Flush the accumulated inlining time for this round to the metric.
      inlining_timings.ReportAndStop();
    }
    // Fixed point reached: no pass changed the graph this round.
    if (!changed) break;
  }

  // Clone the graph to copy the input FunctionLibraryDefinition, since the
  // original lib def will go out of scope.
  *graph = g->Clone();

  DumpGraph("ReCopy" , graph->get());
}
159 | |
160 | void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g, |
161 | const GraphOptimizer::Options& graph_optimizer_options) { |
162 | OptimizerOptions opts; |
163 | opts.set_do_common_subexpression_elimination(true); |
164 | opts.set_do_function_inlining(true); |
165 | opts.set_do_constant_folding(true); |
166 | GraphOptimizer optimizer(opts); |
167 | optimizer.Optimize(lib, lib->env(), lib->device(), g, |
168 | graph_optimizer_options); |
169 | } |
170 | |
171 | void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) { |
172 | OptimizeGraph(lib, g, GraphOptimizer::Options()); |
173 | } |
174 | |
175 | } // end namespace tensorflow |
176 | |