1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_PARTITIONER_PARTITIONUTILS_H
17#define GLOW_PARTITIONER_PARTITIONUTILS_H
18
19#include "glow/Graph/Graph.h"
20#include "glow/Partitioner/PartitionerTypes.h"
21#include "llvm/ADT/DenseMap.h"
22
23namespace glow {
24/// Visit the nodes of Function \p F in BFS order and \returns them grouped by
25/// level (a node's level is the longest distance between it and the root).
26BFSLevel getBFSLevel(Function *F);
27
28/// Given \p nodes, \returns a list of the nodes that use any node in this set.
/// NOTE(review): the name suggests that only users outside \p nodes are
/// reported -- confirm against the implementation.
29std::vector<Node *> getOutUsers(const NodesSet &nodes);
30
31/// Given \p nodes, \returns a list of the nodes that use only nodes in this
32/// set or constants.
33std::vector<Node *> getOutUsersWithOnePredecessor(const NodesSet &nodes);
34
35/// \returns the memory usage of the output produced by \p node when \p node
36/// has users that are not in the set \p nodes.
/// NOTE(review): the unit is presumably bytes -- confirm against the
/// implementation.
37uint64_t getOutMemPerNode(const NodesSet &nodes, const Node *node);
38
39/// Given \p node, \returns the NodesSet holding the inputs of this node.
40NodesSet getInputs(const Node *node);
41
42/// \returns the estimated computation time of the op \p node, based on
/// \p backendInfo.
/// NOTE(review): the time unit is not stated here -- confirm against the
/// implementation.
43float getNodeComputeTime(const Node *node, const BackendInfo &backendInfo);
44
45/// Given \p node, \returns the memory usage of its inputs (i.e. its Storage
/// inputs).
46uint64_t getNodeMemUsage(const Node *node);
47
48/// Given the node set \p currNodes and its memory usage info \p info,
49/// \returns the new memory usage if \p newNode were added to \p currNodes.
/// \p currNodes and \p info are taken by const reference; the updated memory
/// info is returned by value.
50GraphMemInfo updateGraphMemInfoByAddingNode(const NodesSet &currNodes,
51 const GraphMemInfo &info,
52 Node *newNode);
53
54/// \returns the memory usage of the node set \p nodes for the given
/// \p contextCount.
55GraphMemInfo getGraphMemInfo(const NodesSet &nodes, unsigned contextCount);
56
57/// \returns the memory usage of the function \p func.
58GraphMemInfo getFunctionMemory(Function *func);
59
60/// Parse \p names, a string of node kind names (e.g. "Div,Add"), and \returns
61/// the set of NodeKinds corresponding to the names in the string.
62std::set<Kinded::Kind> generateNodeKindsSet(llvm::StringRef names);
63
64/// Log the info of the current partitioning described by \p partitions.
65void logPartitionInfo(const NodeToFunctionMap &partitions);
66
67/// Print numBytesInTable, deviceID, cost and cost/numBytesInTable for each
68/// item in the half-open range [\p start, \p end), i.e. \p start is included
69/// and \p end is excluded. If \p verbose_only is true (the default), VLOG(1)
70/// is used; otherwise LOG(INFO) is used.
71void printSlsTableInfo(std::vector<SLSTableInfo>::iterator start,
72 std::vector<SLSTableInfo>::iterator end,
73 bool verbose_only = true);
/// Overload taking the table vector \p slsTables directly, with the same
/// \p verbose_only switch (default true).
/// NOTE(review): presumably prints every entry of \p slsTables -- confirm
/// against the implementation.
74void printSlsTableInfo(std::vector<SLSTableInfo> &slsTables,
75 bool verbose_only = true);
76
77/// Print deviceId, used_memory, free_memory, cost, node_size and
78/// cost/used_memory. Used memory is calculated using \p nodesets and
79/// \p contextCount. If \p verbose_only is true, VLOG(1) is used; otherwise
/// LOG(INFO) is used.
80void printSlsDeviceInfo(const std::vector<SLSDeviceInfo> &slsDevices,
81 const std::vector<NodesSet> &nodesets,
82 const unsigned contextCount, bool verbose_only);
83
84/// \returns whether \p node is an SLS node.
85bool isSLSNode(const Node *node);
86
87/// \returns whether all inputs to \p node are of the kind \p kind.
88bool checkNodeInputsAllKind(const Node *node, glow::Kinded::Kind kind);
89
90/// Loop through \p slsDevices and assign \p table to the first available
91/// device that can fit \p table.
92/// \returns Error if no such device could be found.
/// NOTE(review): \p table, \p slsDevices, \p nodesets and \p frontierValues
/// are taken by non-const reference and so may be modified; together with
/// \p contextCount they presumably feed the fit check -- confirm against the
/// implementation.
93Error assignSlsTableToFirstAvailableDevice(
94 SLSTableInfo &table, std::vector<SLSDeviceInfo> &slsDevices,
95 std::vector<NodesSet> &nodesets,
96 std::vector<std::unordered_set<NodeValue>> &frontierValues,
97 const unsigned contextCount);
98
99/// Assign \p slsTables to \p slsDevices by:
100/// 1. Sorting \p slsTables by decreasing size.
101/// 2. Splitting \p slsTables into two parts, large tables and small tables,
102///    where large tables have numBytesInTable >
103///    glow::runtime::flags::BigTableThresholdBytes.
104/// 3. For large tables: sorting them by size, then assigning each table to
105///    the device with the lowest size.
106/// 4. For small tables: sorting them by cost, then assigning each table to
107///    the device with the lowest cost.
108///
109/// \returns Error if no feasible partitioning plan could be found that fits
110/// all of \p slsTables into \p slsDevices.
111/// In case of error, all the inputs are restored to their original values.
/// NOTE(review): the roles of \p frontierValues and \p contextCount are not
/// described here; presumably they feed the per-device memory accounting --
/// confirm against the implementation.
112Error assignSlsTablesToDevices(
113 std::vector<SLSTableInfo> &slsTables,
114 std::vector<SLSDeviceInfo> &slsDevices,
115 std::vector<std::unordered_set<NodeValue>> &frontierValues,
116 const unsigned contextCount);
117
118/// Assign \p slsTables to \p slsDevices greedily:
119/// sort \p slsTables by size, then assign each SLS table in turn to the
120/// device in \p slsDevices with the lowest cost.
121///
122/// \returns Error if no feasible allocation plan could be found that fits all
123/// of \p slsTables into \p slsDevices.
124/// In case of error, all the inputs are restored to their original values.
/// NOTE(review): the roles of \p frontierValues and \p contextCount are not
/// described here; presumably they feed the per-device memory accounting --
/// confirm against the implementation.
125Error assignSlsTablesToDevicesGreedy(
126 std::vector<SLSTableInfo> &slsTables,
127 std::vector<SLSDeviceInfo> &slsDevices,
128 std::vector<std::unordered_set<NodeValue>> &frontierValues,
129 const unsigned contextCount);
130} // namespace glow
131#endif // GLOW_PARTITIONER_PARTITIONUTILS_H
132