1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Partitioner/PartitionerOptimizer.h"
18#include "glow/Partitioner/PartitionerUtils.h"
19#include <unordered_set>
20
21namespace glow {
22using llvm::isa;
23
24// Sorted the std::pair<DAGNode *, uint64_t> based on the second from min to
25// max.
26bool sortMinMemory(const std::pair<Function *, uint64_t> &a,
27 const std::pair<Function *, uint64_t> &b) {
28 return a.second < b.second;
29}
30
/// Greedily reduces the communication cost of a partitioning by moving nodes
/// between connected partitions, until no move yields any further gain.
///
/// \param partitions node-to-partition mapping; updated in place as nodes move.
/// \param nodesSet per-partition node sets; kept in sync with \p partitions.
/// \param mod module owning the partition functions; partitions that become
///        empty are erased from it.
/// \param availableMemory per-partition memory budget in bytes.
void optimizeCommunicationCost(NodeToFunctionMap &partitions,
                               FunctionToNodesMap &nodesSet, Module *mod,
                               uint64_t availableMemory) {
  // Move/Exchange nodes between any two connected partitions, until no gain is
  // achieved.
  // Step1 Move: Assume Partition1 -> Partition2, try to move nodes from
  // Partition2 to Partition1 if those nodes only use the nodes in
  // Partition1(recursively) and the move won't make Partition1's memory exceed
  // the memory constraint, and the communication cost is minimized.
  bool gain = true;
  while (gain) {
    // gain is initialized as false, it will be set to be true if there is at
    // least one node can be moved from one set to another set.
    gain = false;
    for (FunctionToNodesMap::iterator it = nodesSet.begin();
         it != nodesSet.end(); ++it) {
      NodesSet &curSet = (*it).second;
      // Candidate nodes that live in successor partitions; presumably these
      // are nodes whose only predecessor partition is curSet -- confirm
      // against getOutUsersWithOnePredecessor's definition.
      std::vector<Node *> outUsers = getOutUsersWithOnePredecessor(curSet);
      if (outUsers.empty()) {
        continue;
      }
      Function *cur = (*it).first;
      GraphMemInfo curCost = partitions.getGraphMemInfo(cur);
      // NOTE(review): contextCount is taken from cur's cost and later reused
      // when recomputing the successor's cost -- assumes both partitions share
      // the same context count; confirm.
      auto contextCount = curCost.contextCount;
      // Check if a node can be moved to current node set (i.e curSet).
      for (int i = 0, e = outUsers.size(); i < e; i++) {
        // Get the new cost if outUsers[i] is added.
        GraphMemInfo newCurCost =
            updateGraphMemInfoByAddingNode(curSet, curCost, outUsers[i]);

        // Rule 1: this move won't break memory constraint.
        if (newCurCost.getTotalMemSize() > availableMemory) {
          continue;
        }
        // Rule 2: this move won't cause constant duplication.
        // A Storage input with more than one user would have to be duplicated
        // into curSet's partition, so skip such candidates.
        bool cont = false;
        for (int j = 0, e1 = outUsers[i]->getNumInputs(); j < e1; j++) {
          auto in = outUsers[i]->getNthInput(j);
          if (isa<Storage>(in.getNode()) && !in.hasOneUse()) {
            cont = true;
            break;
          }
        }
        if (cont) {
          continue;
        }
        // Rule 3: this move won't increase communication cost. Even if this
        // move won't change communication cost, according to rule 1 and rule 2,
        // the memory consumption of the partition where this node (i.e
        // outUsers[i]) belongs can be reduced. Therefore, it may trigger later
        // node movement or partitionsCombine.
        Function *suc = partitions[outUsers[i]];
        uint64_t outMem = getOutMemPerNode(nodesSet[suc], outUsers[i]);
        // NOTE(review): unsigned arithmetic -- assumes
        // newCurCost.outMemSize >= outMem, otherwise the subtraction wraps
        // around and the condition is trivially false; confirm the invariant.
        if (newCurCost.outMemSize - outMem <= curCost.outMemSize) {
          // Move this node to current node set.
          curSet.insert(outUsers[i]);
          nodesSet[suc].erase(outUsers[i]);
          // Snapshot the new cost so subsequent candidates in this inner loop
          // are evaluated against the updated partition state.
          curCost = newCurCost;
          // Update the partitions.
          partitions.add(outUsers[i], cur);
          partitions.setGraphMemInfo(cur, newCurCost);
          if (nodesSet[suc].empty()) {
            // It is possible that after moving a node from Partition2 to
            // Partition1, Partition2 become empty. Remove the empty partition.
            // (Erasing suc does not invalidate `it`, which points at cur.)
            partitions.deletePartition(suc);
            nodesSet.erase(suc);
            mod->eraseFunction(suc);
          } else {
            GraphMemInfo newCost = getGraphMemInfo(nodesSet[suc], contextCount);
            partitions.setGraphMemInfo(suc, newCost);
          }
          gain = true;
        }
      }
    }
  }
}
108
/// Merges pairs of partitions where one partition's outputs all feed a single
/// successor partition, as long as the combined partition fits in memory.
/// Repeats until the number of partitions stops changing (a fixed point).
///
/// \param partitions node-to-partition mapping; updated in place.
/// \param nodesSet per-partition node sets; kept in sync with \p partitions.
/// \param mod module owning the partition functions; merged-away partitions
///        are erased from it.
/// \param availableMemory per-partition memory budget in bytes.
void partitionsCombine(NodeToFunctionMap &partitions,
                       FunctionToNodesMap &nodesSet, Module *mod,
                       uint64_t availableMemory) {

  size_t origPartitions = 0;

  // Do the combination until the size of partitions is stable.
  while (partitions.getPartitions().size() != origPartitions) {
    origPartitions = partitions.getPartitions().size();
    // Rule 1:
    for (FunctionToNodesMap::iterator it = nodesSet.begin();
         it != nodesSet.end(); ++it) {
      std::vector<Node *> outUsers = getOutUsers((*it).second);
      if (outUsers.empty()) {
        continue;
      }

      // flag stays true iff every out-user belongs to the same partition,
      // i.e. this partition has exactly one successor partition.
      bool flag = true;
      for (int i = 1, e = outUsers.size(); i < e; i++) {
        if (partitions[outUsers[i]] != partitions[outUsers[i - 1]]) {
          flag = false;
          break;
        }
      }
      if (flag) {
        // This partition only has one successor.
        Function *cur = (*it).first;
        Function *suc = partitions[outUsers[0]];
        // Copy suc's node set: nodesSet[suc] is erased below, so we need a
        // stable snapshot to merge from.
        NodesSet tmp = (nodesSet.find(suc))->second;
        GraphMemInfo cost1 = partitions.getGraphMemInfo(cur);
        GraphMemInfo cost2 = partitions.getGraphMemInfo(suc);
        // Combined footprint: cur's outputs become internal edges, so
        // cost1.outMemSize is subtracted. Note the strict '<' comparison.
        if (cost1.getTotalMemSize() + cost2.getTotalMemSize() -
                cost1.outMemSize <
            availableMemory) {
          // We can combine the two partitions to fit one device.
          for (NodesSet::iterator it2 = tmp.begin(); it2 != tmp.end(); ++it2) {
            partitions.add(*it2, cur);
          }
          // NOTE(review): newCost.contextCount is left default-initialized
          // here rather than copied from cost1/cost2 -- confirm that is
          // intended.
          GraphMemInfo newCost;
          newCost.constMemSize = cost1.constMemSize + cost2.constMemSize;
          newCost.inMemSize =
              cost1.inMemSize + cost2.inMemSize - cost1.outMemSize;
          newCost.outMemSize = cost2.outMemSize;
          partitions.setGraphMemInfo((*it).first, newCost);
          (*it).second.insert(tmp.begin(), tmp.end());
          // Erasing suc does not invalidate `it` (it points at cur), per
          // associative-container erase guarantees.
          partitions.deletePartition(suc);
          nodesSet.erase(suc);
          mod->eraseFunction(suc);
        }
      }
    }
  }
}
162
163DeviceIDTy
164assignLogicalDeviceID(NodeToFunctionMap &mapping,
165 const std::map<std::string, BackendInfo> &backendMap) {
166 DeviceIDTy logicalDeviceID = 0;
167
168 std::map<std::string, std::vector<Function *>> backendFuncMap;
169 for (auto &func : mapping.getPartitions()) {
170 // Traverse the partitions, and get list of partitions with each
171 // backendName.
172 auto backendName = mapping.getPartitionBackendName(func);
173 if (backendFuncMap.find(backendName) == backendFuncMap.end()) {
174 backendFuncMap.emplace(backendName, std::vector<Function *>{func});
175 } else {
176 backendFuncMap[backendName].push_back(func);
177 }
178 }
179
180 // For each type of the backend, assign the logicalDevice ID.
181 for (const auto &p : backendFuncMap) {
182 if (mapping.getPartitions().size() <= backendMap.at(p.first).num) {
183 // There is enough device with this backendName, no need to adjust the
184 // logical ID.
185 for (auto &func : p.second) {
186 mapping.appendLogicalDeviceID(func, logicalDeviceID++);
187 }
188 continue;
189 }
190 // Get the list of functions with current BackendName, and sort it based on
191 // used memory from min to max.
192 std::vector<std::pair<Function *, uint64_t>> nodeSize;
193 for (size_t i = 0, e = p.second.size(); i < e; i++) {
194 Function *function = p.second[i];
195 uint64_t totalMem = mapping.getGraphMemInfo(function).getTotalMemSize();
196 nodeSize.push_back(std::make_pair(p.second[i], totalMem));
197 }
198 std::sort(nodeSize.begin(), nodeSize.end(), sortMinMemory);
199
200 // Assume we have n devices(NOTE: here the n devices have the same available
201 // memory size, and the following algorithm can find the accurate result. If
202 // the memory size are differnt, this assignment issue will be a NP problem
203 // -- multiple knapsack problem, and the following algorithm becomes greedy
204 // and the result may not be optimal), and m partitions, where m > n. If
205 // these m partitions can be assigned to n devices, there must be 1 device
206 // have at least (m - 1)/n + 1 partitions(Pigeonhole principle). Based on
207 // this theory, the algorithm is: Given N devices, and M partitions:
208 // Step 1 : sort the partitions from min to max based on their memory usage.
209 std::sort(nodeSize.begin(), nodeSize.end(), sortMinMemory);
210 // Step 2 : let n = N, m = M.
211 size_t m = p.second.size();
212 size_t n = backendMap.at(p.first).num;
213 while (m > 0) {
214 // Step 3 : find the first k partitions whose total memory usage still
215 // under the memory limitation (k should be max).
216 uint64_t usedMem = 0;
217 size_t numOfPartitionsWithSameID = (m - 1) / n + 1;
218 size_t start = p.second.size() - m;
219 size_t i;
220 for (i = start; i < p.second.size(); i++) {
221 if (usedMem + nodeSize[i].second > backendMap.at(p.first).memSize) {
222 break;
223 }
224 usedMem += nodeSize[i].second;
225 }
226 // Step 4 : if k = start - i found in step 3 is smaller than (m - 1) / n +
227 // 1, this means we can't find a proper assignment to fit the number of
228 // devices. Assign each partition with a unique logicalDevice ID and
229 // return.
230 if (i - start < numOfPartitionsWithSameID) {
231 // Can't find a proper assignment. Assign each partition a unique
232 // logicalDevice ID and return;
233 logicalDeviceID = 0;
234 for (auto &func : mapping.getPartitions()) {
235 mapping.appendLogicalDeviceID(func, logicalDeviceID++);
236 }
237 return logicalDeviceID;
238 }
239
240 // Step 5 : Assign these partitions which are assigned to one device with
241 // the same logical ID.
242 for (size_t j = start; j < i; j++) {
243 mapping.appendLogicalDeviceID(nodeSize[j].first, logicalDeviceID);
244 }
245 logicalDeviceID++;
246 // Step 6 : Update the left number of devices and partitions.
247 n--;
248 m = m - (i - start);
249 }
250 }
251 return logicalDeviceID;
252}
253} // namespace glow
254