1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_GRAPH_NODE_H
17#define GLOW_GRAPH_NODE_H
18
19#include "llvm/ADT/StringRef.h"
20#include "llvm/ADT/ilist.h"
21#include "llvm/ADT/ilist_node.h"
22#include "llvm/ADT/iterator_range.h"
23#include "llvm/Support/Casting.h"
24
25#include "glow/Base/Traits.h"
26#include "glow/Base/Type.h"
27#include "glow/Graph/NodeValue.h"
28#include "glow/Graph/UseDef.h"
29#include "glow/Support/Support.h"
30
#include <cstddef>
#include <iterator>
#include <list>
#include <string>
#include <unordered_set>
33
34namespace glow {
35
// Forward declarations for types this header refers to before (or without)
// seeing their definitions.
struct NodeUse;
class NodeWalker;
class Node;
class Function;

/// Iterator over the uses of a NodeValue; the definition (with its default
/// template argument) appears later in this header.
template <bool is_const_iter> class NodeValueIteratorImpl;
/// Iterator over the uses of a const NodeValue.
using NodeValueConstIterator = NodeValueIteratorImpl<true>;
/// Iterator over the uses of a mutable NodeValue.
using NodeValueIterator = NodeValueIteratorImpl<false>;
43
44/// Represents a node in the compute graph.
45class Node : public Named,
46 public Kinded,
47 public UseDef<Node, NodeUse>,
48 public llvm::ilist_node<Node> {
49 friend llvm::ilist_traits<Node>;
50
51protected:
52 /// The output types for the results of the node.
53 llvm::SmallVector<TypeRef, 6> types_;
54 /// A nullable reference to some tensor value that may predicate the execution
55 /// of the current node.
56 NodeHandle predicate_;
57
58 /// Link to the function holding this node.
59 Function *parent_{nullptr};
60
61public:
62 Node(Kinded::Kind k, llvm::StringRef name)
63 : Named(name), Kinded(k), predicate_(this, nullptr), parent_(nullptr) {}
64
65 /// \returns the nullable predicate of the current node.
66 const NodeValue getPredicate() const;
67 /// Assigns a nullable predicate to the current node.
68 void setPredicate(const NodeValue &P);
69 /// Checks if a predicate is assigned to the current node.
70 bool hasPredicate() const;
71
72 /// \returns the number of results that the node has.
73 unsigned getNumResults() const { return types_.size(); }
74 /// \returns the \p idx result of the node.
75 NodeValue getNthResult(unsigned idx);
76 /// \returns the n'th result of the node.
77 const NodeValue getNthResult(unsigned idx) const;
78
79 /// \returns the function holding this node.
80 /// If that node does not belong to any function, this
81 /// is nullptr.
82 const Function *getParent() const { return parent_; }
83 Function *getParent() { return parent_; }
84 /// Set the link to the function that holds this node.
85 void setParent(Function *parent) { parent_ = parent; }
86
87 /// Getters/setters to access Node's inputs and outputs.
88 unsigned getNumInputs() const;
89 std::string getInputName(unsigned idx) const;
90 NodeValue getNthInput(unsigned idx);
91 const NodeValue getNthInput(unsigned idx) const;
92 void setNthInput(unsigned idx, NodeValue val);
93 llvm::StringRef getOutputName(unsigned idx) const;
94 bool hasSideEffects() const;
95 bool isArithmetic() const;
96 bool isCanonical() const;
97 bool isDataParallel() const;
98
99 /// \returns true if this input is being overwritten by the node.
100 bool isOverwrittenNthInput(unsigned idx) const;
101
102 /// \returns a textual description of the node.
103 std::string getDebugDesc() const;
104
105 /// Dump a textual representation of the Node into provided output stream.
106 void dump(llvm::raw_ostream &out) const;
107
108 /// Dump a textual representation of the Node into default output stream.
109 void dump() const;
110
111 /// Dump a textual representation of the Node to std::string.
112 std::string toString() const;
113
114 /// \returns the total memory size (in bytes) of the node as the sum of sizes
115 /// for all the inputs and outputs.
116 size_t getTotMemSize() const;
117
118 /// \returns copy of the current node. Notice that the new node is not
119 /// inserted into any DAG. The caller of this method should add it to some
120 /// node-list.
121 Node *clone() const;
122
123 /// \returns true if the node is equal to the other node.
124 bool isEqual(const Node &other) const;
125
126 /// \returns true if the node is equal to the other node.
127 bool operator==(const Node &O) const { return isEqual(O); }
128
129 /// \returns a hash code of the node.
130 llvm::hash_code getHash() const;
131
132 /// This method implements the visitor pattern that scans the compute DAG top
133 /// to bottom. The visitor \p visitor is sent by the parent node \p parent,
134 /// or nullptr if this is the first node to be visited.
135 void visit(Node *parent, NodeWalker *visitor);
136
137 void visit(const Node *parent, NodeWalker *visitor) const;
138
139 /// Verify node.
140 /// \returns True if the node is valid. False otherwise.
141 bool verify() const;
142
143 /// Replace all uses of this node with null. This method is used by the
144 /// destruction sequence. When the node is deleted we need to unregister all
145 /// users. This allows us to deconstruct the graph in an arbitrary order.
146 void releaseUsers() {
147 NodeValue nop(nullptr);
148 for (unsigned i = 0; i < getNumResults(); i++) {
149 NodeValue(this, i).replaceAllUsesOfWith(nop);
150 }
151 }
152
153 ~Node() { releaseUsers(); }
154
155 /// Destroys a node and deallocates the memory. This method is implicitly
156 /// invoked by the parent function when a node is being removed from the
157 /// intrusive list of nodes. You can also invoke this method explicitly to
158 /// destroy a node which has no parent function (orphan node).
159 static void destroyNode(Node *N);
160
161 /// \returns the n'th result type of the node.
162 TypeRef getType(unsigned idx) const;
163 /// Set the \p idx'th result type of the node.
164 /// \note This setter only changes the type of this one
165 /// result. If that type is incompatible with
166 /// the inputs of the node, the caller is
167 /// responsible to update these if need be.
168 void setType(unsigned idx, TypeRef ty);
169
170 /// Set the \p idx'th result type of the node, without checking if the dims of
171 /// the old type match the dims of the new one.
172 /// \note This setter only changes the type of this one
173 /// result. If that type is incompatible with
174 /// the inputs of the node, the caller is
175 /// responsible to update these if need be.
176 /// This function does not check for validity
177 /// of input dims and whether the result exists.
178 void setTypeUnsafe(unsigned idx, TypeRef ty);
179
180 /// Methods that forward to the result type (that must be valid):
181 /// @{
182 ElemKind getElementType(unsigned resNo) const;
183 llvm::ArrayRef<dim_t> dims(unsigned resNo) const;
184 /// @}
185
186protected:
187 /// When constructing the node, add a new result of type \p T.
188 void addResult(TypeRef T);
189};
190
191/// A walker that recursively visits a node and its children.
192class NodeWalker {
193public:
194 /// This callback is called before visiting the children of \p N.
195 virtual void pre(Node *parent, Node *N) {}
196 virtual void pre(const Node *parent, const Node *N) {}
197
198 /// This callback is called after visiting the children of \p N.
199 virtual void post(Node *parent, Node *N) {}
200 virtual void post(const Node *parent, const Node *N) {}
201
202 /// This callback is called before processing the graph. If the method returns
203 /// false then we skip this node.
204 virtual bool shouldVisit(Node *parent, Node *N) { return true; }
205 virtual bool shouldVisit(const Node *parent, const Node *N) { return true; }
206
207 /// Dtor.
208 virtual ~NodeWalker() = default;
209};
210
211using IndicesSet = std::unordered_set<unsigned>;
212
213/// Helper class to hold info of a Node, containing its \p opKind, \p inTypes,
214/// and \p outTypes
215class NodeInfo : public Kinded {
216private:
217 /// The input types of the NodeInfo.
218 std::vector<TypeRef> inTypes_;
219 /// The output types of the NodeInfo.
220 std::vector<TypeRef> outTypes_;
221 /// The name of the node.
222 llvm::StringRef name_;
223
224 /// Helper function for checking if all of the ElemKinds contained in \p types
225 /// are equal to \p allowedElemKind. Indices in \p ignore are ignored when
226 /// checking from \p types.
227 bool allSameElemKind(const ElemKind allowedElemKind,
228 llvm::ArrayRef<TypeRef> types,
229 const IndicesSet &ignore) const {
230 for (size_t i = 0; i < types.size(); i++) {
231 if (ignore.count(i)) {
232 continue;
233 }
234 const TypeRef currType = types[i];
235 if (currType->getElementType() != allowedElemKind) {
236 return false;
237 }
238 }
239 return true;
240 }
241
242public:
243 NodeInfo(Kinded::Kind kind, llvm::ArrayRef<TypeRef> inTypes,
244 llvm::ArrayRef<TypeRef> outTypes)
245 : Kinded(kind), inTypes_(inTypes), outTypes_(outTypes) {}
246
247 NodeInfo(const Node &N) : Kinded(N.getKind()) {
248 for (unsigned i = 0, e = N.getNumResults(); i < e; ++i) {
249 outTypes_.push_back(N.getType(i));
250 }
251 for (unsigned idx = 0, end = N.getNumInputs(); idx != end; ++idx) {
252 inTypes_.push_back(N.getNthInput(idx).getType());
253 }
254 name_ = N.getName();
255 }
256
257 /// \returns the input types.
258 llvm::ArrayRef<TypeRef> getInTypes() const { return inTypes_; }
259
260 /// \returns the output types.
261 llvm::ArrayRef<TypeRef> getOutTypes() const { return outTypes_; }
262
263 /// \returns the input type located at \p idx.
264 const TypeRef getInTy(size_t idx) const {
265 assert(idx < inTypes_.size());
266 return inTypes_[idx];
267 }
268
269 /// \returns the output type located at \p idx.
270 const TypeRef getOutTy(size_t idx) const {
271 assert(idx < outTypes_.size());
272 return outTypes_[idx];
273 }
274
275 /// \returns the input type located at \p idx.
276 const ElemKind getInElemTy(size_t idx) const {
277 assert(idx < inTypes_.size());
278 return inTypes_[idx]->getElementType();
279 }
280
281 /// \returns the output type located at \p idx.
282 const ElemKind getOutElemTy(size_t idx) const {
283 assert(idx < outTypes_.size());
284 return outTypes_[idx]->getElementType();
285 }
286
287 /// \returns the name of the node.
288 llvm::StringRef getName() const { return name_; }
289
290 /// \returns whether all of the element types of inTypes_ and outTypes_ are
291 /// all the same and one of those found in \p allowedElemKinds. \p ignoreIn
292 /// and \p ignoreOut represent indices that can be skipped in inTypes_ and
293 /// outTypes_ respectively.
294 bool
295 allInputsAndOutputsHaveSameElemKind(llvm::ArrayRef<ElemKind> allowedElemKinds,
296 const IndicesSet &ignoreIn = {},
297 const IndicesSet &ignoreOut = {}) const {
298 for (const ElemKind elemKind : allowedElemKinds) {
299 if (allSameElemKind(elemKind, inTypes_, ignoreIn) &&
300 allSameElemKind(elemKind, outTypes_, ignoreOut)) {
301 return true;
302 }
303 }
304 return false;
305 }
306
307 /// Helper for debugging which \returns a string representation for the
308 /// NodeInfo.
309 std::string getDebugDesc() const {
310 DescriptionBuilder db(getKindName());
311 for (size_t i = 0; i < inTypes_.size(); ++i) {
312 db.addParam("inType" + std::to_string(i), *inTypes_[i]);
313 }
314 for (size_t i = 0; i < outTypes_.size(); ++i) {
315 db.addParam("outType" + std::to_string(i), *outTypes_[i]);
316 }
317 return db;
318 }
319};
320
321/// Helper class to walk through the specific uses of a NodeValue.
322/// This class is built on top of the regular users-list (Node::getUsers)
323/// but filters out the uses that don't affect the desired NodeValue.
324template <bool is_const_iter = false>
325class NodeValueIteratorImpl
326 : public std::iterator<std::forward_iterator_tag, NodeUse> {
327public:
328 /// Base type of the iterator.
329 using iterator =
330 typename std::conditional<is_const_iter,
331 std::list<NodeUse>::const_iterator,
332 std::list<NodeUse>::iterator>::type;
333 /// Type of the NodeValue that this iterator is filtering for.
334 using NodeValueTy = typename std::conditional<is_const_iter, const NodeValue,
335 NodeValue>::type;
336 /// Type of the NodeUse that this iterator should return when dereferenced.
337 using NodeUseTy =
338 typename std::conditional<is_const_iter, const NodeUse, NodeUse>::type;
339
340private:
341 /// NodeValue that this iterator tracks.
342 NodeValueTy &parent_;
343 /// Actual iterator on the users-list.
344 /// \invariant if it_ points to a valid iterator, then the NodeValue it
345 /// references (via the NodeUse) is equal to parent_.
346 iterator it_;
347
348 /// Convenient method to get the end iterator of the users list that this
349 /// iterator walks.
350 iterator getEnd() const { return parent_.getNode()->getUsers().end(); }
351
352 /// Check if \p it_ points to a NodeUse that references \p parents_.
353 bool hasSameParent() const {
354 assert(it_ != getEnd() && "Cannot check invalid iterator");
355 // A users-list should be for one node.
356 // If this assert breaks, that means the input list is broken,
357 // or this iterator is not used as it was intended: to walk
358 // through a users-list.
359 assert(it_->get()->getNode() == parent_.getNode() &&
360 "Iterator points to a list with different parent?!");
361 return it_->get()->getResNo() == parent_.getResNo();
362 }
363
364public:
365 NodeValueIteratorImpl(NodeValueTy &parent, iterator it)
366 : parent_(parent), it_(it) {
367 if (it_ != getEnd() && !hasSameParent()) {
368 ++(*this);
369 }
370 assert((it_ == getEnd() || hasSameParent()) &&
371 "operator++ should return the next valid iterator");
372 }
373
374 /// Move to the next use of parent_.
375 NodeValueIteratorImpl &operator++() {
376 auto endIt = getEnd();
377 while (++it_ != endIt && !hasSameParent()) {
378 }
379 return *this;
380 }
381
382 NodeUseTy &operator*() {
383 assert(hasSameParent() && "Invalid iterator");
384 return *it_;
385 }
386
387 const NodeUseTy &operator*() const {
388 assert(hasSameParent() && "Invalid iterator");
389 return *it_;
390 }
391
392 bool operator!=(const NodeValueIteratorImpl &other) const {
393 return it_ != other.it_;
394 }
395};
396
/// Input/result indices shared by Arithmetic nodes. These constants (grouped
/// in a namespace, not an enum) are expected to match the operand and result
/// ordering of every node for which Node::isArithmetic() returns true.
namespace ArithmeticNode {
/// Index of the left-hand-side operand.
constexpr unsigned LHSIdx = 0;
/// Index of the right-hand-side operand.
constexpr unsigned RHSIdx = 1;
/// Index of the (single) result.
constexpr unsigned ResultIdx = 0;
} // namespace ArithmeticNode
404
405llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node &node);
406
407llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node *node);
408
409/// Helper to get the Kind of a Node (e.g. Kinded::Kind::AddNodeKind) given its
410/// \p nodeName (e.g. Add).
411inline Kinded::Kind getKindFromNodeName(llvm::StringRef nodeName) {
412#define DEF_NODE(CLASS, NAME) \
413 if (nodeName == #NAME) { \
414 return Kinded::Kind::CLASS##Kind; \
415 }
416#include "glow/AutoGenNodes.def"
417 LOG(FATAL) << "Unknown node name: " << nodeName.str();
418}
419
420} // namespace glow
421
422namespace llvm {
423/// Allow casting NodeValue into Node*.
424template <> struct simplify_type<glow::NodeValue> {
425 typedef glow::Node *SimpleType;
426 static SimpleType getSimplifiedValue(glow::NodeValue &val) {
427 return val.getNode();
428 }
429};
430
431/// Allow casting NodeValue into Node*.
432template <> struct simplify_type<const glow::NodeValue> {
433 typedef glow::Node *SimpleType;
434 static SimpleType getSimplifiedValue(const glow::NodeValue &val) {
435 return val.getNode();
436 }
437};
438
439/// Allow casting NodeHandle into Node*.
440template <> struct simplify_type<glow::NodeHandle> {
441 typedef glow::Node *SimpleType;
442 static SimpleType getSimplifiedValue(glow::NodeHandle &val) {
443 return val.getNode();
444 }
445};
446/// Allow casting const NodeHandle into Node*.
447template <> struct simplify_type<const glow::NodeHandle> {
448 typedef glow::Node *SimpleType;
449 static SimpleType getSimplifiedValue(const glow::NodeHandle &val) {
450 return val.getNode();
451 }
452};
453
454//===----------------------------------------------------------------------===//
455// ilist_traits for glow::Node
456//===----------------------------------------------------------------------===//
457
458template <>
459struct ilist_traits<glow::Node> : public ilist_node_traits<glow::Node> {
460 using Node = glow::Node;
461
462 glow::Function *getContainingFunction();
463
464private:
465 using node_iterator = simple_ilist<Node>::iterator;
466
467public:
468 static void deleteNode(Node *N) { glow::Node::destroyNode(N); }
469
470 void addNodeToList(Node *N);
471 void removeNodeFromList(Node *N);
472 void transferNodesFromList(ilist_traits<Node> &L2, node_iterator first,
473 node_iterator last);
474
475private:
476 void createNode(const Node &);
477};
478
479} // namespace llvm
480
481// custom specialization of std::hash for NodeValue.
482namespace std {
483template <> struct hash<glow::NodeValue> {
484 typedef glow::NodeValue argument_type;
485 typedef std::size_t result_type;
486 result_type operator()(argument_type const &s) const noexcept {
487 auto name = s.getNode()->getName();
488 result_type const h1(std::hash<std::string>{}(name.str()));
489 result_type const h2(std::hash<unsigned>{}(s.getResNo()));
490 return h1 ^ (h2 << 8);
491 }
492};
493} // namespace std
494
495#endif // GLOW_GRAPH_NODE_H
496