1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Graph/NodeValue.h" |
18 | #include "glow/Graph/Graph.h" |
19 | #include "glow/Graph/Node.h" |
20 | |
21 | using namespace glow; |
22 | |
23 | NodeValue::NodeValue(Node *N) { |
24 | assert((!N || (N->getNumResults() == 1)) && |
25 | "Constructing a value for a multi-res node" ); |
26 | node_ = N; |
27 | resNo_ = 0; |
28 | } |
29 | |
30 | NodeValue::NodeValue(Node *N, unsigned resNo) { |
31 | assert(resNo < N->getNumResults() && "Invalid result number" ); |
32 | node_ = N; |
33 | resNo_ = resNo; |
34 | } |
35 | |
36 | void NodeValue::replaceAllUsesOfWith(NodeValue v, const Function *F, |
37 | Node *skipReplacement) const { |
38 | if (v.getNode() && getType() != v.getType()) { |
39 | // Fresh replacement nodes without users are usually created by the graph |
40 | // optimizer, where the type of the replacement value should always match |
41 | // the type of the original node value. Check if the only difference is |
42 | // related to strides and adjust accordingly. |
43 | assert(v.getNumUsers() == 0 && "Cannot update type if there are users" ); |
44 | assert(getType()->isEqual(*v.getType(), /* allowDifferentShape */ false, |
45 | /* allowDifferentStrides */ true) && |
46 | "Replacing value with the wrong type" ); |
47 | v.setType(getType()); |
48 | } |
49 | typeUnsafeReplaceAllUsesOfWith(v, F, skipReplacement); |
50 | } |
51 | |
52 | void NodeValue::typeUnsafeReplaceAllUsesOfWith(NodeValue v, const Function *F, |
53 | Node *skipReplacement) const { |
54 | // Copy the list of users in a temporary vector since that list (and the |
55 | // underlying iterators) are going to be invalidated by the next loop. |
56 | auto nodeValueUsers = getUsers(); |
57 | llvm::SmallVector<NodeUse, 4> usersVec(nodeValueUsers.begin(), |
58 | nodeValueUsers.end()); |
59 | for (auto &U : usersVec) { |
60 | NodeHandle *site = U.get(); |
61 | auto *userF = U.getUser()->getParent(); |
62 | // If the user is not in function F, don't touch it. |
63 | if (F && userF != F) { |
64 | continue; |
65 | } |
66 | assert(site->getNode() == node_ && "Invalid user" ); |
67 | assert(site->getResNo() == getResNo() && "Invalid list of uses" ); |
68 | |
69 | if (U.getUser() == skipReplacement) { |
70 | continue; |
71 | } |
72 | |
73 | // Log the change of node input(operand). |
74 | if (Function *F = getNode()->getParent()) { |
75 | F->getLogContext()->logNodeInputChange(*(U.getUser()), *this, v); |
76 | } |
77 | // Constant or Placeholder has no associated Function, we need to log the |
78 | // input changes inside its user's Function. |
79 | else if (getNode()->getKind() == Kinded::Kind::ConstantKind || |
80 | getNode()->getKind() == Kinded::Kind::PlaceholderKind) { |
81 | userF->getLogContext()->logNodeInputChange(*(U.getUser()), *this, v); |
82 | } |
83 | |
84 | site->setOperand(v.getNode(), v.getResNo()); |
85 | } |
86 | } |
87 | |
88 | unsigned NodeValue::getNumUsers() const { |
89 | auto range = getUsers(); |
90 | return std::distance(range.begin(), range.end()); |
91 | } |
92 | |
93 | llvm::iterator_range<NodeValueIterator> NodeValue::getUsers() { |
94 | auto &unfilteredUsers = getNode()->getUsers(); |
95 | return llvm::make_range(NodeValueIterator(*this, unfilteredUsers.begin()), |
96 | NodeValueIterator(*this, unfilteredUsers.end())); |
97 | } |
98 | |
99 | llvm::iterator_range<NodeValueConstIterator> NodeValue::getUsers() const { |
100 | const auto &unfilteredUsers = getNode()->getUsers(); |
101 | return llvm::make_range( |
102 | NodeValueConstIterator(*this, unfilteredUsers.begin()), |
103 | NodeValueConstIterator(*this, unfilteredUsers.end())); |
104 | } |
105 | |
106 | TypeRef NodeValue::getType() const { return node_->getType(resNo_); } |
107 | void NodeValue::setType(TypeRef ty) { node_->setType(resNo_, ty); } |
108 | void NodeValue::setTypeUnsafe(TypeRef ty) { node_->setTypeUnsafe(resNo_, ty); } |
109 | |
110 | ElemKind NodeValue::getElementType() const { |
111 | return getType()->getElementType(); |
112 | } |
113 | |
114 | float NodeValue::getScale() const { return getType()->getScale(); } |
115 | |
116 | int32_t NodeValue::getOffset() const { return getType()->getOffset(); } |
117 | |
118 | llvm::ArrayRef<dim_t> NodeValue::dims() const { return getType()->dims(); } |
119 | |
120 | std::string |
121 | NodeValue::generateNodeOutputName(bool stripResNoFor0thInput) const { |
122 | return generateNodeOutputName(node_->getName().str(), resNo_, |
123 | stripResNoFor0thInput); |
124 | } |
125 | |
126 | NodeHandle::NodeHandle(Node *parent, Node *N) : NodeValue(N), parent_(parent) { |
127 | setOperand(N, 0); |
128 | } |
129 | |
130 | NodeHandle::NodeHandle(Node *parent, Node *N, unsigned resNo) |
131 | : NodeValue(N, resNo), parent_(parent) { |
132 | setOperand(N, resNo); |
133 | } |
134 | |
135 | void NodeHandle::setOperand(Node *v, unsigned resNo) { |
136 | if (node_ == v && resNo == resNo_) { |
137 | return; |
138 | } |
139 | |
140 | if (node_) { |
141 | node_->removeUse(NodeUse(this)); |
142 | node_ = nullptr; |
143 | resNo_ = 0; |
144 | } |
145 | |
146 | if (v) { |
147 | node_ = v; |
148 | resNo_ = resNo; |
149 | v->addUse(NodeUse(this)); |
150 | } |
151 | } |
152 | |
153 | void NodeUse::setOperand(NodeHandle &other) { |
154 | if (other && site_->getNode()) { |
155 | assert(site_->getType() == other.getType() && |
156 | "Setting operand to a node with a different type" ); |
157 | } |
158 | site_->setOperand(other.getNode(), other.getResNo()); |
159 | } |
160 | |