/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This module implements a common subexpression elimination pass. We
// process the nodes in the graph in reverse postorder
// (i.e. inputs before their downstream dependencies). The rough algorithm is
// as follows:
//
//   std::unordered_map<size_t, Node*> available
//   for each node n in forward topological order:
//     h = NodeHash(n)
//     if available[h] exists and Equivalent(available[h], n)
//       redirect downstream uses of outputs of n to available[h]
//       remove n from graph
//     else
//       if available[h] does not exist
//         available[h] = n
//
// This is similar to the global value numbering algorithm described in this
// paper:
//
// "Global code motion/global value numbering", Cliff Click, PLDI '95
// Proceedings of the ACM SIGPLAN 1995 conference on Programming
// language design and implementation, Pages 246-257
// http://dl.acm.org/citation.cfm?id=207154
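//
// For example (illustrative only; the names a, b, c, d are hypothetical):
//   c = MatMul(a, b)
//   d = MatMul(a, b)
// The two MatMul nodes hash identically and compare Equivalent, so consumers
// of d's outputs are rewired to read from c and d is removed from the graph.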

#include "tensorflow/core/graph/optimizer_cse.h"

#include <algorithm>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

class OptimizerCSE {
 public:
  explicit OptimizerCSE(Graph* g) : g_(g) {}

  bool Optimize(const std::function<bool(const Node*)>& consider_fn);

 private:
  static size_t NodeHash(const Node* n);
  static bool Equivalent(const Node* a, const Node* b,
                         AttrSlice::Scratch* scratch);

  Graph* g_;
};

static void FillInputs(
    const Node* n, gtl::InlinedVector<const Node*, 4>* control_edges,
    gtl::InlinedVector<std::pair<const Node*, int>, 4>* in) {
  DCHECK_EQ(in->size(), n->num_inputs());
  control_edges->clear();
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      control_edges->push_back(e->src());
    } else {
      (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
    }
  }
  std::sort(control_edges->begin(), control_edges->end());
  if (n->op_def().is_commutative()) {
    // For commutative inputs, we sort the input by the input Node*
    // to get a canonical ordering (so that add(a,b) and add(b, a) will
    // hash to the same value if is_commutative is true for 'add').
    std::sort(in->begin(), in->end());
  }
}

static size_t kIllegalNodeHash = 0;

class Hasher {
 public:
  uint64 hash() { return h_ == kIllegalNodeHash ? kIllegalNodeHash + 1 : h_; }

  void MixString(const string& s) { h_ = Hash64(s.data(), s.size(), h_); }

  void MixInteger(size_t z) { h_ = Hash64Combine(h_, z); }

  void MixProto(const protobuf::MessageLite& msg) {
    msg.ByteSizeLong();  // Ensure sizes are cached accurately.
    HashingOutputStream hasher;
    {
      // CodedOutputStream doesn't call BackUp until it's destroyed, so we need
      // it to be destroyed before we call hasher.hash().
      protobuf::io::CodedOutputStream stream(&hasher);
      stream.EnableAliasing(true);
      stream.SetSerializationDeterministic(true);
      msg.SerializeWithCachedSizes(&stream);
    }
    h_ = Hash64Combine(h_, hasher.hash());
  }

 private:
  // HashingOutputStream produces the same exact hash as if you serialized the
  // proto and hashed it sequentially in kBufSize chunks, except it doesn't
  // manifest the entire proto into memory at any point.
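  //
  // Illustratively (assuming the serialized message is written straight
  // through this stream), a 300-byte serialization hashes as
  //   h1 = Hash64(bytes[0..228), kDefaultSeed)
  //   h  = Hash64(bytes[228..300), h1)
  // which is exactly what Mix() below computes chunk by chunk.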
  class HashingOutputStream : public protobuf::io::ZeroCopyOutputStream {
   public:
    // This kBufSize makes sizeof(HashingOutputStream) == 256. It's not chosen
    // for any particular reason except it's a nice even number of cache lines.
    static constexpr size_t kBufSize = 228;
    static constexpr uint64 kDefaultSeed = 2570847921467975139ULL;
    bool Next(void** data, int* size) override {
      if (i_ == kBufSize) {
        // Mix the chunk in.
        Mix(buf_, kBufSize);
        *data = buf_;
        *size = kBufSize;
      } else {
        *data = buf_ + i_;
        *size = kBufSize - i_;
      }
      // We always set i_ to be past the end, since we've given the rest of
      // buf_ out.
      i_ = kBufSize;
      return true;
    }

    void BackUp(int count) override { i_ -= count; }

    int64_t ByteCount() const override { return byte_count_; }

    bool WriteAliasedRaw(const void* void_data, int size) override {
      // We can't do math on void*.
      const char* data = static_cast<const char*>(void_data);
      const auto remaining = kBufSize - i_;
      if (remaining > 0) {
        if (size < remaining) {
          memcpy(buf_ + i_, data, size);
          i_ += size;
          return true;
        }
        memcpy(buf_ + i_, data, remaining);
        i_ = kBufSize;
        data += remaining;
        size -= remaining;
      }
      if (i_ == kBufSize) {
        Mix(buf_, kBufSize);
        i_ = 0;
      }
      while (size >= kBufSize) {
        Mix(data, kBufSize);
        data += kBufSize;
        size -= kBufSize;
      }
      memcpy(buf_, data, size);
      i_ = size;
      return true;
    }

    bool AllowsAliasing() const override { return true; }

    uint64 hash() {
      if (i_ != 0) {
        Mix(buf_, i_);
        i_ = 0;
      }
      return h_;
    }

   private:
    void Mix(const char* p, size_t n) {
      byte_count_ += n;
      h_ = Hash64(p, n, h_);
    }
    char buf_[kBufSize];
    int i_ = 0;
    int64_t byte_count_ = 0;
    uint64 h_ = kDefaultSeed;
  };

  uint64 h_ = HashingOutputStream::kDefaultSeed;
};
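
// Example use of Hasher (illustrative only; these exact values are not used
// anywhere in this file):
//   Hasher h;
//   h.MixString("MatMul");
//   h.MixInteger(2);
//   uint64 fingerprint = h.hash();
// NodeHash() below mixes the op type, output types, inputs, and (on
// non-Android builds) attrs in the same way to build a per-node fingerprint.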

size_t OptimizerCSE::NodeHash(const Node* n) {
  Hasher hasher;
  hasher.MixString(n->type_string());
  hasher.MixInteger(n->output_types().size());
  for (DataType dt : n->output_types()) {
    hasher.MixInteger(dt);
  }

  hasher.MixInteger(n->num_inputs());
  gtl::InlinedVector<const Node*, 4> control_edges;
  gtl::InlinedVector<std::pair<const Node*, int>, 4> in(n->num_inputs());
  FillInputs(n, &control_edges, &in);
  for (const auto& edge : in) {
    hasher.MixInteger(edge.first->id());
    hasher.MixInteger(edge.second);
  }

#if !defined(__ANDROID__)
  // Hash the attrs. For example, this makes sure different constants
  // end up in different hash buckets.
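  // Per-attr hashes are combined with Hash64CombineUnordered, so the result
  // does not depend on the iteration order of the attr map.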
  size_t attr_hashes = 0;
  for (const auto& attr : n->attrs()) {
    Hasher h;
    h.MixString(attr.first);
    h.MixProto(attr.second);
    attr_hashes = Hash64CombineUnordered(attr_hashes, h.hash());
  }
  hasher.MixInteger(attr_hashes);
#endif

  return hasher.hash();
}

static bool HasRefInput(const Node* n) {
  for (auto dt : n->input_types()) {
    if (IsRefType(dt)) return true;
  }
  return false;
}

bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
                              AttrSlice::Scratch* scratch) {
  // Nodes with different op types are never equivalent.
  if (a->type_string() != b->type_string()) return false;

  // Never consider stateful nodes (such as non-const inputs) equivalent.
  if (a->op_def().is_stateful()) return false;

  // For now, we consider any node that takes a ref input to not be
  // equivalent to any other node.
  if (HasRefInput(a) || HasRefInput(b)) return false;

  // Compare attrs. Note that equal attrs implies equal input and
  // output types.
  if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false;

  // Compare input sources.
  if (a->num_inputs() != b->num_inputs()) return false;
  const int N_in = a->num_inputs();
  gtl::InlinedVector<const Node*, 4> a_control_edges;
  gtl::InlinedVector<const Node*, 4> b_control_edges;
  gtl::InlinedVector<std::pair<const Node*, int>, 4> a_in(N_in);
  gtl::InlinedVector<std::pair<const Node*, int>, 4> b_in(N_in);
  FillInputs(a, &a_control_edges, &a_in);
  FillInputs(b, &b_control_edges, &b_in);
  if (a_in != b_in) return false;
  if (a_control_edges != b_control_edges) return false;

  return true;
}
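
// Illustrative consequences of the checks above (not exhaustive): two Const
// nodes are merged only if their attrs (including "value" and "dtype")
// compare equal, while stateful ops such as random-number generators are
// never merged even when their attrs and inputs match.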

bool OptimizerCSE::Optimize(
    const std::function<bool(const Node*)>& consider_fn) {
  // This very simple implementation works as if the whole graph were one
  // giant basic block (we just traverse the nodes in topological order).
  // It still behaves correctly in the presence of control flow/loops/etc.,
  // but we would need to be careful about control flow if we wanted to add
  // more sophisticated CSE optimizations.

  // TODO(jeff): We need to handle Update nodes specially, but dealing
  // with more general control flow will also solve this issue, and for
  // now, our updates are almost always the most downstream nodes in
  // the graph.
  std::vector<Node*> order;
  GetReversePostOrder(*g_, &order, NodeComparatorID());

  // Our value is just a single Node*, meaning we keep just a single
  // candidate for a given node hash value. This may cause us to
  // (rarely) lose some optimization opportunities if there are
  // hash collisions, but it allows us to avoid having the value
  // be a set<Node*> (or equivalent).
  std::unordered_map<size_t, Node*> available;

  // Scratch space for Equivalent calls. Allocated here and passed in to
  // Equivalent to avoid allocation inside the loop below.
  bool changed = false;
  AttrSlice::Scratch scratch;
  for (Node* n : order) {
    if (!n->IsOp()) continue;

    // Don't prune placeholder nodes.
    if (n->type_string() == "Placeholder" ||
        n->type_string() == "PlaceholderV2" ||
        n->type_string() == "PlaceholderWithDefault") {
      continue;
    }

    // See if we should consider this node at all.
    if (consider_fn != nullptr && !consider_fn(n)) continue;

    size_t h = NodeHash(n);
    Node** candidate = &available[h];
    if (*candidate == nullptr) {
      // No existing match: insert "n" into the hash table under "h".
      *candidate = n;
    } else if (Equivalent(*candidate, n, &scratch)) {
      VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
              << n->name();
      // *candidate and n are equivalent. Therefore, we can replace
      // n with *candidate by fixing up outgoing edges from "n" to instead
      // come from "*candidate", and then delete n from the graph.
      for (const Edge* e : n->out_edges()) {
        g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
      }

      MergeDebugInfo(NodeDebugInfo(*n), *candidate);
      g_->RemoveNode(n);
      changed = true;
    }
  }
  return changed;
}

bool OptimizeCSE(Graph* g,
                 const std::function<bool(const Node*)>& consider_fn) {
  OptimizerCSE opt(g);
  return opt.Optimize(consider_fn);
}
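
// Illustrative usage (hypothetical caller; the variable names below are not
// defined in this file):
//   Graph* graph = ...;
//   bool changed = OptimizeCSE(graph, /*consider_fn=*/nullptr);  // all nodes
// A non-null consider_fn restricts which nodes the pass may merge, e.g.
//   OptimizeCSE(graph, [](const Node* n) { return n->type_string() != "Const"; });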

}  // namespace tensorflow