/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_PARTITIONER_PARTITIONERTYPES_H
#define GLOW_PARTITIONER_PARTITIONERTYPES_H

#include "glow/Graph/Graph.h"
#include "glow/Runtime/RuntimeTypes.h"

namespace glow {

using namespace runtime;

using NodesSet = std::set<Node *>;

/// The memory usage of a subgraph (i.e. a list of nodes of a function).
struct GraphMemInfo {
  // The memory usage of all input nodes (whose predecessors are not included
  // in this subgraph) of this subgraph.
  uint64_t inMemSize;
  // The memory usage of all output nodes (whose successors are not included
  // in this subgraph) of this subgraph.
  uint64_t outMemSize;
  // The memory usage of all constants used in this subgraph.
  uint64_t constMemSize;
  // The number of contexts reserved on the device; this affects input/output
  // memory usage.
  unsigned contextCount;
  // Count of inputs to the graph.
  unsigned inputCount{0};
  // Count of inputs to the graph that come from a peer graph, i.e. that are
  // the output of another graph rather than inputs to the original model.
  unsigned inputFromPeerCount{0};
  // The memory usage of only the deferred constants used in this subgraph.
  uint64_t deferredConstMemSize{0};

  GraphMemInfo()
      : inMemSize(0), outMemSize(0), constMemSize(0), contextCount(1) {}
  GraphMemInfo(uint64_t inMem, uint64_t outMem, uint64_t constMem,
               unsigned count = 1)
      : inMemSize(inMem), outMemSize(outMem), constMemSize(constMem),
        contextCount(count) {}

  /// Get the total memory size of this partition.
  uint64_t getTotalMemSize() const {
    return ((inMemSize + outMemSize) * contextCount) + constMemSize;
  }

  bool equals(const GraphMemInfo &other) const {
    return inMemSize == other.inMemSize && outMemSize == other.outMemSize &&
           constMemSize == other.constMemSize;
  }
};

inline bool operator==(const GraphMemInfo &LHS, const GraphMemInfo &RHS) {
  return LHS.equals(RHS);
}
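
// Illustrative example (hypothetical sizes): getTotalMemSize() scales the
// activation (input/output) memory by the number of reserved contexts and
// then adds the constant memory:
//
//   GraphMemInfo info(/* inMem */ 1024, /* outMem */ 512, /* constMem */ 4096,
//                     /* count */ 2);
//   // (1024 + 512) * 2 + 4096 == 7168
//   assert(info.getTotalMemSize() == 7168);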

/// A list of node lists, one list per level, in BFS order.
using BFSLevel = std::vector<std::vector<Node *>>;

/// Data structure that contains the info for each type of backend used for
/// partitioning.
struct BackendInfo {
  /// Number of devices that have this type of backend.
  size_t num = 0;
  /// The memory constraints for this backend.
  uint64_t memSize;
  /// Maximum amount of input resources; defaults to 0 if there is no limit.
  uint64_t inputCountMax{0};
  /// The following fields (peakCompute, peakDramBw, peakSramBw, peakPCIeBw)
  /// come from DeviceInfo_.
  /// Available SRAM capacity in bytes.
  uint64_t sramCapacity;
  /// Peak compute on device in ops/second. Assumes all ops are in int8.
  float peakCompute;
  /// Peak memory bandwidth from DRAM on device in bytes/second.
  float peakDramBw;
  /// Peak memory bandwidth from SRAM on device in bytes/second.
  float peakSramBw;
  /// Peak ingress/egress PCI-E bandwidth from device in bytes/second.
  float peakPCIeBw;
  /// Backend pointer.
  Backend *backend = nullptr;
  /// The kinds of nodes that are not supported by this backend.
  std::set<Kinded::Kind> nonSupportedNodesKinds;
  /// The kinds of nodes that are supported by this backend.
  std::set<Kinded::Kind> supportedNodesKinds;
};
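
// Illustrative sketch: a partitioner typically fills one BackendInfo per
// backend type from the runtime's DeviceInfo. The DeviceInfo field names used
// below (availableMemory, sramCapacity, peakCompute) are assumptions for the
// purpose of this example, not guaranteed by this header.
//
//   BackendInfo bInfo;
//   bInfo.num = 2; // two devices share this backend type
//   bInfo.memSize = deviceInfo.availableMemory;
//   bInfo.sramCapacity = deviceInfo.sramCapacity;
//   bInfo.peakCompute = deviceInfo.peakCompute;
//   bInfo.backend = backend;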

/// Information about a single SLS table used during partitioning.
struct SLSTableInfo {
  /// The SLS node for this table.
  Node *node;
  /// Nodes grouped together with this SLS node.
  std::unordered_set<Node *> neighbors;
  /// The frontier NodeValues of this group.
  std::unordered_set<NodeValue> frontier;
  /// Size of the table in bytes.
  uint64_t numBytesInTable;
  /// The device this table is assigned to.
  unsigned int deviceId;
  /// The result value of the SLS node.
  NodeValue slsResult;
  /// The cost associated with this table.
  uint64_t cost;
};

/// Per-device state used when assigning SLS tables to devices.
struct SLSDeviceInfo {
  /// The device ID.
  unsigned int deviceId;
  /// Memory available on this device, in bytes.
  uint64_t memAvailableInBytes;
  /// The current cost accumulated on this device.
  size_t currentCost;
};

/// A mapping of newly-created functions along with a set of node sets. The
/// comparator below keeps the map sorted by the key's name (i.e. the
/// function's name), which keeps the optimization sequence consistent across
/// runs.
struct FunctionNameComparator {
  bool operator()(const Function *lhs, const Function *rhs) const {
    return strcmp(lhs->getName().data(), rhs->getName().data()) < 0;
  }
};
using FunctionToNodesMap =
    std::map<Function *, NodesSet, FunctionNameComparator>;

using FunctionToBackendNameMap =
    std::map<Function *, std::string, FunctionNameComparator>;

class NodeToFunctionMap {
  /// Helper structure for building a partition. Records the mapping of nodes
  /// in the original function to destination partitions, along with a list of
  /// the newly-created functions.
  using Map = llvm::DenseMap<Node *, Function *>;

  using PartitionCostMap = llvm::DenseMap<Function *, GraphMemInfo>;

  using BackendHintsMap = llvm::DenseMap<Function *, BackendHints>;

  using BackendSpecificOptsMap =
      llvm::DenseMap<Function *, BackendSpecificOptions>;

  /// Newly-created partitions.
  FunctionList functions_;

  /// Map of nodes in the original function to their target partition.
  Map nodeToFunction_;

  /// Map of the partitions to the backend which will be used for compiling
  /// this partition.
  FunctionToBackendNameMap functionToBackendName_;

  /// Map of sub-functions to their memory consumption.
  PartitionCostMap partitionCost_;

  /// BackendHints for this sub-function.
  BackendHintsMap backendHints_;

  /// BackendSpecificOpts for this sub-function.
  BackendSpecificOptsMap backendSpecificOpts_;

  /// Map of partitions to their logical device IDs. Partitions with the same
  /// logical device ID will be assigned to the same physical device.
  std::map<Function *, std::vector<DeviceIDTy>> logicalDeviceIDMap_;

  /// Map of partitions to their replication count.
  std::map<Function *, unsigned> replicationCountMap_;

public:
  /// Create a new partition \p F, and map it with \p backendName.
  void createPartition(Function *F, llvm::StringRef backendName) {
    functions_.emplace_back(F);
    functionToBackendName_[F] = backendName.str();
  }

  /// Get the backend name used for compiling partition \p F.
  std::string getPartitionBackendName(Function *F) const {
    DCHECK(functionToBackendName_.find(F) != functionToBackendName_.end())
        << "Unknown partition in Function: " << F->getName().str();
    return functionToBackendName_.find(F)->second;
  }

  /// Add a new Node->Function mapping.
  void add(Node *N, Function *F) { nodeToFunction_[N] = F; }

  /// Get the list of functions contained in this map.
  const FunctionList &getPartitions() const { return functions_; }

  /// Get the list of logical device IDs related to the function \p F.
  const std::vector<DeviceIDTy> getLogicalDeviceIDList(Function *F) const {
    if (logicalDeviceIDMap_.find(F) == logicalDeviceIDMap_.end()) {
      return {};
    }
    return logicalDeviceIDMap_.at(F);
  }

  /// Clear all logical device ID assignments.
  void clearLogicalDeviceID() { logicalDeviceIDMap_.clear(); }

  /// Append the logical device \p id to the list for partition \p F.
  void appendLogicalDeviceID(Function *F, DeviceIDTy id) {
    if (logicalDeviceIDMap_.find(F) == logicalDeviceIDMap_.end()) {
      logicalDeviceIDMap_.emplace(
          std::make_pair(F, std::vector<DeviceIDTy>{id}));
    } else {
      logicalDeviceIDMap_[F].push_back(id);
    }
  }

  /// Set the replication count for partition \p F to \p count.
  void addReplicationCount(Function *F, unsigned count) {
    replicationCountMap_[F] = count;
  }

  /// Get the replication count for partition \p F; defaults to 1.
  unsigned getReplicationCount(Function *F) {
    auto it = replicationCountMap_.find(F);
    if (it == replicationCountMap_.end()) {
      return 1;
    } else {
      return it->second;
    }
  }

  /// Attach \p map to the current mapping.
  void insert(NodeToFunctionMap &map) {
    FunctionList flist = map.getPartitions();
    for (auto it = flist.begin(); it != flist.end(); ++it) {
      Function *func = *it;
      auto backendName = map.getPartitionBackendName(func);
      createPartition(func, backendName);
      GraphMemInfo cost = map.getGraphMemInfo(func);
      setGraphMemInfo(func, cost);
    }
    for (auto it = map.begin(); it != map.end(); ++it) {
      Node *n = it->first;
      Function *f = it->second;
      add(n, f);
    }
  }

  /// Map API.
  Map::iterator find(Node *N) { return nodeToFunction_.find(N); }
  Map::iterator begin() { return nodeToFunction_.begin(); }
  Map::iterator end() { return nodeToFunction_.end(); }
  Function *operator[](Node *n) { return nodeToFunction_[n]; }

  /// Delete the partition \p func and all bookkeeping associated with it.
  void deletePartition(Function *func) {
    functions_.remove(func);
    functionToBackendName_.erase(func);
    partitionCost_.erase(func);
    backendHints_.erase(func);
    backendSpecificOpts_.erase(func);
  }

  /// Set the memory consumption \p cost for a partition \p func.
  void setGraphMemInfo(Function *func, GraphMemInfo cost) {
    partitionCost_[func] = cost;
  }

  /// Get the memory consumption for a partition \p func.
  GraphMemInfo getGraphMemInfo(Function *func) const {
    if (partitionCost_.find(func) == partitionCost_.end()) {
      return GraphMemInfo{};
    }
    return partitionCost_.find(func)->second;
  }

  /// Set the backend hints for a partition \p func.
  void setBackendHints(Function *func, BackendHints hints) {
    backendHints_[func] = hints;
  }

  /// Get the backend hints for a partition \p func.
  BackendHints getBackendHints(Function *func) const {
    if (backendHints_.find(func) == backendHints_.end()) {
      return BackendHints{};
    }
    return backendHints_.find(func)->second;
  }

  /// Set the backend specific opts \p opts for a partition \p func.
  void setBackendSpecificOpts(Function *func,
                              const BackendSpecificOptions &opts) {
    backendSpecificOpts_[func] = opts;
  }

  /// Get the backend specific opts for a partition \p func.
  BackendSpecificOptions getBackendSpecificOpts(Function *func) const {
    if (backendSpecificOpts_.find(func) == backendSpecificOpts_.end()) {
      return BackendSpecificOptions{};
    }
    return backendSpecificOpts_.find(func)->second;
  }
};
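
// Illustrative usage sketch (F0, F1, nodeA, and nodeB are hypothetical): a
// partitioning algorithm creates one partition per sub-function, assigns
// nodes to partitions, and records per-partition cost and logical devices:
//
//   NodeToFunctionMap mapping;
//   mapping.createPartition(F0, "Interpreter");
//   mapping.createPartition(F1, "Interpreter");
//   mapping.add(nodeA, F0);
//   mapping.add(nodeB, F1);
//   mapping.setGraphMemInfo(F0, GraphMemInfo(64, 32, 128));
//   mapping.appendLogicalDeviceID(F0, 0);
//   mapping.appendLogicalDeviceID(F1, 1);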

} // namespace glow
#endif // GLOW_PARTITIONER_PARTITIONERTYPES_H