1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ |
17 | #define TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ |
18 | |
19 | #include <functional> |
20 | #include <unordered_set> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/graph/graph.h" |
24 | #include "tensorflow/core/lib/gtl/array_slice.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | // Comparator for two nodes. This is used in order to get a stable ording. |
29 | using NodeComparator = std::function<bool(const Node*, const Node*)>; |
30 | |
31 | using EdgeFilter = std::function<bool(const Edge&)>; |
32 | |
33 | // Compares two node based on their ids. |
34 | struct NodeComparatorID { |
35 | bool operator()(const Node* n1, const Node* n2) const { |
36 | return n1->id() < n2->id(); |
37 | } |
38 | }; |
39 | |
40 | // Compare two nodes based on their names. |
41 | struct NodeComparatorName { |
42 | bool operator()(const Node* n1, const Node* n2) const { |
43 | return n1->name() < n2->name(); |
44 | } |
45 | }; |
46 | |
47 | // Perform a depth-first-search on g starting at the source node. |
48 | // If enter is not empty, calls enter(n) before visiting any children of n. |
49 | // If leave is not empty, calls leave(n) after visiting all children of n. |
50 | // If stable_comparator is set, a stable ordering of visit is achieved by |
51 | // sorting a node's neighbors first before visiting them. |
52 | // If edge_filter is set then ignores edges for which edge_filter returns false. |
53 | extern void DFS(const Graph& g, const std::function<void(Node*)>& enter, |
54 | const std::function<void(Node*)>& leave, |
55 | const NodeComparator& stable_comparator = {}, |
56 | const EdgeFilter& edge_filter = {}); |
57 | |
58 | // Perform a depth-first-search on g starting at the 'start' nodes. |
59 | // If enter is not empty, calls enter(n) before visiting any children of n. |
60 | // If leave is not empty, calls leave(n) after visiting all children of n. |
61 | // If stable_comparator is set, a stable ordering of visit is achieved by |
62 | // sorting a node's neighbors first before visiting them. |
63 | // If edge_filter is set then ignores edges for which edge_filter returns false. |
64 | extern void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, |
65 | const std::function<void(Node*)>& enter, |
66 | const std::function<void(Node*)>& leave, |
67 | const NodeComparator& stable_comparator = {}, |
68 | const EdgeFilter& edge_filter = {}); |
69 | extern void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, |
70 | const std::function<void(const Node*)>& enter, |
71 | const std::function<void(const Node*)>& leave, |
72 | const NodeComparator& stable_comparator = {}, |
73 | const EdgeFilter& edge_filter = {}); |
74 | |
75 | // Perform a reverse depth-first-search on g starting at the sink node. |
76 | // If enter is not empty, calls enter(n) before visiting any parents of n. |
77 | // If leave is not empty, calls leave(n) after visiting all parents of n. |
78 | // If stable_comparator is set, a stable ordering of visit is achieved by |
79 | // sorting a node's neighbors first before visiting them. |
80 | // If edge_filter is set then ignores edges for which edge_filter returns false. |
81 | extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter, |
82 | const std::function<void(Node*)>& leave, |
83 | const NodeComparator& stable_comparator = {}, |
84 | const EdgeFilter& edge_filter = {}); |
85 | |
86 | // Perform a reverse depth-first-search on g starting at the 'start' nodes. |
87 | // If enter is not empty, calls enter(n) before visiting any parents of n. |
88 | // If leave is not empty, calls leave(n) after visiting all parents of n. |
89 | // If stable_comparator is set, a stable ordering of visit is achieved by |
90 | // sorting a node's neighbors first before visiting them. |
91 | // If edge_filter is set then ignores edges for which edge_filter returns false. |
92 | extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, |
93 | const std::function<void(Node*)>& enter, |
94 | const std::function<void(Node*)>& leave, |
95 | const NodeComparator& stable_comparator = {}, |
96 | const EdgeFilter& edge_filter = {}); |
97 | extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, |
98 | const std::function<void(const Node*)>& enter, |
99 | const std::function<void(const Node*)>& leave, |
100 | const NodeComparator& stable_comparator = {}, |
101 | const EdgeFilter& edge_filter = {}); |
102 | |
103 | // Stores in *order the post-order numbering of all nodes |
104 | // in graph found via a depth first search starting at the source node. |
105 | // |
106 | // Note that this is equivalent to reverse topological sorting when the |
107 | // graph does not have cycles. |
108 | // |
109 | // If stable_comparator is set, a stable ordering of visit is achieved by |
110 | // sorting a node's neighbors first before visiting them. |
111 | // |
112 | // If edge_filter is set then ignores edges for which edge_filter returns false. |
113 | // |
114 | // REQUIRES: order is not NULL. |
115 | void GetPostOrder(const Graph& g, std::vector<Node*>* order, |
116 | const NodeComparator& stable_comparator = {}, |
117 | const EdgeFilter& edge_filter = {}); |
118 | |
119 | // Stores in *order the reverse post-order numbering of all nodes |
120 | // If stable_comparator is set, a stable ordering of visit is achieved by |
121 | // sorting a node's neighbors first before visiting them. |
122 | // |
123 | // If edge_filter is set then ignores edges for which edge_filter returns false. |
124 | void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, |
125 | const NodeComparator& stable_comparator = {}, |
126 | const EdgeFilter& edge_filter = {}); |
127 | |
128 | // Prune nodes in "g" that are not in some path from the source node |
129 | // to any node in 'nodes'. Returns true if changes were made to the graph. |
130 | // Does not fix up source and sink edges. |
131 | bool PruneForReverseReachability(Graph* g, |
132 | std::unordered_set<const Node*> nodes); |
133 | |
134 | // Connect all nodes with no incoming edges to source. |
135 | // Connect all nodes with no outgoing edges to sink. |
136 | // |
137 | // Returns true if and only if 'g' is mutated. |
138 | bool FixupSourceAndSinkEdges(Graph* g); |
139 | |
140 | } // namespace tensorflow |
141 | |
142 | #endif // TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ |
143 | |