1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // This module implements a common subexpression elimination pass. We |
17 | // process the nodes in the graph in reverse postorder |
18 | // (i.e. inputs before their downstream dependencies). The rough algorithm is |
19 | // as follows: |
20 | // |
21 | // std::unordered_map<size_t, Node*> available |
22 | // for each node n in forward topological order: |
23 | // h = NodeHash(n) |
//     if available[h] exists and Equivalent(available[h], n)
25 | // redirect downstream uses of outputs of n to available[h] |
26 | // remove n from graph |
27 | // else |
28 | // if available[h] does not exist |
29 | // available[h] = n |
30 | // |
// This is similar to the global value number algorithm described in this
32 | // paper: |
33 | // |
34 | // "Global code motion/global value numbering", Cliff Click, PLDI '95 |
35 | // Proceedings of the ACM SIGPLAN 1995 conference on Programming |
36 | // language design and implementation, Pages 246-257 |
37 | // http://dl.acm.org/citation.cfm?id=207154 |
38 | |
39 | #include "tensorflow/core/graph/optimizer_cse.h" |
40 | |
#include <algorithm>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>
45 | |
46 | #include "tensorflow/core/framework/node_def.pb.h" |
47 | #include "tensorflow/core/framework/node_def_util.h" |
48 | #include "tensorflow/core/graph/algorithm.h" |
49 | #include "tensorflow/core/graph/graph_node_util.h" |
50 | #include "tensorflow/core/lib/gtl/map_util.h" |
51 | #include "tensorflow/core/lib/hash/hash.h" |
52 | #include "tensorflow/core/platform/logging.h" |
53 | #include "tensorflow/core/platform/protobuf.h" |
54 | |
55 | namespace tensorflow { |
56 | |
// Performs common subexpression elimination on a Graph: nodes that compute
// the same value (same op type, attrs, data inputs, and control inputs) are
// merged so only one survives.
class OptimizerCSE {
 public:
  explicit OptimizerCSE(Graph* g) : g_(g) {}

  // Runs one CSE pass over the graph. `consider_fn`, if non-null, restricts
  // which nodes may participate (nodes for which it returns false are left
  // alone). Returns true iff the graph was modified.
  bool Optimize(const std::function<bool(const Node*)>& consider_fn);

 private:
  // Hash of a node's op type, output types, data inputs, and (except on
  // Android builds) attrs. Equivalent nodes hash to the same value.
  static size_t NodeHash(const Node* n);
  // True iff `a` and `b` compute the same value and may be merged. `scratch`
  // is caller-provided reusable space for the attr comparison.
  static bool Equivalent(const Node* a, const Node* b,
                         AttrSlice::Scratch* scratch);

  Graph* g_;  // Not owned; mutated in place by Optimize().
};
70 | |
71 | static void FillInputs(const Node* n, |
72 | gtl::InlinedVector<const Node*, 4>* control_edges, |
73 | gtl::InlinedVector<std::pair<const Node*, int>, 4>* in) { |
74 | DCHECK_EQ(in->size(), n->num_inputs()); |
75 | control_edges->clear(); |
76 | for (const Edge* e : n->in_edges()) { |
77 | if (e->IsControlEdge()) { |
78 | control_edges->push_back(e->src()); |
79 | } else { |
80 | (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); |
81 | } |
82 | } |
83 | std::sort(control_edges->begin(), control_edges->end()); |
84 | if (n->op_def().is_commutative()) { |
85 | // For commutative inputs, we sort the input by the input Node* |
86 | // to get a canonical ordering (so that add(a,b) and add(b, a) will |
87 | // hash to the same value if is_commutative is true for 'add'). |
88 | std::sort(in->begin(), in->end()); |
89 | } |
90 | } |
91 | |
92 | static size_t kIllegalNodeHash = 0; |
93 | |
// Incrementally accumulates a 64-bit hash over strings, integers, and
// protobuf messages. The final hash is guaranteed to never equal the
// kIllegalNodeHash sentinel.
class Hasher {
 public:
  // Returns the accumulated hash, remapped away from the kIllegalNodeHash
  // sentinel so callers can reserve that value to mean "no hash".
  uint64 hash() { return h_ == kIllegalNodeHash ? kIllegalNodeHash + 1 : h_; }

  // Folds the bytes of `s` into the running hash.
  void MixString(const string& s) { h_ = Hash64(s.data(), s.size(), h_); }

  // Folds the integer `z` into the running hash.
  void MixInteger(size_t z) { h_ = Hash64Combine(h_, z); }

  // Folds the deterministic serialization of `msg` into the running hash,
  // streaming through a fixed-size buffer instead of materializing the whole
  // serialized proto in memory.
  void MixProto(const protobuf::MessageLite& msg) {
    msg.ByteSizeLong();  // Ensure sizes are cached accurately.
    HashingOutputStream hasher;
    {
      // CodedOutputStream doesn't call BackUp until it's destroyed, so we need
      // it to be destroyed before we call hasher.hash().
      protobuf::io::CodedOutputStream stream(&hasher);
      stream.EnableAliasing(true);
      stream.SetSerializationDeterministic(true);
      msg.SerializeWithCachedSizes(&stream);
    }
    h_ = Hash64Combine(h_, hasher.hash());
  }

 private:
  // HashingOutputStream produces the same exact hash as if you serialized the
  // proto and hashed it sequentially in kBufSize chunks, except it doesn't
  // manifest the entire proto into memory at any point.
  class HashingOutputStream : public protobuf::io::ZeroCopyOutputStream {
   public:
    // This kBufSize makes sizeof(HashingOutputStream) == 256. It's not chosen
    // for any particular reason except it's a nice even number of cache lines.
    static constexpr size_t kBufSize = 228;
    static constexpr uint64 kDefaultSeed = 2570847921467975139ULL;

    // ZeroCopyOutputStream interface: hands out writable space in buf_. If
    // the buffer is already full, mixes it into the hash first and hands the
    // whole buffer out again; otherwise hands out the unused tail.
    bool Next(void** data, int* size) override {
      if (i_ == kBufSize) {
        // Mix the chunk in.
        Mix(buf_, kBufSize);
        *data = buf_;
        *size = kBufSize;
      } else {
        *data = buf_ + i_;
        *size = kBufSize - i_;
      }
      // We always set i_ to be past the end, since we've given the rest of buf_
      // out.
      i_ = kBufSize;
      return true;
    }

    // Returns the last `count` bytes handed out by Next(); they were not
    // actually written, so rewind the buffer cursor past them.
    void BackUp(int count) override { i_ -= count; }

    // Total number of bytes mixed into the hash so far (excludes any bytes
    // still pending in buf_).
    int64_t ByteCount() const override { return byte_count_; }

    // Aliased write path: consumes `data` directly. Fills any partially-used
    // buffer first, then mixes full kBufSize chunks straight from `data`,
    // and finally stashes the remainder in buf_ — preserving the exact
    // kBufSize chunk boundaries the class documents above.
    bool WriteAliasedRaw(const void* void_data, int size) override {
      // We can't do math on void*.
      const char* data = static_cast<const char*>(void_data);
      const auto remaining = kBufSize - i_;
      if (remaining > 0) {
        if (size < remaining) {
          // Everything fits in the current buffer; no chunk completed yet.
          memcpy(buf_ + i_, data, size);
          i_ += size;
          return true;
        }
        // Top off the current buffer so it forms a complete chunk.
        memcpy(buf_ + i_, data, remaining);
        i_ = kBufSize;
        data += remaining;
        size -= remaining;
      }
      if (i_ == kBufSize) {
        Mix(buf_, kBufSize);
        i_ = 0;
      }
      // Mix whole chunks directly from the caller's memory (no copy).
      while (size >= kBufSize) {
        Mix(data, kBufSize);
        data += kBufSize;
        size -= kBufSize;
      }
      // Buffer the tail for the next write (or final hash()).
      memcpy(buf_, data, size);
      i_ = size;
      return true;
    }

    bool AllowsAliasing() const override { return true; }

    // Flushes any buffered partial chunk into the hash and returns the
    // result. Must be called only after serialization is complete (see the
    // CodedOutputStream destruction note in MixProto).
    uint64 hash() {
      if (i_ != 0) {
        Mix(buf_, i_);
        i_ = 0;
      }
      return h_;
    }

   private:
    // Folds `n` bytes at `p` into the running hash and byte count.
    void Mix(const char* p, size_t n) {
      byte_count_ += n;
      h_ = Hash64(p, n, h_);
    }
    char buf_[kBufSize];          // Staging buffer for one hash chunk.
    int i_ = 0;                   // Bytes of buf_ currently in use.
    int64_t byte_count_ = 0;      // Total bytes mixed so far.
    uint64 h_ = kDefaultSeed;     // Running hash value.
  };

  uint64 h_ = HashingOutputStream::kDefaultSeed;  // Running hash value.
};
198 | |
199 | size_t OptimizerCSE::NodeHash(const Node* n) { |
200 | Hasher hasher; |
201 | hasher.MixString(n->type_string()); |
202 | hasher.MixInteger(n->output_types().size()); |
203 | for (DataType dt : n->output_types()) { |
204 | hasher.MixInteger(dt); |
205 | } |
206 | |
207 | hasher.MixInteger(n->num_inputs()); |
208 | gtl::InlinedVector<const Node*, 4> control_edges; |
209 | gtl::InlinedVector<std::pair<const Node*, int>, 4> in(n->num_inputs()); |
210 | FillInputs(n, &control_edges, &in); |
211 | for (const auto& edge : in) { |
212 | hasher.MixInteger(edge.first->id()); |
213 | hasher.MixInteger(edge.second); |
214 | } |
215 | |
216 | #if !defined(__ANDROID__) |
217 | // Hash the attrs. For example, this makes sure different constants |
218 | // end up in different hash buckets. |
219 | size_t attr_hashes = 0; |
220 | for (const auto& attr : n->attrs()) { |
221 | Hasher h; |
222 | h.MixString(attr.first); |
223 | h.MixProto(attr.second); |
224 | attr_hashes = Hash64CombineUnordered(attr_hashes, h.hash()); |
225 | } |
226 | hasher.MixInteger(attr_hashes); |
227 | #endif |
228 | |
229 | return hasher.hash(); |
230 | } |
231 | |
232 | static bool HasRefInput(const Node* n) { |
233 | for (auto dt : n->input_types()) { |
234 | if (IsRefType(dt)) return true; |
235 | } |
236 | return false; |
237 | } |
238 | |
239 | bool OptimizerCSE::Equivalent(const Node* a, const Node* b, |
240 | AttrSlice::Scratch* scratch) { |
241 | // Different op names are different |
242 | if (a->type_string() != b->type_string()) return false; |
243 | |
244 | // Never consider stateful nodes (such as non-const inputs) equivalent. |
245 | if (a->op_def().is_stateful()) return false; |
246 | |
247 | // For now, we consider any node that takes a ref input to not be |
248 | // equivalent to any other node. |
249 | if (HasRefInput(a) || HasRefInput(b)) return false; |
250 | |
251 | // Compare attrs. Note that equal attrs implies equal input and |
252 | // output types. |
253 | if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false; |
254 | |
255 | // Compare input sources |
256 | if (a->num_inputs() != b->num_inputs()) return false; |
257 | const int N_in = a->num_inputs(); |
258 | gtl::InlinedVector<const Node*, 4> a_control_edges; |
259 | gtl::InlinedVector<const Node*, 4> b_control_edges; |
260 | gtl::InlinedVector<std::pair<const Node*, int>, 4> a_in(N_in); |
261 | gtl::InlinedVector<std::pair<const Node*, int>, 4> b_in(N_in); |
262 | FillInputs(a, &a_control_edges, &a_in); |
263 | FillInputs(b, &b_control_edges, &b_in); |
264 | if (a_in != b_in) return false; |
265 | if (a_control_edges != b_control_edges) return false; |
266 | |
267 | return true; |
268 | } |
269 | |
// Runs one CSE pass: visits nodes in topological order, keeps one candidate
// node per hash value, and redirects/deletes any later node found equivalent
// to its candidate. Returns true iff the graph was modified.
bool OptimizerCSE::Optimize(
    const std::function<bool(const Node*)>& consider_fn) {
  // This very simple implementation works if the whole graph is one
  // giant basic block (because we just traverse nodes in a
  // topological order). This simple implementation works well
  // with control flow/loops/etc. But we need to be careful about
  // control flow if we want to add more sophisticated CSE optimizations.

  // TODO(jeff): We need to handle Update nodes specially, but dealing
  // with more general control flow will also solve this issue, and for
  // now, our updates are almost always the most downstream nodes in
  // the graph.
  // Reverse postorder visits every node after all of its inputs, so a
  // surviving candidate is always upstream of the node it replaces.
  // NodeComparatorID makes the order deterministic across runs.
  std::vector<Node*> order;
  GetReversePostOrder(*g_, &order, NodeComparatorID());

  // Our value is just a single Node*, meaning we keep just a single
  // candidate for a given node hash value. This may cause us to
  // (rarely) lose some optimization opportunities if there are
  // hash collisions, but it allows us to avoid having the value
  // be a set<Node*> (or equivalent).
  std::unordered_map<size_t, Node*> available;

  // Scratch space for Equivalent calls. Allocated here and passed in to
  // Equivalent to avoid allocation inside the loop below.
  bool changed = false;
  AttrSlice::Scratch scratch;
  for (Node* n : order) {
    // Skip non-op nodes (e.g. the graph's source/sink markers).
    if (!n->IsOp()) continue;

    // Don't prune placeholder nodes.
    if (n->type_string() == "Placeholder" ||
        n->type_string() == "PlaceholderV2" ||
        n->type_string() == "PlaceholderWithDefault" ) {
      continue;
    }

    // See if we should consider this node at all
    if (consider_fn != nullptr && !consider_fn(n)) continue;

    size_t h = NodeHash(n);
    Node** candidate = &available[h];
    if (*candidate == nullptr) {
      // No existing match: insert "n" into the hash table under "h"
      *candidate = n;
    } else if (Equivalent(*candidate, n, &scratch)) {
      VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
              << n->name();
      // *candidate and n are equivalent. Therefore, we can replace
      // n with *candidate by fixing up outgoing edges from "n" to instead
      // come from "*candidate", and then delete n from the graph.
      // NOTE(review): this assumes Graph::AddEdge does not invalidate
      // iteration over n->out_edges() — confirm against the Graph API.
      for (const Edge* e : n->out_edges()) {
        g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
      }

      // Preserve n's debug/source info on the surviving node before removal.
      MergeDebugInfo(NodeDebugInfo(*n), *candidate);
      g_->RemoveNode(n);
      changed = true;
    }
  }
  return changed;
}
331 | |
332 | bool OptimizeCSE(Graph* g, |
333 | const std::function<bool(const Node*)>& consider_fn) { |
334 | OptimizerCSE opt(g); |
335 | return opt.Optimize(consider_fn); |
336 | } |
337 | |
338 | } // namespace tensorflow |
339 | |