1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_GRAPH_NODEVALUE_H |
17 | #define GLOW_GRAPH_NODEVALUE_H |
18 | |
#include "glow/Base/Traits.h"
#include "glow/Base/Type.h"
#include "llvm/ADT/StringMap.h"

#include <cassert>
#include <functional>
#include <set>
#include <string>
#include <vector>
22 | |
23 | namespace glow { |
24 | |
25 | class Function; |
26 | class Node; |
27 | class NodeWalker; |
28 | struct NodeUse; |
29 | template <bool is_const_iter> class NodeValueIteratorImpl; |
30 | using NodeValueIterator = NodeValueIteratorImpl<false>; |
31 | using NodeValueConstIterator = NodeValueIteratorImpl<true>; |
32 | |
33 | /// Unlike LLVM values, graph nodes may return multiple values as the result of |
34 | /// a computation. Gradient-calculating nodes such as conv-grad return multiple |
35 | /// values. As such, each use of a node computation must indicate the node that |
36 | /// computes it as well as which return value to use from that node. This pair |
37 | /// of information is represented in this class. |
38 | |
39 | /// NodeValue is a simple POD struct that contains a reference to a node and a |
40 | /// result number. |
41 | struct NodeValue { |
42 | protected: |
43 | /// A pointer to the node (owned by the graph). |
44 | Node *node_{nullptr}; |
45 | /// Specifies the node result number to use. |
46 | unsigned resNo_{0}; |
47 | |
48 | public: |
49 | /// Create a new value. |
50 | NodeValue() = default; |
51 | /// Create a new value. |
52 | /*implicit*/ NodeValue(Node *N); |
53 | |
54 | /// Create a new value for result \p resNo. |
55 | NodeValue(Node *N, unsigned resNo); |
56 | |
57 | /// Create a new value from an existing one. |
58 | NodeValue(const NodeValue &that) : node_(that.node_), resNo_{that.resNo_} {} |
59 | |
60 | /// Assignment. |
61 | NodeValue &operator=(const NodeValue &that) { |
62 | node_ = that.node_; |
63 | resNo_ = that.resNo_; |
64 | return *this; |
65 | } |
66 | |
67 | /// Destructor. |
68 | ~NodeValue() {} |
69 | |
70 | /// Get the index which selects a specific result in the SDNode |
71 | unsigned getResNo() const { return resNo_; } |
72 | /// \returns the underlying pointer. |
73 | Node *getNode() const { return node_; } |
74 | |
75 | /// \returns the underlying pointer when casting. |
76 | operator Node *() const { return node_; } |
77 | |
78 | /// Replace all of the uses in \p F of this value with \p v. Types of the node |
79 | /// value and \p v should be exactly the same. |
80 | void replaceAllUsesOfWith(NodeValue v, const Function *F = nullptr, |
81 | Node *skipReplacement = nullptr) const; |
82 | |
83 | /// Replace all of the uses in \p F of this value with \p v. Types of the node |
84 | /// value and \p v can be different. |
85 | void typeUnsafeReplaceAllUsesOfWith(NodeValue v, const Function *F = nullptr, |
86 | Node *skipReplacement = nullptr) const; |
87 | |
88 | /// Return the TypeRef of the referenced return value. |
89 | TypeRef getType() const; |
90 | /// Set the type of the referenced value. |
91 | void setType(TypeRef ty); |
92 | /// Set the type of the referenced value. Does not check that dims() match. |
93 | void setTypeUnsafe(TypeRef ty); |
94 | |
95 | /// Methods that forward to the result type (that must be valid): |
96 | /// @{ |
97 | ElemKind getElementType() const; |
98 | llvm::ArrayRef<dim_t> dims() const; |
99 | float getScale() const; |
100 | int32_t getOffset() const; |
101 | /// @} |
102 | |
103 | bool operator==(const NodeValue &O) const { |
104 | return node_ == O.node_ && resNo_ == O.resNo_; |
105 | } |
106 | |
107 | bool operator<(const NodeValue &O) const { |
108 | if (node_ == O.node_) |
109 | return resNo_ < O.resNo_; |
110 | return (node_ < O.node_); |
111 | } |
112 | |
113 | /// Check if this NodeValue has exactly one use. |
114 | bool hasOneUse() const { return getNumUsers() == 1; } |
115 | /// Get the number of users of this NodeValue. |
116 | unsigned getNumUsers() const; |
117 | |
118 | /// Get the list of users of this NodeValue. |
119 | llvm::iterator_range<NodeValueIterator> getUsers(); |
120 | llvm::iterator_range<NodeValueConstIterator> getUsers() const; |
121 | |
122 | /// Get the full node output name based on the node name and output number. |
123 | /// The following format is used: nodename:outputNumber |
124 | static std::string |
125 | generateNodeOutputName(const std::string &nodeName, unsigned outputNumber = 0, |
126 | bool stripResNoFor0thInput = false) { |
127 | return nodeName + ((stripResNoFor0thInput && outputNumber == 0) |
128 | ? "" |
129 | : ":" + std::to_string(outputNumber)); |
130 | } |
131 | |
132 | /// \returns a unique name for this NodeValue, where the name of the node is |
133 | /// appended with a colon followed by \ref resNo_. |
134 | /// If \p stripResNoFor0thInput then the result number for the 0th input will |
135 | /// not be appended (i.e. no ":0" will be appended). |
136 | std::string generateNodeOutputName(bool stripResNoFor0thInput = false) const; |
137 | }; |
138 | |
139 | /// Struct containing the output name string and node kind for use in the |
140 | /// LoweredInfoMap for keeping track of lowered node info. |
141 | struct NodeNameAndKind : public Named, public Kinded { |
142 | public: |
143 | NodeNameAndKind(llvm::StringRef name, size_t resNo, Kinded::Kind k) |
144 | : Named(NodeValue::generateNodeOutputName(name.str(), resNo)), Kinded(k) { |
145 | } |
146 | }; |
147 | |
148 | /// Overload < operator for NodeNameAndKind to allow for usage with std::set. |
149 | inline bool operator<(const NodeNameAndKind &x, const NodeNameAndKind &y) { |
150 | return x.getName() < y.getName(); |
151 | } |
152 | |
153 | /// Overload == operator for NodeNameAndKind to allow for usage with std::set. |
154 | inline bool operator==(const NodeNameAndKind &x, const NodeNameAndKind &y) { |
155 | return x.getName() == y.getName(); |
156 | } |
157 | |
/// Used to keep track of the origin of lowered Nodes via output names as
/// determined by NodeValue::generateNodeOutputName(). For example if some
/// NodeValue X is lowered from some NodeValue Y, then the output name of X is a
/// key which maps to a set of names which contains the output name of Y.
/// Each set is ordered with operator< on NodeNameAndKind, which compares names
/// only, so a set holds at most one entry per output name.
using LoweredInfoMap = llvm::StringMap<std::set<NodeNameAndKind>>;
163 | |
164 | /// A handle type for a NodeValue. This type should be used only by the |
165 | /// class members of Node classes when they need to refer to other nodes! |
166 | /// |
167 | /// This class also manages the node use-def chain, by registering and |
168 | /// removing the address of the value from the use-list. This data structure |
169 | /// is similar to LLVM's SDValue. Only these NodeHandle instances are |
170 | /// registered as users of the nodes they refer to. The is different from the |
171 | /// usual NodeValue instances, which are not registered as users of the nodes |
172 | /// they refer to. |
173 | /// |
174 | /// Instances of NodeHandle should always stay inside the Nodes they are |
175 | /// members of and should never leave it. E.g. they cannot be returned as |
176 | /// results of function calls, etc. |
177 | struct NodeHandle : NodeValue { |
178 | private: |
179 | friend NodeUse; |
180 | /// Parent object which contains this handle. |
181 | Node *parent_{nullptr}; |
182 | |
183 | public: |
184 | /// Create a new value and register the node we reference |
185 | /*implicit*/ NodeHandle(Node *parent, Node *N); |
186 | |
187 | /// Create a new value for result \p resNo and register the node we |
188 | /// reference. |
189 | NodeHandle(Node *parent, Node *N, unsigned resNo); |
190 | |
191 | /// Create a new operand and register it as a new user to the node. |
192 | NodeHandle(Node *parent, const NodeValue &that) |
193 | : NodeValue(nullptr), parent_(parent) { |
194 | setOperand(that.getNode(), that.getResNo()); |
195 | } |
196 | |
197 | /// Create a new NodeHandle from an existing one and register it. |
198 | NodeHandle(Node *parent, const NodeHandle &that) |
199 | : NodeValue(nullptr), parent_(parent) { |
200 | setOperand(that.getNode(), that.getResNo()); |
201 | } |
202 | |
203 | NodeHandle(const NodeHandle &that) : NodeHandle(that.parent_, that) {} |
204 | |
205 | /// Create an empty handle. |
206 | NodeHandle() : NodeValue(nullptr), parent_(nullptr) {} |
207 | |
208 | /// When deleting an operand we need to unregister the operand from the |
209 | /// use-list of the node it used to reference. |
210 | ~NodeHandle() { setOperand(nullptr, 0); } |
211 | |
212 | /// Unregister old value, assign new NodeValue and register it. |
213 | NodeHandle &operator=(const NodeHandle &that) { |
214 | setOperand(that.getNode(), that.getResNo()); |
215 | return *this; |
216 | } |
217 | |
218 | /// Unregister old value, assign new NodeValue and register it. |
219 | NodeHandle &operator=(const NodeValue &that) { |
220 | setOperand(that.getNode(), that.getResNo()); |
221 | return *this; |
222 | } |
223 | /// Sets the operand to point to \p N. This method registers the operand as |
224 | /// a user of \p N. |
225 | void setOperand(Node *v, unsigned resNo); |
226 | |
227 | /// Set the parent object. |
228 | void setParent(Node *parent) { |
229 | assert(!parent_ && "Offset was set already" ); |
230 | parent_ = parent; |
231 | } |
232 | }; |
233 | |
234 | /// A wrapper class to expose a vector of NodeHandles inside an |
235 | /// object as a vector of NodeValues. This is done to avoid leaking of |
236 | /// NodeHandles from Nodes into the user-code. This type can be used as a |
237 | /// return type of e.g. getInputs() and similar functions. |
238 | class NodeValueArrayRef { |
239 | llvm::ArrayRef<NodeHandle> ref_; |
240 | |
241 | public: |
242 | using const_iterator = llvm::ArrayRef<NodeHandle>::const_iterator; |
243 | |
244 | NodeValueArrayRef(llvm::ArrayRef<NodeHandle> ref) : ref_(ref) {} |
245 | NodeValueArrayRef(const std::vector<NodeHandle> &ref) : ref_(ref) {} |
246 | const NodeValue &operator[](std::size_t idx) const { return ref_[idx]; } |
247 | operator std::vector<NodeValue>() { |
248 | return std::vector<NodeValue>(ref_.begin(), ref_.end()); |
249 | } |
250 | size_t size() const { return ref_.size(); } |
251 | bool empty() const { return ref_.empty(); } |
252 | const_iterator begin() { return ref_.begin(); } |
253 | const_iterator end() { return ref_.end(); } |
254 | NodeValue front() { return *begin(); } |
255 | }; |
256 | |
257 | /// A 'Use' is a use-list representation of a Node operand. |
258 | struct NodeUse { |
259 | /// The operand site. This is the address of the operand that points to our |
260 | /// node. |
261 | NodeHandle *site_; |
262 | |
263 | explicit NodeUse(NodeHandle *site) : site_(site) {} |
264 | |
265 | bool operator==(const NodeUse &other) const { return site_ == other.site_; } |
266 | |
267 | /// \returns the instruction that the use refers to. |
268 | NodeHandle *get() const { return site_; } |
269 | /// Get the node containing this use. |
270 | const Node *getUser() const { return site_->parent_; } |
271 | Node *getUser() { return site_->parent_; } |
272 | /// Sets the operand to a new value. |
273 | void setOperand(NodeHandle &site); |
274 | }; |
275 | |
276 | } // namespace glow |
277 | |
278 | #endif // GLOW_GRAPH_NODEVALUE_H |
279 | |