1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
17 | #include "glow/Exporter/ONNXModelWriter.h" |
18 | #include "glow/Graph/Graph.h" |
19 | #include "glow/Graph/Nodes.h" |
20 | #include "glow/Graph/PlaceholderBindings.h" |
21 | #include "glow/Importer/ONNXModelLoader.h" |
22 | #include "gtest/gtest.h" |
23 | |
24 | #include "llvm/Support/FileSystem.h" |
25 | |
26 | #ifndef GLOW_DATA_PATH |
27 | #define GLOW_DATA_PATH |
28 | #endif |
29 | |
30 | using namespace glow; |
31 | |
32 | namespace { |
/// Given a Function \p F and input names \p inputTensorNames and input types \p
/// inputTensorTypes, writes the function to a file (or an in-memory string)
/// using ONNXModelWriter and reads it back using ONNXModelLoader, then
/// \returns the reloaded function. \p useGlowCustomOps is used for determining
/// the format for ONNXModelWriter to write with. \p useString is used to use
/// strings rather than files for reading and writing functions.
39 | Expected<Function *> saveAndReloadFunction( |
40 | Module &reloadMod, Function *F, |
41 | llvm::ArrayRef<const char *> inputTensorNames, |
42 | llvm::ArrayRef<TypeRef> inputTensorTypes, size_t irVer = 5, |
43 | size_t opsetVer = 10, bool zipMode = false, bool useGlowCustomOps = false, |
44 | bool useString = false, bool includeConstantData = true, |
45 | ConstantFoldingRecordMap *constFoldRecord = nullptr, |
46 | CompilationContext *reloadCctx = nullptr, |
47 | const BackendSpecificNodeInfo &backendSpecificNodeInfo = {}, |
48 | const OriginNameToTQPMap &originNameToTQPMap = {}) { |
49 | std::string outputString; |
50 | std::string outputFilename = zipMode ? "output.zip" : "output.onnxtxt" ; |
51 | |
52 | if (!useString) { |
53 | llvm::SmallString<64> path; |
54 | |
55 | auto tempFileRes = |
56 | llvm::sys::fs::createTemporaryFile("exporter" , outputFilename, path); |
57 | |
58 | RETURN_ERR_IF_NOT(tempFileRes.value() == 0, |
59 | "Failed to create temp file to write into." ); |
60 | |
61 | outputFilename = path.c_str(); |
62 | } |
63 | ScopeGuard cleanup([&]() { llvm::sys::fs::remove(outputFilename); }); |
64 | if (useString) { |
65 | cleanup.dismiss(); |
66 | } |
67 | // Write model to file or string. |
68 | { |
69 | Error err = Error::empty(); |
70 | llvm::StringMap<std::string> ; |
71 | RETURN_IF_ERR(ONNXModelWriter::insertLoaderNameUniqueOffsetMetadata( |
72 | extraMetadataProps, originNameToTQPMap)); |
73 | ONNXModelWriter onnxWR( |
74 | outputFilename, *F, irVer, opsetVer, &err, !zipMode, zipMode, |
75 | useGlowCustomOps, includeConstantData, extraMetadataProps, |
76 | constFoldRecord ? *constFoldRecord : ConstantFoldingRecordMap(), |
77 | backendSpecificNodeInfo, (useString) ? &outputString : nullptr); |
78 | if (err) { |
79 | llvm::errs() << "Failed to write model\n" ; |
80 | } |
81 | RETURN_IF_ERR(std::move(err)); |
82 | } |
83 | |
84 | Function *R = nullptr; |
85 | Module &origMod = *F->getParent(); |
86 | if (!includeConstantData) { |
87 | R = reloadMod.getFunction(F->getName()); |
88 | RETURN_ERR_IF_NOT(R, "Did not find Function to reload into." ); |
89 | R->clear(); |
90 | if (constFoldRecord) { |
91 | // Additionally remove the original Constants that we first folded, so |
92 | // that when we reload below we can recreate them. |
93 | std::unordered_set<Function *> funsToDelete; |
94 | for (auto &pair : *constFoldRecord) { |
95 | Function *origF = pair.second->getParent(); |
96 | funsToDelete.insert(origF); |
97 | Constant *C = reloadMod.getConstantByName(pair.first->getName()); |
98 | RETURN_ERR_IF_NOT(C, "Did not find constant that was initially folded" ); |
99 | reloadMod.eraseConstant(C); |
100 | } |
101 | for (Function *origF : funsToDelete) { |
102 | Function *reloadConstFoldF = reloadMod.getFunction(origF->getName()); |
103 | RETURN_ERR_IF_NOT(reloadConstFoldF, |
104 | "Did not find const folding function reloaded" ); |
105 | reloadMod.eraseFunction(reloadConstFoldF); |
106 | origMod.eraseFunction(origF); |
107 | } |
108 | } |
109 | } else { |
110 | R = reloadMod.createFunction("R" ); |
111 | } |
112 | |
113 | // Load model from file. |
114 | { |
115 | Error err = Error::empty(); |
116 | ONNXModelLoader onnxLD( |
117 | outputFilename, inputTensorNames, inputTensorTypes, *R, &err, zipMode, |
118 | reloadCctx ? &reloadCctx->backendOpts.backendSpecificNodeInfo : nullptr, |
119 | /* disableConstFoldInLoader */ true, |
120 | /* loadIntoExistingModule */ !includeConstantData, |
121 | /* Backend */ nullptr, |
122 | /* inputStringPtr */ (useString) ? &outputString : nullptr); |
123 | |
124 | if (err) { |
125 | llvm::errs() << "ONNXModelLoader failed to reload model: " |
126 | << outputFilename << "\n" ; |
127 | } |
128 | |
129 | RETURN_IF_ERR(std::move(err)); |
130 | } |
131 | |
132 | // Verify reloaded function is valid. |
133 | RETURN_ERR_IF_NOT(R->verify(), "Reloaded function is not valid." ); |
134 | |
135 | // Verify that the Constants from the original Module have the same data as |
136 | // those in the reloaded module. |
137 | if (constFoldRecord) { |
138 | deleteUnusedConstants(reloadMod); |
139 | deleteUnusedConstants(origMod); |
140 | for (Constant *newC : reloadMod.getConstants()) { |
141 | Constant *origC = R->getParent()->getConstantByName(newC->getName()); |
142 | RETURN_ERR_IF_NOT(origC, |
143 | strFormat("Expected original Constant by name %s" , |
144 | newC->getName().data())); |
145 | RETURN_ERR_IF_NOT(newC->getPayload().isBitwiseEqual(origC->getPayload()), |
146 | strFormat("Mismatch on Constants of name %s" , |
147 | newC->getName().data())); |
148 | } |
149 | } |
150 | |
151 | return R; |
152 | } |
153 | |
154 | /// Loads model from ONNX format file \p name into glow Function. |
155 | /// On success exports glow graph to the output file in "extended" ONNX format, |
156 | /// i.e. some glow operators don't have presentation in vanilla ONNX standard. |
157 | void testLoadAndSaveONNXModel(const std::string &name, bool zipMode, |
158 | bool useString) { |
159 | ExecutionEngine EE{}; |
160 | auto &mod = EE.getModule(); |
161 | Function *F = mod.createFunction("main" ); |
162 | llvm::errs() << "loading model " << name << "\n" ; |
163 | |
164 | size_t irVer = 0, opsetVer = 0; |
165 | |
166 | bool useGlowCustomOps = false; |
167 | |
168 | // Load model from file. |
169 | { |
170 | Error err = Error::empty(); |
171 | ONNXModelLoader onnxLD(name, {}, {}, *F, &err); |
172 | irVer = onnxLD.getIrVersion(); |
173 | opsetVer = onnxLD.getOpSetVersion(); |
174 | useGlowCustomOps = onnxLD.usingGlowCustomOps(); |
175 | |
176 | if (err) { |
177 | llvm::errs() << "ONNXModelLoader failed to load model: " << name << "\n" ; |
178 | } |
179 | FAIL_TEST_IF_ERR(std::move(err)); |
180 | } |
181 | |
182 | Module reloadMod; |
183 | FAIL_TEST_IF_ERR(saveAndReloadFunction(reloadMod, F, {}, {}, irVer, opsetVer, |
184 | zipMode, useGlowCustomOps, useString) |
185 | .takeError()); |
186 | } |
187 | |
/// \returns true iff \p full ends with the suffix \p ending.
bool endsWith(const std::string &full, const std::string &ending) {
  // A suffix longer than the string itself can never match.
  if (ending.length() > full.length()) {
    return false;
  }
  return full.compare(full.length() - ending.length(), ending.length(),
                      ending) == 0;
}
196 | |
197 | /// Given a Function \p F, \returns a list of nodes with the Kind \p kind. |
198 | std::vector<Node *> getNodesByType(Function *F, Kinded::Kind kind) { |
199 | std::vector<Node *> found; |
200 | for (auto &N : F->getNodes()) { |
201 | if (N.getKind() == kind) { |
202 | found.push_back(&N); |
203 | } |
204 | } |
205 | return found; |
206 | } |
207 | |
208 | /// Given a function \p F and a Kind \p type, returns a casted pointer to the |
209 | /// single node in F with that kind or an Error if one occurs. |
210 | template <typename T> |
211 | Expected<T *> getSingleNodeWithKind(Function *F, Kinded::Kind type) { |
212 | auto nodesWithKind = getNodesByType(F, type); |
213 | |
214 | RETURN_ERR_IF_NOT(nodesWithKind.size() == 1, |
215 | strFormat("Expected one node with kind %s but found %lu" , |
216 | Kinded::getKindName(type), nodesWithKind.size())); |
217 | |
218 | T *node = llvm::dyn_cast<T>(nodesWithKind[0]); |
219 | |
220 | RETURN_ERR_IF_NOT(node != nullptr, "Node is not of expected types" ); |
221 | |
222 | return node; |
223 | } |
224 | |
225 | /// Helper that \returns whether two StringMaps \p LHS and \p RHS are equal. |
226 | template <typename T> |
227 | static bool isStrMapEqual(const llvm::StringMap<T> &LHS, |
228 | const llvm::StringMap<T> &RHS) { |
229 | if (LHS.size() != RHS.size()) { |
230 | return false; |
231 | } |
232 | for (const auto &keyValue : LHS) { |
233 | auto findInRHS = RHS.find(keyValue.getKey()); |
234 | if (findInRHS == RHS.end()) { |
235 | return false; |
236 | } |
237 | if (keyValue.getValue() != findInRHS->getValue()) { |
238 | return false; |
239 | } |
240 | } |
241 | return true; |
242 | } |
243 | |
244 | } // namespace |
245 | |
246 | /// Use to test constant folding exporting and reloading tests, where some |
247 | /// constant folding is recorded and serialized in the custom Glow ONNX model. |
class ConstFoldReloadTest : public ::testing::Test {
public:
  // Each test gets a fresh Interpreter-backed engine with a single Function
  // named "main" to build its graph into.
  ConstFoldReloadTest() : EE_("Interpreter" ), mod_(EE_.getModule()) {
    F_ = mod_.createFunction("main" );
  }

protected:
  /// Engine that compiles and runs the original Function.
  ExecutionEngine EE_;
  /// Module owned by \ref EE_; holds \ref F_ and its storage.
  Module &mod_;
  /// Function under test; built by each test case before calling
  /// serializeAndReloadAndCompareResults().
  Function *F_;
  /// Bindings for the original run.
  PlaceholderBindings bindings_;
  /// Compilation context for the original run; also accumulates
  /// backend-specific node info that must survive serialization.
  CompilationContext cctx_;

  /// Constant folds \ref F_ and then serializes it. Then deserializes it and
  /// runs it and makes sure that running the original and the reloaded Function
  /// are bitwise equal. Verifies that \p numExpectedConstsFolded constant
  /// folding records are created during constant folding/recording. Any Nodes
  /// listed in \p nodesToPar will be Model parallelized in two.
  void serializeAndReloadAndCompareResults(
      unsigned numExpectedConstsFolded,
      const std::unordered_set<Node *> &nodesToPar = {}) {
    bindings_.allocate(mod_.getPlaceholders());

    // Perform constant folding, recording what occurs so we can serialize it.
    ConstantFoldingRecordMap record = constantFoldAndRecord(F_, cctx_);
    EXPECT_EQ(record.size(), numExpectedConstsFolded);
    runDCEPass(F_, cctx_);

    if (nodesToPar.size()) {
      // Model-parallelize each requested node into two chunks, then fold and
      // DCE again since parallelization can expose new foldable constants.
      llvm::DenseMap<Node *, size_t> numChunks;
      llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
      for (Node *N : nodesToPar) {
        numChunks[N] = 2;
        parOpts[N] = ParallelTransformKind::Model;
      }

      std::unordered_map<Node *, ConcatNode *> replacedMap;
      ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
                                ::glow::parallelizeOps(F_, numChunks, parOpts));
      // Every node we asked to parallelize must have been replaced.
      EXPECT_EQ(replacedMap.size(), parOpts.size());

      ConstantFoldingRecordMap parRecord = constantFoldAndRecord(F_, cctx_);
      record.insert(parRecord.begin(), parRecord.end());
      runDCEPass(F_, cctx_);
    }

    // Clone the original module into a new module used for reloading the model.
    ExecutionEngine reloadEE(EE_.getBackendName());
    Module &reloadMod = reloadEE.getModule();
    mod_.clone(&reloadMod);

    // Save and reload F.
    Function *reloadF;
    CompilationContext reloadCctx;
    ASSIGN_VALUE_OR_FAIL_TEST(
        reloadF, saveAndReloadFunction(
                     reloadMod, F_, {}, {}, 7, 9,
                     /* zipMode */ false,
                     /* useGlowCustomOps */ true,
                     /* useString */ false,
                     /* includeConstantData */ false, &record, &reloadCctx,
                     cctx_.backendOpts.backendSpecificNodeInfo));

    // Verify that the Function and its Module are the same before and after.
    EXPECT_EQ(reloadF->toString(), F_->toString());
    EXPECT_EQ(reloadMod.toString(), mod_.toString());

    PlaceholderBindings reloadBindings =
        bindings_.clone(reloadMod.getPlaceholders());

    // Now run both to check they have bitwise equal results.
    EE_.compile(cctx_);
    EE_.run(bindings_);

    reloadEE.compile(reloadCctx);
    reloadEE.run(reloadBindings);

    // Tolerance 0.0 => results must match exactly (bitwise).
    EXPECT_TRUE(
        PlaceholderBindings::compare(&bindings_, &reloadBindings, 0.0f));

    // Verify that backend-specific node info was serialized correctly.
    EXPECT_EQ(cctx_.backendOpts.backendSpecificNodeInfo.count(F_),
              reloadCctx.backendOpts.backendSpecificNodeInfo.count(reloadF));

    if (cctx_.backendOpts.backendSpecificNodeInfo.count(F_)) {
      auto &origNodeMap = cctx_.backendOpts.backendSpecificNodeInfo[F_];
      auto &reloadNodeMap =
          reloadCctx.backendOpts.backendSpecificNodeInfo[reloadF];

      // Match reloaded nodes to originals by name and compare their
      // backend-specific info maps entry by entry.
      for (const Node &origN : F_->getNodes()) {
        auto reloadNodeIt = std::find_if(
            reloadF->getNodes().begin(), reloadF->getNodes().end(),
            [&](const Node &N) { return N.getName() == origN.getName(); });
        ASSERT_NE(reloadNodeIt, reloadF->getNodes().end());
        EXPECT_TRUE(
            isStrMapEqual(reloadNodeMap[&*reloadNodeIt], origNodeMap[&origN]));
      }
    }
  }
};
348 | |
349 | TEST(exporter, onnxModels) { |
350 | std::string inputDirectory(GLOW_DATA_PATH "tests/models/onnxModels" ); |
351 | std::cout << "inputDirectory: " << inputDirectory << std::endl; |
352 | std::error_code code; |
353 | for (llvm::sys::fs::directory_iterator dirIt(inputDirectory, code); |
354 | !code && dirIt != llvm::sys::fs::directory_iterator(); |
355 | dirIt.increment(code)) { |
356 | auto name = dirIt->path(); |
357 | if (!endsWith(name, ".onnxtxt" )) { |
358 | llvm::outs() << "Ignore non-onnxtxt input: " << name << "\n" ; |
359 | continue; |
360 | } |
361 | if (name.find("getInputsOnnxDefineSample.onnxtxt" ) != std::string::npos || |
362 | name.find("preluInvalidBroadcastSlope.onnxtxt" ) != std::string::npos || |
363 | name.find("padReflect.onnxtxt" ) != std::string::npos || |
364 | name.find("powMultiBroadcastOp7.onnxtxt" ) != std::string::npos || |
365 | name.find("gatherConstantFolding.onnxtxt" ) != std::string::npos || |
366 | name.find("averagePool3D.onnxtxt" ) != std::string::npos || |
367 | name.find("sparseLengthsSum.onnxtxt" ) != std::string::npos || |
368 | name.find("constantOfShapeInt32Fail.onnxtxt" ) != std::string::npos || |
369 | name.find("padEdge.onnxtxt" ) != std::string::npos || |
370 | name.find("castToFloat.onnxtxt" ) != std::string::npos || |
371 | name.find("castToFloat16.onnxtxt" ) != std::string::npos || |
372 | name.find("castToInt64.onnxtxt" ) != std::string::npos || |
373 | name.find("castToInt32.onnxtxt" ) != std::string::npos || |
374 | name.find("simpleConvBiasFail.onnxtxt" ) != std::string::npos || |
375 | name.find("Where.onnxtxt" ) != std::string::npos || |
376 | name.find("constantOfShapeInt64Fail.onnxtxt" ) != std::string::npos || |
377 | name.find("ArgMaxDefault.onnxtxt" ) != std::string::npos || |
378 | name.find("ArgMaxKeepDim.onnxtxt" ) != std::string::npos || |
379 | name.find("ArgMaxNoKeepDim.onnxtxt" ) != std::string::npos || |
380 | name.find("ArgMinDefault.onnxtxt" ) != std::string::npos || |
381 | name.find("ArgMinKeepDim.onnxtxt" ) != std::string::npos || |
382 | name.find("ArgMinNoKeepDim.onnxtxt" ) != std::string::npos || |
383 | name.find("upsampleOpset7.onnxtxt" ) != std::string::npos || |
384 | name.find("upsampleOpset9.onnxtxt" ) != std::string::npos || |
385 | name.find("resizeNearestV11compat.onnxtxt" ) != std::string::npos || |
386 | name.find("resizeNearestV11compat_sizes.onnxtxt" ) != |
387 | std::string::npos || |
388 | name.find("resizeBilinearV11compat.onnxtxt" ) != std::string::npos || |
389 | name.find("resizeBilinearV11compat_sizes.onnxtxt" ) != |
390 | std::string::npos || |
391 | name.find("upsampleOpset9.onnxtxt" ) != std::string::npos || |
392 | name.find("NonMaxSuppressionSSD_ONNX.onnxtxt" ) != std::string::npos || |
393 | name.find("NonMaxSuppression.onnxtxt" ) != std::string::npos || |
394 | name.find("NonMaxSuppressionOptionalParams.onnxtxt" ) != |
395 | std::string::npos || |
396 | name.find("NonMaxSuppressionSSD.onnxtxt" ) != std::string::npos || |
397 | name.find("ROIAlign_onnx.onnxtxt" ) != std::string::npos || |
398 | name.find("MatMul4D.onnxtxt" ) != std::string::npos || |
399 | name.find("Less.onnxtxt" ) != std::string::npos || |
400 | name.find("Erf.onnxtxt" ) != std::string::npos || |
401 | name.find("Asin.onnxtxt" ) != std::string::npos || |
402 | name.find("Acos.onnxtxt" ) != std::string::npos || |
403 | name.find("Atan.onnxtxt" ) != std::string::npos || |
404 | name.find("Sin.onnxtxt" ) != std::string::npos || |
405 | name.find("Cos.onnxtxt" ) != std::string::npos || |
406 | name.find("abs.onnxtxt" ) != std::string::npos || |
407 | name.find("log.onnxtxt" ) != std::string::npos || |
408 | name.find("RangeInt32.onnxtxt" ) != std::string::npos || |
409 | name.find("RangeFloat.onnxtxt" ) != std::string::npos || |
410 | name.find("scatterND.onnxtxt" ) != std::string::npos || |
411 | name.find("mscatterND.onnxtxt" ) != std::string::npos || |
412 | name.find("loop_cond.onnxtxt" ) != std::string::npos || |
413 | name.find("loop_empty_tripcount.onnxtxt" ) != std::string::npos || |
414 | name.find("loop_emptycond.onnxtxt" ) != std::string::npos || |
415 | name.find("loop_no_iteration.onnxtxt" ) != std::string::npos || |
416 | name.find("loop_tripcount.onnxtxt" ) != std::string::npos || |
417 | name.find("loop_withoutN.onnxtxt" ) != std::string::npos || |
418 | name.find("sign.onnxtxt" ) != std::string::npos || |
419 | name.find("gatherND.onnxtxt" ) != std::string::npos || |
420 | name.find("softmax13.onnxtxt" ) != std::string::npos || |
421 | name.find("logsoftmax.onnxtxt" ) != std::string::npos || |
422 | name.find("hardsigmoid.onnxtxt" ) != std::string::npos || |
423 | name.find("simpleConvTranspose.onnxtxt" ) != std::string::npos || |
424 | name.find("simpleConvTransposeOutShape.onnxtxt" ) != std::string::npos || |
425 | name.find("simpleConvTransposeOutShapeDilation.onnxtxt" ) != |
426 | std::string::npos || |
427 | name.find("simpleConvTransposeOutShapeSameLower.onnxtxt" ) != |
428 | std::string::npos || |
429 | name.find("simpleConvTransposeOutShapeSameUpper.onnxtxt" ) != |
430 | std::string::npos || |
431 | name.find("simpleConvTransposeAutoPadSameLower.onnxtxt" ) != |
432 | std::string::npos || |
433 | name.find("simpleConvTransposeAutoPadSameUpper.onnxtxt" ) != |
434 | std::string::npos || |
435 | name.find("convTransposeAsymmetric.onnxtxt" ) != std::string::npos || |
436 | name.find("Mean.onnxtxt" ) != std::string::npos || |
437 | name.find("Mean_broadcast.onnxtxt" ) != std::string::npos || |
438 | name.find("NonZero.onnxtxt" ) != std::string::npos || |
439 | name.find("logicalAnd.onnxtxt" ) != std::string::npos || |
440 | name.find("logicalAndBcast.onnxtxt" ) != std::string::npos || |
441 | name.find("logicalOrBcast.onnxtxt" ) != std::string::npos || |
442 | name.find("logicalOr.onnxtxt" ) != std::string::npos || |
443 | name.find("logicalXorBcast.onnxtxt" ) != std::string::npos || |
444 | name.find("logicalXor.onnxtxt" ) != std::string::npos || |
445 | name.find("logicalNot.onnxtxt" ) != std::string::npos || |
446 | |
447 | name.find("simpleConvTransposePads.onnxtxt" ) != std::string::npos || |
448 | name.find("simpleConvTransposeAutoPadValid.onnxtxt" ) != |
449 | std::string::npos || |
450 | name.find("simpleConvTransposeOutShapeSameUpper.onnxtxt" ) != |
451 | std::string::npos || |
452 | name.find("simpleConvTransposeAutoPadSameLower.onnxtxt" ) != |
453 | std::string::npos || |
454 | name.find("convTransposeGroup.onnxtxt" ) != std::string::npos || |
455 | name.find("pow_element_wise.onnxtxt" ) != std::string::npos || |
456 | name.find("pow_array_broadcast.onnxtxt" ) != std::string::npos || |
457 | name.find("pow_scalar_broadcast.onnxtxt" ) != std::string::npos || |
458 | name.find("simpleConvTransposeAutoPadSameUpper.onnxtxt" ) != |
459 | std::string::npos || |
460 | name.find("sliceInvalidAxes.onnxtxt" ) != std::string::npos || |
461 | name.find("sliceWithUnsupportedStep.onnxtxt" ) != std::string::npos || |
462 | name.find("simpleConv3DNonSquareDilation.onnxtxt" ) != |
463 | std::string::npos) { |
464 | // Ignore invalid ONNX files and graphs without nodes. |
465 | llvm::outs() << "Ignore invalid input files: " << name << "\n" ; |
466 | continue; |
467 | } |
468 | if (name.find("constant.onnxtxt" ) != std::string::npos || |
469 | name.find("shape.onnxtxt" ) != std::string::npos || |
470 | name.find("bool_from_int.onnxtxt" ) != std::string::npos || |
471 | name.find("sum1.onnxtxt" ) != std::string::npos) { |
472 | // Ignore invalid ONNX files and graphs without nodes. |
473 | llvm::outs() << "Ignore empty graph file: " << name << "\n" ; |
474 | continue; |
475 | } |
476 | if (name.find(".output.onnxtxt" ) != std::string::npos) { |
477 | // Ignore output files - debugging mode only. |
478 | llvm::outs() << "Ignore output file: " << name << "\n" ; |
479 | continue; |
480 | } |
481 | // TODO: Debug why these RNN models don`t work! |
482 | if (name.find("rnn" ) != std::string::npos) { |
483 | // Ignore RNN files. |
484 | llvm::outs() << "Ignore RNN model file: " << name << "\n" ; |
485 | continue; |
486 | } |
487 | if (name.find("gru" ) != std::string::npos) { |
488 | // Ignore GRU files. |
489 | llvm::outs() << "Ignore GRU model file: " << name << "\n" ; |
490 | continue; |
491 | } |
492 | if (name.find("lstm" ) != std::string::npos) { |
493 | // Ignore LSTM files. |
494 | llvm::outs() << "Ignore LSTM model file: " << name << "\n" ; |
495 | continue; |
496 | } |
497 | const bool customOnnxDefineSymbol = |
498 | name.find("dimParam.onnxtxt" ) != std::string::npos; |
499 | if (customOnnxDefineSymbol) { |
500 | setOnnxDefineSymbol({"ONNXUndefinedSymbol,1" }); |
501 | } |
502 | |
503 | // Disable constant folding for these tests. |
504 | setConstantFoldLoaderOpsFlag(false); |
505 | |
506 | testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ true, |
507 | /* useString */ false); |
508 | testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ false, |
509 | /* useString */ false); |
510 | testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ false, |
511 | /* useString */ true); |
512 | |
513 | // Reset the custom symbol used. |
514 | if (customOnnxDefineSymbol) { |
515 | setOnnxDefineSymbol({}); |
516 | } |
517 | } |
518 | } |
519 | |
520 | TEST(exporter, ChannelwiseQuantizedConvolution) { |
521 | ExecutionEngine EE{}; |
522 | auto &mod = EE.getModule(); |
523 | auto *F = mod.createFunction("F" ); |
524 | |
525 | unsigned_t inChannels = 8; |
526 | unsigned_t inSide = 6; |
527 | unsigned_t batchSize = 8; |
528 | unsigned_t outChannels = 12; |
529 | unsigned_t filterSide = 3; |
530 | unsigned_t groups = 4; |
531 | unsigned_t dilation = 1; |
532 | |
533 | Placeholder *input = mod.createPlaceholder( |
534 | ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3, |
535 | "input" , /* isTrainable */ false); |
536 | |
537 | Constant *biasConstant = |
538 | mod.createConstant(ElemKind::FloatTy, {outChannels}, "bias" ); |
539 | biasConstant->getPayloadMutable().getHandle<float>().randomize(-0.1, 0.1, |
540 | mod.getPRNG()); |
541 | |
542 | Constant *filterScalesConstant = |
543 | mod.createConstant(ElemKind::FloatTy, {outChannels}, "filter_scales" ); |
544 | |
545 | Constant *filterOffsetsConstant = |
546 | mod.createConstant(ElemKind::Int32ITy, {outChannels}, "filter_offsets" ); |
547 | |
548 | Constant *weightsConstant = mod.createConstant( |
549 | ElemKind::Int8QTy, |
550 | {outChannels, filterSide, filterSide, inChannels / groups}, 2.5, 1, |
551 | "offsets" ); |
552 | |
553 | std::vector<unsigned_t> kernels = {filterSide, filterSide}; |
554 | std::vector<unsigned_t> strides = {1, 1}; |
555 | std::vector<unsigned_t> pads = {1, 1, 1, 1}; |
556 | |
557 | auto outSize = |
558 | calculateConvPoolOutputDims(inSide, inSide, kernels, strides, pads); |
559 | auto *outTy = mod.uniqueType( |
560 | ElemKind::Int8QTy, |
561 | {batchSize, outSize.first, outSize.second, outChannels}, 3.8, 4); |
562 | |
563 | auto *cqConv = F->createChannelwiseQuantizedConv( |
564 | "cqconv" , input, weightsConstant, biasConstant, filterScalesConstant, |
565 | filterOffsetsConstant, /* biasScales */ nullptr, |
566 | /* biasOffsets */ nullptr, outTy, kernels, strides, pads, groups, |
567 | {dilation, dilation}); |
568 | |
569 | auto *save = F->createSave("save_out" , cqConv); |
570 | |
571 | Placeholder *output = save->getPlaceholder(); |
572 | |
573 | ASSERT_TRUE(F->verify()); |
574 | |
575 | PlaceholderBindings bindings; |
576 | bindings.allocate({input, output}); |
577 | |
578 | // Save and reload F. |
579 | Function *R; |
580 | Module reloadMod; |
581 | ASSIGN_VALUE_OR_FAIL_TEST( |
582 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
583 | |
584 | ChannelwiseQuantizedConvolutionNode *cqConvReloaded; |
585 | ASSIGN_VALUE_OR_FAIL_TEST( |
586 | cqConvReloaded, |
587 | getSingleNodeWithKind<ChannelwiseQuantizedConvolutionNode>( |
588 | R, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind)); |
589 | |
590 | EXPECT_TRUE(cqConvReloaded->getInput().getType()->isEqual( |
591 | *cqConv->getInput().getType())); |
592 | EXPECT_TRUE(cqConvReloaded->getResult().getType()->isEqual( |
593 | *cqConv->getResult().getType())); |
594 | |
595 | EXPECT_TRUE(cqConvReloaded->getFilter().getType()->isEqual( |
596 | *cqConv->getFilter().getType())); |
597 | EXPECT_TRUE(cqConvReloaded->getBias().getType()->isEqual( |
598 | *cqConv->getBias().getType())); |
599 | EXPECT_TRUE(cqConvReloaded->getFilterScales().getType()->isEqual( |
600 | *cqConv->getFilterScales().getType())); |
601 | EXPECT_TRUE(cqConvReloaded->getFilterOffsets().getType()->isEqual( |
602 | *cqConv->getFilterOffsets().getType())); |
603 | EXPECT_TRUE(cqConvReloaded->getBiasScales().getType()->isEqual( |
604 | *cqConv->getBiasScales().getType())); |
605 | EXPECT_TRUE(cqConvReloaded->getBiasOffsets().getType()->isEqual( |
606 | *cqConv->getBiasOffsets().getType())); |
607 | |
608 | EXPECT_EQ(cqConvReloaded->getKernels(), cqConv->getKernels()); |
609 | EXPECT_EQ(cqConvReloaded->getStrides(), cqConv->getStrides()); |
610 | EXPECT_EQ(cqConvReloaded->getPads(), cqConv->getPads()); |
611 | EXPECT_EQ(cqConvReloaded->getGroup(), cqConv->getGroup()); |
612 | EXPECT_EQ(cqConvReloaded->getDilation(), cqConv->getDilation()); |
613 | } |
614 | |
615 | TEST(exporter, QuantizedConvolution) { |
616 | ExecutionEngine EE{}; |
617 | auto &mod = EE.getModule(); |
618 | auto *F = mod.createFunction("F" ); |
619 | |
620 | unsigned_t inChannels = 8; |
621 | unsigned_t inSide = 6; |
622 | unsigned_t batchSize = 8; |
623 | unsigned_t outChannels = 12; |
624 | unsigned_t filterSide = 3; |
625 | unsigned_t groups = 4; |
626 | unsigned_t dilation = 1; |
627 | |
628 | Placeholder *input = mod.createPlaceholder( |
629 | ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3, |
630 | "input" , /* isTrainable */ false); |
631 | |
632 | Placeholder *weights = mod.createPlaceholder( |
633 | ElemKind::Int8QTy, |
634 | {outChannels, filterSide, filterSide, inChannels / groups}, 2.5, 1, |
635 | "weights" , |
636 | /* isTrainable */ false); |
637 | |
638 | Placeholder *bias = |
639 | mod.createPlaceholder(ElemKind::Int32QTy, {outChannels}, 0.25, 2, "bias" , |
640 | /* isTrainable */ false); |
641 | |
642 | std::vector<unsigned_t> kernels = {filterSide, filterSide}; |
643 | std::vector<unsigned_t> strides = {1, 1}; |
644 | std::vector<unsigned_t> pads = {1, 1, 1, 1}; |
645 | |
646 | auto outSize = |
647 | calculateConvPoolOutputDims(inSide, inSide, kernels, strides, pads); |
648 | auto *outTy = mod.uniqueType( |
649 | ElemKind::Int8QTy, |
650 | {batchSize, outSize.first, outSize.second, outChannels}, 3.8, 4); |
651 | |
652 | auto *qConv = F->createConv("qconv" , input, weights, bias, outTy, kernels, |
653 | strides, pads, groups, {dilation, dilation}); |
654 | |
655 | auto *save = F->createSave("save_out" , qConv); |
656 | |
657 | Placeholder *output = save->getPlaceholder(); |
658 | |
659 | ASSERT_TRUE(F->verify()); |
660 | |
661 | PlaceholderBindings bindings; |
662 | bindings.allocate({weights, bias, input, output}); |
663 | convertPlaceholdersToConstants(F, bindings, {input, output}); |
664 | |
665 | // Save and reload F. |
666 | Function *R; |
667 | Module reloadMod; |
668 | ASSIGN_VALUE_OR_FAIL_TEST( |
669 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
670 | |
671 | // Verify reloaded function matches the original. |
672 | ConvolutionNode *qConvReloaded; |
673 | ASSIGN_VALUE_OR_FAIL_TEST(qConvReloaded, |
674 | getSingleNodeWithKind<ConvolutionNode>( |
675 | R, Kinded::Kind::ConvolutionNodeKind)); |
676 | |
677 | EXPECT_TRUE(qConvReloaded->getInput().getType()->isEqual( |
678 | *qConv->getInput().getType())); |
679 | EXPECT_TRUE(qConvReloaded->getResult().getType()->isEqual( |
680 | *qConv->getResult().getType())); |
681 | |
682 | EXPECT_TRUE(qConvReloaded->getFilter().getType()->isEqual( |
683 | *qConv->getFilter().getType())); |
684 | EXPECT_TRUE( |
685 | qConvReloaded->getBias().getType()->isEqual(*qConv->getBias().getType())); |
686 | |
687 | EXPECT_EQ(qConvReloaded->getKernels(), qConv->getKernels()); |
688 | EXPECT_EQ(qConvReloaded->getStrides(), qConv->getStrides()); |
689 | EXPECT_EQ(qConvReloaded->getPads(), qConv->getPads()); |
690 | EXPECT_EQ(qConvReloaded->getGroup(), qConv->getGroup()); |
691 | EXPECT_EQ(qConvReloaded->getDilation(), qConv->getDilation()); |
692 | } |
693 | |
694 | TEST(exporter, QuantizedMaxPool) { |
695 | ExecutionEngine EE{}; |
696 | auto &mod = EE.getModule(); |
697 | auto *F = mod.createFunction("F" ); |
698 | |
699 | unsigned_t inChannels = 8; |
700 | unsigned_t inSide = 6; |
701 | unsigned_t batchSize = 8; |
702 | unsigned_t filterSide = 3; |
703 | |
704 | Placeholder *input = mod.createPlaceholder( |
705 | ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3, |
706 | "input" , /* isTrainable */ false); |
707 | |
708 | std::vector<unsigned_t> kernels = {filterSide, filterSide}; |
709 | std::vector<unsigned_t> strides = {1, 1}; |
710 | std::vector<unsigned_t> pads = {1, 1, 1, 1}; |
711 | |
712 | auto *maxPool = F->createMaxPool("maxpool" , input, kernels, strides, pads); |
713 | |
714 | auto *save = F->createSave("save_out" , maxPool->getNthResult(0)); |
715 | |
716 | Placeholder *output = save->getPlaceholder(); |
717 | |
718 | ASSERT_TRUE(F->verify()); |
719 | |
720 | PlaceholderBindings bindings; |
721 | bindings.allocate({input, output}); |
722 | convertPlaceholdersToConstants(F, bindings, {input, output}); |
723 | |
724 | // Save and reload F. |
725 | Function *R; |
726 | Module reloadMod; |
727 | ASSIGN_VALUE_OR_FAIL_TEST( |
728 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
729 | |
730 | // Verify reloaded function matches the original. |
731 | MaxPoolNode *maxPoolReloaded; |
732 | ASSIGN_VALUE_OR_FAIL_TEST( |
733 | maxPoolReloaded, |
734 | getSingleNodeWithKind<MaxPoolNode>(R, Kinded::Kind::MaxPoolNodeKind)); |
735 | |
736 | EXPECT_TRUE(maxPoolReloaded->getInput().getType()->isEqual( |
737 | *maxPool->getInput().getType())); |
738 | EXPECT_TRUE(maxPoolReloaded->getResult().getType()->isEqual( |
739 | *maxPool->getResult().getType())); |
740 | |
741 | EXPECT_EQ(maxPoolReloaded->getKernels(), maxPool->getKernels()); |
742 | EXPECT_EQ(maxPoolReloaded->getStrides(), maxPool->getStrides()); |
743 | EXPECT_EQ(maxPoolReloaded->getPads(), maxPool->getPads()); |
744 | } |
745 | |
746 | TEST(exporter, QuantizedAvgPool) { |
747 | ExecutionEngine EE{}; |
748 | auto &mod = EE.getModule(); |
749 | auto *F = mod.createFunction("F" ); |
750 | |
751 | unsigned_t inChannels = 8; |
752 | unsigned_t inSide = 6; |
753 | unsigned_t batchSize = 8; |
754 | unsigned_t filterSide = 3; |
755 | |
756 | Placeholder *input = mod.createPlaceholder( |
757 | ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3, |
758 | "input" , /* isTrainable */ false); |
759 | |
760 | std::vector<unsigned_t> kernels = {filterSide, filterSide}; |
761 | std::vector<unsigned_t> strides = {1, 1}; |
762 | std::vector<unsigned_t> pads = {1, 1, 1, 1}; |
763 | |
764 | auto *avgPool = F->createAvgPool("avgpool" , input, kernels, strides, pads); |
765 | |
766 | auto *save = F->createSave("save_out" , avgPool->getNthResult(0)); |
767 | |
768 | Placeholder *output = save->getPlaceholder(); |
769 | |
770 | ASSERT_TRUE(F->verify()); |
771 | |
772 | PlaceholderBindings bindings; |
773 | bindings.allocate({input, output}); |
774 | convertPlaceholdersToConstants(F, bindings, {input, output}); |
775 | |
776 | // Save and reload F. |
777 | Function *R; |
778 | Module reloadMod; |
779 | ASSIGN_VALUE_OR_FAIL_TEST( |
780 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
781 | |
782 | // Verify reloaded function matches the original. |
783 | AvgPoolNode *avgPoolReloaded; |
784 | ASSIGN_VALUE_OR_FAIL_TEST( |
785 | avgPoolReloaded, |
786 | getSingleNodeWithKind<AvgPoolNode>(R, Kinded::Kind::AvgPoolNodeKind)); |
787 | |
788 | EXPECT_TRUE(avgPoolReloaded->getInput().getType()->isEqual( |
789 | *avgPool->getInput().getType())); |
790 | EXPECT_TRUE(avgPoolReloaded->getResult().getType()->isEqual( |
791 | *avgPool->getResult().getType())); |
792 | |
793 | EXPECT_EQ(avgPoolReloaded->getKernels(), avgPool->getKernels()); |
794 | EXPECT_EQ(avgPoolReloaded->getStrides(), avgPool->getStrides()); |
795 | EXPECT_EQ(avgPoolReloaded->getPads(), avgPool->getPads()); |
796 | EXPECT_EQ(avgPoolReloaded->getCountIncludePads(), |
797 | avgPool->getCountIncludePads()); |
798 | } |
799 | |
800 | TEST(exporter, QuantizedAdaptiveAvgPool) { |
801 | ExecutionEngine EE{}; |
802 | auto &mod = EE.getModule(); |
803 | auto *F = mod.createFunction("F" ); |
804 | |
805 | unsigned_t inChannels = 8; |
806 | unsigned_t inSide = 6; |
807 | unsigned_t batchSize = 8; |
808 | |
809 | Placeholder *input = mod.createPlaceholder( |
810 | ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3, |
811 | "input" , /* isTrainable */ false); |
812 | |
813 | auto *outTy = mod.uniqueTypeWithNewShape(input->getType(), |
814 | {batchSize, 3, 3, inChannels}); |
815 | |
816 | auto *adaptiveAvgPool = |
817 | F->createAdaptiveAvgPool("adaptive_avgpool" , input, outTy); |
818 | |
819 | auto *save = F->createSave("save_out" , adaptiveAvgPool->getNthResult(0)); |
820 | |
821 | Placeholder *output = save->getPlaceholder(); |
822 | |
823 | ASSERT_TRUE(F->verify()); |
824 | |
825 | PlaceholderBindings bindings; |
826 | bindings.allocate({input, output}); |
827 | convertPlaceholdersToConstants(F, bindings, {input, output}); |
828 | |
829 | // Save and reload F. |
830 | Function *R; |
831 | Module reloadMod; |
832 | ASSIGN_VALUE_OR_FAIL_TEST( |
833 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
834 | |
835 | // Verify reloaded function matches the original. |
836 | AdaptiveAvgPoolNode *adaptiveAvgPoolReloaded; |
837 | ASSIGN_VALUE_OR_FAIL_TEST(adaptiveAvgPoolReloaded, |
838 | getSingleNodeWithKind<AdaptiveAvgPoolNode>( |
839 | R, Kinded::Kind::AdaptiveAvgPoolNodeKind)); |
840 | |
841 | EXPECT_TRUE(adaptiveAvgPoolReloaded->getInput().getType()->isEqual( |
842 | *adaptiveAvgPool->getInput().getType())); |
843 | EXPECT_TRUE(adaptiveAvgPoolReloaded->getResult().getType()->isEqual( |
844 | *adaptiveAvgPool->getResult().getType())); |
845 | } |
846 | |
847 | TEST(exporter, RowwiseQuantizedFullyConnected) { |
848 | ExecutionEngine EE{}; |
849 | auto &mod = EE.getModule(); |
850 | auto *F = mod.createFunction("F" ); |
851 | |
852 | Placeholder *input = mod.createPlaceholder( |
853 | ElemKind::Int8QTy, {2, 100}, 1.2, 3, "input" , /* isTrainable */ false); |
854 | |
855 | Constant *weightsConstant = |
856 | mod.createConstant(ElemKind::Int8QTy, {10, 100}, 1.0, 0, "weights" ); |
857 | |
858 | Constant *biasConstant = |
859 | mod.createConstant(ElemKind::Int32QTy, {10}, 1.0, 0, "bias" ); |
860 | |
861 | Constant *scalesConstant = |
862 | mod.createConstant(ElemKind::FloatTy, {10}, "scales" ); |
863 | |
864 | Constant *offsetsConstant = |
865 | mod.createConstant(ElemKind::Int32ITy, {10}, "offsets" ); |
866 | |
867 | auto *outTy = mod.uniqueType(ElemKind::Int8QTy, {2, 10}, 3.8, 4); |
868 | |
869 | auto *rwqFC = F->createRowwiseQuantizedFullyConnected( |
870 | "rwqFC" , input, weightsConstant, scalesConstant, offsetsConstant, |
871 | biasConstant, outTy); |
872 | |
873 | auto *save = F->createSave("save_out" , rwqFC); |
874 | |
875 | Placeholder *output = save->getPlaceholder(); |
876 | |
877 | ASSERT_TRUE(F->verify()); |
878 | |
879 | PlaceholderBindings bindings; |
880 | bindings.allocate({input, output}); |
881 | |
882 | // Save and reload F. |
883 | Function *R; |
884 | Module reloadMod; |
885 | ASSIGN_VALUE_OR_FAIL_TEST( |
886 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
887 | |
888 | RowwiseQuantizedFullyConnectedNode *rwqFCReloaded; |
889 | ASSIGN_VALUE_OR_FAIL_TEST( |
890 | rwqFCReloaded, |
891 | getSingleNodeWithKind<RowwiseQuantizedFullyConnectedNode>( |
892 | R, Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind)); |
893 | |
894 | EXPECT_TRUE(rwqFCReloaded->getInput().getType()->isEqual( |
895 | *rwqFC->getInput().getType())); |
896 | EXPECT_TRUE(rwqFCReloaded->getResult().getType()->isEqual( |
897 | *rwqFC->getResult().getType())); |
898 | |
899 | EXPECT_TRUE(rwqFCReloaded->getWeights().getType()->isEqual( |
900 | *rwqFC->getWeights().getType())); |
901 | EXPECT_TRUE( |
902 | rwqFCReloaded->getBias().getType()->isEqual(*rwqFC->getBias().getType())); |
903 | EXPECT_TRUE(rwqFCReloaded->getScales().getType()->isEqual( |
904 | *rwqFC->getScales().getType())); |
905 | EXPECT_TRUE(rwqFCReloaded->getOffsets().getType()->isEqual( |
906 | *rwqFC->getOffsets().getType())); |
907 | } |
908 | |
909 | TEST_F(ConstFoldReloadTest, exportGraphWithOneConstFoldingRecord) { |
910 | Placeholder *I = |
911 | mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input" , |
912 | /* isTrainable */ false); |
913 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight" ); |
914 | ClipNode *clipW = F_->createClip("clip" , W, -5.f, 5.f); |
915 | ConvertToNode *convertW = |
916 | F_->createConvertTo("conv" , clipW, ElemKind::Float16Ty); |
917 | TransposeNode *transposeW = |
918 | F_->createTranspose("transpose" , convertW, {1, 0}); |
919 | MatMulNode *MM = F_->createMatMul("matmul" , I, transposeW); |
920 | F_->createSave("save" , MM); |
921 | |
922 | bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10, |
923 | mod_.getPRNG()); |
924 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
925 | |
926 | serializeAndReloadAndCompareResults(1); |
927 | } |
928 | |
929 | TEST_F(ConstFoldReloadTest, exportGraphWithTwoConstFoldingRecords) { |
930 | Placeholder *I = |
931 | mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input" , |
932 | /* isTrainable */ false); |
933 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight" ); |
934 | ClipNode *clipW = F_->createClip("clip" , W, -5.f, 5.f); |
935 | ConvertToNode *convertW = |
936 | F_->createConvertTo("conv" , clipW, ElemKind::Float16Ty); |
937 | TransposeNode *transposeW = |
938 | F_->createTranspose("transpose" , convertW, {1, 0}); |
939 | MatMulNode *MM = F_->createMatMul("matmul" , I, transposeW); |
940 | F_->createSave("save_mm" , MM); |
941 | |
942 | Constant *W2 = mod_.createConstant(ElemKind::Float16Ty, {2, 100}, "weight2" ); |
943 | TanhNode *tanhW = F_->createTanh("tanh" , W2); |
944 | AddNode *add = F_->createAdd("add" , tanhW, I); |
945 | F_->createSave("save_add" , add); |
946 | |
947 | bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10, |
948 | mod_.getPRNG()); |
949 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
950 | W2->getPayloadMutable().getHandle<float16_t>().randomize(-10, 10, |
951 | mod_.getPRNG()); |
952 | |
953 | serializeAndReloadAndCompareResults(2); |
954 | } |
955 | |
956 | TEST_F(ConstFoldReloadTest, exportGraphWithTwoConstFoldingMultiOutputRecord) { |
957 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {100}, "weight" ); |
958 | SigmoidNode *sigmoidW = F_->createSigmoid("sig" , W); |
959 | ConvertToNode *convertW = |
960 | F_->createConvertTo("conv" , sigmoidW, ElemKind::Float16Ty); |
961 | TopKNode *TK = F_->createTopK("topk" , convertW, 5); |
962 | F_->createSave("save_indices" , TK->getIndices()); |
963 | |
964 | Placeholder *I = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "input" , |
965 | /* isTrainable */ false); |
966 | AddNode *add = F_->createAdd("add" , I, TK->getValues()); |
967 | F_->createSave("save_add" , add); |
968 | |
969 | bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10, |
970 | mod_.getPRNG()); |
971 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
972 | |
973 | serializeAndReloadAndCompareResults(2); |
974 | } |
975 | |
/// Verify that exporting and reloading with placement hints retains the hints.
TEST_F(ConstFoldReloadTest, exportWithPlacementHints) {
  // Two FP16 inputs that feed both the main FC chain and a side add branch.
  auto *input1 =
      mod_.createPlaceholder(ElemKind::Float16Ty, {16, 32}, "input1" , false);
  auto *input2 =
      mod_.createPlaceholder(ElemKind::Float16Ty, {16, 32}, "input2" , false);
  auto *weights =
      F_->getParent()->createConstant(ElemKind::Float16Ty, {16, 16}, "weights" );
  auto *bias =
      F_->getParent()->createConstant(ElemKind::Float16Ty, {16}, "bias" );
  weights->getPayloadMutable().getHandle<float16_t>().randomize(-1.0, 1.0,
                                                                mod_.getPRNG());
  bias->getPayloadMutable().getHandle<float16_t>().randomize(-1.0, 1.0,
                                                             mod_.getPRNG());

  // Main chain: concat -> transpose -> FC -> tanh -> sigmoid -> save.
  auto *CI = F_->createConcat("concat" , {input1, input2}, 1);
  auto *TN = F_->createTranspose("transpose" , CI, {1, 0});
  auto *FC = F_->createFullyConnected("fc" , TN, weights, bias);
  auto *THN = F_->createTanh("tanh" , FC);
  auto *SN = F_->createSigmoid("sigmoid" , THN);
  F_->createSave("ret" , SN);

  // Independent side branch so hints can refer across multiple nodes.
  auto *AN = F_->createAdd("add" , input1, input2);
  F_->createSave("add_save" , AN);

  bindings_.allocate(input1)->getHandle<float16_t>().randomize(-1.0, 1.0,
                                                               mod_.getPRNG());
  bindings_.allocate(input2)->getHandle<float16_t>().randomize(-1.0, 1.0,
                                                               mod_.getPRNG());

  // Attach backend-specific node info (placement hints) to the compilation
  // context for F_; the export/reload cycle must reproduce these lists.
  auto &nodeInfo = cctx_.backendOpts.backendSpecificNodeInfo[F_];

  // Hints whose values name other nodes. Note some entries are pushed twice
  // (e.g. AN/Hint1 -> CI); the duplicates appear deliberate, so both the
  // ordering and the multiplicity of each hint list should survive reload.
  nodeInfo[AN]["Interpreter_Hint1" ].push_back(CI->getName().str());
  nodeInfo[AN]["Interpreter_Hint2" ].push_back("@1" );
  nodeInfo[CI]["Interpreter_Hint1" ].push_back(TN->getName().str());
  nodeInfo[CI]["Interpreter_Hint3" ].push_back("@1" );
  nodeInfo[CI]["Interpreter_Hint1" ].push_back(FC->getName().str());
  nodeInfo[CI]["Interpreter_Hint3" ].push_back("@1" );
  nodeInfo[AN]["Interpreter_Hint1" ].push_back(CI->getName().str());
  nodeInfo[AN]["Interpreter_Hint2" ].push_back("@1" );
  nodeInfo[TN]["Interpreter_Hint1" ].push_back(FC->getName().str());
  nodeInfo[TN]["Interpreter_Hint1" ].push_back(SN->getName().str());
  nodeInfo[FC]["Interpreter_Hint1" ].push_back(THN->getName().str());

  // Hints with plain numeric / "@N" string values.
  nodeInfo[TN]["Interpreter_Hint4" ].push_back("3" );
  nodeInfo[FC]["Interpreter_Hint4" ].push_back("2" );
  nodeInfo[CI]["Interpreter_Hint4" ].push_back("1" );
  nodeInfo[CI]["Interpreter_Hint5" ].push_back("@0" );
  nodeInfo[CI]["Interpreter_Hint4" ].push_back("3" );
  nodeInfo[CI]["Interpreter_Hint5" ].push_back("@1" );

  // No const-folding records expected; this test only checks hint retention.
  serializeAndReloadAndCompareResults(0);
}
1029 | |
1030 | TEST_F(ConstFoldReloadTest, exportParallelizedGraphWithTwoConstFoldingRecords) { |
1031 | Placeholder *I = |
1032 | mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input" , |
1033 | /* isTrainable */ false); |
1034 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weights" ); |
1035 | Constant *B = mod_.createConstant(ElemKind::Float16Ty, {10}, "bias" ); |
1036 | |
1037 | ClipNode *clipW = F_->createClip("clip" , W, -5.f, 5.f); |
1038 | ConvertToNode *convertW = |
1039 | F_->createConvertTo("conv" , clipW, ElemKind::Float16Ty); |
1040 | TransposeNode *transposeW = |
1041 | F_->createTranspose("transpose" , convertW, {1, 0}); |
1042 | FullyConnectedNode *FC = F_->createFullyConnected("fc" , I, transposeW, B); |
1043 | F_->createSave("save" , FC); |
1044 | |
1045 | Constant *W2 = mod_.createConstant(ElemKind::Float16Ty, {2, 100}, "weight2" ); |
1046 | TanhNode *tanhW = F_->createTanh("tanh" , W2); |
1047 | AddNode *add = F_->createAdd("add" , tanhW, I); |
1048 | F_->createSave("save_add" , add); |
1049 | |
1050 | bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10, |
1051 | mod_.getPRNG()); |
1052 | W->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
1053 | B->getHandle<float16_t>().randomize(0.0, 0.5, mod_.getPRNG()); |
1054 | W2->getPayloadMutable().getHandle<float16_t>().randomize(-10, 10, |
1055 | mod_.getPRNG()); |
1056 | |
1057 | serializeAndReloadAndCompareResults(2, {FC}); |
1058 | } |
1059 | |
1060 | TEST(exporter, VeryLongChain) { |
1061 | ExecutionEngine EE{}; |
1062 | auto &mod = EE.getModule(); |
1063 | auto *F = mod.createFunction("F" ); |
1064 | |
1065 | Placeholder *input = |
1066 | mod.createPlaceholder(ElemKind::Float16Ty, {1, 6}, "input" , false); |
1067 | |
1068 | Node *cur = input; |
1069 | for (dim_t iter = 0; iter < 3000; iter++) { |
1070 | auto *mul = F->createMul("mul" , cur, cur); |
1071 | auto *clip = F->createClip("clip" , mul, 0.0, 128.0); |
1072 | if (iter == 0) { |
1073 | F->createSave("save_out0" , clip); |
1074 | } |
1075 | cur = (Node *)clip; |
1076 | } |
1077 | auto *save = F->createSave("save_out" , cur); |
1078 | |
1079 | Placeholder *output = save->getPlaceholder(); |
1080 | |
1081 | ASSERT_TRUE(F->verify()); |
1082 | |
1083 | PlaceholderBindings bindings; |
1084 | bindings.allocate({input, output}); |
1085 | |
1086 | // Save and reload F. |
1087 | Function *R; |
1088 | Module reloadMod; |
1089 | ASSIGN_VALUE_OR_FAIL_TEST( |
1090 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {input->getType()})); |
1091 | (void)R; |
1092 | } |
1093 | |
1094 | /// Tests that we can serialize and then reload a model with OriginNameToTQPMap |
1095 | /// added to the model. Note that we don't do anything with the reloaded map. |
1096 | TEST(exporter, TestUniqueOffsetMapSerialization) { |
1097 | ExecutionEngine EE{}; |
1098 | auto &mod = EE.getModule(); |
1099 | auto *F = mod.createFunction("F" ); |
1100 | |
1101 | Placeholder *I = |
1102 | mod.createPlaceholder(ElemKind::Float16Ty, {5, 3}, "input" , false); |
1103 | Constant *W = |
1104 | mod.createConstant(ElemKind::Int8QTy, {3, 4}, 0.f, 0, "weights" ); |
1105 | Constant *B = mod.createConstant(ElemKind::Int8QTy, {4}, 0.f, 1, "bias" ); |
1106 | QuantizeNode *QI = F->createQuantize("quant" , I, ElemKind::Int8QTy, 0.f, 2); |
1107 | FullyConnectedNode *FC = F->createFullyConnected("fc" , QI, W, B); |
1108 | SaveNode *save = F->createSave("save_out" , FC); |
1109 | |
1110 | Placeholder *output = save->getPlaceholder(); |
1111 | |
1112 | ASSERT_TRUE(F->verify()); |
1113 | |
1114 | PlaceholderBindings bindings; |
1115 | bindings.allocate({I, output}); |
1116 | |
1117 | #define GET_TQP(T_) TensorQuantizationParams{T_->getScale(), T_->getOffset()} |
1118 | OriginNameToTQPMap originNameToTQPMap; |
1119 | originNameToTQPMap.emplace(W->getName(), GET_TQP(W->getOutput().getType())); |
1120 | originNameToTQPMap.emplace(B->getName(), GET_TQP(B->getOutput().getType())); |
1121 | originNameToTQPMap.emplace(QI->getName(), GET_TQP(QI->getResult().getType())); |
1122 | #undef GET_TQP |
1123 | |
1124 | // Save and reload F. |
1125 | Function *R; |
1126 | Module reloadMod; |
1127 | ASSIGN_VALUE_OR_FAIL_TEST( |
1128 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {I->getType()}, 7, 9, |
1129 | /* zipMode */ false, |
1130 | /* useGlowCustomOps */ true, |
1131 | /* useString */ false, |
1132 | /* includeConstantData */ true, |
1133 | /* record */ nullptr, /* reloadCctx */ nullptr, |
1134 | /* backendSpecificNodeInfo */ {}, |
1135 | originNameToTQPMap)); |
1136 | (void)R; |
1137 | } |
1138 | |
1139 | /// Test that we can serialize tensor strides. |
1140 | TEST(exporter, TestStridesSerialization) { |
1141 | ExecutionEngine EE{}; |
1142 | auto &mod = EE.getModule(); |
1143 | auto *F = mod.createFunction("F" ); |
1144 | |
1145 | auto ty = mod.uniqueType(ElemKind::Float16Ty, {5, 3}); |
1146 | // Create a type with non-standard strides. |
1147 | ty = mod.uniqueTypeWithNewStrides(ty, ty->dims(), {332, 1}); |
1148 | Placeholder *I = mod.createPlaceholder(ty, "input" , false); |
1149 | SaveNode *save = F->createSave("save_out" , I); |
1150 | |
1151 | Placeholder *output = save->getPlaceholder(); |
1152 | |
1153 | ASSERT_TRUE(F->verify()); |
1154 | |
1155 | PlaceholderBindings bindings; |
1156 | bindings.allocate({I, output}); |
1157 | |
1158 | // Save and reload F in text mode. |
1159 | { |
1160 | Function *R; |
1161 | Module reloadMod; |
1162 | ASSIGN_VALUE_OR_FAIL_TEST( |
1163 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {I->getType()}, 7, 9, |
1164 | /* zipMode */ false, |
1165 | /* useGlowCustomOps */ true, |
1166 | /* useString */ false, |
1167 | /* includeConstantData */ true, |
1168 | /* record */ nullptr, /* reloadCctx */ nullptr, |
1169 | /* backendSpecificNodeInfo */ {})); |
1170 | (void)R; |
1171 | } |
1172 | // Save and reload F in zip mode. |
1173 | { |
1174 | Function *R; |
1175 | Module reloadMod; |
1176 | ASSIGN_VALUE_OR_FAIL_TEST( |
1177 | R, saveAndReloadFunction(reloadMod, F, {"input" }, {I->getType()}, 7, 9, |
1178 | /* zipMode */ true, |
1179 | /* useGlowCustomOps */ true, |
1180 | /* useString */ false, |
1181 | /* includeConstantData */ true, |
1182 | /* record */ nullptr, /* reloadCctx */ nullptr, |
1183 | /* backendSpecificNodeInfo */ {})); |
1184 | (void)R; |
1185 | } |
1186 | } |
1187 | |