1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_ |
17 | #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_ |
18 | |
19 | #include <functional> |
20 | #include <iterator> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | #include "absl/container/flat_hash_set.h" |
25 | #include "absl/container/node_hash_map.h" |
26 | #include "absl/strings/string_view.h" |
27 | #include "absl/types/span.h" |
28 | #include "tensorflow/core/framework/graph.pb.h" |
29 | #include "tensorflow/core/framework/node_def.pb.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/types.h" |
32 | #include "tensorflow/core/graph/tensor_id.h" |
33 | #include "tensorflow/core/lib/core/status.h" |
34 | #include "tensorflow/core/lib/core/stringpiece.h" |
35 | #include "tensorflow/core/lib/core/threadpool.h" |
36 | #include "tensorflow/core/lib/gtl/flatmap.h" |
37 | #include "tensorflow/core/lib/gtl/flatset.h" |
38 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
39 | #include "tensorflow/core/platform/types.h" |
40 | |
41 | namespace tensorflow { |
42 | namespace grappler { |
43 | |
44 | // Utilities for manipulating node name and input strings. |
45 | |
46 | // Returns the trailing position number (or zero if no number is present) if |
47 | // NodeName(input_name) is equal to node_name. Returns -1 for control inputs. |
48 | // Returns -2 if input_name is empty or NodeName(input_name) is not equal to |
49 | // node_name. |
50 | inline int NodePositionIfSameNode(absl::string_view input_name, |
51 | absl::string_view node_name) { |
52 | bool is_control = absl::StartsWith(input_name, "^" ); |
53 | if (is_control) input_name.remove_prefix(1); |
54 | if (input_name.empty() || node_name.empty() || |
55 | input_name.size() < node_name.size()) { |
56 | return -2; |
57 | } |
58 | TensorId id = ParseTensorName(input_name); |
59 | if (id.first != node_name) return -2; |
60 | if (is_control) return -1; |
61 | return id.second; |
62 | } |
63 | |
64 | // Returns the node name and position in a single call. |
65 | inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name, |
66 | int* position) { |
67 | const bool is_control = absl::StartsWith(name, "^" ); |
68 | TensorId id = ParseTensorName(name); |
69 | if (position) { |
70 | *position = is_control ? -1 : id.second; |
71 | } |
72 | if (is_control && id.second >= 0) { |
73 | id.first.remove_prefix(1); |
74 | } |
75 | return id.first; |
76 | } |
77 | |
78 | // Returns the node name and position in a single call. |
79 | inline string ParseNodeName(const string& name, int* position) { |
80 | return string(ParseNodeNameAsStringPiece(name, position)); |
81 | } |
82 | |
83 | // Return the node name corresponding to 'name' if name is valid, or the empty |
84 | // string otherwise. |
85 | inline StringPiece NodeNameAsStringPiece(const string& name) { |
86 | return ParseNodeNameAsStringPiece(name, nullptr); |
87 | } |
88 | |
89 | // Return the node name corresponding to 'name' if name is valid, or the empty |
90 | // string otherwise. |
91 | inline string NodeName(const string& name) { |
92 | return string(NodeNameAsStringPiece(name)); |
93 | } |
94 | |
95 | inline int NodePosition(const string& name) { |
96 | int position; |
97 | ParseNodeNameAsStringPiece(name, &position); |
98 | return position; |
99 | } |
100 | |
101 | namespace internal { |
102 | // Base template class for NodeMap and ImmutableNodeMap. |
103 | template <typename GraphDefT, typename NodeDefT> |
104 | class NodeMapInternal { |
105 | public: |
106 | // Note: The NodeMap will store pointers to nodes in graph, which may become |
107 | // invalid if graph is changed. |
108 | explicit NodeMapInternal(GraphDefT* graph) { |
109 | if (graph == nullptr) { |
110 | LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!" ; |
111 | return; |
112 | } |
113 | nodes_.reserve(graph->node_size()); |
114 | outputs_.reserve(graph->node_size()); |
115 | for (int i = 0; i < graph->node_size(); i++) { |
116 | NodeDefT* node = GetNodeDefFromGraph(graph, i); |
117 | const string& node_name = node->name(); |
118 | auto rslt = nodes_.emplace(node_name, node); |
119 | // Check that the graph doesn't contain multiple nodes with the same name. |
120 | if (!rslt.second) { |
121 | // The first node found with a given name becomes the canonical. |
122 | LOG(WARNING) << "Duplicated node in the graph: " << node_name; |
123 | } |
124 | NodeDefT* canonical = rslt.second ? node : rslt.first->second; |
125 | for (const auto& input : node->input()) { |
126 | outputs_[NodeName(input)].insert(canonical); |
127 | } |
128 | } |
129 | } |
130 | |
131 | // Get unordered list of fanouts from node. Notice, that the order is |
132 | // non-deterministic. |
133 | const absl::flat_hash_set<NodeDefT*>& GetOutputs( |
134 | const string& node_name) const { |
135 | auto it = outputs_.find(node_name); |
136 | if (it == outputs_.end()) { |
137 | return empty_set_; |
138 | } |
139 | return it->second; |
140 | } |
141 | |
142 | // Get fanouts ordered by name. |
143 | std::vector<NodeDefT*> GetOutputsOrderedByNodeName( |
144 | const string& node_name) const { |
145 | std::vector<NodeDefT*> result; |
146 | auto it = outputs_.find(node_name); |
147 | if (it != outputs_.end()) { |
148 | const absl::flat_hash_set<NodeDefT*>& outputs = it->second; |
149 | result.reserve(outputs.size()); |
150 | result.assign(outputs.begin(), outputs.end()); |
151 | std::sort(result.begin(), result.end(), |
152 | [](const NodeDef* n1, const NodeDef* n2) { |
153 | return n1->name() < n2->name(); |
154 | }); |
155 | } |
156 | return result; |
157 | } |
158 | |
159 | // This method doesn't record the outputs of the added node; the outputs need |
160 | // to be explicitly added by the AddOutput method. |
161 | void AddNode(const string& node_name, NodeDefT* node) { |
162 | DCHECK(node != nullptr); |
163 | auto ret = nodes_.emplace(node_name, node); |
164 | DCHECK(ret.second) |
165 | << "Pair (" << node_name << "," << node |
166 | << ") is not inserted because the same key already exists." ; |
167 | } |
168 | |
169 | void RemoveNode(const string& name) { |
170 | nodes_.erase(NodeName(name)); |
171 | outputs_.erase(NodeName(name)); |
172 | } |
173 | |
174 | NodeDefT* GetNode(const string& name) const { |
175 | const string node_name = NodeName(name); |
176 | auto it = nodes_.find(node_name); |
177 | if (it == nodes_.end()) { |
178 | VLOG(1) << "Node could not be found: " << name; |
179 | return nullptr; |
180 | } |
181 | return it->second; |
182 | } |
183 | |
184 | bool NodeExists(const string& name) const { |
185 | const string node_name = NodeName(name); |
186 | return nodes_.find(node_name) != nodes_.end(); |
187 | } |
188 | |
189 | void AddOutput(const string& node_name, const string& output_name) { |
190 | auto output_node = nodes_[NodeName(output_name)]; |
191 | DCHECK(output_node) << "Output node " << output_name |
192 | << " is missing in NodeMap." ; |
193 | outputs_[node_name].insert(output_node); |
194 | } |
195 | |
196 | void RemoveOutput(const string& node_name, const string& output_name) { |
197 | outputs_[node_name].erase(nodes_[NodeName(output_name)]); |
198 | } |
199 | |
200 | void UpdateInput(const string& node_name, const string& old_input_name, |
201 | const string& new_input_name) { |
202 | RemoveOutput(NodeName(old_input_name), node_name); |
203 | AddOutput(NodeName(new_input_name), node_name); |
204 | } |
205 | |
206 | void RemoveInputs(const string& node_name) { |
207 | auto node = nodes_[node_name]; |
208 | for (const auto& input : node->input()) { |
209 | RemoveOutput(NodeName(input), node->name()); |
210 | } |
211 | } |
212 | |
213 | void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); } |
214 | |
215 | void UpdateOutput(const string& node_name, const string& old_output_name, |
216 | const string& new_output_name) { |
217 | absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name]; |
218 | outputs.erase(nodes_[NodeName(old_output_name)]); |
219 | outputs.insert(nodes_[NodeName(new_output_name)]); |
220 | } |
221 | |
222 | private: |
223 | // Helper method to get the NodeDef pointer of i-th node in a graph. |
224 | inline NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64_t i) const; |
225 | |
226 | const absl::flat_hash_set<NodeDefT*> empty_set_; |
227 | absl::node_hash_map<string, NodeDefT*> nodes_; |
228 | absl::node_hash_map<string, absl::flat_hash_set<NodeDefT*>> outputs_; |
229 | }; |
230 | |
231 | // Specialized template class method GetNodeDefFromGraph. |
232 | template <> |
233 | inline NodeDef* NodeMapInternal<GraphDef, NodeDef>::GetNodeDefFromGraph( |
234 | GraphDef* graph, int64_t i) const { |
235 | return graph->mutable_node(i); |
236 | } |
237 | |
238 | template <> |
239 | inline const NodeDef* |
240 | NodeMapInternal<const GraphDef, const NodeDef>::GetNodeDefFromGraph( |
241 | const GraphDef* graph, int64_t i) const { |
242 | return &graph->node(i); |
243 | } |
244 | } // namespace internal |
245 | |
246 | // A utility class to lookup a node and its outputs by node name. |
247 | class NodeMap : public internal::NodeMapInternal<GraphDef, NodeDef> { |
248 | public: |
249 | explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {} |
250 | }; |
251 | |
252 | // Same to NodeMap, but uses const GraphDef. |
253 | class ImmutableNodeMap |
254 | : public internal::NodeMapInternal<const GraphDef, const NodeDef> { |
255 | public: |
256 | explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {} |
257 | }; |
258 | |
259 | // A vector with a set. The set stores the same elements as the vector, and |
260 | // quickly answers whether a value is in the vector. Duplicated elements are not |
261 | // allowed for now. |
262 | template <class T, class Hash = std::hash<T>> |
263 | class SetVector { |
264 | public: |
265 | // Returns false if value already existed in the set, true otherwise. |
266 | bool PushBack(const T& value) { |
267 | if (!set_.insert(value).second) { |
268 | return false; |
269 | } |
270 | vector_.push_back(value); |
271 | return true; |
272 | } |
273 | |
274 | T PopBack() { |
275 | T back = vector_.back(); |
276 | set_.erase(back); |
277 | vector_.pop_back(); |
278 | return back; |
279 | } |
280 | |
281 | bool Exists(const T& value) const { return set_.find(value) != set_.end(); } |
282 | |
283 | bool Empty() const { return vector_.empty(); } |
284 | |
285 | void Reserve(int64_t size) { vector_.reserve(size); } |
286 | |
287 | private: |
288 | gtl::FlatSet<T, Hash> set_; |
289 | std::vector<T> vector_; |
290 | }; |
291 | |
292 | // Returns formatted string from TensorId specific to grappler. Specifically, |
293 | // for the 0 port (first output), only the node name is returned. |
294 | string TensorIdToString(const TensorId& tensor_id); |
295 | |
296 | // Returns formatted string from SafeTensorId specific to grappler. |
297 | // Specifically, for the 0 port (first output), only the node name is returned. |
298 | string SafeTensorIdToString(const SafeTensorId& tensor_id); |
299 | |
// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
302 | bool IsControlInput(absl::string_view name); |
303 | |
304 | // True iff tensor index refers to a control input. |
305 | bool IsControlInput(const TensorId& tensor_id); |
306 | |
307 | // True iff 'name1' and 'name2' refer to the same input. |
308 | bool IsSameInput(const string& name1, const string& name2); |
309 | |
310 | |
311 | // Add a prefix to a node name with a custom delimiter. |
312 | string AddPrefixToNodeName(const string& name, const string& prefix, |
313 | const string& delimiter); |
314 | |
315 | // Add a prefix to a node name. |
316 | string AddPrefixToNodeName(const string& name, const string& prefix); |
317 | |
318 | // Executes a 'fn' in the 'thread_pool'. The method waits for the configured |
319 | // timeout (in milliseconds) for 'fn' to complete, before returning false. |
320 | // |
321 | // If returning false, the 'fn' may still continue to execute in the |
322 | // thread-pool. It is the responsibility of the caller to reset the thread-pool |
323 | // as appropriate. |
324 | bool ExecuteWithTimeout(std::function<void()> fn, int64_t timeout_in_ms, |
325 | thread::ThreadPool* thread_pool); |
326 | |
327 | // Returns the node name prefixed with conventional symbol '^' |
328 | // for control dependency, given a NodeDef. |
329 | string AsControlDependency(const NodeDef& node); |
330 | |
331 | // Returns the node name prefixed with conventional symbol '^' |
332 | // for control dependency, given a node name |
333 | string AsControlDependency(const string& node); |
334 | |
335 | // Returns true if the node is assigned to run on CPU device. |
336 | bool NodeIsOnCpu(const NodeDef* node); |
337 | |
338 | // Returns true if the node is assigned to run on GPU device. |
339 | bool NodeIsOnGpu(const NodeDef* node); |
340 | |
341 | // Returns the number of outputs of a node according to its OpDef. Note that |
342 | // some of the outputs may be unconnected. |
343 | int NumOutputs(const NodeDef& node, GraphDef* graph); |
344 | |
345 | // Returns true iff the node has at least one control input. |
346 | bool HasControlInputs(const NodeDef& node); |
347 | |
348 | // Returns true iff the node has at least one regular input. |
349 | bool HasRegularInputs(const NodeDef& node); |
350 | |
351 | // Returns true iff the node has at least one regular output. |
352 | bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map); |
353 | |
354 | // Returns true iff the node has at least one control output. |
355 | bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map); |
356 | |
357 | // Number of connected control inputs. |
358 | int NumControlInputs(const NodeDef& node); |
359 | |
360 | // Number of connected non-control inputs. |
361 | int NumNonControlInputs(const NodeDef& node); |
362 | |
363 | // Number of connected control outputs. |
364 | int NumControlOutputs(const NodeDef& node, const NodeMap& node_map); |
365 | |
366 | // Number of connected non-control outputs. |
367 | int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map); |
368 | |
// Number of connected non-control data outputs (Ops that consume output tensor
// data, not just its shape).
371 | int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map); |
372 | |
373 | // Removes redundant control inputs from node. |
374 | void DedupControlInputs(NodeDef* node); |
375 | |
376 | // Returns an error if an attribute with the given key does not exist in node. |
377 | Status CheckAttrExists(const NodeDef& node, const string& key); |
378 | |
379 | // Returns an error if attributes with the given keys do not exist in node. |
380 | Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys); |
381 | |
// Returns the data type in attribute `type_attr` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
384 | DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr); |
385 | |
386 | // Returns the last node in the simple chain starting at source and traversing |
387 | // through the input(0) edge from each node as long as the next node satisfies |
388 | // the predicate given in pred_fn. If no nodes satisfy the predicate, &source |
389 | // will be returned. Example: For the chain |
390 | // source <- a <- b <- ... <- y <- z |
391 | // where |
392 | // pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true, |
393 | // pred_fn(z) = false, |
394 | // the return value will be a pointer to y. |
395 | NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, |
396 | bool follow_control_input, |
397 | const std::function<bool(const NodeDef&)>& pred_fn); |
398 | |
399 | // Permute the nodes of graph in place according to the permutation. |
400 | void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation, |
401 | bool invert_permutation); |
402 | |
403 | // Returns OkStatus() if a kernel is registered for node.op() on the device |
404 | // type corresponding to node.device(). |
405 | Status IsKernelRegisteredForNode( |
406 | absl::string_view node_name, bool has_experimental_debug_info, |
407 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info, |
408 | absl::string_view node_op, absl::string_view node_device, |
409 | AttrSlice node_attrs); |
410 | Status IsKernelRegisteredForNode(const NodeDef& node); |
411 | |
412 | Status SetTensorValue(DataType dtype, int value, Tensor* tensor); |
413 | |
414 | void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph); |
415 | |
416 | void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph); |
417 | |
418 | void EraseNodesFromGraph(const std::set<string>& nodes_to_delete, |
419 | GraphDef* graph); |
420 | |
421 | // Erase all attributes without leading underscore. Returns the number of |
422 | // attributes erased. |
423 | int EraseRegularNodeAttributes(NodeDef* node); |
424 | |
425 | // Erase attribute "_xla_inferred_shapes" as well as all attributes starting in |
426 | // "_output_". |
427 | int EraseNodeOutputAttributes(NodeDef* node); |
428 | |
429 | } // end namespace grappler |
430 | } // end namespace tensorflow |
431 | |
432 | #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_ |
433 | |