1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_GRAPH_NODEVALUE_H
17#define GLOW_GRAPH_NODEVALUE_H
18
#include "glow/Base/Traits.h"
#include "glow/Base/Type.h"

#include "llvm/ADT/StringMap.h"

#include <cassert>
#include <set>
#include <string>
22
23namespace glow {
24
// Forward declarations needed by the declarations in this header.
class Node;
class Function;
class NodeWalker;
struct NodeUse;

/// Iterator over the users of a NodeValue (see NodeValue::getUsers()). The
/// boolean template argument selects the const (true) or mutable (false)
/// flavor of the iterator.
template <bool IsConstIter> class NodeValueIteratorImpl;
using NodeValueConstIterator = NodeValueIteratorImpl<true>;
using NodeValueIterator = NodeValueIteratorImpl<false>;
32
33/// Unlike LLVM values, graph nodes may return multiple values as the result of
34/// a computation. Gradient-calculating nodes such as conv-grad return multiple
35/// values. As such, each use of a node computation must indicate the node that
36/// computes it as well as which return value to use from that node. This pair
37/// of information is represented in this class.
38
39/// NodeValue is a simple POD struct that contains a reference to a node and a
40/// result number.
41struct NodeValue {
42protected:
43 /// A pointer to the node (owned by the graph).
44 Node *node_{nullptr};
45 /// Specifies the node result number to use.
46 unsigned resNo_{0};
47
48public:
49 /// Create a new value.
50 NodeValue() = default;
51 /// Create a new value.
52 /*implicit*/ NodeValue(Node *N);
53
54 /// Create a new value for result \p resNo.
55 NodeValue(Node *N, unsigned resNo);
56
57 /// Create a new value from an existing one.
58 NodeValue(const NodeValue &that) : node_(that.node_), resNo_{that.resNo_} {}
59
60 /// Assignment.
61 NodeValue &operator=(const NodeValue &that) {
62 node_ = that.node_;
63 resNo_ = that.resNo_;
64 return *this;
65 }
66
67 /// Destructor.
68 ~NodeValue() {}
69
70 /// Get the index which selects a specific result in the SDNode
71 unsigned getResNo() const { return resNo_; }
72 /// \returns the underlying pointer.
73 Node *getNode() const { return node_; }
74
75 /// \returns the underlying pointer when casting.
76 operator Node *() const { return node_; }
77
78 /// Replace all of the uses in \p F of this value with \p v. Types of the node
79 /// value and \p v should be exactly the same.
80 void replaceAllUsesOfWith(NodeValue v, const Function *F = nullptr,
81 Node *skipReplacement = nullptr) const;
82
83 /// Replace all of the uses in \p F of this value with \p v. Types of the node
84 /// value and \p v can be different.
85 void typeUnsafeReplaceAllUsesOfWith(NodeValue v, const Function *F = nullptr,
86 Node *skipReplacement = nullptr) const;
87
88 /// Return the TypeRef of the referenced return value.
89 TypeRef getType() const;
90 /// Set the type of the referenced value.
91 void setType(TypeRef ty);
92 /// Set the type of the referenced value. Does not check that dims() match.
93 void setTypeUnsafe(TypeRef ty);
94
95 /// Methods that forward to the result type (that must be valid):
96 /// @{
97 ElemKind getElementType() const;
98 llvm::ArrayRef<dim_t> dims() const;
99 float getScale() const;
100 int32_t getOffset() const;
101 /// @}
102
103 bool operator==(const NodeValue &O) const {
104 return node_ == O.node_ && resNo_ == O.resNo_;
105 }
106
107 bool operator<(const NodeValue &O) const {
108 if (node_ == O.node_)
109 return resNo_ < O.resNo_;
110 return (node_ < O.node_);
111 }
112
113 /// Check if this NodeValue has exactly one use.
114 bool hasOneUse() const { return getNumUsers() == 1; }
115 /// Get the number of users of this NodeValue.
116 unsigned getNumUsers() const;
117
118 /// Get the list of users of this NodeValue.
119 llvm::iterator_range<NodeValueIterator> getUsers();
120 llvm::iterator_range<NodeValueConstIterator> getUsers() const;
121
122 /// Get the full node output name based on the node name and output number.
123 /// The following format is used: nodename:outputNumber
124 static std::string
125 generateNodeOutputName(const std::string &nodeName, unsigned outputNumber = 0,
126 bool stripResNoFor0thInput = false) {
127 return nodeName + ((stripResNoFor0thInput && outputNumber == 0)
128 ? ""
129 : ":" + std::to_string(outputNumber));
130 }
131
132 /// \returns a unique name for this NodeValue, where the name of the node is
133 /// appended with a colon followed by \ref resNo_.
134 /// If \p stripResNoFor0thInput then the result number for the 0th input will
135 /// not be appended (i.e. no ":0" will be appended).
136 std::string generateNodeOutputName(bool stripResNoFor0thInput = false) const;
137};
138
139/// Struct containing the output name string and node kind for use in the
140/// LoweredInfoMap for keeping track of lowered node info.
141struct NodeNameAndKind : public Named, public Kinded {
142public:
143 NodeNameAndKind(llvm::StringRef name, size_t resNo, Kinded::Kind k)
144 : Named(NodeValue::generateNodeOutputName(name.str(), resNo)), Kinded(k) {
145 }
146};
147
148/// Overload < operator for NodeNameAndKind to allow for usage with std::set.
149inline bool operator<(const NodeNameAndKind &x, const NodeNameAndKind &y) {
150 return x.getName() < y.getName();
151}
152
153/// Overload == operator for NodeNameAndKind to allow for usage with std::set.
154inline bool operator==(const NodeNameAndKind &x, const NodeNameAndKind &y) {
155 return x.getName() == y.getName();
156}
157
158/// Used to keep track of the origin of lowered Nodes via output names as
159/// determined by NodeValue::generateNodeOutputName(). For example if some
160/// NodeValue X is lowered from some NodeValue Y, then the output name of X is a
161/// key which maps to a set of names which contains the output name of Y.
162using LoweredInfoMap = llvm::StringMap<std::set<NodeNameAndKind>>;
163
164/// A handle type for a NodeValue. This type should be used only by the
165/// class members of Node classes when they need to refer to other nodes!
166///
167/// This class also manages the node use-def chain, by registering and
168/// removing the address of the value from the use-list. This data structure
169/// is similar to LLVM's SDValue. Only these NodeHandle instances are
170/// registered as users of the nodes they refer to. The is different from the
171/// usual NodeValue instances, which are not registered as users of the nodes
172/// they refer to.
173///
174/// Instances of NodeHandle should always stay inside the Nodes they are
175/// members of and should never leave it. E.g. they cannot be returned as
176/// results of function calls, etc.
177struct NodeHandle : NodeValue {
178private:
179 friend NodeUse;
180 /// Parent object which contains this handle.
181 Node *parent_{nullptr};
182
183public:
184 /// Create a new value and register the node we reference
185 /*implicit*/ NodeHandle(Node *parent, Node *N);
186
187 /// Create a new value for result \p resNo and register the node we
188 /// reference.
189 NodeHandle(Node *parent, Node *N, unsigned resNo);
190
191 /// Create a new operand and register it as a new user to the node.
192 NodeHandle(Node *parent, const NodeValue &that)
193 : NodeValue(nullptr), parent_(parent) {
194 setOperand(that.getNode(), that.getResNo());
195 }
196
197 /// Create a new NodeHandle from an existing one and register it.
198 NodeHandle(Node *parent, const NodeHandle &that)
199 : NodeValue(nullptr), parent_(parent) {
200 setOperand(that.getNode(), that.getResNo());
201 }
202
203 NodeHandle(const NodeHandle &that) : NodeHandle(that.parent_, that) {}
204
205 /// Create an empty handle.
206 NodeHandle() : NodeValue(nullptr), parent_(nullptr) {}
207
208 /// When deleting an operand we need to unregister the operand from the
209 /// use-list of the node it used to reference.
210 ~NodeHandle() { setOperand(nullptr, 0); }
211
212 /// Unregister old value, assign new NodeValue and register it.
213 NodeHandle &operator=(const NodeHandle &that) {
214 setOperand(that.getNode(), that.getResNo());
215 return *this;
216 }
217
218 /// Unregister old value, assign new NodeValue and register it.
219 NodeHandle &operator=(const NodeValue &that) {
220 setOperand(that.getNode(), that.getResNo());
221 return *this;
222 }
223 /// Sets the operand to point to \p N. This method registers the operand as
224 /// a user of \p N.
225 void setOperand(Node *v, unsigned resNo);
226
227 /// Set the parent object.
228 void setParent(Node *parent) {
229 assert(!parent_ && "Offset was set already");
230 parent_ = parent;
231 }
232};
233
234/// A wrapper class to expose a vector of NodeHandles inside an
235/// object as a vector of NodeValues. This is done to avoid leaking of
236/// NodeHandles from Nodes into the user-code. This type can be used as a
237/// return type of e.g. getInputs() and similar functions.
238class NodeValueArrayRef {
239 llvm::ArrayRef<NodeHandle> ref_;
240
241public:
242 using const_iterator = llvm::ArrayRef<NodeHandle>::const_iterator;
243
244 NodeValueArrayRef(llvm::ArrayRef<NodeHandle> ref) : ref_(ref) {}
245 NodeValueArrayRef(const std::vector<NodeHandle> &ref) : ref_(ref) {}
246 const NodeValue &operator[](std::size_t idx) const { return ref_[idx]; }
247 operator std::vector<NodeValue>() {
248 return std::vector<NodeValue>(ref_.begin(), ref_.end());
249 }
250 size_t size() const { return ref_.size(); }
251 bool empty() const { return ref_.empty(); }
252 const_iterator begin() { return ref_.begin(); }
253 const_iterator end() { return ref_.end(); }
254 NodeValue front() { return *begin(); }
255};
256
257/// A 'Use' is a use-list representation of a Node operand.
258struct NodeUse {
259 /// The operand site. This is the address of the operand that points to our
260 /// node.
261 NodeHandle *site_;
262
263 explicit NodeUse(NodeHandle *site) : site_(site) {}
264
265 bool operator==(const NodeUse &other) const { return site_ == other.site_; }
266
267 /// \returns the instruction that the use refers to.
268 NodeHandle *get() const { return site_; }
269 /// Get the node containing this use.
270 const Node *getUser() const { return site_->parent_; }
271 Node *getUser() { return site_->parent_; }
272 /// Sets the operand to a new value.
273 void setOperand(NodeHandle &site);
274};
275
276} // namespace glow
277
278#endif // GLOW_GRAPH_NODEVALUE_H
279