1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Runtime/Provisioner/Provisioner.h" |
17 | #include "../../lib/Backends/CPU/CPUDeviceManager.h" |
18 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
19 | |
20 | #include "gtest/gtest.h" |
21 | |
22 | using namespace glow; |
23 | using namespace glow::runtime; |
24 | |
25 | class ProvisionerTest : public ::testing::Test {}; |
26 | |
27 | std::unique_ptr<Module> setupModule(unsigned functionCount) { |
28 | auto mod = glow::make_unique<Module>(); |
29 | for (unsigned int i = 0; i < functionCount; i++) { |
30 | auto *F = mod->createFunction("function" + std::to_string(i)); |
31 | auto *X = mod->createPlaceholder(ElemKind::FloatTy, {16, 1024}, "X" , false); |
32 | auto *W = mod->createConstant(ElemKind::FloatTy, {1024, 1024}, "W" ); |
33 | auto *B = mod->createConstant(ElemKind::FloatTy, {1024}, "B" ); |
34 | auto *FC = F->createFullyConnected("FC" , X, W, B); |
35 | F->createSave("save" , FC); |
36 | CompilationContext cctx; |
37 | lower(F, cctx); |
38 | } |
39 | return mod; |
40 | } |
41 | |
42 | DAGListTy setupDAG(unsigned rootCount, unsigned childCount, |
43 | unsigned replicationCount = 1) { |
44 | DAGListTy partitions; |
45 | unsigned currentFunction = 0; |
46 | for (unsigned int root = 0; root < rootCount; root++) { |
47 | DAGNodePtrVec nodes; |
48 | auto rootNode = glow::make_unique<DAGNode>(); |
49 | auto firstNode = glow::make_unique<DAGNode>(); |
50 | rootNode->name = "root" + std::to_string(root); |
51 | rootNode->children.push_back(firstNode.get()); |
52 | firstNode->name = "function" + std::to_string(currentFunction); |
53 | firstNode->logicalDevices = {0, 1}; |
54 | firstNode->backendName = "CPU" ; |
55 | firstNode->replicationCount = replicationCount; |
56 | currentFunction++; |
57 | for (unsigned int child = 0; child < childCount; child++) { |
58 | auto newChild = glow::make_unique<DAGNode>(); |
59 | newChild->name = "function" + std::to_string(currentFunction); |
60 | newChild->logicalDevices = {0}; |
61 | newChild->backendName = "CPU" ; |
62 | newChild->replicationCount = replicationCount; |
63 | currentFunction++; |
64 | firstNode->children.push_back(newChild.get()); |
65 | nodes.push_back(std::move(newChild)); |
66 | } |
67 | nodes.push_back(std::move(firstNode)); |
68 | partitions.push_back({std::move(rootNode), std::move(nodes)}); |
69 | } |
70 | return partitions; |
71 | } |
72 | |
73 | TEST_F(ProvisionerTest, provisionDag) { |
74 | auto mod = setupModule(6); |
75 | auto networks = setupDAG(2, 0); |
76 | |
77 | DeviceManagerMapTy devices; |
78 | for (int i = 0; i < 6; i++) { |
79 | std::unique_ptr<DeviceManager> device( |
80 | new CPUDeviceManager(DeviceConfig("CPU" ))); |
81 | devices.emplace(i, std::move(device)); |
82 | } |
83 | |
84 | CompilationContext cctx; |
85 | Provisioner provisioner(devices); |
86 | auto err = provisioner.provision(networks, *mod.get(), cctx); |
87 | // Expect that there was no Error when provisioning |
88 | EXPECT_FALSE(ERR_TO_BOOL(std::move(err))); |
89 | } |
90 | |
91 | TEST_F(ProvisionerTest, provisionDagFail) { |
92 | auto mod = setupModule(6); |
93 | auto networks = setupDAG(2, 0); |
94 | |
95 | DeviceManagerMapTy devices; |
96 | for (int i = 0; i < 6; i++) { |
97 | auto config = DeviceConfig("CPU" ); |
98 | config.setDeviceMemory(1000); |
99 | std::unique_ptr<DeviceManager> device(new CPUDeviceManager(config)); |
100 | devices.emplace(i, std::move(device)); |
101 | } |
102 | |
103 | CompilationContext cctx; |
104 | Provisioner provisioner(devices); |
105 | auto err = provisioner.provision(networks, *mod.get(), cctx); |
106 | // Expect that there was an Error when provisioning |
107 | EXPECT_TRUE(ERR_TO_BOOL(std::move(err))); |
108 | } |
109 | |
110 | TEST_F(ProvisionerTest, provisionFailCleanup) { |
111 | // We want this provisioning to fail after adding the first partition |
112 | // successfully. This is to test that cleanup properly evicts networks. |
113 | DeviceConfig configBig("CPU" ); |
114 | DeviceConfig configSmall("CPU" ); |
115 | configSmall.setDeviceMemory(1); |
116 | std::unique_ptr<DeviceManager> deviceBig(new CPUDeviceManager(configBig)); |
117 | std::unique_ptr<DeviceManager> deviceSmall(new CPUDeviceManager(configSmall)); |
118 | DeviceManagerMapTy devices; |
119 | devices.emplace(0, std::move(deviceBig)); |
120 | devices.emplace(1, std::move(deviceSmall)); |
121 | |
122 | auto mod = setupModule(2); |
123 | auto networks = setupDAG(2, 0); |
124 | CompilationContext cctx; |
125 | Provisioner provisioner(devices); |
126 | auto err = provisioner.provision(networks, *mod.get(), cctx); |
127 | // Expect that there was an Error when provisioning |
128 | EXPECT_TRUE(ERR_TO_BOOL(std::move(err))); |
129 | } |
130 | |