1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
17#define TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
18
19#include <functional>
20#include <unordered_set>
21#include <vector>
22
23#include "tensorflow/core/graph/graph.h"
24#include "tensorflow/core/lib/gtl/array_slice.h"
25
26namespace tensorflow {
27
28// Comparator for two nodes. This is used in order to get a stable ording.
29using NodeComparator = std::function<bool(const Node*, const Node*)>;
30
31using EdgeFilter = std::function<bool(const Edge&)>;
32
33// Compares two node based on their ids.
34struct NodeComparatorID {
35 bool operator()(const Node* n1, const Node* n2) const {
36 return n1->id() < n2->id();
37 }
38};
39
40// Compare two nodes based on their names.
41struct NodeComparatorName {
42 bool operator()(const Node* n1, const Node* n2) const {
43 return n1->name() < n2->name();
44 }
45};
46
47// Perform a depth-first-search on g starting at the source node.
48// If enter is not empty, calls enter(n) before visiting any children of n.
49// If leave is not empty, calls leave(n) after visiting all children of n.
50// If stable_comparator is set, a stable ordering of visit is achieved by
51// sorting a node's neighbors first before visiting them.
52// If edge_filter is set then ignores edges for which edge_filter returns false.
53extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
54 const std::function<void(Node*)>& leave,
55 const NodeComparator& stable_comparator = {},
56 const EdgeFilter& edge_filter = {});
57
58// Perform a depth-first-search on g starting at the 'start' nodes.
59// If enter is not empty, calls enter(n) before visiting any children of n.
60// If leave is not empty, calls leave(n) after visiting all children of n.
61// If stable_comparator is set, a stable ordering of visit is achieved by
62// sorting a node's neighbors first before visiting them.
63// If edge_filter is set then ignores edges for which edge_filter returns false.
64extern void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
65 const std::function<void(Node*)>& enter,
66 const std::function<void(Node*)>& leave,
67 const NodeComparator& stable_comparator = {},
68 const EdgeFilter& edge_filter = {});
69extern void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
70 const std::function<void(const Node*)>& enter,
71 const std::function<void(const Node*)>& leave,
72 const NodeComparator& stable_comparator = {},
73 const EdgeFilter& edge_filter = {});
74
75// Perform a reverse depth-first-search on g starting at the sink node.
76// If enter is not empty, calls enter(n) before visiting any parents of n.
77// If leave is not empty, calls leave(n) after visiting all parents of n.
78// If stable_comparator is set, a stable ordering of visit is achieved by
79// sorting a node's neighbors first before visiting them.
80// If edge_filter is set then ignores edges for which edge_filter returns false.
81extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
82 const std::function<void(Node*)>& leave,
83 const NodeComparator& stable_comparator = {},
84 const EdgeFilter& edge_filter = {});
85
86// Perform a reverse depth-first-search on g starting at the 'start' nodes.
87// If enter is not empty, calls enter(n) before visiting any parents of n.
88// If leave is not empty, calls leave(n) after visiting all parents of n.
89// If stable_comparator is set, a stable ordering of visit is achieved by
90// sorting a node's neighbors first before visiting them.
91// If edge_filter is set then ignores edges for which edge_filter returns false.
92extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
93 const std::function<void(Node*)>& enter,
94 const std::function<void(Node*)>& leave,
95 const NodeComparator& stable_comparator = {},
96 const EdgeFilter& edge_filter = {});
97extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
98 const std::function<void(const Node*)>& enter,
99 const std::function<void(const Node*)>& leave,
100 const NodeComparator& stable_comparator = {},
101 const EdgeFilter& edge_filter = {});
102
103// Stores in *order the post-order numbering of all nodes
104// in graph found via a depth first search starting at the source node.
105//
106// Note that this is equivalent to reverse topological sorting when the
107// graph does not have cycles.
108//
109// If stable_comparator is set, a stable ordering of visit is achieved by
110// sorting a node's neighbors first before visiting them.
111//
112// If edge_filter is set then ignores edges for which edge_filter returns false.
113//
114// REQUIRES: order is not NULL.
115void GetPostOrder(const Graph& g, std::vector<Node*>* order,
116 const NodeComparator& stable_comparator = {},
117 const EdgeFilter& edge_filter = {});
118
119// Stores in *order the reverse post-order numbering of all nodes
120// If stable_comparator is set, a stable ordering of visit is achieved by
121// sorting a node's neighbors first before visiting them.
122//
123// If edge_filter is set then ignores edges for which edge_filter returns false.
124void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
125 const NodeComparator& stable_comparator = {},
126 const EdgeFilter& edge_filter = {});
127
128// Prune nodes in "g" that are not in some path from the source node
129// to any node in 'nodes'. Returns true if changes were made to the graph.
130// Does not fix up source and sink edges.
131bool PruneForReverseReachability(Graph* g,
132 std::unordered_set<const Node*> nodes);
133
134// Connect all nodes with no incoming edges to source.
135// Connect all nodes with no outgoing edges to sink.
136//
137// Returns true if and only if 'g' is mutated.
138bool FixupSourceAndSinkEdges(Graph* g);
139
140} // namespace tensorflow
141
142#endif // TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
143