1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Runtime/DeferredWeightLoader.h" |
18 | #include "BackendTestUtils.h" |
19 | #include "glow/ExecutionContext/ExecutionContext.h" |
20 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
21 | #include "glow/Runtime/Provisioner/Provisioner.h" |
22 | |
23 | #include "gtest/gtest.h" |
24 | |
25 | using namespace glow; |
26 | using namespace glow::runtime; |
27 | |
28 | class TestDeferredWeightLoader : public DeferredWeightLoader { |
29 | public: |
30 | Error loadNextWeight() override { |
31 | position_++; |
32 | if (position_ < names_.size() && names_[position_] == "fail" ) { |
33 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR, |
34 | "Fail to load weight." ); |
35 | } |
36 | return Error::success(); |
37 | } |
38 | Error setSrc(void *loaderObject) override { return Error::success(); } |
39 | void addWeight(Tensor *weight) { weights_.push_back(weight); } |
40 | void addName(std::string name) { names_.push_back(name); } |
41 | void setTypeInfo(std::map<std::string, Type> info) override {} |
42 | |
43 | std::string getName() override { |
44 | for (auto na : names_) { |
45 | } |
46 | if (position_ >= int(names_.size())) { |
47 | return "" ; |
48 | } |
49 | return names_[position_]; |
50 | } |
51 | |
52 | Tensor *getTensor() override { |
53 | if (position_ >= int(weights_.size())) { |
54 | return nullptr; |
55 | } |
56 | return weights_[position_]; |
57 | } |
58 | |
59 | private: |
60 | std::vector<Tensor *> weights_{}; |
61 | std::vector<std::string> names_{}; |
62 | int position_{-1}; |
63 | }; |
64 | |
65 | class DeferredWeightLoaderTest : public ::testing::TestWithParam<std::string> { |
66 | }; |
67 | |
68 | std::unique_ptr<HostManager> |
69 | createHostManager(llvm::StringRef backendName, |
70 | HostConfig hostConfig = HostConfig()) { |
71 | std::vector<std::unique_ptr<DeviceConfig>> configs; |
72 | auto deviceConfig = glow::make_unique<DeviceConfig>(backendName); |
73 | configs.push_back(std::move(deviceConfig)); |
74 | std::unique_ptr<HostManager> hostManager = |
75 | glow::make_unique<HostManager>(std::move(configs), hostConfig); |
76 | return hostManager; |
77 | } |
78 | |
79 | TEST_P(DeferredWeightLoaderTest, cleanupFailedDeferred) { |
80 | // We want this provisioning to fail after loading a deferred weight, then |
81 | // verify that the network is cleaned up properly. |
82 | CHECK_IF_ENABLED(); |
83 | std::unique_ptr<Module> module = glow::make_unique<Module>(); |
84 | auto F = module->createFunction("main" ); |
85 | auto *X = module->createPlaceholder(ElemKind::FloatTy, {1}, "X" , false); |
86 | |
87 | auto *Y = module->createPlaceholder(ElemKind::FloatTy, {1}, "Y" , false); |
88 | auto *Z = module->createPlaceholder(ElemKind::FloatTy, {1}, "Z" , false); |
89 | auto *output = |
90 | module->createPlaceholder(ElemKind::FloatTy, {1}, "output" , false); |
91 | // Set X and Y as static. |
92 | X->setStatic(true); |
93 | Y->setStatic(true); |
94 | auto pow1 = F->createPow("pow" , X, Y); |
95 | auto pow2 = F->createPow("pow2" , Z, pow1); |
96 | F->createSave("save" , pow2, output); |
97 | std::vector<Tensor> staticInputs; |
98 | auto xTensor = Tensor(X->getType()); |
99 | auto yTensor = Tensor(Y->getType()); |
100 | auto zTensor = Tensor(Z->getType()); |
101 | xTensor.getHandle().clear(2.0); |
102 | yTensor.getHandle().clear(3.0); |
103 | zTensor.getHandle().clear(2.0); |
104 | |
105 | TestDeferredWeightLoader loader; |
106 | loader.addWeight(&xTensor); |
107 | loader.addWeight(&yTensor); |
108 | loader.addName("fail" ); |
109 | loader.addName("fail" ); |
110 | DeferredLoader()->registerLoader(&loader); |
111 | |
112 | CompilationContext cctx; |
113 | cctx.deferredWeightLoader = &loader; |
114 | cctx.optimizationOpts.foldStaticPlaceholderConversions = true; |
115 | |
116 | DeviceConfig config(GetParam()); |
117 | std::unique_ptr<DeviceManager> device( |
118 | DeviceManager::createDeviceManager(config)); |
119 | EXPECT_FALSE(ERR_TO_BOOL(device->init())); |
120 | |
121 | DeviceManagerMapTy devices; |
122 | devices.emplace(0, std::move(device)); |
123 | |
124 | DAGListTy partitions; |
125 | |
126 | DAGNodePtrVec nodes; |
127 | auto rootNode = glow::make_unique<DAGNode>(); |
128 | auto firstNode = glow::make_unique<DAGNode>(); |
129 | rootNode->name = "root" ; |
130 | rootNode->children.push_back(firstNode.get()); |
131 | firstNode->name = "main" ; |
132 | firstNode->logicalDevices = {0}; |
133 | firstNode->backendName = GetParam(); |
134 | nodes.push_back(std::move(firstNode)); |
135 | partitions.push_back({std::move(rootNode), std::move(nodes)}); |
136 | |
137 | Provisioner provisioner(devices); |
138 | auto err = provisioner.provision(partitions, *module.get(), cctx); |
139 | // Expect that there was an Error when provisioning |
140 | EXPECT_TRUE(ERR_TO_BOOL(std::move(err))); |
141 | |
142 | // Setup a new loader with correct info. |
143 | TestDeferredWeightLoader loaderNew; |
144 | loaderNew.addWeight(&xTensor); |
145 | loaderNew.addWeight(&yTensor); |
146 | loaderNew.addName("X" ); |
147 | loaderNew.addName("Y" ); |
148 | DeferredLoader()->registerLoader(&loaderNew); |
149 | cctx.deferredWeightLoader = &loaderNew; |
150 | auto err2 = provisioner.provision(partitions, *module.get(), cctx); |
151 | // Verify provisioning completes correctly. |
152 | EXPECT_FALSE(ERR_TO_BOOL(std::move(err2))); |
153 | } |
154 | |
155 | TEST_P(DeferredWeightLoaderTest, staticPlaceholderInference) { |
156 | CHECK_IF_ENABLED(); |
157 | auto hostmanager = createHostManager(GetParam()); |
158 | ExecutionEngine EE{GetParam()}; |
159 | auto &module = EE.getModule(); |
160 | auto F = module.createFunction("main" ); |
161 | auto *X = module.createPlaceholder(ElemKind::FloatTy, {1}, "X" , false); |
162 | |
163 | auto *Y = module.createPlaceholder(ElemKind::FloatTy, {1}, "Y" , false); |
164 | auto *Z = module.createPlaceholder(ElemKind::FloatTy, {1}, "Z" , false); |
165 | auto *output = |
166 | module.createPlaceholder(ElemKind::FloatTy, {1}, "output" , false); |
167 | // Set X and Y as static. |
168 | X->setStatic(true); |
169 | Y->setStatic(true); |
170 | auto pow1 = F->createPow("pow" , X, Y); |
171 | auto pow2 = F->createPow("pow2" , Z, pow1); |
172 | F->createSave("save" , pow2, output); |
173 | std::vector<Tensor> staticInputs; |
174 | auto xTensor = Tensor(X->getType()); |
175 | auto yTensor = Tensor(Y->getType()); |
176 | auto zTensor = Tensor(Z->getType()); |
177 | xTensor.getHandle().clear(2.0); |
178 | yTensor.getHandle().clear(3.0); |
179 | zTensor.getHandle().clear(2.0); |
180 | |
181 | TestDeferredWeightLoader loader; |
182 | loader.addWeight(&xTensor); |
183 | loader.addWeight(&yTensor); |
184 | loader.addName("X" ); |
185 | loader.addName("Y" ); |
186 | DeferredLoader()->registerLoader(&loader); |
187 | |
188 | CompilationContext cctx; |
189 | cctx.deferredWeightLoader = &loader; |
190 | cctx.optimizationOpts.foldStaticPlaceholderConversions = true; |
191 | EE.compile(cctx); |
192 | PlaceholderBindings pBindings; |
193 | pBindings.allocate(Z); |
194 | pBindings.allocate(output); |
195 | updateInputPlaceholders(pBindings, {Z}, {&zTensor}); |
196 | EE.run(pBindings); |
197 | auto resHandle = pBindings.get(output)->getHandle(); |
198 | EXPECT_NEAR(resHandle.at({0}), 256.0, 1E-5); |
199 | } |
200 | |
201 | TEST_P(DeferredWeightLoaderTest, FP16StaticPlaceholderInference) { |
202 | CHECK_IF_ENABLED(); |
203 | auto hostmanager = createHostManager(GetParam()); |
204 | ExecutionEngine EE{GetParam()}; |
205 | auto &module = EE.getModule(); |
206 | auto F = module.createFunction("main" ); |
207 | auto *X = module.createPlaceholder(ElemKind::FloatTy, {1}, "X" , false); |
208 | |
209 | auto *Y = module.createPlaceholder(ElemKind::FloatTy, {1}, "Y" , false); |
210 | auto *Z = module.createPlaceholder(ElemKind::FloatTy, {1}, "Z" , false); |
211 | auto *output = |
212 | module.createPlaceholder(ElemKind::FloatTy, {1}, "output" , false); |
213 | // Set X and Y as static. |
214 | X->setStatic(true); |
215 | Y->setStatic(true); |
216 | auto mul1 = F->createMul("mul" , X, Y); |
217 | auto mul2 = F->createMul("mul2" , Z, mul1); |
218 | F->createSave("save" , mul2, output); |
219 | std::vector<Tensor> staticInputs; |
220 | auto xTensor = Tensor(X->getType()); |
221 | auto yTensor = Tensor(Y->getType()); |
222 | auto zTensor = Tensor(Z->getType()); |
223 | xTensor.getHandle().clear(2.0); |
224 | yTensor.getHandle().clear(3.0); |
225 | zTensor.getHandle().clear(2.0); |
226 | |
227 | TestDeferredWeightLoader loader; |
228 | loader.addWeight(&xTensor); |
229 | loader.addWeight(&yTensor); |
230 | loader.addName("X" ); |
231 | loader.addName("Y" ); |
232 | DeferredLoader()->registerLoader(&loader); |
233 | |
234 | PlaceholderBindings pBindings; |
235 | |
236 | CompilationContext cctx; |
237 | cctx.deferredWeightLoader = &loader; |
238 | cctx.optimizationOpts.foldStaticPlaceholderConversions = true; |
239 | cctx.precisionConfig.convertToFP16 = true; |
240 | |
241 | EE.compile(cctx); |
242 | |
243 | pBindings.allocate(Z); |
244 | pBindings.allocate(output); |
245 | updateInputPlaceholders(pBindings, {Z}, {&zTensor}); |
246 | EE.run(pBindings); |
247 | auto resHandle = pBindings.get(output)->getHandle(); |
248 | EXPECT_NEAR(resHandle.at({0}), 12.0, 1E-5); |
249 | } |
250 | |
// Instantiates the DeferredWeightLoaderTest TEST_Ps above via the backend-test
// instantiation macro; presumably one instantiation per enabled backend name
// (which the tests read back through GetParam()) -- see BackendTestUtils.h.
INSTANTIATE_BACKEND_TEST(DeferredWeightLoaderTest);
252 | |