1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_GRAPH_UTILS_H |
17 | #define GLOW_GRAPH_UTILS_H |
18 | |
19 | #include "glow/Graph/Graph.h" |
20 | #include "glow/Graph/Node.h" |
21 | |
22 | #include "llvm/ADT/ArrayRef.h" |
23 | |
24 | #include <unordered_set> |
25 | #include <vector> |
26 | |
27 | namespace glow { |
28 | |
29 | /// A helper class for ordering the nodes in a post-order order. |
30 | struct PostOrderVisitor : NodeWalker { |
31 | /// A post-order list of nodes. |
32 | std::vector<Node *> postOrder_; |
33 | /// A set of visited nodes. |
34 | std::unordered_set<const Node *> visited_; |
35 | |
36 | public: |
37 | bool shouldVisit(Node *parent, Node *N) override { |
38 | // Don't revisit nodes that we've already processed. |
39 | return !visited_.count(N); |
40 | } |
41 | |
42 | bool shouldVisit(const Node *parent, const Node *N) override { |
43 | // Don't revisit nodes that we've already processed. |
44 | return !visited_.count(N); |
45 | } |
46 | |
47 | explicit PostOrderVisitor() = default; |
48 | |
49 | void post(Node *parent, Node *N) override { |
50 | visited_.insert(N); |
51 | postOrder_.push_back(N); |
52 | } |
53 | |
54 | /// \returns the order. |
55 | llvm::ArrayRef<Node *> getPostOrder() { return postOrder_; } |
56 | }; |
57 | |
58 | /// A helper class for ordering Graph nodes in a post-order order. |
59 | class GraphPostOrderVisitor : public PostOrderVisitor { |
60 | Function &G; |
61 | void visit() { |
62 | for (const auto *V : G.getParent()->getConstants()) { |
63 | V->visit(nullptr, this); |
64 | } |
65 | // Start visiting all root nodes, i.e. nodes that do not have any users. |
66 | for (auto &N : G.getNodes()) { |
67 | if (N.getNumUsers() == 0) |
68 | N.visit(nullptr, this); |
69 | } |
70 | } |
71 | |
72 | public: |
73 | explicit GraphPostOrderVisitor(Function &G) : G(G) {} |
74 | /// \returns the order. |
75 | llvm::ArrayRef<Node *> getPostOrder() { |
76 | if (postOrder_.empty()) |
77 | visit(); |
78 | return postOrder_; |
79 | } |
80 | }; |
81 | |
82 | /// A helper class for ordering the nodes in a pre-order order. |
83 | struct PreOrderVisitor : NodeWalker { |
84 | /// A pre-order list of nodes. |
85 | std::vector<Node *> preOrder_; |
86 | /// A set of visited nodes. |
87 | std::unordered_set<const Node *> visited_; |
88 | |
89 | public: |
90 | bool shouldVisit(Node *parent, Node *N) override { |
91 | // Don't revisit nodes that we've already processed. |
92 | return !visited_.count(N); |
93 | } |
94 | |
95 | bool shouldVisit(const Node *parent, const Node *N) override { |
96 | // Don't revisit nodes that we've already processed. |
97 | return !visited_.count(N); |
98 | } |
99 | |
100 | explicit PreOrderVisitor() = default; |
101 | |
102 | void pre(Node *parent, Node *N) override { |
103 | visited_.insert(N); |
104 | preOrder_.push_back(N); |
105 | } |
106 | |
107 | /// \returns the order. |
108 | llvm::ArrayRef<Node *> getPreOrder() { return preOrder_; } |
109 | }; |
110 | |
111 | /// A helper class for ordering Graph nodes in a pre-order order. |
112 | class GraphPreOrderVisitor : public PreOrderVisitor { |
113 | Function &G; |
114 | void visit() { |
115 | for (const auto *V : G.getParent()->getConstants()) { |
116 | V->visit(nullptr, this); |
117 | } |
118 | // Start visiting all root nodes, i.e. nodes that do not have any users. |
119 | for (auto &N : G.getNodes()) { |
120 | if (N.getNumUsers() == 0) |
121 | N.visit(nullptr, this); |
122 | } |
123 | } |
124 | |
125 | public: |
126 | explicit GraphPreOrderVisitor(Function &G) : G(G) {} |
127 | /// \returns the order. |
128 | llvm::ArrayRef<Node *> getPreOrder() { |
129 | if (preOrder_.empty()) |
130 | visit(); |
131 | return preOrder_; |
132 | } |
133 | }; |
134 | |
135 | /// Check if a Constant node \p CN consists of a single repeating value of type |
136 | /// \p ElemTy. If so, return the value in \p val. |
137 | template <typename ElemTy = float> |
138 | static bool isUniformConstant(const Constant &CN, ElemTy &val) { |
139 | if (!CN.getType()->isType<ElemTy>()) { |
140 | return false; |
141 | } |
142 | const auto handle = CN.getPayload().getHandle<ElemTy>(); |
143 | for (size_t i = 1; i < handle.size(); i++) { |
144 | if (handle.raw(i) != handle.raw(0)) { |
145 | return false; |
146 | } |
147 | } |
148 | val = handle.raw(0); |
149 | return true; |
150 | } |
151 | |
152 | #ifdef WIN32 |
153 | static int |
154 | #else |
155 | static int __attribute__((unused)) |
156 | #endif |
157 | getMaxDimOtherThanBatch(const llvm::ArrayRef<dim_t> &dims) { |
158 | return std::distance(dims.begin(), |
159 | std::max_element(dims.begin() + 1, dims.end())); |
160 | } |
161 | |
162 | } // namespace glow |
163 | |
164 | #endif // GLOW_GRAPH_UTILS_H |
165 | |