1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Partitioner/PartitionerBase.h" |
18 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
19 | #include "llvm/Support/FileSystem.h" |
20 | #include "llvm/Support/raw_ostream.h" |
21 | #include <fstream> |
22 | |
23 | using namespace glow; |
24 | using llvm::isa; |
25 | |
26 | /// Creates and \returns a new DAGNode from \p F given \p mapping. |
27 | static std::unique_ptr<DAGNode> |
28 | createDAGNodeFromFun(Function *F, NodeToFunctionMap &mapping) { |
29 | std::unique_ptr<DAGNode> DN = glow::make_unique<DAGNode>(); |
30 | DN->name = F->getName().str(); |
31 | DN->logicalDevices = mapping.getLogicalDeviceIDList(F); |
32 | DN->backendName = mapping.getPartitionBackendName(F); |
33 | DN->size = mapping.getGraphMemInfo(F).getTotalMemSize(); |
34 | DN->backendHints = mapping.getBackendHints(F); |
35 | DN->backendSpecificOpts = mapping.getBackendSpecificOpts(F); |
36 | DN->replicationCount = mapping.getReplicationCount(F); |
37 | return DN; |
38 | } |
39 | |
40 | // Current only partition the representative function. |
41 | DAGListTy PartitionerBase::doPartitioning( |
42 | llvm::StringRef funcName, std::vector<Function *> funcs, Module *module, |
43 | NodeToFunctionMap &mapping, bool saveDAG, BackendSpecificNodeInfo &nodeInfo, |
44 | bool skipCloning) { |
45 | DAGListTy partitions; |
46 | // Add a dummy node to make sure that a DAG has a single entrance. |
47 | DAGNodePtr DAGRoot = glow::make_unique<DAGNode>(); |
48 | DAGNodePtrVec nodes; |
49 | DAGRoot->logicalDevices = {0}; |
50 | DAGRoot->name = funcName.str(); |
51 | DAGRoot->module = module; |
52 | DAGNode *root = DAGRoot.get(); |
53 | |
54 | llvm::DenseMap<Node *, Node *> currToNew; |
55 | |
56 | if (!skipCloning) { |
57 | // Clone nodes into target partition. Update nodeInfo as necessary. |
58 | for (size_t i = 0, e = funcs.size(); i < e; i++) { |
59 | for (auto &N : funcs[i]->getNodes()) { |
60 | auto *clone = N.clone(); |
61 | currToNew[&N] = clone; |
62 | mapping[&N]->addNode(clone); |
63 | |
64 | // If needed, update NodeInfo to point old Node's info to clone. |
65 | auto itF = nodeInfo.find(funcs[i]); |
66 | if (itF == nodeInfo.end()) { |
67 | continue; |
68 | } |
69 | auto &currNodeInfo = itF->second; |
70 | auto itN = currNodeInfo.find(&N); |
71 | if (itN != currNodeInfo.end()) { |
72 | currNodeInfo[clone] = std::move(itN->second); |
73 | // Erase old NodeInfo; current Nodes will be eliminated later when |
74 | // input funcs will be erased. |
75 | currNodeInfo.erase(itN); |
76 | } |
77 | } |
78 | } |
79 | } |
80 | |
81 | // For any dependency that crosses a partition, add a placeholder and save |
82 | // node. Record the dependence in the function graph. |
83 | std::unordered_map<NodeValue, Placeholder *> placeholders; |
84 | llvm::DenseMap<Function *, DAGNode *> funcDAG; |
85 | for (auto *subF : mapping.getPartitions()) { |
86 | if (funcDAG.find(subF) == funcDAG.end()) { |
87 | std::unique_ptr<DAGNode> subDAG = createDAGNodeFromFun(subF, mapping); |
88 | funcDAG[subF] = subDAG.get(); |
89 | nodes.push_back(std::move(subDAG)); |
90 | } |
91 | |
92 | // Link subF to its parents. |
93 | std::set<Function *> parents; |
94 | for (auto &N : subF->getNodes()) { |
95 | for (int inp = 0, e = N.getNumInputs(); inp < e; inp++) { |
96 | auto input = N.getNthInput(inp); |
97 | // No need to check Constant since it won't be the result of another |
98 | // function. |
99 | if (isa<Constant>(input.getNode())) { |
100 | continue; |
101 | } |
102 | |
103 | Function *inputF = nullptr; |
104 | // It is possible that one input is the output of anther function. |
105 | if (Placeholder *ph = llvm::dyn_cast<Placeholder>(input.getNode())) { |
106 | for (auto &user : ph->getUsers()) { |
107 | if (auto *save = llvm::dyn_cast<SaveNode>(user.getUser())) { |
108 | placeholders[input] = save->getPlaceholder(); |
109 | inputF = mapping[user.getUser()]; |
110 | break; |
111 | } |
112 | } |
113 | if (!inputF) { |
114 | continue; |
115 | } |
116 | } |
117 | |
118 | if (!inputF) { |
119 | inputF = mapping[input.getNode()]; |
120 | } |
121 | if (subF == inputF) { |
122 | continue; |
123 | } |
124 | // Check if a DAGNode for subF's parent is created or not. If not, |
125 | // create one. |
126 | if (funcDAG.find(inputF) == funcDAG.end()) { |
127 | std::unique_ptr<DAGNode> subDAG = |
128 | createDAGNodeFromFun(inputF, mapping); |
129 | funcDAG[inputF] = subDAG.get(); |
130 | nodes.push_back(std::move(subDAG)); |
131 | } |
132 | |
133 | // subF is a child of inputF, inputF is a parent of subF. |
134 | if (parents.find(inputF) == parents.end()) { |
135 | funcDAG[inputF]->children.push_back(funcDAG[subF]); |
136 | funcDAG[subF]->parents.push_back(funcDAG[inputF]); |
137 | parents.insert(inputF); |
138 | } |
139 | // If we've already created a placeholder for this dependence, use it. |
140 | auto it = placeholders.find(input); |
141 | if (it != placeholders.end()) { |
142 | N.setNthInput(inp, it->second); |
143 | continue; |
144 | } |
145 | |
146 | // Create a new placeholder to represent this dependence. |
147 | auto *save = inputF->createSave("tmp" , input); |
148 | auto *tmp = save->getPlaceholder(); |
149 | placeholders[input] = tmp; |
150 | N.setNthInput(inp, tmp); |
151 | } |
152 | } |
153 | } |
154 | |
155 | if (saveDAG) { |
156 | DAG dag; |
157 | dag.root = std::move(DAGRoot); |
158 | dag.nodes = std::move(nodes); |
159 | partitions.push_back(std::move(dag)); |
160 | } |
161 | |
162 | if (!skipCloning) { |
163 | // Update links between nodes in the cloned functions. Add placeholders (and |
164 | // save nodes) where a link crosses a partition boundary. |
165 | for (auto *subF : mapping.getPartitions()) { |
166 | for (auto &N : subF->getNodes()) { |
167 | for (int inp = 0, e = N.getNumInputs(); inp < e; inp++) { |
168 | auto input = N.getNthInput(inp); |
169 | if (isa<Storage>(input.getNode())) { |
170 | continue; |
171 | } |
172 | // Link this node to the clone of its input. |
173 | auto *clone = currToNew[input.getNode()]; |
174 | N.setNthInput(inp, NodeValue(clone, input.getResNo())); |
175 | } |
176 | } |
177 | } |
178 | } |
179 | |
180 | // For all DAGNode without parents, link them to the root DAG. |
181 | for (auto *subF : mapping.getPartitions()) { |
182 | if (funcDAG[subF]->parents.size() == 0) { |
183 | funcDAG[subF]->parents.push_back(root); |
184 | root->children.push_back(funcDAG[subF]); |
185 | } |
186 | } |
187 | return partitions; |
188 | } |
189 | |
190 | void PartitionerBase::dumpDAG(llvm::StringRef dotFilename, |
191 | const DAGListTy &partitions) const { |
192 | if (partitions.size() == 0) { |
193 | return; |
194 | } |
195 | auto *root = partitions[0].root.get(); |
196 | LOG(INFO) << "Writing dotty graph for DAG after graph partitioning: " |
197 | << dotFilename.str(); |
198 | std::ofstream myfile; |
199 | myfile.open(dotFilename.str()); |
200 | myfile << "digraph DAG {\n\trankdir=TB;\n" ; |
201 | // Dump DAGNodes |
202 | std::vector<DAGNode *> nodes; |
203 | llvm::SmallSet<DAGNode *, 10> used; |
204 | nodes.push_back(root); |
205 | int cur = 0; |
206 | int num = 1; |
207 | while (cur < num) { |
208 | auto *node = nodes[cur]; |
209 | for (size_t i = 0; i < node->children.size(); i++) { |
210 | auto child = node->children[i]; |
211 | DescriptionBuilder db(child->name.c_str()); |
212 | const std::string &backendName = child->backendName; |
213 | db.addParam("BackendName" , backendName); |
214 | myfile << "\"" << escapeDottyString(child->name) << "\"" |
215 | << " [ label = \"" << escapeDottyString(db) << "\"" ; |
216 | myfile << "\tshape = \"record\"\n" ; |
217 | myfile << "\tstyle=\"filled,rounded\"\n" ; |
218 | auto colorIdx = llvm::hash_value(backendName); |
219 | myfile << "\tfillcolor=" << getDotFileNodeColor(colorIdx) << "\n" ; |
220 | myfile << "penwidth = 2];\n" ; |
221 | if (used.count(child) == 0) { |
222 | nodes.push_back(child); |
223 | used.insert(child); |
224 | num++; |
225 | } |
226 | } |
227 | cur++; |
228 | } |
229 | |
230 | // Dump edges. |
231 | for (size_t i = 0; i < nodes.size(); i++) { |
232 | auto *node = nodes[i]; |
233 | for (size_t j = 0; j < node->children.size(); j++) { |
234 | auto child = node->children[j]; |
235 | if (node->name.compare(child->name) == 0) { |
236 | // If a network is too small to be partitioned, the dummy node's name |
237 | // and its child (i.e. the original network) share the same name. The |
238 | // edge will create loop. So in this case, this edge just be ignored. |
239 | continue; |
240 | } |
241 | myfile << "\"" << escapeDottyString(node->name) << "\"" |
242 | << " -> " |
243 | << "\"" << escapeDottyString(child->name) << "\"" |
244 | << ";" ; |
245 | } |
246 | } |
247 | myfile << "}" ; |
248 | |
249 | myfile.close(); |
250 | return; |
251 | } |
252 | |