1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Runtime/Provisioner/Provisioner.h"
17#include "../../lib/Backends/CPU/CPUDeviceManager.h"
18#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
19
20#include "gtest/gtest.h"
21
22using namespace glow;
23using namespace glow::runtime;
24
25class ProvisionerTest : public ::testing::Test {};
26
27std::unique_ptr<Module> setupModule(unsigned functionCount) {
28 auto mod = glow::make_unique<Module>();
29 for (unsigned int i = 0; i < functionCount; i++) {
30 auto *F = mod->createFunction("function" + std::to_string(i));
31 auto *X = mod->createPlaceholder(ElemKind::FloatTy, {16, 1024}, "X", false);
32 auto *W = mod->createConstant(ElemKind::FloatTy, {1024, 1024}, "W");
33 auto *B = mod->createConstant(ElemKind::FloatTy, {1024}, "B");
34 auto *FC = F->createFullyConnected("FC", X, W, B);
35 F->createSave("save", FC);
36 CompilationContext cctx;
37 lower(F, cctx);
38 }
39 return mod;
40}
41
42DAGListTy setupDAG(unsigned rootCount, unsigned childCount,
43 unsigned replicationCount = 1) {
44 DAGListTy partitions;
45 unsigned currentFunction = 0;
46 for (unsigned int root = 0; root < rootCount; root++) {
47 DAGNodePtrVec nodes;
48 auto rootNode = glow::make_unique<DAGNode>();
49 auto firstNode = glow::make_unique<DAGNode>();
50 rootNode->name = "root" + std::to_string(root);
51 rootNode->children.push_back(firstNode.get());
52 firstNode->name = "function" + std::to_string(currentFunction);
53 firstNode->logicalDevices = {0, 1};
54 firstNode->backendName = "CPU";
55 firstNode->replicationCount = replicationCount;
56 currentFunction++;
57 for (unsigned int child = 0; child < childCount; child++) {
58 auto newChild = glow::make_unique<DAGNode>();
59 newChild->name = "function" + std::to_string(currentFunction);
60 newChild->logicalDevices = {0};
61 newChild->backendName = "CPU";
62 newChild->replicationCount = replicationCount;
63 currentFunction++;
64 firstNode->children.push_back(newChild.get());
65 nodes.push_back(std::move(newChild));
66 }
67 nodes.push_back(std::move(firstNode));
68 partitions.push_back({std::move(rootNode), std::move(nodes)});
69 }
70 return partitions;
71}
72
73TEST_F(ProvisionerTest, provisionDag) {
74 auto mod = setupModule(6);
75 auto networks = setupDAG(2, 0);
76
77 DeviceManagerMapTy devices;
78 for (int i = 0; i < 6; i++) {
79 std::unique_ptr<DeviceManager> device(
80 new CPUDeviceManager(DeviceConfig("CPU")));
81 devices.emplace(i, std::move(device));
82 }
83
84 CompilationContext cctx;
85 Provisioner provisioner(devices);
86 auto err = provisioner.provision(networks, *mod.get(), cctx);
87 // Expect that there was no Error when provisioning
88 EXPECT_FALSE(ERR_TO_BOOL(std::move(err)));
89}
90
91TEST_F(ProvisionerTest, provisionDagFail) {
92 auto mod = setupModule(6);
93 auto networks = setupDAG(2, 0);
94
95 DeviceManagerMapTy devices;
96 for (int i = 0; i < 6; i++) {
97 auto config = DeviceConfig("CPU");
98 config.setDeviceMemory(1000);
99 std::unique_ptr<DeviceManager> device(new CPUDeviceManager(config));
100 devices.emplace(i, std::move(device));
101 }
102
103 CompilationContext cctx;
104 Provisioner provisioner(devices);
105 auto err = provisioner.provision(networks, *mod.get(), cctx);
106 // Expect that there was an Error when provisioning
107 EXPECT_TRUE(ERR_TO_BOOL(std::move(err)));
108}
109
110TEST_F(ProvisionerTest, provisionFailCleanup) {
111 // We want this provisioning to fail after adding the first partition
112 // successfully. This is to test that cleanup properly evicts networks.
113 DeviceConfig configBig("CPU");
114 DeviceConfig configSmall("CPU");
115 configSmall.setDeviceMemory(1);
116 std::unique_ptr<DeviceManager> deviceBig(new CPUDeviceManager(configBig));
117 std::unique_ptr<DeviceManager> deviceSmall(new CPUDeviceManager(configSmall));
118 DeviceManagerMapTy devices;
119 devices.emplace(0, std::move(deviceBig));
120 devices.emplace(1, std::move(deviceSmall));
121
122 auto mod = setupModule(2);
123 auto networks = setupDAG(2, 0);
124 CompilationContext cctx;
125 Provisioner provisioner(devices);
126 auto err = provisioner.provision(networks, *mod.get(), cctx);
127 // Expect that there was an Error when provisioning
128 EXPECT_TRUE(ERR_TO_BOOL(std::move(err)));
129}
130