1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Partitioner/PartitionerBase.h"
18#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
19#include "llvm/Support/FileSystem.h"
20#include "llvm/Support/raw_ostream.h"
21#include <fstream>
22
23using namespace glow;
24using llvm::isa;
25
26/// Creates and \returns a new DAGNode from \p F given \p mapping.
27static std::unique_ptr<DAGNode>
28createDAGNodeFromFun(Function *F, NodeToFunctionMap &mapping) {
29 std::unique_ptr<DAGNode> DN = glow::make_unique<DAGNode>();
30 DN->name = F->getName().str();
31 DN->logicalDevices = mapping.getLogicalDeviceIDList(F);
32 DN->backendName = mapping.getPartitionBackendName(F);
33 DN->size = mapping.getGraphMemInfo(F).getTotalMemSize();
34 DN->backendHints = mapping.getBackendHints(F);
35 DN->backendSpecificOpts = mapping.getBackendSpecificOpts(F);
36 DN->replicationCount = mapping.getReplicationCount(F);
37 return DN;
38}
39
40// Current only partition the representative function.
41DAGListTy PartitionerBase::doPartitioning(
42 llvm::StringRef funcName, std::vector<Function *> funcs, Module *module,
43 NodeToFunctionMap &mapping, bool saveDAG, BackendSpecificNodeInfo &nodeInfo,
44 bool skipCloning) {
45 DAGListTy partitions;
46 // Add a dummy node to make sure that a DAG has a single entrance.
47 DAGNodePtr DAGRoot = glow::make_unique<DAGNode>();
48 DAGNodePtrVec nodes;
49 DAGRoot->logicalDevices = {0};
50 DAGRoot->name = funcName.str();
51 DAGRoot->module = module;
52 DAGNode *root = DAGRoot.get();
53
54 llvm::DenseMap<Node *, Node *> currToNew;
55
56 if (!skipCloning) {
57 // Clone nodes into target partition. Update nodeInfo as necessary.
58 for (size_t i = 0, e = funcs.size(); i < e; i++) {
59 for (auto &N : funcs[i]->getNodes()) {
60 auto *clone = N.clone();
61 currToNew[&N] = clone;
62 mapping[&N]->addNode(clone);
63
64 // If needed, update NodeInfo to point old Node's info to clone.
65 auto itF = nodeInfo.find(funcs[i]);
66 if (itF == nodeInfo.end()) {
67 continue;
68 }
69 auto &currNodeInfo = itF->second;
70 auto itN = currNodeInfo.find(&N);
71 if (itN != currNodeInfo.end()) {
72 currNodeInfo[clone] = std::move(itN->second);
73 // Erase old NodeInfo; current Nodes will be eliminated later when
74 // input funcs will be erased.
75 currNodeInfo.erase(itN);
76 }
77 }
78 }
79 }
80
81 // For any dependency that crosses a partition, add a placeholder and save
82 // node. Record the dependence in the function graph.
83 std::unordered_map<NodeValue, Placeholder *> placeholders;
84 llvm::DenseMap<Function *, DAGNode *> funcDAG;
85 for (auto *subF : mapping.getPartitions()) {
86 if (funcDAG.find(subF) == funcDAG.end()) {
87 std::unique_ptr<DAGNode> subDAG = createDAGNodeFromFun(subF, mapping);
88 funcDAG[subF] = subDAG.get();
89 nodes.push_back(std::move(subDAG));
90 }
91
92 // Link subF to its parents.
93 std::set<Function *> parents;
94 for (auto &N : subF->getNodes()) {
95 for (int inp = 0, e = N.getNumInputs(); inp < e; inp++) {
96 auto input = N.getNthInput(inp);
97 // No need to check Constant since it won't be the result of another
98 // function.
99 if (isa<Constant>(input.getNode())) {
100 continue;
101 }
102
103 Function *inputF = nullptr;
104 // It is possible that one input is the output of anther function.
105 if (Placeholder *ph = llvm::dyn_cast<Placeholder>(input.getNode())) {
106 for (auto &user : ph->getUsers()) {
107 if (auto *save = llvm::dyn_cast<SaveNode>(user.getUser())) {
108 placeholders[input] = save->getPlaceholder();
109 inputF = mapping[user.getUser()];
110 break;
111 }
112 }
113 if (!inputF) {
114 continue;
115 }
116 }
117
118 if (!inputF) {
119 inputF = mapping[input.getNode()];
120 }
121 if (subF == inputF) {
122 continue;
123 }
124 // Check if a DAGNode for subF's parent is created or not. If not,
125 // create one.
126 if (funcDAG.find(inputF) == funcDAG.end()) {
127 std::unique_ptr<DAGNode> subDAG =
128 createDAGNodeFromFun(inputF, mapping);
129 funcDAG[inputF] = subDAG.get();
130 nodes.push_back(std::move(subDAG));
131 }
132
133 // subF is a child of inputF, inputF is a parent of subF.
134 if (parents.find(inputF) == parents.end()) {
135 funcDAG[inputF]->children.push_back(funcDAG[subF]);
136 funcDAG[subF]->parents.push_back(funcDAG[inputF]);
137 parents.insert(inputF);
138 }
139 // If we've already created a placeholder for this dependence, use it.
140 auto it = placeholders.find(input);
141 if (it != placeholders.end()) {
142 N.setNthInput(inp, it->second);
143 continue;
144 }
145
146 // Create a new placeholder to represent this dependence.
147 auto *save = inputF->createSave("tmp", input);
148 auto *tmp = save->getPlaceholder();
149 placeholders[input] = tmp;
150 N.setNthInput(inp, tmp);
151 }
152 }
153 }
154
155 if (saveDAG) {
156 DAG dag;
157 dag.root = std::move(DAGRoot);
158 dag.nodes = std::move(nodes);
159 partitions.push_back(std::move(dag));
160 }
161
162 if (!skipCloning) {
163 // Update links between nodes in the cloned functions. Add placeholders (and
164 // save nodes) where a link crosses a partition boundary.
165 for (auto *subF : mapping.getPartitions()) {
166 for (auto &N : subF->getNodes()) {
167 for (int inp = 0, e = N.getNumInputs(); inp < e; inp++) {
168 auto input = N.getNthInput(inp);
169 if (isa<Storage>(input.getNode())) {
170 continue;
171 }
172 // Link this node to the clone of its input.
173 auto *clone = currToNew[input.getNode()];
174 N.setNthInput(inp, NodeValue(clone, input.getResNo()));
175 }
176 }
177 }
178 }
179
180 // For all DAGNode without parents, link them to the root DAG.
181 for (auto *subF : mapping.getPartitions()) {
182 if (funcDAG[subF]->parents.size() == 0) {
183 funcDAG[subF]->parents.push_back(root);
184 root->children.push_back(funcDAG[subF]);
185 }
186 }
187 return partitions;
188}
189
190void PartitionerBase::dumpDAG(llvm::StringRef dotFilename,
191 const DAGListTy &partitions) const {
192 if (partitions.size() == 0) {
193 return;
194 }
195 auto *root = partitions[0].root.get();
196 LOG(INFO) << "Writing dotty graph for DAG after graph partitioning: "
197 << dotFilename.str();
198 std::ofstream myfile;
199 myfile.open(dotFilename.str());
200 myfile << "digraph DAG {\n\trankdir=TB;\n";
201 // Dump DAGNodes
202 std::vector<DAGNode *> nodes;
203 llvm::SmallSet<DAGNode *, 10> used;
204 nodes.push_back(root);
205 int cur = 0;
206 int num = 1;
207 while (cur < num) {
208 auto *node = nodes[cur];
209 for (size_t i = 0; i < node->children.size(); i++) {
210 auto child = node->children[i];
211 DescriptionBuilder db(child->name.c_str());
212 const std::string &backendName = child->backendName;
213 db.addParam("BackendName", backendName);
214 myfile << "\"" << escapeDottyString(child->name) << "\""
215 << " [ label = \"" << escapeDottyString(db) << "\"";
216 myfile << "\tshape = \"record\"\n";
217 myfile << "\tstyle=\"filled,rounded\"\n";
218 auto colorIdx = llvm::hash_value(backendName);
219 myfile << "\tfillcolor=" << getDotFileNodeColor(colorIdx) << "\n";
220 myfile << "penwidth = 2];\n";
221 if (used.count(child) == 0) {
222 nodes.push_back(child);
223 used.insert(child);
224 num++;
225 }
226 }
227 cur++;
228 }
229
230 // Dump edges.
231 for (size_t i = 0; i < nodes.size(); i++) {
232 auto *node = nodes[i];
233 for (size_t j = 0; j < node->children.size(); j++) {
234 auto child = node->children[j];
235 if (node->name.compare(child->name) == 0) {
236 // If a network is too small to be partitioned, the dummy node's name
237 // and its child (i.e. the original network) share the same name. The
238 // edge will create loop. So in this case, this edge just be ignored.
239 continue;
240 }
241 myfile << "\"" << escapeDottyString(node->name) << "\""
242 << " -> "
243 << "\"" << escapeDottyString(child->name) << "\""
244 << ";";
245 }
246 }
247 myfile << "}";
248
249 myfile.close();
250 return;
251}
252