1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Runtime/DeferredWeightLoader.h"
18#include "BackendTestUtils.h"
19#include "glow/ExecutionContext/ExecutionContext.h"
20#include "glow/ExecutionEngine/ExecutionEngine.h"
21#include "glow/Runtime/Provisioner/Provisioner.h"
22
23#include "gtest/gtest.h"
24
25using namespace glow;
26using namespace glow::runtime;
27
28class TestDeferredWeightLoader : public DeferredWeightLoader {
29public:
30 Error loadNextWeight() override {
31 position_++;
32 if (position_ < names_.size() && names_[position_] == "fail") {
33 return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
34 "Fail to load weight.");
35 }
36 return Error::success();
37 }
38 Error setSrc(void *loaderObject) override { return Error::success(); }
39 void addWeight(Tensor *weight) { weights_.push_back(weight); }
40 void addName(std::string name) { names_.push_back(name); }
41 void setTypeInfo(std::map<std::string, Type> info) override {}
42
43 std::string getName() override {
44 for (auto na : names_) {
45 }
46 if (position_ >= int(names_.size())) {
47 return "";
48 }
49 return names_[position_];
50 }
51
52 Tensor *getTensor() override {
53 if (position_ >= int(weights_.size())) {
54 return nullptr;
55 }
56 return weights_[position_];
57 }
58
59private:
60 std::vector<Tensor *> weights_{};
61 std::vector<std::string> names_{};
62 int position_{-1};
63};
64
65class DeferredWeightLoaderTest : public ::testing::TestWithParam<std::string> {
66};
67
68std::unique_ptr<HostManager>
69createHostManager(llvm::StringRef backendName,
70 HostConfig hostConfig = HostConfig()) {
71 std::vector<std::unique_ptr<DeviceConfig>> configs;
72 auto deviceConfig = glow::make_unique<DeviceConfig>(backendName);
73 configs.push_back(std::move(deviceConfig));
74 std::unique_ptr<HostManager> hostManager =
75 glow::make_unique<HostManager>(std::move(configs), hostConfig);
76 return hostManager;
77}
78
79TEST_P(DeferredWeightLoaderTest, cleanupFailedDeferred) {
80 // We want this provisioning to fail after loading a deferred weight, then
81 // verify that the network is cleaned up properly.
82 CHECK_IF_ENABLED();
83 std::unique_ptr<Module> module = glow::make_unique<Module>();
84 auto F = module->createFunction("main");
85 auto *X = module->createPlaceholder(ElemKind::FloatTy, {1}, "X", false);
86
87 auto *Y = module->createPlaceholder(ElemKind::FloatTy, {1}, "Y", false);
88 auto *Z = module->createPlaceholder(ElemKind::FloatTy, {1}, "Z", false);
89 auto *output =
90 module->createPlaceholder(ElemKind::FloatTy, {1}, "output", false);
91 // Set X and Y as static.
92 X->setStatic(true);
93 Y->setStatic(true);
94 auto pow1 = F->createPow("pow", X, Y);
95 auto pow2 = F->createPow("pow2", Z, pow1);
96 F->createSave("save", pow2, output);
97 std::vector<Tensor> staticInputs;
98 auto xTensor = Tensor(X->getType());
99 auto yTensor = Tensor(Y->getType());
100 auto zTensor = Tensor(Z->getType());
101 xTensor.getHandle().clear(2.0);
102 yTensor.getHandle().clear(3.0);
103 zTensor.getHandle().clear(2.0);
104
105 TestDeferredWeightLoader loader;
106 loader.addWeight(&xTensor);
107 loader.addWeight(&yTensor);
108 loader.addName("fail");
109 loader.addName("fail");
110 DeferredLoader()->registerLoader(&loader);
111
112 CompilationContext cctx;
113 cctx.deferredWeightLoader = &loader;
114 cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
115
116 DeviceConfig config(GetParam());
117 std::unique_ptr<DeviceManager> device(
118 DeviceManager::createDeviceManager(config));
119 EXPECT_FALSE(ERR_TO_BOOL(device->init()));
120
121 DeviceManagerMapTy devices;
122 devices.emplace(0, std::move(device));
123
124 DAGListTy partitions;
125
126 DAGNodePtrVec nodes;
127 auto rootNode = glow::make_unique<DAGNode>();
128 auto firstNode = glow::make_unique<DAGNode>();
129 rootNode->name = "root";
130 rootNode->children.push_back(firstNode.get());
131 firstNode->name = "main";
132 firstNode->logicalDevices = {0};
133 firstNode->backendName = GetParam();
134 nodes.push_back(std::move(firstNode));
135 partitions.push_back({std::move(rootNode), std::move(nodes)});
136
137 Provisioner provisioner(devices);
138 auto err = provisioner.provision(partitions, *module.get(), cctx);
139 // Expect that there was an Error when provisioning
140 EXPECT_TRUE(ERR_TO_BOOL(std::move(err)));
141
142 // Setup a new loader with correct info.
143 TestDeferredWeightLoader loaderNew;
144 loaderNew.addWeight(&xTensor);
145 loaderNew.addWeight(&yTensor);
146 loaderNew.addName("X");
147 loaderNew.addName("Y");
148 DeferredLoader()->registerLoader(&loaderNew);
149 cctx.deferredWeightLoader = &loaderNew;
150 auto err2 = provisioner.provision(partitions, *module.get(), cctx);
151 // Verify provisioning completes correctly.
152 EXPECT_FALSE(ERR_TO_BOOL(std::move(err2)));
153}
154
155TEST_P(DeferredWeightLoaderTest, staticPlaceholderInference) {
156 CHECK_IF_ENABLED();
157 auto hostmanager = createHostManager(GetParam());
158 ExecutionEngine EE{GetParam()};
159 auto &module = EE.getModule();
160 auto F = module.createFunction("main");
161 auto *X = module.createPlaceholder(ElemKind::FloatTy, {1}, "X", false);
162
163 auto *Y = module.createPlaceholder(ElemKind::FloatTy, {1}, "Y", false);
164 auto *Z = module.createPlaceholder(ElemKind::FloatTy, {1}, "Z", false);
165 auto *output =
166 module.createPlaceholder(ElemKind::FloatTy, {1}, "output", false);
167 // Set X and Y as static.
168 X->setStatic(true);
169 Y->setStatic(true);
170 auto pow1 = F->createPow("pow", X, Y);
171 auto pow2 = F->createPow("pow2", Z, pow1);
172 F->createSave("save", pow2, output);
173 std::vector<Tensor> staticInputs;
174 auto xTensor = Tensor(X->getType());
175 auto yTensor = Tensor(Y->getType());
176 auto zTensor = Tensor(Z->getType());
177 xTensor.getHandle().clear(2.0);
178 yTensor.getHandle().clear(3.0);
179 zTensor.getHandle().clear(2.0);
180
181 TestDeferredWeightLoader loader;
182 loader.addWeight(&xTensor);
183 loader.addWeight(&yTensor);
184 loader.addName("X");
185 loader.addName("Y");
186 DeferredLoader()->registerLoader(&loader);
187
188 CompilationContext cctx;
189 cctx.deferredWeightLoader = &loader;
190 cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
191 EE.compile(cctx);
192 PlaceholderBindings pBindings;
193 pBindings.allocate(Z);
194 pBindings.allocate(output);
195 updateInputPlaceholders(pBindings, {Z}, {&zTensor});
196 EE.run(pBindings);
197 auto resHandle = pBindings.get(output)->getHandle();
198 EXPECT_NEAR(resHandle.at({0}), 256.0, 1E-5);
199}
200
201TEST_P(DeferredWeightLoaderTest, FP16StaticPlaceholderInference) {
202 CHECK_IF_ENABLED();
203 auto hostmanager = createHostManager(GetParam());
204 ExecutionEngine EE{GetParam()};
205 auto &module = EE.getModule();
206 auto F = module.createFunction("main");
207 auto *X = module.createPlaceholder(ElemKind::FloatTy, {1}, "X", false);
208
209 auto *Y = module.createPlaceholder(ElemKind::FloatTy, {1}, "Y", false);
210 auto *Z = module.createPlaceholder(ElemKind::FloatTy, {1}, "Z", false);
211 auto *output =
212 module.createPlaceholder(ElemKind::FloatTy, {1}, "output", false);
213 // Set X and Y as static.
214 X->setStatic(true);
215 Y->setStatic(true);
216 auto mul1 = F->createMul("mul", X, Y);
217 auto mul2 = F->createMul("mul2", Z, mul1);
218 F->createSave("save", mul2, output);
219 std::vector<Tensor> staticInputs;
220 auto xTensor = Tensor(X->getType());
221 auto yTensor = Tensor(Y->getType());
222 auto zTensor = Tensor(Z->getType());
223 xTensor.getHandle().clear(2.0);
224 yTensor.getHandle().clear(3.0);
225 zTensor.getHandle().clear(2.0);
226
227 TestDeferredWeightLoader loader;
228 loader.addWeight(&xTensor);
229 loader.addWeight(&yTensor);
230 loader.addName("X");
231 loader.addName("Y");
232 DeferredLoader()->registerLoader(&loader);
233
234 PlaceholderBindings pBindings;
235
236 CompilationContext cctx;
237 cctx.deferredWeightLoader = &loader;
238 cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
239 cctx.precisionConfig.convertToFP16 = true;
240
241 EE.compile(cctx);
242
243 pBindings.allocate(Z);
244 pBindings.allocate(output);
245 updateInputPlaceholders(pBindings, {Z}, {&zTensor});
246 EE.run(pBindings);
247 auto resHandle = pBindings.get(output)->getHandle();
248 EXPECT_NEAR(resHandle.at({0}), 12.0, 1E-5);
249}
250
// Instantiate the DeferredWeightLoaderTest suite. The backend-name parameter
// values come from this macro (presumably defined in BackendTestUtils.h);
// confirm which backends it selects.
INSTANTIATE_BACKEND_TEST(DeferredWeightLoaderTest);
252