1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Partitioner/PartitionerValidation.h"
17#include "glow/Partitioner/PartitionerUtils.h"
18
19#include "llvm/Support/FormatVariadic.h"
20
21namespace glow {
22Error logicalDevicesValidation(
23 const NodeToFunctionMap &partitions,
24 const std::map<std::string, BackendInfo> &backendMap) {
25 std::map<std::string, std::set<DeviceIDTy>> partitionsNum;
26 for (auto &func : partitions.getPartitions()) {
27 auto backendName = partitions.getPartitionBackendName(func);
28 if (partitionsNum.find(backendName) == partitionsNum.end()) {
29 partitionsNum.emplace(backendName, std::set<DeviceIDTy>{});
30 }
31 auto logicalIDList = partitions.getLogicalDeviceIDList(func);
32 for (size_t i = 0, e = logicalIDList.size(); i < e; i++) {
33 partitionsNum[backendName].insert(logicalIDList[i]);
34 }
35 auto backendNum = backendMap.at(backendName).num;
36 if (partitionsNum[backendName].size() > backendNum) {
37 logPartitionInfo(partitions);
38 return MAKE_ERR(
39 llvm::formatv(
40 "Partition failed: the number of given({0}) devices({1}) "
41 "is fewer than the required minimal partitions({2}).",
42 backendName, backendNum, partitionsNum[backendName].size())
43 .str());
44 }
45 }
46 return Error::success();
47}
48
49Error resourceCountValidation(
50 const NodeToFunctionMap &partitions,
51 const std::map<std::string, BackendInfo> &backendMap) {
52 VLOG(1) << "Entering resource count validation";
53 std::map<DeviceIDTy, uint64_t> logicalIDInputResources;
54 // We are assuming the same resource limit for all devices.
55 // Since we do not support P2P for heterogenous devices we do not run this
56 // check for heterogeneous partitioning.
57 uint64_t availableInputResources = 0;
58 for (auto &func : partitions.getPartitions()) {
59 auto logicalDeviceList = partitions.getLogicalDeviceIDList(func);
60 auto backendName = partitions.getPartitionBackendName(func);
61 auto graphMem = partitions.getGraphMemInfo(func);
62 if (!availableInputResources) {
63 availableInputResources = backendMap.at(backendName).inputCountMax;
64 }
65 // For each logicalDevice add the input resources used.
66 for (auto dev : logicalDeviceList) {
67 if (!logicalIDInputResources.count(dev)) {
68 logicalIDInputResources[dev] = 0;
69 }
70 // Scale the number of inputs from peers by the context count, as we make
71 // that many copies of input peer resources.
72 logicalIDInputResources[dev] +=
73 graphMem.inputFromPeerCount * graphMem.contextCount;
74 }
75 }
76
77 // If availableInputResources is not 0 check that we are below the limit.
78 if (availableInputResources) {
79 for (auto &resourceCount : logicalIDInputResources) {
80 VLOG(1) << "Checking logical device " << resourceCount.first
81 << " resource count: " << resourceCount.second
82 << "is less than: " << availableInputResources << "\n";
83 if (resourceCount.second > availableInputResources) {
84 return MAKE_ERR(
85 llvm::formatv(
86 "Partition failed: the resource count usage ({0}) of one "
87 "partition exceeds "
88 "the available resource count ({1}).",
89 resourceCount.second, availableInputResources)
90 .str());
91 }
92 }
93 }
94 return Error::success();
95}
96
97Error memoryUsageValidation(
98 const NodeToFunctionMap &partitions,
99 const std::map<std::string, BackendInfo> &backendMap) {
100 VLOG(1) << "Entering mem validation";
101 for (auto &func : partitions.getPartitions()) {
102 auto backendName = partitions.getPartitionBackendName(func);
103 auto usedMemSize = partitions.getGraphMemInfo(func).getTotalMemSize();
104 auto availableMemSize = backendMap.at(backendName).memSize;
105 VLOG(1) << "Comparing " << usedMemSize << " " << availableMemSize << " for "
106 << backendName;
107 if (usedMemSize > availableMemSize) {
108 logPartitionInfo(partitions);
109 return MAKE_ERR(
110 llvm::formatv("Partition failed: the memory usage({0}) of one "
111 "partition exceeds "
112 "the available memory({1}) of given devices({2}).",
113 usedMemSize, availableMemSize, backendName)
114 .str());
115 }
116 }
117 return Error::success();
118}
119
120/// \returns true if \p node contains no cycles. \p path contains the nodes in a
121/// path, and \p visited contains the nodes checked before.
122static bool isDAG(DAGNode *node, llvm::SmallSet<DAGNode *, 10> &path,
123 llvm::SmallSet<DAGNode *, 10> &visited) {
124 if (!visited.count(node)) {
125 path.insert(node);
126 visited.insert(node);
127 for (size_t i = 0; i < node->children.size(); i++) {
128 auto child = node->children[i];
129 if (path.count(child)) {
130 // Cycle found.
131 return false;
132 }
133 if (!isDAG(child, path, visited)) {
134 return false;
135 }
136 }
137 }
138 if (path.count(node)) {
139 path.erase(node);
140 }
141 return true;
142}
143
144Error dagValidation(const DAG &dag) {
145 auto *root = dag.root.get();
146 llvm::SmallSet<DAGNode *, 10> path;
147 llvm::SmallSet<DAGNode *, 10> visited;
148 // For the first condition: root->children.size() > 0 -- When a dag is
149 // created, its root is a dummy node and other DAGNode without parents will be
150 // linked to this root. Therefore, root without any child means that each of
151 // the rest of DAGNodes has at least one parent. That is, a cycle exists.
152
153 RETURN_ERR_IF_NOT((root->children.size() > 0 && isDAG(root, path, visited)),
154 "Invalid partition: a cycle is detected.");
155
156 // There should not be isolated nodes in partitions.
157 RETURN_ERR_IF_NOT((visited.size() == dag.nodes.size() + 1),
158 "Invalid partition: isolated node is detected.");
159 return Error::success();
160}
161} // namespace glow
162