/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/graph_optimizer.h"

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"

namespace tensorflow {

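// At optimization level L1 or higher, enable common subexpression elimination
// and constant folding regardless of the individual option flags passed in.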
GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
  if (opts_.opt_level() >= OptimizerOptions::L1) {
    opts_.set_do_common_subexpression_elimination(true);
    opts_.set_do_constant_folding(true);
  }
}

GraphOptimizer::~GraphOptimizer() {}

void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
                              const Device* device,
                              std::unique_ptr<Graph>* graph,
                              const Options& options) {
  static const char* kGraphOptimizerCategory = "GraphOptimizerPass";

  Graph* g = graph->get();
  DumpGraph("Initial", g);
  bool changed = true;
  const int kMaxRounds = 10;
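  // Apply each enabled rewrite pass once per round, for at most kMaxRounds
  // rounds.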
  for (int rounds = 0; rounds < kMaxRounds; ++rounds) {
    changed = false;
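    // Rewrite _ListToArray and _ArrayToList conversion nodes into Identity
    // nodes.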
    if (RemoveListArrayConverter(g)) {
      DumpGraph("RemoveListArrayConverter", g);
      changed = true;
    }

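    // The dead-node and Identity cleanups below run only when function
    // inlining is enabled and are timed under the "function_inlining" metric.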
    tensorflow::metrics::ScopedCounter<2> inlining_timings(
        tensorflow::metrics::GetGraphOptimizationCounter(),
        {kGraphOptimizerCategory, "function_inlining"});
    if (opts_.do_function_inlining() && RemoveDeadNodes(g)) {
      DumpGraph("RemoveDeadNodes", g);
      changed = true;
    }
    if (opts_.do_function_inlining() && RemoveIdentityNodes(g)) {
      DumpGraph("RemoveIdentityNodes", g);
      changed = true;
    }
    if (opts_.do_function_inlining()) {
      inlining_timings.AccumulateAndStop();
    }

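    // Evaluate subgraphs whose inputs are all constant and splice the results
    // back into the graph as Const nodes.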
    if (opts_.do_constant_folding()) {
      tensorflow::metrics::ScopedCounter<2> timings(
          tensorflow::metrics::GetGraphOptimizationCounter(),
          {kGraphOptimizerCategory, "constant_folding"});

      ConstantFoldingOptions cf_opts;
      cf_opts.shape_map = options.shape_map;
      cf_opts.consider = options.cf_consider_fn;
      if (opts_.max_folded_constant_in_bytes() > 0) {
        cf_opts.max_constant_size_in_bytes =
            opts_.max_folded_constant_in_bytes();
      }
      bool was_mutated;
      ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
          .IgnoreError();
      if (was_mutated) {
        RemoveDeadNodes(g);
        DumpGraph("ConstFolding", g);
        changed = true;
      }
    }

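    // When inlining is enabled, reconnect nodes that have no incoming edges to
    // the source node and nodes that have no outgoing edges to the sink node.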
    if (opts_.do_function_inlining()) {
      inlining_timings.Start();
      if (FixupSourceAndSinkEdges(g)) {
        DumpGraph("FixupSourceAndSinkEdges", g);
        changed = true;
      }
      inlining_timings.AccumulateAndStop();
    }

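    // Merge structurally identical subexpressions; options.cse_consider_fn can
    // exclude individual nodes from consideration.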
    if (opts_.do_common_subexpression_elimination()) {
      tensorflow::metrics::ScopedCounter<2> timings(
          tensorflow::metrics::GetGraphOptimizationCounter(),
          {kGraphOptimizerCategory, "common_subexpression_elimination"});
      if (OptimizeCSE(g, options.cse_consider_fn)) {
        DumpGraph("OptimizeCSE", g);
        changed = true;
      }
    }
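    // Replace function call nodes with their instantiated function bodies.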
    if (opts_.do_function_inlining()) {
      inlining_timings.Start();
      ExpandInlineFunctionsOptions expand_inline_opts;
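      // Inline native (single-device) function bodies onto the caller's
      // device.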
      expand_inline_opts.native_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::SingleDevice();

      // Force the single-device placement strategy for multi-device function
      // bodies as well.
      if (options.inline_with_single_device_body_placer) {
        expand_inline_opts.multi_device_options.inlined_function_body_placer =
            InlinedFunctionBodyPlacer::SingleDevice();
      }

      if (!options.inline_multi_device_functions) {
        // GraphOptimizer is running:
        // (1) After partitioning when executing with a Session API.
        // (2) For a single device function body after instantiation.
        // We can't inline multi-device functions in these cases, because it
        // might lead to multiple device assignments.
        expand_inline_opts.multi_device_options.disable_inlining = true;
      }
      if (options.inline_impl_selection_group_functions) {
        expand_inline_opts.native_options
            .inline_impl_selection_group_functions = true;
        expand_inline_opts.multi_device_options
            .inline_impl_selection_group_functions = true;
      }

      if (options.ignore_noinline) {
        expand_inline_opts.multi_device_options.ignore_noinline = true;
        expand_inline_opts.native_options.ignore_noinline = true;
      }

      bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
      if (was_mutated) {
        DumpGraph("ExpandInlineFunctions", g);
        changed = true;
      }

      inlining_timings.ReportAndStop();
    }
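    // A full round made no changes, so the graph has converged; stop early.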
    if (!changed) break;
  }

  // Clone the graph to copy the input FunctionLibraryDefinition, since the
  // original lib def will go out of scope.
  *graph = g->Clone();

  DumpGraph("ReCopy", graph->get());
}

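// Convenience wrapper that enables the standard optimizations (CSE, function
// inlining and constant folding) and runs them with the function library's
// Env and Device.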
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
                   const GraphOptimizer::Options& graph_optimizer_options) {
  OptimizerOptions opts;
  opts.set_do_common_subexpression_elimination(true);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(lib, lib->env(), lib->device(), g,
                     graph_optimizer_options);
}

void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
  OptimizeGraph(lib, g, GraphOptimizer::Options());
}

}  // end namespace tensorflow