1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "GraphScheduler.h"
18
19#include "glow/Graph/Utils.h"
20#include "glow/Support/Debug.h"
21
22#include "llvm/Support/Casting.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/raw_ostream.h"
25
26#define DEBUG_TYPE "graph-scheduler"
27
28using llvm::cast;
29using llvm::dyn_cast;
30using llvm::isa;
31
32namespace glow {
33/// \returns true if a node \p N is scheduled already.
34bool ChildMemSizeBasedScheduler::isScheduled(const Node *N) const {
35 return std::find(scheduled_.begin(), scheduled_.end(), N) != scheduled_.end();
36}
37
38/// Computes the amount of memory required to keep the result
39/// of each node.
40void ChildMemSizeBasedScheduler::computeNodeResultsMemorySize() {
41 for (auto &N : G_.getNodes()) {
42 int64_t resultSize = 0;
43 for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) {
44 resultSize += N.getType(idx)->getSizeInBytes();
45 }
46 resultMemSize_[&N] = resultSize;
47 DEBUG_GLOW(llvm::dbgs()
48 << "ResultSize of " << N.getName() << ":" << resultSize << "\n");
49 }
50}
51
/// Computes the max amount of memory required during the computation
/// of children for each node, i.e. the peak of live result memory while
/// evaluating the node's whole input subtree. Results are stored in
/// \ref maxMemSize_, keyed by node.
void ChildMemSizeBasedScheduler::computeNodeComputationMaxMemorySize() {
  // Traverse nodes in such a way, that dependencies are processed
  // before the node using them.
  GraphPostOrderVisitor visitor(G_);
  for (auto *N : visitor.getPostOrder()) {
    // Seed with the cost of the first input: either the memory needed while
    // computing it (maxMemSize_) or the memory holding its result
    // (resultMemSize_), whichever is larger. Leaf nodes start at 0.
    // NOTE(review): unlike the loop below, input 0 is not skipped when it is
    // a Storage node and its sizes are not asserted; `operator[]`
    // default-inserts 0 for unseen keys, which happens to match what a
    // Storage-skip would produce — confirm this asymmetry is intentional.
    int64_t maxSize = (N->getNumInputs() > 0)
                          ? std::max(resultMemSize_[N->getNthInput(0)],
                                     maxMemSize_[N->getNthInput(0)])
                          : 0;
    for (size_t idx = 1, e = N->getNumInputs(); idx < e; ++idx) {
      const auto &input = N->getNthInput(idx);
      // Skip operands that do not require memory allocations for storing
      // their results.
      if (isa<Storage>(input))
        continue;
      // Post-order traversal guarantees both sizes were computed already.
      assert(resultMemSize_.count(input) > 0);
      assert(maxMemSize_.count(input) > 0);
      // While this input is being computed, the results of all previously
      // accounted inputs stay live: add this input's result footprint, then
      // take the peak against the memory needed to compute the input itself.
      maxSize += resultMemSize_[input];
      if (maxSize < maxMemSize_[input])
        maxSize = maxMemSize_[input];
    }
    maxMemSize_[N] = maxSize;
    DEBUG_GLOW(llvm::dbgs()
               << "MaxSize of " << N->getName() << ":" << maxSize << "\n");
  }
}
80
/// Order children by (maxSize - resultSize). It gives more
/// priority to the nodes that free more memory after
/// their computation. Appends \p node and its not-yet-scheduled
/// dependencies to \ref scheduled_ in a valid execution order.
void ChildMemSizeBasedScheduler::orderChildNodesAndSchedule(Node *node) {
  // Use `worklist` as a stack-like container to hold nodes for processing. Each
  // node is going to appear in the container at least twice:
  // 1) To put its children (and sometimes parents) in the worklist
  // 2) To schedule the node itself (i.e. append it to `scheduled_` list)
  // At the end of stage #1, the node is added in `readyNodes` set. By checking
  // this set the algorithm figures out if it should perform stage #1 or #2 with
  // the node it pulls out of the worklist.
  // Since the algorithm does some special handling of node's parents, a node
  // can get into the worklist more than twice. To ensure that stage #2 happens
  // only after all children of the node are scheduled, `readyNodes` keeps
  // node's worklist index instead of just node's pointer.
  std::vector<Node *> worklist;
  std::unordered_set<size_t> readyNodes;
  worklist.push_back(node);

  while (worklist.size()) {
    auto *N = worklist.back();
    worklist.pop_back();
    // Index of the slot N was just popped from. Everything pushed above this
    // slot has already been popped, so this index uniquely identifies this
    // particular occurrence of N and can be matched against `readyNodes`.
    size_t idx = worklist.size();

    // Do not explicitly schedule storage nodes.
    if (isa<Storage>(N)) {
      continue;
    }
    // Each child should be scheduled just once. Also drop any ready-marker
    // this occurrence may have left, so it cannot be matched by a later
    // occurrence that lands on the same slot index.
    if (isScheduled(N)) {
      readyNodes.erase(idx);
      continue;
    }

    // If node is marked as ready, all its children have been scheduled, so it
    // can be scheduled now too.
    if (readyNodes.count(idx)) {
      readyNodes.erase(idx);
      scheduled_.push_back(N);
      continue;
    }

    // If this node has a user which does not have any users and which does not
    // require any additional memory, schedule it here, because we don't want to
    // extend the lifetime of this value for no reason. We want to execute and
    // get rid of this node as soon as possible to reduce the memory pressure.
    for (NodeUse &use : N->getUsers()) {
      Node *user = use.getUser();
      // Users may be scattered across different functions.
      // Only accounts for the ones in that function.
      if (&G_ != user->getParent()) {
        continue;
      }
      // Skip users that have users of their own: scheduling such a node here
      // could violate its downstream dependencies.
      if (user->getNumUsers()) {
        continue;
      }
      // Schedule a node if it does not require any additional memory.
      if (resultMemSize_[user] == 0) {
        worklist.push_back(user);
      }
    }

    // Push the node again in the stack. By the time it's going to pop up again,
    // all children are going to be scheduled, so mark it as ready.
    readyNodes.insert(worklist.size());
    worklist.push_back(N);

    // Take care about node's children. They sit above N in the stack, so they
    // are processed (and scheduled) before N pops up again.
    size_t childrenStartIdx = worklist.size();
    for (int i = 0, e = N->getNumInputs(); i < e; ++i) {
      worklist.push_back(N->getNthInput(i));
    }

    // The predicate is an input as well and must be available before N runs.
    if (N->hasPredicate()) {
      worklist.push_back(N->getPredicate());
    }

    // We don't model memory dependencies, but we still need to honor them.
    // Make sure a node mutating any of its inputs happens after the last
    // non-mutating use of the operand being mutated. Some examples of such
    // nodes would be SaveNode and QuantizationProfileNode.
    for (unsigned i = 0, e = N->getNumInputs(); i < e; ++i) {
      // We don't care about inputs that are not mutated by the node.
      if (!N->isOverwrittenNthInput(i)) {
        continue;
      }
      auto mutatedInput = N->getNthInput(i);
      auto *destination = mutatedInput.getNode();
      // Push every other in-function user of the mutated value above N on the
      // stack, so it is scheduled before N (the worklist is LIFO).
      for (NodeUse &use : destination->getUsers()) {
        Node *user = use.getUser();
        if (user == N) {
          continue;
        }
        // Nodes may have users scattered across different functions.
        // Only accounts for the ones in that function.
        if (&G_ != user->getParent()) {
          continue;
        }
        worklist.push_back(user);
      }
    }

    // Order children by (maxSize - resultSize). It gives more
    // priority to the nodes that free more memory after
    // their computation.
    // Bubble sort over the freshly pushed segment: entries with a larger
    // (maxSize - resultSize) drift toward the top of the stack, so they are
    // popped — and therefore scheduled — earlier.
    for (size_t j = childrenStartIdx, e = worklist.size(); j < e; ++j) {
      for (size_t i = j; i > childrenStartIdx; --i) {
        auto &currentChild = worklist[i];
        auto &prevChild = worklist[i - 1];
        if (maxMemSize_[currentChild] - resultMemSize_[currentChild] <=
            maxMemSize_[prevChild] - resultMemSize_[prevChild]) {
          std::swap(currentChild, prevChild);
        }
      }
    }

    DEBUG_GLOW(llvm::dbgs() << "\nAbout to schedule children of "
                            << N->getName() << "\n";
               llvm::dbgs() << "Children are:\n");
    DEBUG_GLOW(
        for (size_t i = childrenStartIdx, e = worklist.size(); i < e; ++i) {
          auto &child = worklist[i];
          llvm::dbgs() << "Child " << child->getName() << ": "
                       << maxMemSize_[child] - resultMemSize_[child] << "\n";
        });
  }
}
210
211void ChildMemSizeBasedScheduler::scheduleNodes() {
212 /// Try to schedule all root nodes.
213 for (auto &N : G_.getNodes()) {
214 if (N.getNumUsers() == 0)
215 orderChildNodesAndSchedule(&N);
216 }
217}
218
219void ChildMemSizeBasedScheduler::schedule() {
220 computeNodeResultsMemorySize();
221 computeNodeComputationMaxMemorySize();
222 scheduleNodes();
223}
224} // namespace glow
225