1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Base/Type.h"
18#include "glow/Graph/Graph.h"
19#include "glow/Graph/Nodes.h"
20#include "glow/Graph/VerifierHelper.h"
21#include "glow/Support/Support.h"
22
23using namespace glow;
24
25void Node::setPredicate(const NodeValue &P) { predicate_ = P; }
26
27bool Node::hasPredicate() const { return predicate_.getNode(); }
28
29TypeRef Node::getType(unsigned idx) const {
30 assert(idx < getNumResults() && "Result number does not exist.");
31 return types_[idx];
32}
33
34void Node::setType(unsigned idx, TypeRef ty) {
35 assert(types_[idx]->dims() == ty->dims() &&
36 "Better create a new node at this point");
37 setTypeUnsafe(idx, ty);
38}
39
40void Node::setTypeUnsafe(unsigned idx, TypeRef ty) {
41 assert(idx < getNumResults() && "Result number does not exist.");
42 types_[idx] = ty;
43}
44
45ElemKind Node::getElementType(unsigned resNo) const {
46 TypeRef TR = getType(resNo);
47 return TR->getElementType();
48}
49
50llvm::ArrayRef<dim_t> Node::dims(unsigned resNo) const {
51 TypeRef TR = getType(resNo);
52 return TR->dims();
53}
54
55void Node::addResult(TypeRef T) { types_.push_back(T); }
56
/// \returns true if this node is equal to \p other. The same object is
/// trivially equal to itself, and nodes of different kinds are never equal.
/// Otherwise the comparison is delegated to isEqual() of the concrete node
/// class selected by getKind().
bool Node::isEqual(const Node &other) const {
  if (this == &other)
    return true;

  if (getKind() != other.getKind())
    return false;

  switch (getKind()) {
  // For every node kind in AutoGenNodes.def, downcast both operands to the
  // concrete class and compare them there.
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isEqual( \
        *static_cast<const CLASS *>(&other));
#include "glow/AutoGenNodes.def"

  // Instruction and value kinds share the Kinded::Kind enum but are not
  // handled here. Their case labels all fall through to the unreachable below.
#define DEF_INSTR(CLASS, NAME) case glow::Kinded::Kind::CLASS##Kind:
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) DEF_INSTR(CLASS, NAME)
#define DEF_VALUE(CLASS, NAME) DEF_INSTR(CLASS, NAME)
#include "glow/AutoGenInstr.def"

    llvm_unreachable(
        "Not reachable, values and instructions are not handled here");
  }
  // Fallback for a kind that matches no case label above.
  return false;
}
81
82const NodeValue Node::getPredicate() const { return predicate_; }
83
namespace {
/// Visitor that hashes a node by dispatching on its kind to the concrete
/// class's getHash(). AutoGenNodes.def generates one visitCLASS method per
/// node class. The visitor's result type is llvm::hash_code.
class HashNodeVisitor : public NodeVisitor<HashNodeVisitor, llvm::hash_code> {
  using hash_code = llvm::hash_code;
  using super = NodeVisitor;

public:
#define DEF_NODE(CLASS, NAME) \
  hash_code visit##CLASS(const CLASS *N) const { return N->getHash(); }
#include "glow/AutoGenNodes.def"

  /// Const entry point. Both the visitor and the node are const_cast before
  /// calling the base visit(), which presumably takes non-const arguments.
  /// The generated visitCLASS methods only call getHash() on a const node.
  hash_code visit(const Node *N) const {
    return const_cast<HashNodeVisitor *>(this)->super::visit(
        const_cast<Node *>(N));
  }
};

} // namespace
101
102llvm::hash_code Node::getHash() const { return HashNodeVisitor().visit(this); }
103
/// Walk this node with \p visitor, where \p parent is the node from which
/// this one was reached. If a predicate is attached, its node is visited
/// first, with this node as its parent. The walk is then delegated to visit()
/// of the concrete class selected by getKind().
void Node::visit(Node *parent, NodeWalker *visitor) {
  if (hasPredicate()) {
    getPredicate().getNode()->visit(this, visitor);
  }

  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<CLASS *>(this)->visit(parent, visitor);
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
118
119//===----------------------------------------------------------------------===//
120// Debug description methods
121//===----------------------------------------------------------------------===//
122
/// \returns the number of inputs of this node, as reported by the concrete
/// node class selected by getKind().
unsigned Node::getNumInputs() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getNumInputs();
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
133
/// \returns the name of input number \p idx, as reported by the concrete
/// node class selected by getKind().
std::string Node::getInputName(unsigned idx) const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getInputName(idx);
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
144
/// \returns input number \p idx of this node. The lookup is delegated to the
/// concrete node class selected by getKind().
NodeValue Node::getNthInput(unsigned idx) {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<CLASS *>(this)->getNthInput(idx);
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
155
156const NodeValue Node::getNthInput(unsigned idx) const {
157 switch (getKind()) {
158#define DEF_NODE(CLASS, NAME) \
159 case glow::Kinded::Kind::CLASS##Kind: \
160 return static_cast<CLASS *>(const_cast<Node *>(this))->getNthInput(idx);
161#include "glow/AutoGenNodes.def"
162 default:
163 llvm_unreachable("Unhandled node");
164 }
165}
166
167void Node::setNthInput(unsigned idx, NodeValue val) {
168 switch (getKind()) {
169#define DEF_NODE(CLASS, NAME) \
170 case glow::Kinded::Kind::CLASS##Kind: \
171 if (getParent()) { \
172 getParent()->getLogContext()->logNodeInputChange( \
173 *this, this->getNthInput(idx), val); \
174 } \
175 return static_cast<CLASS *>(this)->setNthInput(idx, val);
176#include "glow/AutoGenNodes.def"
177 default:
178 llvm_unreachable("Unhandled node");
179 }
180}
181
182NodeValue Node::getNthResult(unsigned idx) {
183 assert(idx < getNumResults());
184 return NodeValue(this, idx);
185}
186
187const NodeValue Node::getNthResult(unsigned idx) const {
188 assert(idx < getNumResults());
189 return NodeValue(const_cast<Node *>(this), idx);
190}
191
/// \returns the name of result number \p idx, as reported by the concrete
/// node class selected by getKind().
llvm::StringRef Node::getOutputName(unsigned idx) const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getOutputName(idx);
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
202
/// \returns whether this node has side effects, as reported by the concrete
/// node class selected by getKind().
bool Node::hasSideEffects() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->hasSideEffects();
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
213
/// \returns whether this node is canonical, as reported by the concrete node
/// class selected by getKind().
bool Node::isCanonical() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isCanonical();
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
224
/// \returns whether this node is data parallel, as reported by the concrete
/// node class selected by getKind().
bool Node::isDataParallel() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isDataParallel();
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
235
// NOTE: Callers use this predicate together with the assumption that the 1st
// input is LHS, the 2nd input is RHS, and the 1st result is Result.
/// \returns true if this node is one of the binary kinds listed below (Add,
/// Mul, Sub, Div, FloorDiv, Max, Min, CmpLTE, CmpLT, CmpEQ, Pow, Fmod).
bool Node::isArithmetic() const {
  // Each case includes a static assert that the generated nodes that we
  // consider arithmetic have the expected format/order of LHS, RHS, Result.
  // A kind added here with a different input/result layout therefore fails
  // to compile instead of being silently misinterpreted by callers.
#define ARITHMETIC_NODE_CASE(NODE_NAME_) \
  static_assert((NODE_NAME_##Node::LHSIdx == ArithmeticNode::LHSIdx && \
                 NODE_NAME_##Node::RHSIdx == ArithmeticNode::RHSIdx && \
                 NODE_NAME_##Node::ResultIdx == ArithmeticNode::ResultIdx), \
                #NODE_NAME_ \
                "Node does not match expected arithmetic node format."); \
  case glow::Kinded::Kind::NODE_NAME_##NodeKind:

  switch (getKind()) {
    // The case labels below all fall through to `return true`.
    ARITHMETIC_NODE_CASE(Add)
    ARITHMETIC_NODE_CASE(Mul)
    ARITHMETIC_NODE_CASE(Sub)
    ARITHMETIC_NODE_CASE(Div)
    ARITHMETIC_NODE_CASE(FloorDiv)
    ARITHMETIC_NODE_CASE(Max)
    ARITHMETIC_NODE_CASE(Min)
    ARITHMETIC_NODE_CASE(CmpLTE)
    ARITHMETIC_NODE_CASE(CmpLT)
    ARITHMETIC_NODE_CASE(CmpEQ)
    ARITHMETIC_NODE_CASE(Pow)
    ARITHMETIC_NODE_CASE(Fmod)
    return true;
  default:
    return false;
  }
#undef ARITHMETIC_NODE_CASE
}
268
/// \returns whether input number \p idx is overwritten by this node, as
/// reported by the concrete node class selected by getKind().
bool Node::isOverwrittenNthInput(unsigned idx) const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isOverwrittenNthInput(idx);
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
279
/// \returns a human-readable description of this node, produced by the
/// concrete node class selected by getKind(). Used by dump() and toString().
std::string Node::getDebugDesc() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getDebugDesc();
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
290
291void Node::dump(llvm::raw_ostream &out) const { out << this->getDebugDesc(); }
292
293void Node::dump() const { dump(llvm::outs()); }
294
295std::string Node::toString() const { return this->getDebugDesc(); }
296
297size_t Node::getTotMemSize() const {
298 size_t totMemSize = 0;
299 for (unsigned idx = 0, e = getNumInputs(); idx < e; idx++) {
300 totMemSize += getNthInput(idx).getType()->getSizeInBytes();
301 }
302 for (unsigned idx = 0, e = getNumResults(); idx < e; idx++) {
303 totMemSize += getNthResult(idx).getType()->getSizeInBytes();
304 }
305 return totMemSize;
306}
307
/// \returns a copy of this node produced by clone() of the concrete class
/// selected by getKind(). NOTE(review): ownership of the returned pointer is
/// defined by the concrete clone() implementations, not shown here. Confirm
/// with callers.
Node *Node::clone() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->clone();
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
318
/// Delete \p N through a pointer to its concrete class, selected by getKind(),
/// so the concrete class's destructor runs. NOTE(review): this presumably
/// exists because Node's destructor is not virtual. Confirm before deleting
/// Nodes any other way.
void Node::destroyNode(Node *N) {
  switch (N->getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: { \
    delete static_cast<CLASS *>(N); \
    break; \
  }
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
}
331
332namespace glow {
333
334llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node &node) {
335 node.dump(os);
336 return os;
337}
338
339llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node *node) {
340 assert(node != nullptr && "Null Pointer.");
341 node->dump(os);
342 return os;
343}
344} // namespace glow
345
346//===----------------------------------------------------------------------===//
347// Nodes verification
348//===----------------------------------------------------------------------===//
349
/// Verify the invariants shared by all nodes, then the node-specific ones.
/// \returns true only if every check passes. The checks are:
///  - if a predicate is set, it refers to a node and is 1-dimensional;
///  - if the node has a parent Function, it is found in that Function's node
///    list;
///  - the concrete class's verify().
/// Each shared check goes through expectCompareTrue(), whose boolean result is
/// accumulated into the return value.
bool Node::verify() const {
  // Verify the shared members of the node.
  bool isValid = true;

  // Verify the predicate field.
  if (hasPredicate()) {
    auto pred = getPredicate();
    // hasPredicate() already established a non-null node, so this is a
    // defensive guard for the type access below.
    if (!expectCompareTrue("Invalid predicate", bool(pred.getNode()), true,
                           this)) {
      // The following code assumes pred is valid.
      return false;
    }
    auto Ty = pred.getType();
    isValid &= expectCompareTrue("Predicate must be a vector",
                                 Ty->dims().size(), size_t(1), this);
  }

  if (getParent()) {
    // NOTE(review): linear scan of the parent's node list (O(#nodes) per
    // call), comparing list elements against *this by value. Confirm that
    // Node's equality means identity rather than structural equality.
    isValid &=
        expectCompareTrue("Node not present in its parent",
                          std::find(getParent()->getNodes().begin(),
                                    getParent()->getNodes().end(),
                                    *this) != getParent()->getNodes().end(),
                          true, this);
  }

  // Verify node-specific properties:
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    isValid &= static_cast<const CLASS *>(this)->verify(); \
    break;
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
  return isValid;
}
388
389//===----------------------------------------------------------------------===//
390// ilist_traits<glow::Node> Implementation
391//===----------------------------------------------------------------------===//
392
// The trait object is embedded into a Function. Use dirty hacks to
// reconstruct the Function from the 'self' pointer of the trait: subtract the
// byte offset of the node-list member (Function::getNodesMemberPtr()) from the
// address of the list this trait belongs to.
// NOTE(review): the offset is computed by forming a member address through a
// null Function pointer. The standard does not sanction this; it relies on
// compiler behavior. The static_cast also assumes this trait object is a base
// subobject of an iplist<Node> stored in a Function. Confirm against the list
// declaration.
Function *llvm::ilist_traits<Node>::getContainingFunction() {
  size_t Offset(size_t(&((Function *)nullptr->*Function::getNodesMemberPtr())));
  iplist<Node> *Anchor(static_cast<iplist<Node> *>(this));
  return reinterpret_cast<Function *>(reinterpret_cast<char *>(Anchor) -
                                      Offset);
}
401
402void llvm::ilist_traits<Node>::addNodeToList(Node *node) {
403 assert(node->getParent() == nullptr && "Already in a list!");
404 node->setParent(getContainingFunction());
405}
406
407void llvm::ilist_traits<Node>::removeNodeFromList(Node *node) {
408 // When an instruction is removed from a function, clear the parent pointer.
409 assert(node->getParent() && "Not in a list!");
410 node->setParent(nullptr);
411}
412
413void llvm::ilist_traits<Node>::transferNodesFromList(
414 llvm::ilist_traits<Node> &L2, node_iterator first, node_iterator last) {
415 // If transferring nodes within the same Function, no reason to
416 // update their parent pointers.
417 Function *ThisParent = getContainingFunction();
418 if (ThisParent == L2.getContainingFunction())
419 return;
420
421 // Update the parent fields in the nodes.
422 for (; first != last; ++first)
423 first->setParent(ThisParent);
424}
425