1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Base/Type.h" |
18 | #include "glow/Graph/Graph.h" |
19 | #include "glow/Graph/Nodes.h" |
20 | #include "glow/Graph/VerifierHelper.h" |
21 | #include "glow/Support/Support.h" |
22 | |
23 | using namespace glow; |
24 | |
25 | void Node::setPredicate(const NodeValue &P) { predicate_ = P; } |
26 | |
27 | bool Node::hasPredicate() const { return predicate_.getNode(); } |
28 | |
29 | TypeRef Node::getType(unsigned idx) const { |
30 | assert(idx < getNumResults() && "Result number does not exist." ); |
31 | return types_[idx]; |
32 | } |
33 | |
34 | void Node::setType(unsigned idx, TypeRef ty) { |
35 | assert(types_[idx]->dims() == ty->dims() && |
36 | "Better create a new node at this point" ); |
37 | setTypeUnsafe(idx, ty); |
38 | } |
39 | |
40 | void Node::setTypeUnsafe(unsigned idx, TypeRef ty) { |
41 | assert(idx < getNumResults() && "Result number does not exist." ); |
42 | types_[idx] = ty; |
43 | } |
44 | |
45 | ElemKind Node::getElementType(unsigned resNo) const { |
46 | TypeRef TR = getType(resNo); |
47 | return TR->getElementType(); |
48 | } |
49 | |
50 | llvm::ArrayRef<dim_t> Node::dims(unsigned resNo) const { |
51 | TypeRef TR = getType(resNo); |
52 | return TR->dims(); |
53 | } |
54 | |
55 | void Node::addResult(TypeRef T) { types_.push_back(T); } |
56 | |
/// Returns true if this node equals \p other: either they are the same object,
/// or they have the same kind and the concrete class's isEqual() returns true.
/// Calling this with a value or instruction kind is a programming error
/// (llvm_unreachable).
bool Node::isEqual(const Node &other) const {
  // Identity implies equality.
  if (this == &other)
    return true;

  // Nodes of different kinds are never equal.
  if (getKind() != other.getKind())
    return false;

  // Dispatch to the concrete node class. Both nodes have the same kind at this
  // point, so casting `other` to the same class is safe.
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isEqual( \
        *static_cast<const CLASS *>(&other));
#include "glow/AutoGenNodes.def"

// Instruction and value kinds share the Kinded::Kind enum with nodes. They are
// listed here as case labels that fall through to llvm_unreachable below.
#define DEF_INSTR(CLASS, NAME) case glow::Kinded::Kind::CLASS##Kind:
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) DEF_INSTR(CLASS, NAME)
#define DEF_VALUE(CLASS, NAME) DEF_INSTR(CLASS, NAME)
#include "glow/AutoGenInstr.def"

    llvm_unreachable(
        "Not reachable, values and instructions are not handled here");
  }
  // Only reached for a kind not covered by any case label above.
  return false;
}
81 | |
82 | const NodeValue Node::getPredicate() const { return predicate_; } |
83 | |
namespace {
/// Visitor that hashes a node by dispatching, per node kind, to the concrete
/// node class's getHash().
class HashNodeVisitor : public NodeVisitor<HashNodeVisitor, llvm::hash_code> {
  using hash_code = llvm::hash_code;
  using super = NodeVisitor;

public:
// Generates one visit##CLASS() per node class; each forwards to CLASS::getHash().
#define DEF_NODE(CLASS, NAME) \
  hash_code visit##CLASS(const CLASS *N) const { return N->getHash(); }
#include "glow/AutoGenNodes.def"

  /// Entry point: returns the hash of \p N.
  /// The const_casts suggest that the base NodeVisitor::visit() takes a
  /// non-const visitor and node (NOTE(review): signature not visible here).
  hash_code visit(const Node *N) const {
    return const_cast<HashNodeVisitor *>(this)->super::visit(
        const_cast<Node *>(N));
  }
};

} // namespace
101 | |
102 | llvm::hash_code Node::getHash() const { return HashNodeVisitor().visit(this); } |
103 | |
/// Walks this node with \p visitor, where \p parent is the node this one was
/// reached from. If a predicate is attached, the predicate's node is walked
/// first, with this node as its parent. The walk is then dispatched to the
/// concrete node class's visit().
void Node::visit(Node *parent, NodeWalker *visitor) {
  if (hasPredicate()) {
    getPredicate().getNode()->visit(this, visitor);
  }

  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<CLASS *>(this)->visit(parent, visitor);
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are nodes.
    llvm_unreachable("Unhandled node");
  }
}
118 | |
119 | //===----------------------------------------------------------------------===// |
// Per-kind dispatch helpers and debug description methods
121 | //===----------------------------------------------------------------------===// |
122 | |
/// Returns the number of inputs of this node, as reported by the concrete node
/// class selected by getKind().
unsigned Node::getNumInputs() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getNumInputs();
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
133 | |
/// Returns the name of input number \p idx, as defined by the concrete node
/// class selected by getKind().
std::string Node::getInputName(unsigned idx) const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getInputName(idx);
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
144 | |
/// Returns input number \p idx of this node, as provided by the concrete node
/// class selected by getKind().
NodeValue Node::getNthInput(unsigned idx) {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<CLASS *>(this)->getNthInput(idx);
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
155 | |
156 | const NodeValue Node::getNthInput(unsigned idx) const { |
157 | switch (getKind()) { |
158 | #define DEF_NODE(CLASS, NAME) \ |
159 | case glow::Kinded::Kind::CLASS##Kind: \ |
160 | return static_cast<CLASS *>(const_cast<Node *>(this))->getNthInput(idx); |
161 | #include "glow/AutoGenNodes.def" |
162 | default: |
163 | llvm_unreachable("Unhandled node" ); |
164 | } |
165 | } |
166 | |
167 | void Node::setNthInput(unsigned idx, NodeValue val) { |
168 | switch (getKind()) { |
169 | #define DEF_NODE(CLASS, NAME) \ |
170 | case glow::Kinded::Kind::CLASS##Kind: \ |
171 | if (getParent()) { \ |
172 | getParent()->getLogContext()->logNodeInputChange( \ |
173 | *this, this->getNthInput(idx), val); \ |
174 | } \ |
175 | return static_cast<CLASS *>(this)->setNthInput(idx, val); |
176 | #include "glow/AutoGenNodes.def" |
177 | default: |
178 | llvm_unreachable("Unhandled node" ); |
179 | } |
180 | } |
181 | |
182 | NodeValue Node::getNthResult(unsigned idx) { |
183 | assert(idx < getNumResults()); |
184 | return NodeValue(this, idx); |
185 | } |
186 | |
187 | const NodeValue Node::getNthResult(unsigned idx) const { |
188 | assert(idx < getNumResults()); |
189 | return NodeValue(const_cast<Node *>(this), idx); |
190 | } |
191 | |
/// Returns the name of result number \p idx, as defined by the concrete node
/// class selected by getKind().
llvm::StringRef Node::getOutputName(unsigned idx) const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getOutputName(idx);
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
202 | |
/// Returns whether this node has side effects, as reported by the concrete
/// node class selected by getKind().
bool Node::hasSideEffects() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->hasSideEffects();
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
213 | |
/// Returns whether this node is canonical, as reported by the concrete node
/// class selected by getKind().
bool Node::isCanonical() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isCanonical();
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
224 | |
/// Returns whether this node is data parallel, as reported by the concrete
/// node class selected by getKind().
bool Node::isDataParallel() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isDataParallel();
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
235 | |
// NOTE: Callers of isArithmetic() assume that, for every node it accepts, the
// 1st input is LHS, the 2nd input is RHS, and the 1st result is Result.
/// Returns true if this node is one of the kinds listed in the switch below
/// (Add, Mul, Sub, Div, FloorDiv, Max, Min, CmpLTE, CmpLT, CmpEQ, Pow, Fmod).
bool Node::isArithmetic() const {
// Each case includes a static assert that the generated nodes that we
// consider arithmetic have the expected format/order of LHS, RHS, Result.
// `#NODE_NAME_` is concatenated with the message literal, so the diagnostic
// names the offending class (e.g. "AddNode does not match ...").
#define ARITHMETIC_NODE_CASE(NODE_NAME_) \
  static_assert((NODE_NAME_##Node::LHSIdx == ArithmeticNode::LHSIdx && \
                 NODE_NAME_##Node::RHSIdx == ArithmeticNode::RHSIdx && \
                 NODE_NAME_##Node::ResultIdx == ArithmeticNode::ResultIdx), \
                #NODE_NAME_ \
                "Node does not match expected arithmetic node format."); \
  case glow::Kinded::Kind::NODE_NAME_##NodeKind:

  // All listed case labels fall through to the shared `return true`.
  switch (getKind()) {
    ARITHMETIC_NODE_CASE(Add)
    ARITHMETIC_NODE_CASE(Mul)
    ARITHMETIC_NODE_CASE(Sub)
    ARITHMETIC_NODE_CASE(Div)
    ARITHMETIC_NODE_CASE(FloorDiv)
    ARITHMETIC_NODE_CASE(Max)
    ARITHMETIC_NODE_CASE(Min)
    ARITHMETIC_NODE_CASE(CmpLTE)
    ARITHMETIC_NODE_CASE(CmpLT)
    ARITHMETIC_NODE_CASE(CmpEQ)
    ARITHMETIC_NODE_CASE(Pow)
    ARITHMETIC_NODE_CASE(Fmod)
    return true;
  default:
    return false;
  }
#undef ARITHMETIC_NODE_CASE
}
268 | |
/// Returns whether input number \p idx is overwritten by this node, as
/// reported by the concrete node class selected by getKind().
bool Node::isOverwrittenNthInput(unsigned idx) const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->isOverwrittenNthInput(idx);
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
279 | |
/// Returns a human-readable description of this node, produced by the
/// concrete node class selected by getKind(). Used by dump() and toString().
std::string Node::getDebugDesc() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->getDebugDesc();
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
290 | |
291 | void Node::dump(llvm::raw_ostream &out) const { out << this->getDebugDesc(); } |
292 | |
293 | void Node::dump() const { dump(llvm::outs()); } |
294 | |
295 | std::string Node::toString() const { return this->getDebugDesc(); } |
296 | |
297 | size_t Node::getTotMemSize() const { |
298 | size_t totMemSize = 0; |
299 | for (unsigned idx = 0, e = getNumInputs(); idx < e; idx++) { |
300 | totMemSize += getNthInput(idx).getType()->getSizeInBytes(); |
301 | } |
302 | for (unsigned idx = 0, e = getNumResults(); idx < e; idx++) { |
303 | totMemSize += getNthResult(idx).getType()->getSizeInBytes(); |
304 | } |
305 | return totMemSize; |
306 | } |
307 | |
/// Returns the result of the concrete node class's clone(), selected by
/// getKind(). NOTE(review): presumably a newly allocated copy owned by the
/// caller; ownership is not established in this file.
Node *Node::clone() const {
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    return static_cast<const CLASS *>(this)->clone();
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
318 | |
/// Deletes \p N through a pointer to its concrete node class (selected by
/// getKind()), so that class's destructor is the one invoked.
/// NOTE(review): presumably needed because Node's destructor is not virtual;
/// confirm in Node.h.
/// \p N must be non-null because its kind is read before deletion.
void Node::destroyNode(Node *N) {
  switch (N->getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: { \
    delete static_cast<CLASS *>(N); \
    break; \
  }
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
}
331 | |
332 | namespace glow { |
333 | |
334 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node &node) { |
335 | node.dump(os); |
336 | return os; |
337 | } |
338 | |
339 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node *node) { |
340 | assert(node != nullptr && "Null Pointer." ); |
341 | node->dump(os); |
342 | return os; |
343 | } |
344 | } // namespace glow |
345 | |
346 | //===----------------------------------------------------------------------===// |
347 | // Nodes verification |
348 | //===----------------------------------------------------------------------===// |
349 | |
/// Verifies this node. It first checks properties shared by all nodes: the
/// predicate (if any) must be rank-1, and the node must be present in its
/// parent Function's node list. It then runs the concrete class's verify().
/// Each check goes through expectCompareTrue() (VerifierHelper). Returns true
/// only if every check passes. A null predicate node returns false
/// immediately.
bool Node::verify() const {
  // Verify the shared members of the node.
  bool isValid = true;

  // Verify the predicate field.
  if (hasPredicate()) {
    auto pred = getPredicate();
    // hasPredicate() already implies a non-null predicate node, so this check
    // is defensive.
    if (!expectCompareTrue("Invalid predicate", bool(pred.getNode()), true,
                           this)) {
      // The following code assumes pred is valid.
      return false;
    }
    auto Ty = pred.getType();
    isValid &= expectCompareTrue("Predicate must be a vector",
                                 Ty->dims().size(), size_t(1), this);
  }

  if (getParent()) {
    // Linear scan of the parent's node list, comparing each element with
    // *this. NOTE(review): this uses Node's equality operator (not visible
    // here); confirm whether it compares identity or structure.
    isValid &=
        expectCompareTrue("Node not present in its parent",
                          std::find(getParent()->getNodes().begin(),
                                    getParent()->getNodes().end(),
                                    *this) != getParent()->getNodes().end(),
                          true, this);
  }

  // Verify node-specific properties:
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    isValid &= static_cast<const CLASS *>(this)->verify(); \
    break;
#include "glow/AutoGenNodes.def"
  default:
    // Only kinds generated from AutoGenNodes.def are handled.
    llvm_unreachable("Unhandled node");
  }
  return isValid;
}
388 | |
389 | //===----------------------------------------------------------------------===// |
390 | // ilist_traits<glow::Node> Implementation |
391 | //===----------------------------------------------------------------------===// |
392 | |
// The trait object is embedded into a Function. Use dirty hacks to
// reconstruct the Function from the 'self' pointer of the trait.
/// Returns the Function that contains the node list this traits object
/// belongs to.
Function *llvm::ilist_traits<Node>::getContainingFunction() {
  // Byte offset of the node-list member within Function. It is computed by
  // applying Function's member pointer to a null Function pointer, an
  // offsetof-style trick. NOTE(review): formally undefined behavior in
  // standard C++.
  size_t Offset(size_t(&((Function *)nullptr->*Function::getNodesMemberPtr())));
  // This traits object is a base of the iplist<Node> stored inside Function.
  iplist<Node> *Anchor(static_cast<iplist<Node> *>(this));
  // Step back from the list member to the start of the enclosing Function.
  return reinterpret_cast<Function *>(reinterpret_cast<char *>(Anchor) -
                                      Offset);
}
401 | |
402 | void llvm::ilist_traits<Node>::addNodeToList(Node *node) { |
403 | assert(node->getParent() == nullptr && "Already in a list!" ); |
404 | node->setParent(getContainingFunction()); |
405 | } |
406 | |
407 | void llvm::ilist_traits<Node>::removeNodeFromList(Node *node) { |
408 | // When an instruction is removed from a function, clear the parent pointer. |
409 | assert(node->getParent() && "Not in a list!" ); |
410 | node->setParent(nullptr); |
411 | } |
412 | |
413 | void llvm::ilist_traits<Node>::transferNodesFromList( |
414 | llvm::ilist_traits<Node> &L2, node_iterator first, node_iterator last) { |
415 | // If transferring nodes within the same Function, no reason to |
416 | // update their parent pointers. |
417 | Function *ThisParent = getContainingFunction(); |
418 | if (ThisParent == L2.getContainingFunction()) |
419 | return; |
420 | |
421 | // Update the parent fields in the nodes. |
422 | for (; first != last; ++first) |
423 | first->setParent(ThisParent); |
424 | } |
425 | |