1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Graph/NodeValue.h"
18#include "glow/Graph/Graph.h"
19#include "glow/Graph/Node.h"
20
21using namespace glow;
22
23NodeValue::NodeValue(Node *N) {
24 assert((!N || (N->getNumResults() == 1)) &&
25 "Constructing a value for a multi-res node");
26 node_ = N;
27 resNo_ = 0;
28}
29
30NodeValue::NodeValue(Node *N, unsigned resNo) {
31 assert(resNo < N->getNumResults() && "Invalid result number");
32 node_ = N;
33 resNo_ = resNo;
34}
35
36void NodeValue::replaceAllUsesOfWith(NodeValue v, const Function *F,
37 Node *skipReplacement) const {
38 if (v.getNode() && getType() != v.getType()) {
39 // Fresh replacement nodes without users are usually created by the graph
40 // optimizer, where the type of the replacement value should always match
41 // the type of the original node value. Check if the only difference is
42 // related to strides and adjust accordingly.
43 assert(v.getNumUsers() == 0 && "Cannot update type if there are users");
44 assert(getType()->isEqual(*v.getType(), /* allowDifferentShape */ false,
45 /* allowDifferentStrides */ true) &&
46 "Replacing value with the wrong type");
47 v.setType(getType());
48 }
49 typeUnsafeReplaceAllUsesOfWith(v, F, skipReplacement);
50}
51
52void NodeValue::typeUnsafeReplaceAllUsesOfWith(NodeValue v, const Function *F,
53 Node *skipReplacement) const {
54 // Copy the list of users in a temporary vector since that list (and the
55 // underlying iterators) are going to be invalidated by the next loop.
56 auto nodeValueUsers = getUsers();
57 llvm::SmallVector<NodeUse, 4> usersVec(nodeValueUsers.begin(),
58 nodeValueUsers.end());
59 for (auto &U : usersVec) {
60 NodeHandle *site = U.get();
61 auto *userF = U.getUser()->getParent();
62 // If the user is not in function F, don't touch it.
63 if (F && userF != F) {
64 continue;
65 }
66 assert(site->getNode() == node_ && "Invalid user");
67 assert(site->getResNo() == getResNo() && "Invalid list of uses");
68
69 if (U.getUser() == skipReplacement) {
70 continue;
71 }
72
73 // Log the change of node input(operand).
74 if (Function *F = getNode()->getParent()) {
75 F->getLogContext()->logNodeInputChange(*(U.getUser()), *this, v);
76 }
77 // Constant or Placeholder has no associated Function, we need to log the
78 // input changes inside its user's Function.
79 else if (getNode()->getKind() == Kinded::Kind::ConstantKind ||
80 getNode()->getKind() == Kinded::Kind::PlaceholderKind) {
81 userF->getLogContext()->logNodeInputChange(*(U.getUser()), *this, v);
82 }
83
84 site->setOperand(v.getNode(), v.getResNo());
85 }
86}
87
88unsigned NodeValue::getNumUsers() const {
89 auto range = getUsers();
90 return std::distance(range.begin(), range.end());
91}
92
93llvm::iterator_range<NodeValueIterator> NodeValue::getUsers() {
94 auto &unfilteredUsers = getNode()->getUsers();
95 return llvm::make_range(NodeValueIterator(*this, unfilteredUsers.begin()),
96 NodeValueIterator(*this, unfilteredUsers.end()));
97}
98
99llvm::iterator_range<NodeValueConstIterator> NodeValue::getUsers() const {
100 const auto &unfilteredUsers = getNode()->getUsers();
101 return llvm::make_range(
102 NodeValueConstIterator(*this, unfilteredUsers.begin()),
103 NodeValueConstIterator(*this, unfilteredUsers.end()));
104}
105
106TypeRef NodeValue::getType() const { return node_->getType(resNo_); }
107void NodeValue::setType(TypeRef ty) { node_->setType(resNo_, ty); }
108void NodeValue::setTypeUnsafe(TypeRef ty) { node_->setTypeUnsafe(resNo_, ty); }
109
110ElemKind NodeValue::getElementType() const {
111 return getType()->getElementType();
112}
113
114float NodeValue::getScale() const { return getType()->getScale(); }
115
116int32_t NodeValue::getOffset() const { return getType()->getOffset(); }
117
118llvm::ArrayRef<dim_t> NodeValue::dims() const { return getType()->dims(); }
119
120std::string
121NodeValue::generateNodeOutputName(bool stripResNoFor0thInput) const {
122 return generateNodeOutputName(node_->getName().str(), resNo_,
123 stripResNoFor0thInput);
124}
125
126NodeHandle::NodeHandle(Node *parent, Node *N) : NodeValue(N), parent_(parent) {
127 setOperand(N, 0);
128}
129
130NodeHandle::NodeHandle(Node *parent, Node *N, unsigned resNo)
131 : NodeValue(N, resNo), parent_(parent) {
132 setOperand(N, resNo);
133}
134
135void NodeHandle::setOperand(Node *v, unsigned resNo) {
136 if (node_ == v && resNo == resNo_) {
137 return;
138 }
139
140 if (node_) {
141 node_->removeUse(NodeUse(this));
142 node_ = nullptr;
143 resNo_ = 0;
144 }
145
146 if (v) {
147 node_ = v;
148 resNo_ = resNo;
149 v->addUse(NodeUse(this));
150 }
151}
152
153void NodeUse::setOperand(NodeHandle &other) {
154 if (other && site_->getNode()) {
155 assert(site_->getType() == other.getType() &&
156 "Setting operand to a node with a different type");
157 }
158 site_->setOperand(other.getNode(), other.getResNo());
159}
160