1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
17#define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
18
#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
40
41namespace tensorflow {
42namespace grappler {
43
44// Utilities for manipulating node name and input strings.
45
46// Returns the trailing position number (or zero if no number is present) if
47// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
48// Returns -2 if input_name is empty or NodeName(input_name) is not equal to
49// node_name.
50inline int NodePositionIfSameNode(absl::string_view input_name,
51 absl::string_view node_name) {
52 bool is_control = absl::StartsWith(input_name, "^");
53 if (is_control) input_name.remove_prefix(1);
54 if (input_name.empty() || node_name.empty() ||
55 input_name.size() < node_name.size()) {
56 return -2;
57 }
58 TensorId id = ParseTensorName(input_name);
59 if (id.first != node_name) return -2;
60 if (is_control) return -1;
61 return id.second;
62}
63
64// Returns the node name and position in a single call.
65inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
66 int* position) {
67 const bool is_control = absl::StartsWith(name, "^");
68 TensorId id = ParseTensorName(name);
69 if (position) {
70 *position = is_control ? -1 : id.second;
71 }
72 if (is_control && id.second >= 0) {
73 id.first.remove_prefix(1);
74 }
75 return id.first;
76}
77
78// Returns the node name and position in a single call.
79inline string ParseNodeName(const string& name, int* position) {
80 return string(ParseNodeNameAsStringPiece(name, position));
81}
82
83// Return the node name corresponding to 'name' if name is valid, or the empty
84// string otherwise.
85inline StringPiece NodeNameAsStringPiece(const string& name) {
86 return ParseNodeNameAsStringPiece(name, nullptr);
87}
88
89// Return the node name corresponding to 'name' if name is valid, or the empty
90// string otherwise.
91inline string NodeName(const string& name) {
92 return string(NodeNameAsStringPiece(name));
93}
94
95inline int NodePosition(const string& name) {
96 int position;
97 ParseNodeNameAsStringPiece(name, &position);
98 return position;
99}
100
101namespace internal {
102// Base template class for NodeMap and ImmutableNodeMap.
103template <typename GraphDefT, typename NodeDefT>
104class NodeMapInternal {
105 public:
106 // Note: The NodeMap will store pointers to nodes in graph, which may become
107 // invalid if graph is changed.
108 explicit NodeMapInternal(GraphDefT* graph) {
109 if (graph == nullptr) {
110 LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!";
111 return;
112 }
113 nodes_.reserve(graph->node_size());
114 outputs_.reserve(graph->node_size());
115 for (int i = 0; i < graph->node_size(); i++) {
116 NodeDefT* node = GetNodeDefFromGraph(graph, i);
117 const string& node_name = node->name();
118 auto rslt = nodes_.emplace(node_name, node);
119 // Check that the graph doesn't contain multiple nodes with the same name.
120 if (!rslt.second) {
121 // The first node found with a given name becomes the canonical.
122 LOG(WARNING) << "Duplicated node in the graph: " << node_name;
123 }
124 NodeDefT* canonical = rslt.second ? node : rslt.first->second;
125 for (const auto& input : node->input()) {
126 outputs_[NodeName(input)].insert(canonical);
127 }
128 }
129 }
130
131 // Get unordered list of fanouts from node. Notice, that the order is
132 // non-deterministic.
133 const absl::flat_hash_set<NodeDefT*>& GetOutputs(
134 const string& node_name) const {
135 auto it = outputs_.find(node_name);
136 if (it == outputs_.end()) {
137 return empty_set_;
138 }
139 return it->second;
140 }
141
142 // Get fanouts ordered by name.
143 std::vector<NodeDefT*> GetOutputsOrderedByNodeName(
144 const string& node_name) const {
145 std::vector<NodeDefT*> result;
146 auto it = outputs_.find(node_name);
147 if (it != outputs_.end()) {
148 const absl::flat_hash_set<NodeDefT*>& outputs = it->second;
149 result.reserve(outputs.size());
150 result.assign(outputs.begin(), outputs.end());
151 std::sort(result.begin(), result.end(),
152 [](const NodeDef* n1, const NodeDef* n2) {
153 return n1->name() < n2->name();
154 });
155 }
156 return result;
157 }
158
159 // This method doesn't record the outputs of the added node; the outputs need
160 // to be explicitly added by the AddOutput method.
161 void AddNode(const string& node_name, NodeDefT* node) {
162 DCHECK(node != nullptr);
163 auto ret = nodes_.emplace(node_name, node);
164 DCHECK(ret.second)
165 << "Pair (" << node_name << "," << node
166 << ") is not inserted because the same key already exists.";
167 }
168
169 void RemoveNode(const string& name) {
170 nodes_.erase(NodeName(name));
171 outputs_.erase(NodeName(name));
172 }
173
174 NodeDefT* GetNode(const string& name) const {
175 const string node_name = NodeName(name);
176 auto it = nodes_.find(node_name);
177 if (it == nodes_.end()) {
178 VLOG(1) << "Node could not be found: " << name;
179 return nullptr;
180 }
181 return it->second;
182 }
183
184 bool NodeExists(const string& name) const {
185 const string node_name = NodeName(name);
186 return nodes_.find(node_name) != nodes_.end();
187 }
188
189 void AddOutput(const string& node_name, const string& output_name) {
190 auto output_node = nodes_[NodeName(output_name)];
191 DCHECK(output_node) << "Output node " << output_name
192 << " is missing in NodeMap.";
193 outputs_[node_name].insert(output_node);
194 }
195
196 void RemoveOutput(const string& node_name, const string& output_name) {
197 outputs_[node_name].erase(nodes_[NodeName(output_name)]);
198 }
199
200 void UpdateInput(const string& node_name, const string& old_input_name,
201 const string& new_input_name) {
202 RemoveOutput(NodeName(old_input_name), node_name);
203 AddOutput(NodeName(new_input_name), node_name);
204 }
205
206 void RemoveInputs(const string& node_name) {
207 auto node = nodes_[node_name];
208 for (const auto& input : node->input()) {
209 RemoveOutput(NodeName(input), node->name());
210 }
211 }
212
213 void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }
214
215 void UpdateOutput(const string& node_name, const string& old_output_name,
216 const string& new_output_name) {
217 absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
218 outputs.erase(nodes_[NodeName(old_output_name)]);
219 outputs.insert(nodes_[NodeName(new_output_name)]);
220 }
221
222 private:
223 // Helper method to get the NodeDef pointer of i-th node in a graph.
224 inline NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64_t i) const;
225
226 const absl::flat_hash_set<NodeDefT*> empty_set_;
227 absl::node_hash_map<string, NodeDefT*> nodes_;
228 absl::node_hash_map<string, absl::flat_hash_set<NodeDefT*>> outputs_;
229};
230
231// Specialized template class method GetNodeDefFromGraph.
232template <>
233inline NodeDef* NodeMapInternal<GraphDef, NodeDef>::GetNodeDefFromGraph(
234 GraphDef* graph, int64_t i) const {
235 return graph->mutable_node(i);
236}
237
238template <>
239inline const NodeDef*
240NodeMapInternal<const GraphDef, const NodeDef>::GetNodeDefFromGraph(
241 const GraphDef* graph, int64_t i) const {
242 return &graph->node(i);
243}
244} // namespace internal
245
246// A utility class to lookup a node and its outputs by node name.
247class NodeMap : public internal::NodeMapInternal<GraphDef, NodeDef> {
248 public:
249 explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {}
250};
251
252// Same to NodeMap, but uses const GraphDef.
253class ImmutableNodeMap
254 : public internal::NodeMapInternal<const GraphDef, const NodeDef> {
255 public:
256 explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {}
257};
258
259// A vector with a set. The set stores the same elements as the vector, and
260// quickly answers whether a value is in the vector. Duplicated elements are not
261// allowed for now.
262template <class T, class Hash = std::hash<T>>
263class SetVector {
264 public:
265 // Returns false if value already existed in the set, true otherwise.
266 bool PushBack(const T& value) {
267 if (!set_.insert(value).second) {
268 return false;
269 }
270 vector_.push_back(value);
271 return true;
272 }
273
274 T PopBack() {
275 T back = vector_.back();
276 set_.erase(back);
277 vector_.pop_back();
278 return back;
279 }
280
281 bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
282
283 bool Empty() const { return vector_.empty(); }
284
285 void Reserve(int64_t size) { vector_.reserve(size); }
286
287 private:
288 gtl::FlatSet<T, Hash> set_;
289 std::vector<T> vector_;
290};
291
292// Returns formatted string from TensorId specific to grappler. Specifically,
293// for the 0 port (first output), only the node name is returned.
294string TensorIdToString(const TensorId& tensor_id);
295
296// Returns formatted string from SafeTensorId specific to grappler.
297// Specifically, for the 0 port (first output), only the node name is returned.
298string SafeTensorIdToString(const SafeTensorId& tensor_id);
299
300// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
301// the ^ character.
302bool IsControlInput(absl::string_view name);
303
304// True iff tensor index refers to a control input.
305bool IsControlInput(const TensorId& tensor_id);
306
307// True iff 'name1' and 'name2' refer to the same input.
308bool IsSameInput(const string& name1, const string& name2);
309
310
311// Add a prefix to a node name with a custom delimiter.
312string AddPrefixToNodeName(const string& name, const string& prefix,
313 const string& delimiter);
314
315// Add a prefix to a node name.
316string AddPrefixToNodeName(const string& name, const string& prefix);
317
318// Executes a 'fn' in the 'thread_pool'. The method waits for the configured
319// timeout (in milliseconds) for 'fn' to complete, before returning false.
320//
321// If returning false, the 'fn' may still continue to execute in the
322// thread-pool. It is the responsibility of the caller to reset the thread-pool
323// as appropriate.
324bool ExecuteWithTimeout(std::function<void()> fn, int64_t timeout_in_ms,
325 thread::ThreadPool* thread_pool);
326
327// Returns the node name prefixed with conventional symbol '^'
328// for control dependency, given a NodeDef.
329string AsControlDependency(const NodeDef& node);
330
331// Returns the node name prefixed with conventional symbol '^'
332// for control dependency, given a node name
333string AsControlDependency(const string& node);
334
335// Returns true if the node is assigned to run on CPU device.
336bool NodeIsOnCpu(const NodeDef* node);
337
338// Returns true if the node is assigned to run on GPU device.
339bool NodeIsOnGpu(const NodeDef* node);
340
341// Returns the number of outputs of a node according to its OpDef. Note that
342// some of the outputs may be unconnected.
343int NumOutputs(const NodeDef& node, GraphDef* graph);
344
345// Returns true iff the node has at least one control input.
346bool HasControlInputs(const NodeDef& node);
347
348// Returns true iff the node has at least one regular input.
349bool HasRegularInputs(const NodeDef& node);
350
351// Returns true iff the node has at least one regular output.
352bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
353
354// Returns true iff the node has at least one control output.
355bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
356
357// Number of connected control inputs.
358int NumControlInputs(const NodeDef& node);
359
360// Number of connected non-control inputs.
361int NumNonControlInputs(const NodeDef& node);
362
363// Number of connected control outputs.
364int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
365
366// Number of connected non-control outputs.
367int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
368
369// Number of connected non-control data outputs (Ops that consume output tensor
370// data, not just it's shape).
371int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
372
373// Removes redundant control inputs from node.
374void DedupControlInputs(NodeDef* node);
375
376// Returns an error if an attribute with the given key does not exist in node.
377Status CheckAttrExists(const NodeDef& node, const string& key);
378
379// Returns an error if attributes with the given keys do not exist in node.
380Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
381
382// Returns the data type in attribute `attr_name` of `node`. If that attribute
383// doesn't exist, returns DT_INVALID.
384DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
385
386// Returns the last node in the simple chain starting at source and traversing
387// through the input(0) edge from each node as long as the next node satisfies
388// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
389// will be returned. Example: For the chain
390// source <- a <- b <- ... <- y <- z
391// where
392// pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
393// pred_fn(z) = false,
394// the return value will be a pointer to y.
395NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
396 bool follow_control_input,
397 const std::function<bool(const NodeDef&)>& pred_fn);
398
399// Permute the nodes of graph in place according to the permutation.
400void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
401 bool invert_permutation);
402
403// Returns OkStatus() if a kernel is registered for node.op() on the device
404// type corresponding to node.device().
405Status IsKernelRegisteredForNode(
406 absl::string_view node_name, bool has_experimental_debug_info,
407 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
408 absl::string_view node_op, absl::string_view node_device,
409 AttrSlice node_attrs);
410Status IsKernelRegisteredForNode(const NodeDef& node);
411
412Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
413
414void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
415
416void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
417
418void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
419 GraphDef* graph);
420
421// Erase all attributes without leading underscore. Returns the number of
422// attributes erased.
423int EraseRegularNodeAttributes(NodeDef* node);
424
425// Erase attribute "_xla_inferred_shapes" as well as all attributes starting in
426// "_output_".
427int EraseNodeOutputAttributes(NodeDef* node);
428
429} // end namespace grappler
430} // end namespace tensorflow
431
432#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_
433