1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_GRAPH_UTILS_H
17#define GLOW_GRAPH_UTILS_H
18
19#include "glow/Graph/Graph.h"
20#include "glow/Graph/Node.h"
21
22#include "llvm/ADT/ArrayRef.h"
23
24#include <unordered_set>
25#include <vector>
26
27namespace glow {
28
29/// A helper class for ordering the nodes in a post-order order.
30struct PostOrderVisitor : NodeWalker {
31 /// A post-order list of nodes.
32 std::vector<Node *> postOrder_;
33 /// A set of visited nodes.
34 std::unordered_set<const Node *> visited_;
35
36public:
37 bool shouldVisit(Node *parent, Node *N) override {
38 // Don't revisit nodes that we've already processed.
39 return !visited_.count(N);
40 }
41
42 bool shouldVisit(const Node *parent, const Node *N) override {
43 // Don't revisit nodes that we've already processed.
44 return !visited_.count(N);
45 }
46
47 explicit PostOrderVisitor() = default;
48
49 void post(Node *parent, Node *N) override {
50 visited_.insert(N);
51 postOrder_.push_back(N);
52 }
53
54 /// \returns the order.
55 llvm::ArrayRef<Node *> getPostOrder() { return postOrder_; }
56};
57
58/// A helper class for ordering Graph nodes in a post-order order.
59class GraphPostOrderVisitor : public PostOrderVisitor {
60 Function &G;
61 void visit() {
62 for (const auto *V : G.getParent()->getConstants()) {
63 V->visit(nullptr, this);
64 }
65 // Start visiting all root nodes, i.e. nodes that do not have any users.
66 for (auto &N : G.getNodes()) {
67 if (N.getNumUsers() == 0)
68 N.visit(nullptr, this);
69 }
70 }
71
72public:
73 explicit GraphPostOrderVisitor(Function &G) : G(G) {}
74 /// \returns the order.
75 llvm::ArrayRef<Node *> getPostOrder() {
76 if (postOrder_.empty())
77 visit();
78 return postOrder_;
79 }
80};
81
82/// A helper class for ordering the nodes in a pre-order order.
83struct PreOrderVisitor : NodeWalker {
84 /// A pre-order list of nodes.
85 std::vector<Node *> preOrder_;
86 /// A set of visited nodes.
87 std::unordered_set<const Node *> visited_;
88
89public:
90 bool shouldVisit(Node *parent, Node *N) override {
91 // Don't revisit nodes that we've already processed.
92 return !visited_.count(N);
93 }
94
95 bool shouldVisit(const Node *parent, const Node *N) override {
96 // Don't revisit nodes that we've already processed.
97 return !visited_.count(N);
98 }
99
100 explicit PreOrderVisitor() = default;
101
102 void pre(Node *parent, Node *N) override {
103 visited_.insert(N);
104 preOrder_.push_back(N);
105 }
106
107 /// \returns the order.
108 llvm::ArrayRef<Node *> getPreOrder() { return preOrder_; }
109};
110
111/// A helper class for ordering Graph nodes in a pre-order order.
112class GraphPreOrderVisitor : public PreOrderVisitor {
113 Function &G;
114 void visit() {
115 for (const auto *V : G.getParent()->getConstants()) {
116 V->visit(nullptr, this);
117 }
118 // Start visiting all root nodes, i.e. nodes that do not have any users.
119 for (auto &N : G.getNodes()) {
120 if (N.getNumUsers() == 0)
121 N.visit(nullptr, this);
122 }
123 }
124
125public:
126 explicit GraphPreOrderVisitor(Function &G) : G(G) {}
127 /// \returns the order.
128 llvm::ArrayRef<Node *> getPreOrder() {
129 if (preOrder_.empty())
130 visit();
131 return preOrder_;
132 }
133};
134
135/// Check if a Constant node \p CN consists of a single repeating value of type
136/// \p ElemTy. If so, return the value in \p val.
137template <typename ElemTy = float>
138static bool isUniformConstant(const Constant &CN, ElemTy &val) {
139 if (!CN.getType()->isType<ElemTy>()) {
140 return false;
141 }
142 const auto handle = CN.getPayload().getHandle<ElemTy>();
143 for (size_t i = 1; i < handle.size(); i++) {
144 if (handle.raw(i) != handle.raw(0)) {
145 return false;
146 }
147 }
148 val = handle.raw(0);
149 return true;
150}
151
152#ifdef WIN32
153static int
154#else
155static int __attribute__((unused))
156#endif
157getMaxDimOtherThanBatch(const llvm::ArrayRef<dim_t> &dims) {
158 return std::distance(dims.begin(),
159 std::max_element(dims.begin() + 1, dims.end()));
160}
161
162} // namespace glow
163
164#endif // GLOW_GRAPH_UTILS_H
165