1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_GRAPH_NODE_H |
17 | #define GLOW_GRAPH_NODE_H |
18 | |
19 | #include "llvm/ADT/StringRef.h" |
20 | #include "llvm/ADT/ilist.h" |
21 | #include "llvm/ADT/ilist_node.h" |
22 | #include "llvm/ADT/iterator_range.h" |
23 | #include "llvm/Support/Casting.h" |
24 | |
25 | #include "glow/Base/Traits.h" |
26 | #include "glow/Base/Type.h" |
27 | #include "glow/Graph/NodeValue.h" |
28 | #include "glow/Graph/UseDef.h" |
29 | #include "glow/Support/Support.h" |
30 | |
31 | #include <list> |
32 | #include <unordered_set> |
33 | |
34 | namespace glow { |
35 | |
/// Forward declarations of types that are referenced by name in this header.
class Function;
class Node;
class NodeWalker;
struct NodeUse;
/// Iterator over the uses of a single NodeValue. The template is defined
/// further down in this header; \p is_const_iter selects the const flavor.
template <bool is_const_iter> class NodeValueIteratorImpl;
/// Mutable flavor of the NodeValue use iterator.
using NodeValueIterator = NodeValueIteratorImpl<false>;
/// Const flavor of the NodeValue use iterator.
using NodeValueConstIterator = NodeValueIteratorImpl<true>;
43 | |
44 | /// Represents a node in the compute graph. |
45 | class Node : public Named, |
46 | public Kinded, |
47 | public UseDef<Node, NodeUse>, |
48 | public llvm::ilist_node<Node> { |
49 | friend llvm::ilist_traits<Node>; |
50 | |
51 | protected: |
52 | /// The output types for the results of the node. |
53 | llvm::SmallVector<TypeRef, 6> types_; |
54 | /// A nullable reference to some tensor value that may predicate the execution |
55 | /// of the current node. |
56 | NodeHandle predicate_; |
57 | |
58 | /// Link to the function holding this node. |
59 | Function *parent_{nullptr}; |
60 | |
61 | public: |
62 | Node(Kinded::Kind k, llvm::StringRef name) |
63 | : Named(name), Kinded(k), predicate_(this, nullptr), parent_(nullptr) {} |
64 | |
65 | /// \returns the nullable predicate of the current node. |
66 | const NodeValue getPredicate() const; |
67 | /// Assigns a nullable predicate to the current node. |
68 | void setPredicate(const NodeValue &P); |
69 | /// Checks if a predicate is assigned to the current node. |
70 | bool hasPredicate() const; |
71 | |
72 | /// \returns the number of results that the node has. |
73 | unsigned getNumResults() const { return types_.size(); } |
74 | /// \returns the \p idx result of the node. |
75 | NodeValue getNthResult(unsigned idx); |
76 | /// \returns the n'th result of the node. |
77 | const NodeValue getNthResult(unsigned idx) const; |
78 | |
79 | /// \returns the function holding this node. |
80 | /// If that node does not belong to any function, this |
81 | /// is nullptr. |
82 | const Function *getParent() const { return parent_; } |
83 | Function *getParent() { return parent_; } |
84 | /// Set the link to the function that holds this node. |
85 | void setParent(Function *parent) { parent_ = parent; } |
86 | |
87 | /// Getters/setters to access Node's inputs and outputs. |
88 | unsigned getNumInputs() const; |
89 | std::string getInputName(unsigned idx) const; |
90 | NodeValue getNthInput(unsigned idx); |
91 | const NodeValue getNthInput(unsigned idx) const; |
92 | void setNthInput(unsigned idx, NodeValue val); |
93 | llvm::StringRef getOutputName(unsigned idx) const; |
94 | bool hasSideEffects() const; |
95 | bool isArithmetic() const; |
96 | bool isCanonical() const; |
97 | bool isDataParallel() const; |
98 | |
99 | /// \returns true if this input is being overwritten by the node. |
100 | bool isOverwrittenNthInput(unsigned idx) const; |
101 | |
102 | /// \returns a textual description of the node. |
103 | std::string getDebugDesc() const; |
104 | |
105 | /// Dump a textual representation of the Node into provided output stream. |
106 | void dump(llvm::raw_ostream &out) const; |
107 | |
108 | /// Dump a textual representation of the Node into default output stream. |
109 | void dump() const; |
110 | |
111 | /// Dump a textual representation of the Node to std::string. |
112 | std::string toString() const; |
113 | |
114 | /// \returns the total memory size (in bytes) of the node as the sum of sizes |
115 | /// for all the inputs and outputs. |
116 | size_t getTotMemSize() const; |
117 | |
118 | /// \returns copy of the current node. Notice that the new node is not |
119 | /// inserted into any DAG. The caller of this method should add it to some |
120 | /// node-list. |
121 | Node *clone() const; |
122 | |
123 | /// \returns true if the node is equal to the other node. |
124 | bool isEqual(const Node &other) const; |
125 | |
126 | /// \returns true if the node is equal to the other node. |
127 | bool operator==(const Node &O) const { return isEqual(O); } |
128 | |
129 | /// \returns a hash code of the node. |
130 | llvm::hash_code getHash() const; |
131 | |
132 | /// This method implements the visitor pattern that scans the compute DAG top |
133 | /// to bottom. The visitor \p visitor is sent by the parent node \p parent, |
134 | /// or nullptr if this is the first node to be visited. |
135 | void visit(Node *parent, NodeWalker *visitor); |
136 | |
137 | void visit(const Node *parent, NodeWalker *visitor) const; |
138 | |
139 | /// Verify node. |
140 | /// \returns True if the node is valid. False otherwise. |
141 | bool verify() const; |
142 | |
143 | /// Replace all uses of this node with null. This method is used by the |
144 | /// destruction sequence. When the node is deleted we need to unregister all |
145 | /// users. This allows us to deconstruct the graph in an arbitrary order. |
146 | void releaseUsers() { |
147 | NodeValue nop(nullptr); |
148 | for (unsigned i = 0; i < getNumResults(); i++) { |
149 | NodeValue(this, i).replaceAllUsesOfWith(nop); |
150 | } |
151 | } |
152 | |
153 | ~Node() { releaseUsers(); } |
154 | |
155 | /// Destroys a node and deallocates the memory. This method is implicitly |
156 | /// invoked by the parent function when a node is being removed from the |
157 | /// intrusive list of nodes. You can also invoke this method explicitly to |
158 | /// destroy a node which has no parent function (orphan node). |
159 | static void destroyNode(Node *N); |
160 | |
161 | /// \returns the n'th result type of the node. |
162 | TypeRef getType(unsigned idx) const; |
163 | /// Set the \p idx'th result type of the node. |
164 | /// \note This setter only changes the type of this one |
165 | /// result. If that type is incompatible with |
166 | /// the inputs of the node, the caller is |
167 | /// responsible to update these if need be. |
168 | void setType(unsigned idx, TypeRef ty); |
169 | |
170 | /// Set the \p idx'th result type of the node, without checking if the dims of |
171 | /// the old type match the dims of the new one. |
172 | /// \note This setter only changes the type of this one |
173 | /// result. If that type is incompatible with |
174 | /// the inputs of the node, the caller is |
175 | /// responsible to update these if need be. |
176 | /// This function does not check for validity |
177 | /// of input dims and whether the result exists. |
178 | void setTypeUnsafe(unsigned idx, TypeRef ty); |
179 | |
180 | /// Methods that forward to the result type (that must be valid): |
181 | /// @{ |
182 | ElemKind getElementType(unsigned resNo) const; |
183 | llvm::ArrayRef<dim_t> dims(unsigned resNo) const; |
184 | /// @} |
185 | |
186 | protected: |
187 | /// When constructing the node, add a new result of type \p T. |
188 | void addResult(TypeRef T); |
189 | }; |
190 | |
191 | /// A walker that recursively visits a node and its children. |
192 | class NodeWalker { |
193 | public: |
194 | /// This callback is called before visiting the children of \p N. |
195 | virtual void pre(Node *parent, Node *N) {} |
196 | virtual void pre(const Node *parent, const Node *N) {} |
197 | |
198 | /// This callback is called after visiting the children of \p N. |
199 | virtual void post(Node *parent, Node *N) {} |
200 | virtual void post(const Node *parent, const Node *N) {} |
201 | |
202 | /// This callback is called before processing the graph. If the method returns |
203 | /// false then we skip this node. |
204 | virtual bool shouldVisit(Node *parent, Node *N) { return true; } |
205 | virtual bool shouldVisit(const Node *parent, const Node *N) { return true; } |
206 | |
207 | /// Dtor. |
208 | virtual ~NodeWalker() = default; |
209 | }; |
210 | |
/// A set of unsigned indices. NodeInfo uses it to name the input/output
/// positions that should be skipped when comparing element kinds.
using IndicesSet = std::unordered_set<unsigned>;
212 | |
213 | /// Helper class to hold info of a Node, containing its \p opKind, \p inTypes, |
214 | /// and \p outTypes |
215 | class NodeInfo : public Kinded { |
216 | private: |
217 | /// The input types of the NodeInfo. |
218 | std::vector<TypeRef> inTypes_; |
219 | /// The output types of the NodeInfo. |
220 | std::vector<TypeRef> outTypes_; |
221 | /// The name of the node. |
222 | llvm::StringRef name_; |
223 | |
224 | /// Helper function for checking if all of the ElemKinds contained in \p types |
225 | /// are equal to \p allowedElemKind. Indices in \p ignore are ignored when |
226 | /// checking from \p types. |
227 | bool allSameElemKind(const ElemKind allowedElemKind, |
228 | llvm::ArrayRef<TypeRef> types, |
229 | const IndicesSet &ignore) const { |
230 | for (size_t i = 0; i < types.size(); i++) { |
231 | if (ignore.count(i)) { |
232 | continue; |
233 | } |
234 | const TypeRef currType = types[i]; |
235 | if (currType->getElementType() != allowedElemKind) { |
236 | return false; |
237 | } |
238 | } |
239 | return true; |
240 | } |
241 | |
242 | public: |
243 | NodeInfo(Kinded::Kind kind, llvm::ArrayRef<TypeRef> inTypes, |
244 | llvm::ArrayRef<TypeRef> outTypes) |
245 | : Kinded(kind), inTypes_(inTypes), outTypes_(outTypes) {} |
246 | |
247 | NodeInfo(const Node &N) : Kinded(N.getKind()) { |
248 | for (unsigned i = 0, e = N.getNumResults(); i < e; ++i) { |
249 | outTypes_.push_back(N.getType(i)); |
250 | } |
251 | for (unsigned idx = 0, end = N.getNumInputs(); idx != end; ++idx) { |
252 | inTypes_.push_back(N.getNthInput(idx).getType()); |
253 | } |
254 | name_ = N.getName(); |
255 | } |
256 | |
257 | /// \returns the input types. |
258 | llvm::ArrayRef<TypeRef> getInTypes() const { return inTypes_; } |
259 | |
260 | /// \returns the output types. |
261 | llvm::ArrayRef<TypeRef> getOutTypes() const { return outTypes_; } |
262 | |
263 | /// \returns the input type located at \p idx. |
264 | const TypeRef getInTy(size_t idx) const { |
265 | assert(idx < inTypes_.size()); |
266 | return inTypes_[idx]; |
267 | } |
268 | |
269 | /// \returns the output type located at \p idx. |
270 | const TypeRef getOutTy(size_t idx) const { |
271 | assert(idx < outTypes_.size()); |
272 | return outTypes_[idx]; |
273 | } |
274 | |
275 | /// \returns the input type located at \p idx. |
276 | const ElemKind getInElemTy(size_t idx) const { |
277 | assert(idx < inTypes_.size()); |
278 | return inTypes_[idx]->getElementType(); |
279 | } |
280 | |
281 | /// \returns the output type located at \p idx. |
282 | const ElemKind getOutElemTy(size_t idx) const { |
283 | assert(idx < outTypes_.size()); |
284 | return outTypes_[idx]->getElementType(); |
285 | } |
286 | |
287 | /// \returns the name of the node. |
288 | llvm::StringRef getName() const { return name_; } |
289 | |
290 | /// \returns whether all of the element types of inTypes_ and outTypes_ are |
291 | /// all the same and one of those found in \p allowedElemKinds. \p ignoreIn |
292 | /// and \p ignoreOut represent indices that can be skipped in inTypes_ and |
293 | /// outTypes_ respectively. |
294 | bool |
295 | allInputsAndOutputsHaveSameElemKind(llvm::ArrayRef<ElemKind> allowedElemKinds, |
296 | const IndicesSet &ignoreIn = {}, |
297 | const IndicesSet &ignoreOut = {}) const { |
298 | for (const ElemKind elemKind : allowedElemKinds) { |
299 | if (allSameElemKind(elemKind, inTypes_, ignoreIn) && |
300 | allSameElemKind(elemKind, outTypes_, ignoreOut)) { |
301 | return true; |
302 | } |
303 | } |
304 | return false; |
305 | } |
306 | |
307 | /// Helper for debugging which \returns a string representation for the |
308 | /// NodeInfo. |
309 | std::string getDebugDesc() const { |
310 | DescriptionBuilder db(getKindName()); |
311 | for (size_t i = 0; i < inTypes_.size(); ++i) { |
312 | db.addParam("inType" + std::to_string(i), *inTypes_[i]); |
313 | } |
314 | for (size_t i = 0; i < outTypes_.size(); ++i) { |
315 | db.addParam("outType" + std::to_string(i), *outTypes_[i]); |
316 | } |
317 | return db; |
318 | } |
319 | }; |
320 | |
321 | /// Helper class to walk through the specific uses of a NodeValue. |
322 | /// This class is built on top of the regular users-list (Node::getUsers) |
323 | /// but filters out the uses that don't affect the desired NodeValue. |
324 | template <bool is_const_iter = false> |
325 | class NodeValueIteratorImpl |
326 | : public std::iterator<std::forward_iterator_tag, NodeUse> { |
327 | public: |
328 | /// Base type of the iterator. |
329 | using iterator = |
330 | typename std::conditional<is_const_iter, |
331 | std::list<NodeUse>::const_iterator, |
332 | std::list<NodeUse>::iterator>::type; |
333 | /// Type of the NodeValue that this iterator is filtering for. |
334 | using NodeValueTy = typename std::conditional<is_const_iter, const NodeValue, |
335 | NodeValue>::type; |
336 | /// Type of the NodeUse that this iterator should return when dereferenced. |
337 | using NodeUseTy = |
338 | typename std::conditional<is_const_iter, const NodeUse, NodeUse>::type; |
339 | |
340 | private: |
341 | /// NodeValue that this iterator tracks. |
342 | NodeValueTy &parent_; |
343 | /// Actual iterator on the users-list. |
344 | /// \invariant if it_ points to a valid iterator, then the NodeValue it |
345 | /// references (via the NodeUse) is equal to parent_. |
346 | iterator it_; |
347 | |
348 | /// Convenient method to get the end iterator of the users list that this |
349 | /// iterator walks. |
350 | iterator getEnd() const { return parent_.getNode()->getUsers().end(); } |
351 | |
352 | /// Check if \p it_ points to a NodeUse that references \p parents_. |
353 | bool hasSameParent() const { |
354 | assert(it_ != getEnd() && "Cannot check invalid iterator" ); |
355 | // A users-list should be for one node. |
356 | // If this assert breaks, that means the input list is broken, |
357 | // or this iterator is not used as it was intended: to walk |
358 | // through a users-list. |
359 | assert(it_->get()->getNode() == parent_.getNode() && |
360 | "Iterator points to a list with different parent?!" ); |
361 | return it_->get()->getResNo() == parent_.getResNo(); |
362 | } |
363 | |
364 | public: |
365 | NodeValueIteratorImpl(NodeValueTy &parent, iterator it) |
366 | : parent_(parent), it_(it) { |
367 | if (it_ != getEnd() && !hasSameParent()) { |
368 | ++(*this); |
369 | } |
370 | assert((it_ == getEnd() || hasSameParent()) && |
371 | "operator++ should return the next valid iterator" ); |
372 | } |
373 | |
374 | /// Move to the next use of parent_. |
375 | NodeValueIteratorImpl &operator++() { |
376 | auto endIt = getEnd(); |
377 | while (++it_ != endIt && !hasSameParent()) { |
378 | } |
379 | return *this; |
380 | } |
381 | |
382 | NodeUseTy &operator*() { |
383 | assert(hasSameParent() && "Invalid iterator" ); |
384 | return *it_; |
385 | } |
386 | |
387 | const NodeUseTy &operator*() const { |
388 | assert(hasSameParent() && "Invalid iterator" ); |
389 | return *it_; |
390 | } |
391 | |
392 | bool operator!=(const NodeValueIteratorImpl &other) const { |
393 | return it_ != other.it_; |
394 | } |
395 | }; |
396 | |
/// Index constants for the operands and the result of Arithmetic nodes. The
/// values are expected to match the indices order of any Arithmetic node as
/// defined in Node::isArithmetic().
namespace ArithmeticNode {
/// Input index of the left-hand side operand.
constexpr unsigned LHSIdx{0U};
/// Input index of the right-hand side operand.
constexpr unsigned RHSIdx{1U};
/// Result index of the output.
constexpr unsigned ResultIdx{0U};
} // namespace ArithmeticNode
404 | |
/// Stream insertion of a Node, taken by reference, into \p os.
/// \returns an llvm::raw_ostream reference (presumably \p os, to allow
/// chaining — the definition is not in this header).
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node &node);

/// Stream insertion of a Node, taken by pointer, into \p os.
/// NOTE(review): whether a null \p node is accepted is decided by the
/// definition, which is not in this header — confirm before passing nullptr.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node *node);
408 | |
/// Helper to get the Kind of a Node (e.g. Kinded::Kind::AddNodeKind) given its
/// \p nodeName (e.g. Add).
/// The lookup is a chain of string comparisons generated from the DEF_NODE
/// entries of glow/AutoGenNodes.def; the first entry whose NAME equals
/// \p nodeName wins. If nothing matches, a fatal error is logged.
inline Kinded::Kind getKindFromNodeName(llvm::StringRef nodeName) {
// Each DEF_NODE(CLASS, NAME) entry expands to one comparison of nodeName
// against the stringified NAME, returning the matching CLASS kind.
// NOTE(review): DEF_NODE is not #undef'd here; presumably AutoGenNodes.def
// undefines it itself — confirm.
#define DEF_NODE(CLASS, NAME)                                                  \
  if (nodeName == #NAME) {                                                     \
    return Kinded::Kind::CLASS##Kind;                                          \
  }
#include "glow/AutoGenNodes.def"
  // No return follows: this relies on LOG(FATAL) not returning to the caller.
  LOG(FATAL) << "Unknown node name: " << nodeName.str();
}
419 | |
420 | } // namespace glow |
421 | |
422 | namespace llvm { |
423 | /// Allow casting NodeValue into Node*. |
424 | template <> struct simplify_type<glow::NodeValue> { |
425 | typedef glow::Node *SimpleType; |
426 | static SimpleType getSimplifiedValue(glow::NodeValue &val) { |
427 | return val.getNode(); |
428 | } |
429 | }; |
430 | |
431 | /// Allow casting NodeValue into Node*. |
432 | template <> struct simplify_type<const glow::NodeValue> { |
433 | typedef glow::Node *SimpleType; |
434 | static SimpleType getSimplifiedValue(const glow::NodeValue &val) { |
435 | return val.getNode(); |
436 | } |
437 | }; |
438 | |
439 | /// Allow casting NodeHandle into Node*. |
440 | template <> struct simplify_type<glow::NodeHandle> { |
441 | typedef glow::Node *SimpleType; |
442 | static SimpleType getSimplifiedValue(glow::NodeHandle &val) { |
443 | return val.getNode(); |
444 | } |
445 | }; |
446 | /// Allow casting const NodeHandle into Node*. |
447 | template <> struct simplify_type<const glow::NodeHandle> { |
448 | typedef glow::Node *SimpleType; |
449 | static SimpleType getSimplifiedValue(const glow::NodeHandle &val) { |
450 | return val.getNode(); |
451 | } |
452 | }; |
453 | |
454 | //===----------------------------------------------------------------------===// |
455 | // ilist_traits for glow::Node |
456 | //===----------------------------------------------------------------------===// |
457 | |
/// Specialization of llvm::ilist_traits for intrusive lists of glow::Node.
/// It routes node deletion to glow::Node::destroyNode and declares the list
/// callbacks; their definitions are not in this header.
template <>
struct ilist_traits<glow::Node> : public ilist_node_traits<glow::Node> {
  using Node = glow::Node;

  /// \returns the glow::Function associated with this list of nodes
  /// (presumably the function that owns the list — defined elsewhere).
  glow::Function *getContainingFunction();

private:
  using node_iterator = simple_ilist<Node>::iterator;

public:
  /// Destroys the node \p N by forwarding to glow::Node::destroyNode.
  static void deleteNode(Node *N) { glow::Node::destroyNode(N); }

  /// List callbacks for the node \p N being added to / removed from the list
  /// and for the range [\p first, \p last) being transferred from the list
  /// \p L2. NOTE(review): presumably these keep each Node's parent link in
  /// sync with the owning function — confirm against the definitions.
  void addNodeToList(Node *N);
  void removeNodeFromList(Node *N);
  void transferNodesFromList(ilist_traits<Node> &L2, node_iterator first,
                             node_iterator last);

private:
  /// Declared private; presumably so that the list cannot create nodes by
  /// copying an existing one — confirm.
  void createNode(const Node &);
};
478 | |
479 | } // namespace llvm |
480 | |
481 | // custom specialization of std::hash for NodeValue. |
482 | namespace std { |
483 | template <> struct hash<glow::NodeValue> { |
484 | typedef glow::NodeValue argument_type; |
485 | typedef std::size_t result_type; |
486 | result_type operator()(argument_type const &s) const noexcept { |
487 | auto name = s.getNode()->getName(); |
488 | result_type const h1(std::hash<std::string>{}(name.str())); |
489 | result_type const h2(std::hash<unsigned>{}(s.getResNo())); |
490 | return h1 ^ (h2 << 8); |
491 | } |
492 | }; |
493 | } // namespace std |
494 | |
495 | #endif // GLOW_GRAPH_NODE_H |
496 | |