1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/ExecutionEngine/ExecutionEngine.h"
17#include "glow/Exporter/ONNXModelWriter.h"
18#include "glow/Graph/Graph.h"
19#include "glow/Graph/Nodes.h"
20#include "glow/Graph/PlaceholderBindings.h"
21#include "glow/Importer/ONNXModelLoader.h"
22#include "gtest/gtest.h"
23
24#include "llvm/Support/FileSystem.h"
25
26#ifndef GLOW_DATA_PATH
27#define GLOW_DATA_PATH
28#endif
29
30using namespace glow;
31
32namespace {
33/// Given a Function \p F and input names \p inputTensorNames and input types \p
34/// inputTensorTypes, writes the function to file and reads it back using the
35/// ONNXModelWriter and ONNXModelReader respectively then \returns the
36/// reloaded function. \p useGlowCustomOps is used for determining the format
37/// for ONNXModelWriter to write with. \p useString is used to use strings
38/// rather than files for reading and writing functions.
39Expected<Function *> saveAndReloadFunction(
40 Module &reloadMod, Function *F,
41 llvm::ArrayRef<const char *> inputTensorNames,
42 llvm::ArrayRef<TypeRef> inputTensorTypes, size_t irVer = 5,
43 size_t opsetVer = 10, bool zipMode = false, bool useGlowCustomOps = false,
44 bool useString = false, bool includeConstantData = true,
45 ConstantFoldingRecordMap *constFoldRecord = nullptr,
46 CompilationContext *reloadCctx = nullptr,
47 const BackendSpecificNodeInfo &backendSpecificNodeInfo = {},
48 const OriginNameToTQPMap &originNameToTQPMap = {}) {
49 std::string outputString;
50 std::string outputFilename = zipMode ? "output.zip" : "output.onnxtxt";
51
52 if (!useString) {
53 llvm::SmallString<64> path;
54
55 auto tempFileRes =
56 llvm::sys::fs::createTemporaryFile("exporter", outputFilename, path);
57
58 RETURN_ERR_IF_NOT(tempFileRes.value() == 0,
59 "Failed to create temp file to write into.");
60
61 outputFilename = path.c_str();
62 }
63 ScopeGuard cleanup([&]() { llvm::sys::fs::remove(outputFilename); });
64 if (useString) {
65 cleanup.dismiss();
66 }
67 // Write model to file or string.
68 {
69 Error err = Error::empty();
70 llvm::StringMap<std::string> extraMetadataProps;
71 RETURN_IF_ERR(ONNXModelWriter::insertLoaderNameUniqueOffsetMetadata(
72 extraMetadataProps, originNameToTQPMap));
73 ONNXModelWriter onnxWR(
74 outputFilename, *F, irVer, opsetVer, &err, !zipMode, zipMode,
75 useGlowCustomOps, includeConstantData, extraMetadataProps,
76 constFoldRecord ? *constFoldRecord : ConstantFoldingRecordMap(),
77 backendSpecificNodeInfo, (useString) ? &outputString : nullptr);
78 if (err) {
79 llvm::errs() << "Failed to write model\n";
80 }
81 RETURN_IF_ERR(std::move(err));
82 }
83
84 Function *R = nullptr;
85 Module &origMod = *F->getParent();
86 if (!includeConstantData) {
87 R = reloadMod.getFunction(F->getName());
88 RETURN_ERR_IF_NOT(R, "Did not find Function to reload into.");
89 R->clear();
90 if (constFoldRecord) {
91 // Additionally remove the original Constants that we first folded, so
92 // that when we reload below we can recreate them.
93 std::unordered_set<Function *> funsToDelete;
94 for (auto &pair : *constFoldRecord) {
95 Function *origF = pair.second->getParent();
96 funsToDelete.insert(origF);
97 Constant *C = reloadMod.getConstantByName(pair.first->getName());
98 RETURN_ERR_IF_NOT(C, "Did not find constant that was initially folded");
99 reloadMod.eraseConstant(C);
100 }
101 for (Function *origF : funsToDelete) {
102 Function *reloadConstFoldF = reloadMod.getFunction(origF->getName());
103 RETURN_ERR_IF_NOT(reloadConstFoldF,
104 "Did not find const folding function reloaded");
105 reloadMod.eraseFunction(reloadConstFoldF);
106 origMod.eraseFunction(origF);
107 }
108 }
109 } else {
110 R = reloadMod.createFunction("R");
111 }
112
113 // Load model from file.
114 {
115 Error err = Error::empty();
116 ONNXModelLoader onnxLD(
117 outputFilename, inputTensorNames, inputTensorTypes, *R, &err, zipMode,
118 reloadCctx ? &reloadCctx->backendOpts.backendSpecificNodeInfo : nullptr,
119 /* disableConstFoldInLoader */ true,
120 /* loadIntoExistingModule */ !includeConstantData,
121 /* Backend */ nullptr,
122 /* inputStringPtr */ (useString) ? &outputString : nullptr);
123
124 if (err) {
125 llvm::errs() << "ONNXModelLoader failed to reload model: "
126 << outputFilename << "\n";
127 }
128
129 RETURN_IF_ERR(std::move(err));
130 }
131
132 // Verify reloaded function is valid.
133 RETURN_ERR_IF_NOT(R->verify(), "Reloaded function is not valid.");
134
135 // Verify that the Constants from the original Module have the same data as
136 // those in the reloaded module.
137 if (constFoldRecord) {
138 deleteUnusedConstants(reloadMod);
139 deleteUnusedConstants(origMod);
140 for (Constant *newC : reloadMod.getConstants()) {
141 Constant *origC = R->getParent()->getConstantByName(newC->getName());
142 RETURN_ERR_IF_NOT(origC,
143 strFormat("Expected original Constant by name %s",
144 newC->getName().data()));
145 RETURN_ERR_IF_NOT(newC->getPayload().isBitwiseEqual(origC->getPayload()),
146 strFormat("Mismatch on Constants of name %s",
147 newC->getName().data()));
148 }
149 }
150
151 return R;
152}
153
154/// Loads model from ONNX format file \p name into glow Function.
155/// On success exports glow graph to the output file in "extended" ONNX format,
156/// i.e. some glow operators don't have presentation in vanilla ONNX standard.
157void testLoadAndSaveONNXModel(const std::string &name, bool zipMode,
158 bool useString) {
159 ExecutionEngine EE{};
160 auto &mod = EE.getModule();
161 Function *F = mod.createFunction("main");
162 llvm::errs() << "loading model " << name << "\n";
163
164 size_t irVer = 0, opsetVer = 0;
165
166 bool useGlowCustomOps = false;
167
168 // Load model from file.
169 {
170 Error err = Error::empty();
171 ONNXModelLoader onnxLD(name, {}, {}, *F, &err);
172 irVer = onnxLD.getIrVersion();
173 opsetVer = onnxLD.getOpSetVersion();
174 useGlowCustomOps = onnxLD.usingGlowCustomOps();
175
176 if (err) {
177 llvm::errs() << "ONNXModelLoader failed to load model: " << name << "\n";
178 }
179 FAIL_TEST_IF_ERR(std::move(err));
180 }
181
182 Module reloadMod;
183 FAIL_TEST_IF_ERR(saveAndReloadFunction(reloadMod, F, {}, {}, irVer, opsetVer,
184 zipMode, useGlowCustomOps, useString)
185 .takeError());
186}
187
188bool endsWith(const std::string &full, const std::string &ending) {
189 if (full.length() >= ending.length()) {
190 return (0 == full.compare(full.length() - ending.length(), ending.length(),
191 ending));
192 } else {
193 return false;
194 }
195}
196
197/// Given a Function \p F, \returns a list of nodes with the Kind \p kind.
198std::vector<Node *> getNodesByType(Function *F, Kinded::Kind kind) {
199 std::vector<Node *> found;
200 for (auto &N : F->getNodes()) {
201 if (N.getKind() == kind) {
202 found.push_back(&N);
203 }
204 }
205 return found;
206}
207
208/// Given a function \p F and a Kind \p type, returns a casted pointer to the
209/// single node in F with that kind or an Error if one occurs.
210template <typename T>
211Expected<T *> getSingleNodeWithKind(Function *F, Kinded::Kind type) {
212 auto nodesWithKind = getNodesByType(F, type);
213
214 RETURN_ERR_IF_NOT(nodesWithKind.size() == 1,
215 strFormat("Expected one node with kind %s but found %lu",
216 Kinded::getKindName(type), nodesWithKind.size()));
217
218 T *node = llvm::dyn_cast<T>(nodesWithKind[0]);
219
220 RETURN_ERR_IF_NOT(node != nullptr, "Node is not of expected types");
221
222 return node;
223}
224
225/// Helper that \returns whether two StringMaps \p LHS and \p RHS are equal.
226template <typename T>
227static bool isStrMapEqual(const llvm::StringMap<T> &LHS,
228 const llvm::StringMap<T> &RHS) {
229 if (LHS.size() != RHS.size()) {
230 return false;
231 }
232 for (const auto &keyValue : LHS) {
233 auto findInRHS = RHS.find(keyValue.getKey());
234 if (findInRHS == RHS.end()) {
235 return false;
236 }
237 if (keyValue.getValue() != findInRHS->getValue()) {
238 return false;
239 }
240 }
241 return true;
242}
243
244} // namespace
245
246/// Use to test constant folding exporting and reloading tests, where some
247/// constant folding is recorded and serialized in the custom Glow ONNX model.
248class ConstFoldReloadTest : public ::testing::Test {
249public:
250 ConstFoldReloadTest() : EE_("Interpreter"), mod_(EE_.getModule()) {
251 F_ = mod_.createFunction("main");
252 }
253
254protected:
255 ExecutionEngine EE_;
256 Module &mod_;
257 Function *F_;
258 PlaceholderBindings bindings_;
259 CompilationContext cctx_;
260
261 /// Constant folds \ref F_ and then serializes it. Then deserializes it and
262 /// runs it and makes sure that running the original and the reloaded Function
263 /// are bitwise equal. Verifies that \p numExpectedConstsFolded constant
264 /// folding records are created during constant folding/recording. Any Nodes
265 /// listed in \p nodesToPar will be Model parallelized in two.
266 void serializeAndReloadAndCompareResults(
267 unsigned numExpectedConstsFolded,
268 const std::unordered_set<Node *> &nodesToPar = {}) {
269 bindings_.allocate(mod_.getPlaceholders());
270
271 // Perform constant folding, recording what occurs so we can serialize it.
272 ConstantFoldingRecordMap record = constantFoldAndRecord(F_, cctx_);
273 EXPECT_EQ(record.size(), numExpectedConstsFolded);
274 runDCEPass(F_, cctx_);
275
276 if (nodesToPar.size()) {
277 llvm::DenseMap<Node *, size_t> numChunks;
278 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
279 for (Node *N : nodesToPar) {
280 numChunks[N] = 2;
281 parOpts[N] = ParallelTransformKind::Model;
282 }
283
284 std::unordered_map<Node *, ConcatNode *> replacedMap;
285 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
286 ::glow::parallelizeOps(F_, numChunks, parOpts));
287 EXPECT_EQ(replacedMap.size(), parOpts.size());
288
289 ConstantFoldingRecordMap parRecord = constantFoldAndRecord(F_, cctx_);
290 record.insert(parRecord.begin(), parRecord.end());
291 runDCEPass(F_, cctx_);
292 }
293
294 // Clone the original module into a new module used for reloading the model.
295 ExecutionEngine reloadEE(EE_.getBackendName());
296 Module &reloadMod = reloadEE.getModule();
297 mod_.clone(&reloadMod);
298
299 // Save and reload F.
300 Function *reloadF;
301 CompilationContext reloadCctx;
302 ASSIGN_VALUE_OR_FAIL_TEST(
303 reloadF, saveAndReloadFunction(
304 reloadMod, F_, {}, {}, 7, 9,
305 /* zipMode */ false,
306 /* useGlowCustomOps */ true,
307 /* useString */ false,
308 /* includeConstantData */ false, &record, &reloadCctx,
309 cctx_.backendOpts.backendSpecificNodeInfo));
310
311 // Verify that the Function and its Module are the same before and after.
312 EXPECT_EQ(reloadF->toString(), F_->toString());
313 EXPECT_EQ(reloadMod.toString(), mod_.toString());
314
315 PlaceholderBindings reloadBindings =
316 bindings_.clone(reloadMod.getPlaceholders());
317
318 // Now run both to check they have bitwise equal results.
319 EE_.compile(cctx_);
320 EE_.run(bindings_);
321
322 reloadEE.compile(reloadCctx);
323 reloadEE.run(reloadBindings);
324
325 EXPECT_TRUE(
326 PlaceholderBindings::compare(&bindings_, &reloadBindings, 0.0f));
327
328 // Verify that backend-specific node info was serialized correctly.
329 EXPECT_EQ(cctx_.backendOpts.backendSpecificNodeInfo.count(F_),
330 reloadCctx.backendOpts.backendSpecificNodeInfo.count(reloadF));
331
332 if (cctx_.backendOpts.backendSpecificNodeInfo.count(F_)) {
333 auto &origNodeMap = cctx_.backendOpts.backendSpecificNodeInfo[F_];
334 auto &reloadNodeMap =
335 reloadCctx.backendOpts.backendSpecificNodeInfo[reloadF];
336
337 for (const Node &origN : F_->getNodes()) {
338 auto reloadNodeIt = std::find_if(
339 reloadF->getNodes().begin(), reloadF->getNodes().end(),
340 [&](const Node &N) { return N.getName() == origN.getName(); });
341 ASSERT_NE(reloadNodeIt, reloadF->getNodes().end());
342 EXPECT_TRUE(
343 isStrMapEqual(reloadNodeMap[&*reloadNodeIt], origNodeMap[&origN]));
344 }
345 }
346 }
347};
348
349TEST(exporter, onnxModels) {
350 std::string inputDirectory(GLOW_DATA_PATH "tests/models/onnxModels");
351 std::cout << "inputDirectory: " << inputDirectory << std::endl;
352 std::error_code code;
353 for (llvm::sys::fs::directory_iterator dirIt(inputDirectory, code);
354 !code && dirIt != llvm::sys::fs::directory_iterator();
355 dirIt.increment(code)) {
356 auto name = dirIt->path();
357 if (!endsWith(name, ".onnxtxt")) {
358 llvm::outs() << "Ignore non-onnxtxt input: " << name << "\n";
359 continue;
360 }
361 if (name.find("getInputsOnnxDefineSample.onnxtxt") != std::string::npos ||
362 name.find("preluInvalidBroadcastSlope.onnxtxt") != std::string::npos ||
363 name.find("padReflect.onnxtxt") != std::string::npos ||
364 name.find("powMultiBroadcastOp7.onnxtxt") != std::string::npos ||
365 name.find("gatherConstantFolding.onnxtxt") != std::string::npos ||
366 name.find("averagePool3D.onnxtxt") != std::string::npos ||
367 name.find("sparseLengthsSum.onnxtxt") != std::string::npos ||
368 name.find("constantOfShapeInt32Fail.onnxtxt") != std::string::npos ||
369 name.find("padEdge.onnxtxt") != std::string::npos ||
370 name.find("castToFloat.onnxtxt") != std::string::npos ||
371 name.find("castToFloat16.onnxtxt") != std::string::npos ||
372 name.find("castToInt64.onnxtxt") != std::string::npos ||
373 name.find("castToInt32.onnxtxt") != std::string::npos ||
374 name.find("simpleConvBiasFail.onnxtxt") != std::string::npos ||
375 name.find("Where.onnxtxt") != std::string::npos ||
376 name.find("constantOfShapeInt64Fail.onnxtxt") != std::string::npos ||
377 name.find("ArgMaxDefault.onnxtxt") != std::string::npos ||
378 name.find("ArgMaxKeepDim.onnxtxt") != std::string::npos ||
379 name.find("ArgMaxNoKeepDim.onnxtxt") != std::string::npos ||
380 name.find("ArgMinDefault.onnxtxt") != std::string::npos ||
381 name.find("ArgMinKeepDim.onnxtxt") != std::string::npos ||
382 name.find("ArgMinNoKeepDim.onnxtxt") != std::string::npos ||
383 name.find("upsampleOpset7.onnxtxt") != std::string::npos ||
384 name.find("upsampleOpset9.onnxtxt") != std::string::npos ||
385 name.find("resizeNearestV11compat.onnxtxt") != std::string::npos ||
386 name.find("resizeNearestV11compat_sizes.onnxtxt") !=
387 std::string::npos ||
388 name.find("resizeBilinearV11compat.onnxtxt") != std::string::npos ||
389 name.find("resizeBilinearV11compat_sizes.onnxtxt") !=
390 std::string::npos ||
391 name.find("upsampleOpset9.onnxtxt") != std::string::npos ||
392 name.find("NonMaxSuppressionSSD_ONNX.onnxtxt") != std::string::npos ||
393 name.find("NonMaxSuppression.onnxtxt") != std::string::npos ||
394 name.find("NonMaxSuppressionOptionalParams.onnxtxt") !=
395 std::string::npos ||
396 name.find("NonMaxSuppressionSSD.onnxtxt") != std::string::npos ||
397 name.find("ROIAlign_onnx.onnxtxt") != std::string::npos ||
398 name.find("MatMul4D.onnxtxt") != std::string::npos ||
399 name.find("Less.onnxtxt") != std::string::npos ||
400 name.find("Erf.onnxtxt") != std::string::npos ||
401 name.find("Asin.onnxtxt") != std::string::npos ||
402 name.find("Acos.onnxtxt") != std::string::npos ||
403 name.find("Atan.onnxtxt") != std::string::npos ||
404 name.find("Sin.onnxtxt") != std::string::npos ||
405 name.find("Cos.onnxtxt") != std::string::npos ||
406 name.find("abs.onnxtxt") != std::string::npos ||
407 name.find("log.onnxtxt") != std::string::npos ||
408 name.find("RangeInt32.onnxtxt") != std::string::npos ||
409 name.find("RangeFloat.onnxtxt") != std::string::npos ||
410 name.find("scatterND.onnxtxt") != std::string::npos ||
411 name.find("mscatterND.onnxtxt") != std::string::npos ||
412 name.find("loop_cond.onnxtxt") != std::string::npos ||
413 name.find("loop_empty_tripcount.onnxtxt") != std::string::npos ||
414 name.find("loop_emptycond.onnxtxt") != std::string::npos ||
415 name.find("loop_no_iteration.onnxtxt") != std::string::npos ||
416 name.find("loop_tripcount.onnxtxt") != std::string::npos ||
417 name.find("loop_withoutN.onnxtxt") != std::string::npos ||
418 name.find("sign.onnxtxt") != std::string::npos ||
419 name.find("gatherND.onnxtxt") != std::string::npos ||
420 name.find("softmax13.onnxtxt") != std::string::npos ||
421 name.find("logsoftmax.onnxtxt") != std::string::npos ||
422 name.find("hardsigmoid.onnxtxt") != std::string::npos ||
423 name.find("simpleConvTranspose.onnxtxt") != std::string::npos ||
424 name.find("simpleConvTransposeOutShape.onnxtxt") != std::string::npos ||
425 name.find("simpleConvTransposeOutShapeDilation.onnxtxt") !=
426 std::string::npos ||
427 name.find("simpleConvTransposeOutShapeSameLower.onnxtxt") !=
428 std::string::npos ||
429 name.find("simpleConvTransposeOutShapeSameUpper.onnxtxt") !=
430 std::string::npos ||
431 name.find("simpleConvTransposeAutoPadSameLower.onnxtxt") !=
432 std::string::npos ||
433 name.find("simpleConvTransposeAutoPadSameUpper.onnxtxt") !=
434 std::string::npos ||
435 name.find("convTransposeAsymmetric.onnxtxt") != std::string::npos ||
436 name.find("Mean.onnxtxt") != std::string::npos ||
437 name.find("Mean_broadcast.onnxtxt") != std::string::npos ||
438 name.find("NonZero.onnxtxt") != std::string::npos ||
439 name.find("logicalAnd.onnxtxt") != std::string::npos ||
440 name.find("logicalAndBcast.onnxtxt") != std::string::npos ||
441 name.find("logicalOrBcast.onnxtxt") != std::string::npos ||
442 name.find("logicalOr.onnxtxt") != std::string::npos ||
443 name.find("logicalXorBcast.onnxtxt") != std::string::npos ||
444 name.find("logicalXor.onnxtxt") != std::string::npos ||
445 name.find("logicalNot.onnxtxt") != std::string::npos ||
446
447 name.find("simpleConvTransposePads.onnxtxt") != std::string::npos ||
448 name.find("simpleConvTransposeAutoPadValid.onnxtxt") !=
449 std::string::npos ||
450 name.find("simpleConvTransposeOutShapeSameUpper.onnxtxt") !=
451 std::string::npos ||
452 name.find("simpleConvTransposeAutoPadSameLower.onnxtxt") !=
453 std::string::npos ||
454 name.find("convTransposeGroup.onnxtxt") != std::string::npos ||
455 name.find("pow_element_wise.onnxtxt") != std::string::npos ||
456 name.find("pow_array_broadcast.onnxtxt") != std::string::npos ||
457 name.find("pow_scalar_broadcast.onnxtxt") != std::string::npos ||
458 name.find("simpleConvTransposeAutoPadSameUpper.onnxtxt") !=
459 std::string::npos ||
460 name.find("sliceInvalidAxes.onnxtxt") != std::string::npos ||
461 name.find("sliceWithUnsupportedStep.onnxtxt") != std::string::npos ||
462 name.find("simpleConv3DNonSquareDilation.onnxtxt") !=
463 std::string::npos) {
464 // Ignore invalid ONNX files and graphs without nodes.
465 llvm::outs() << "Ignore invalid input files: " << name << "\n";
466 continue;
467 }
468 if (name.find("constant.onnxtxt") != std::string::npos ||
469 name.find("shape.onnxtxt") != std::string::npos ||
470 name.find("bool_from_int.onnxtxt") != std::string::npos ||
471 name.find("sum1.onnxtxt") != std::string::npos) {
472 // Ignore invalid ONNX files and graphs without nodes.
473 llvm::outs() << "Ignore empty graph file: " << name << "\n";
474 continue;
475 }
476 if (name.find(".output.onnxtxt") != std::string::npos) {
477 // Ignore output files - debugging mode only.
478 llvm::outs() << "Ignore output file: " << name << "\n";
479 continue;
480 }
481 // TODO: Debug why these RNN models don`t work!
482 if (name.find("rnn") != std::string::npos) {
483 // Ignore RNN files.
484 llvm::outs() << "Ignore RNN model file: " << name << "\n";
485 continue;
486 }
487 if (name.find("gru") != std::string::npos) {
488 // Ignore GRU files.
489 llvm::outs() << "Ignore GRU model file: " << name << "\n";
490 continue;
491 }
492 if (name.find("lstm") != std::string::npos) {
493 // Ignore LSTM files.
494 llvm::outs() << "Ignore LSTM model file: " << name << "\n";
495 continue;
496 }
497 const bool customOnnxDefineSymbol =
498 name.find("dimParam.onnxtxt") != std::string::npos;
499 if (customOnnxDefineSymbol) {
500 setOnnxDefineSymbol({"ONNXUndefinedSymbol,1"});
501 }
502
503 // Disable constant folding for these tests.
504 setConstantFoldLoaderOpsFlag(false);
505
506 testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ true,
507 /* useString */ false);
508 testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ false,
509 /* useString */ false);
510 testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ false,
511 /* useString */ true);
512
513 // Reset the custom symbol used.
514 if (customOnnxDefineSymbol) {
515 setOnnxDefineSymbol({});
516 }
517 }
518}
519
520TEST(exporter, ChannelwiseQuantizedConvolution) {
521 ExecutionEngine EE{};
522 auto &mod = EE.getModule();
523 auto *F = mod.createFunction("F");
524
525 unsigned_t inChannels = 8;
526 unsigned_t inSide = 6;
527 unsigned_t batchSize = 8;
528 unsigned_t outChannels = 12;
529 unsigned_t filterSide = 3;
530 unsigned_t groups = 4;
531 unsigned_t dilation = 1;
532
533 Placeholder *input = mod.createPlaceholder(
534 ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3,
535 "input", /* isTrainable */ false);
536
537 Constant *biasConstant =
538 mod.createConstant(ElemKind::FloatTy, {outChannels}, "bias");
539 biasConstant->getPayloadMutable().getHandle<float>().randomize(-0.1, 0.1,
540 mod.getPRNG());
541
542 Constant *filterScalesConstant =
543 mod.createConstant(ElemKind::FloatTy, {outChannels}, "filter_scales");
544
545 Constant *filterOffsetsConstant =
546 mod.createConstant(ElemKind::Int32ITy, {outChannels}, "filter_offsets");
547
548 Constant *weightsConstant = mod.createConstant(
549 ElemKind::Int8QTy,
550 {outChannels, filterSide, filterSide, inChannels / groups}, 2.5, 1,
551 "offsets");
552
553 std::vector<unsigned_t> kernels = {filterSide, filterSide};
554 std::vector<unsigned_t> strides = {1, 1};
555 std::vector<unsigned_t> pads = {1, 1, 1, 1};
556
557 auto outSize =
558 calculateConvPoolOutputDims(inSide, inSide, kernels, strides, pads);
559 auto *outTy = mod.uniqueType(
560 ElemKind::Int8QTy,
561 {batchSize, outSize.first, outSize.second, outChannels}, 3.8, 4);
562
563 auto *cqConv = F->createChannelwiseQuantizedConv(
564 "cqconv", input, weightsConstant, biasConstant, filterScalesConstant,
565 filterOffsetsConstant, /* biasScales */ nullptr,
566 /* biasOffsets */ nullptr, outTy, kernels, strides, pads, groups,
567 {dilation, dilation});
568
569 auto *save = F->createSave("save_out", cqConv);
570
571 Placeholder *output = save->getPlaceholder();
572
573 ASSERT_TRUE(F->verify());
574
575 PlaceholderBindings bindings;
576 bindings.allocate({input, output});
577
578 // Save and reload F.
579 Function *R;
580 Module reloadMod;
581 ASSIGN_VALUE_OR_FAIL_TEST(
582 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
583
584 ChannelwiseQuantizedConvolutionNode *cqConvReloaded;
585 ASSIGN_VALUE_OR_FAIL_TEST(
586 cqConvReloaded,
587 getSingleNodeWithKind<ChannelwiseQuantizedConvolutionNode>(
588 R, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind));
589
590 EXPECT_TRUE(cqConvReloaded->getInput().getType()->isEqual(
591 *cqConv->getInput().getType()));
592 EXPECT_TRUE(cqConvReloaded->getResult().getType()->isEqual(
593 *cqConv->getResult().getType()));
594
595 EXPECT_TRUE(cqConvReloaded->getFilter().getType()->isEqual(
596 *cqConv->getFilter().getType()));
597 EXPECT_TRUE(cqConvReloaded->getBias().getType()->isEqual(
598 *cqConv->getBias().getType()));
599 EXPECT_TRUE(cqConvReloaded->getFilterScales().getType()->isEqual(
600 *cqConv->getFilterScales().getType()));
601 EXPECT_TRUE(cqConvReloaded->getFilterOffsets().getType()->isEqual(
602 *cqConv->getFilterOffsets().getType()));
603 EXPECT_TRUE(cqConvReloaded->getBiasScales().getType()->isEqual(
604 *cqConv->getBiasScales().getType()));
605 EXPECT_TRUE(cqConvReloaded->getBiasOffsets().getType()->isEqual(
606 *cqConv->getBiasOffsets().getType()));
607
608 EXPECT_EQ(cqConvReloaded->getKernels(), cqConv->getKernels());
609 EXPECT_EQ(cqConvReloaded->getStrides(), cqConv->getStrides());
610 EXPECT_EQ(cqConvReloaded->getPads(), cqConv->getPads());
611 EXPECT_EQ(cqConvReloaded->getGroup(), cqConv->getGroup());
612 EXPECT_EQ(cqConvReloaded->getDilation(), cqConv->getDilation());
613}
614
615TEST(exporter, QuantizedConvolution) {
616 ExecutionEngine EE{};
617 auto &mod = EE.getModule();
618 auto *F = mod.createFunction("F");
619
620 unsigned_t inChannels = 8;
621 unsigned_t inSide = 6;
622 unsigned_t batchSize = 8;
623 unsigned_t outChannels = 12;
624 unsigned_t filterSide = 3;
625 unsigned_t groups = 4;
626 unsigned_t dilation = 1;
627
628 Placeholder *input = mod.createPlaceholder(
629 ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3,
630 "input", /* isTrainable */ false);
631
632 Placeholder *weights = mod.createPlaceholder(
633 ElemKind::Int8QTy,
634 {outChannels, filterSide, filterSide, inChannels / groups}, 2.5, 1,
635 "weights",
636 /* isTrainable */ false);
637
638 Placeholder *bias =
639 mod.createPlaceholder(ElemKind::Int32QTy, {outChannels}, 0.25, 2, "bias",
640 /* isTrainable */ false);
641
642 std::vector<unsigned_t> kernels = {filterSide, filterSide};
643 std::vector<unsigned_t> strides = {1, 1};
644 std::vector<unsigned_t> pads = {1, 1, 1, 1};
645
646 auto outSize =
647 calculateConvPoolOutputDims(inSide, inSide, kernels, strides, pads);
648 auto *outTy = mod.uniqueType(
649 ElemKind::Int8QTy,
650 {batchSize, outSize.first, outSize.second, outChannels}, 3.8, 4);
651
652 auto *qConv = F->createConv("qconv", input, weights, bias, outTy, kernels,
653 strides, pads, groups, {dilation, dilation});
654
655 auto *save = F->createSave("save_out", qConv);
656
657 Placeholder *output = save->getPlaceholder();
658
659 ASSERT_TRUE(F->verify());
660
661 PlaceholderBindings bindings;
662 bindings.allocate({weights, bias, input, output});
663 convertPlaceholdersToConstants(F, bindings, {input, output});
664
665 // Save and reload F.
666 Function *R;
667 Module reloadMod;
668 ASSIGN_VALUE_OR_FAIL_TEST(
669 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
670
671 // Verify reloaded function matches the original.
672 ConvolutionNode *qConvReloaded;
673 ASSIGN_VALUE_OR_FAIL_TEST(qConvReloaded,
674 getSingleNodeWithKind<ConvolutionNode>(
675 R, Kinded::Kind::ConvolutionNodeKind));
676
677 EXPECT_TRUE(qConvReloaded->getInput().getType()->isEqual(
678 *qConv->getInput().getType()));
679 EXPECT_TRUE(qConvReloaded->getResult().getType()->isEqual(
680 *qConv->getResult().getType()));
681
682 EXPECT_TRUE(qConvReloaded->getFilter().getType()->isEqual(
683 *qConv->getFilter().getType()));
684 EXPECT_TRUE(
685 qConvReloaded->getBias().getType()->isEqual(*qConv->getBias().getType()));
686
687 EXPECT_EQ(qConvReloaded->getKernels(), qConv->getKernels());
688 EXPECT_EQ(qConvReloaded->getStrides(), qConv->getStrides());
689 EXPECT_EQ(qConvReloaded->getPads(), qConv->getPads());
690 EXPECT_EQ(qConvReloaded->getGroup(), qConv->getGroup());
691 EXPECT_EQ(qConvReloaded->getDilation(), qConv->getDilation());
692}
693
694TEST(exporter, QuantizedMaxPool) {
695 ExecutionEngine EE{};
696 auto &mod = EE.getModule();
697 auto *F = mod.createFunction("F");
698
699 unsigned_t inChannels = 8;
700 unsigned_t inSide = 6;
701 unsigned_t batchSize = 8;
702 unsigned_t filterSide = 3;
703
704 Placeholder *input = mod.createPlaceholder(
705 ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3,
706 "input", /* isTrainable */ false);
707
708 std::vector<unsigned_t> kernels = {filterSide, filterSide};
709 std::vector<unsigned_t> strides = {1, 1};
710 std::vector<unsigned_t> pads = {1, 1, 1, 1};
711
712 auto *maxPool = F->createMaxPool("maxpool", input, kernels, strides, pads);
713
714 auto *save = F->createSave("save_out", maxPool->getNthResult(0));
715
716 Placeholder *output = save->getPlaceholder();
717
718 ASSERT_TRUE(F->verify());
719
720 PlaceholderBindings bindings;
721 bindings.allocate({input, output});
722 convertPlaceholdersToConstants(F, bindings, {input, output});
723
724 // Save and reload F.
725 Function *R;
726 Module reloadMod;
727 ASSIGN_VALUE_OR_FAIL_TEST(
728 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
729
730 // Verify reloaded function matches the original.
731 MaxPoolNode *maxPoolReloaded;
732 ASSIGN_VALUE_OR_FAIL_TEST(
733 maxPoolReloaded,
734 getSingleNodeWithKind<MaxPoolNode>(R, Kinded::Kind::MaxPoolNodeKind));
735
736 EXPECT_TRUE(maxPoolReloaded->getInput().getType()->isEqual(
737 *maxPool->getInput().getType()));
738 EXPECT_TRUE(maxPoolReloaded->getResult().getType()->isEqual(
739 *maxPool->getResult().getType()));
740
741 EXPECT_EQ(maxPoolReloaded->getKernels(), maxPool->getKernels());
742 EXPECT_EQ(maxPoolReloaded->getStrides(), maxPool->getStrides());
743 EXPECT_EQ(maxPoolReloaded->getPads(), maxPool->getPads());
744}
745
746TEST(exporter, QuantizedAvgPool) {
747 ExecutionEngine EE{};
748 auto &mod = EE.getModule();
749 auto *F = mod.createFunction("F");
750
751 unsigned_t inChannels = 8;
752 unsigned_t inSide = 6;
753 unsigned_t batchSize = 8;
754 unsigned_t filterSide = 3;
755
756 Placeholder *input = mod.createPlaceholder(
757 ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3,
758 "input", /* isTrainable */ false);
759
760 std::vector<unsigned_t> kernels = {filterSide, filterSide};
761 std::vector<unsigned_t> strides = {1, 1};
762 std::vector<unsigned_t> pads = {1, 1, 1, 1};
763
764 auto *avgPool = F->createAvgPool("avgpool", input, kernels, strides, pads);
765
766 auto *save = F->createSave("save_out", avgPool->getNthResult(0));
767
768 Placeholder *output = save->getPlaceholder();
769
770 ASSERT_TRUE(F->verify());
771
772 PlaceholderBindings bindings;
773 bindings.allocate({input, output});
774 convertPlaceholdersToConstants(F, bindings, {input, output});
775
776 // Save and reload F.
777 Function *R;
778 Module reloadMod;
779 ASSIGN_VALUE_OR_FAIL_TEST(
780 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
781
782 // Verify reloaded function matches the original.
783 AvgPoolNode *avgPoolReloaded;
784 ASSIGN_VALUE_OR_FAIL_TEST(
785 avgPoolReloaded,
786 getSingleNodeWithKind<AvgPoolNode>(R, Kinded::Kind::AvgPoolNodeKind));
787
788 EXPECT_TRUE(avgPoolReloaded->getInput().getType()->isEqual(
789 *avgPool->getInput().getType()));
790 EXPECT_TRUE(avgPoolReloaded->getResult().getType()->isEqual(
791 *avgPool->getResult().getType()));
792
793 EXPECT_EQ(avgPoolReloaded->getKernels(), avgPool->getKernels());
794 EXPECT_EQ(avgPoolReloaded->getStrides(), avgPool->getStrides());
795 EXPECT_EQ(avgPoolReloaded->getPads(), avgPool->getPads());
796 EXPECT_EQ(avgPoolReloaded->getCountIncludePads(),
797 avgPool->getCountIncludePads());
798}
799
800TEST(exporter, QuantizedAdaptiveAvgPool) {
801 ExecutionEngine EE{};
802 auto &mod = EE.getModule();
803 auto *F = mod.createFunction("F");
804
805 unsigned_t inChannels = 8;
806 unsigned_t inSide = 6;
807 unsigned_t batchSize = 8;
808
809 Placeholder *input = mod.createPlaceholder(
810 ElemKind::Int8QTy, {batchSize, inSide, inSide, inChannels}, 1.2, 3,
811 "input", /* isTrainable */ false);
812
813 auto *outTy = mod.uniqueTypeWithNewShape(input->getType(),
814 {batchSize, 3, 3, inChannels});
815
816 auto *adaptiveAvgPool =
817 F->createAdaptiveAvgPool("adaptive_avgpool", input, outTy);
818
819 auto *save = F->createSave("save_out", adaptiveAvgPool->getNthResult(0));
820
821 Placeholder *output = save->getPlaceholder();
822
823 ASSERT_TRUE(F->verify());
824
825 PlaceholderBindings bindings;
826 bindings.allocate({input, output});
827 convertPlaceholdersToConstants(F, bindings, {input, output});
828
829 // Save and reload F.
830 Function *R;
831 Module reloadMod;
832 ASSIGN_VALUE_OR_FAIL_TEST(
833 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
834
835 // Verify reloaded function matches the original.
836 AdaptiveAvgPoolNode *adaptiveAvgPoolReloaded;
837 ASSIGN_VALUE_OR_FAIL_TEST(adaptiveAvgPoolReloaded,
838 getSingleNodeWithKind<AdaptiveAvgPoolNode>(
839 R, Kinded::Kind::AdaptiveAvgPoolNodeKind));
840
841 EXPECT_TRUE(adaptiveAvgPoolReloaded->getInput().getType()->isEqual(
842 *adaptiveAvgPool->getInput().getType()));
843 EXPECT_TRUE(adaptiveAvgPoolReloaded->getResult().getType()->isEqual(
844 *adaptiveAvgPool->getResult().getType()));
845}
846
847TEST(exporter, RowwiseQuantizedFullyConnected) {
848 ExecutionEngine EE{};
849 auto &mod = EE.getModule();
850 auto *F = mod.createFunction("F");
851
852 Placeholder *input = mod.createPlaceholder(
853 ElemKind::Int8QTy, {2, 100}, 1.2, 3, "input", /* isTrainable */ false);
854
855 Constant *weightsConstant =
856 mod.createConstant(ElemKind::Int8QTy, {10, 100}, 1.0, 0, "weights");
857
858 Constant *biasConstant =
859 mod.createConstant(ElemKind::Int32QTy, {10}, 1.0, 0, "bias");
860
861 Constant *scalesConstant =
862 mod.createConstant(ElemKind::FloatTy, {10}, "scales");
863
864 Constant *offsetsConstant =
865 mod.createConstant(ElemKind::Int32ITy, {10}, "offsets");
866
867 auto *outTy = mod.uniqueType(ElemKind::Int8QTy, {2, 10}, 3.8, 4);
868
869 auto *rwqFC = F->createRowwiseQuantizedFullyConnected(
870 "rwqFC", input, weightsConstant, scalesConstant, offsetsConstant,
871 biasConstant, outTy);
872
873 auto *save = F->createSave("save_out", rwqFC);
874
875 Placeholder *output = save->getPlaceholder();
876
877 ASSERT_TRUE(F->verify());
878
879 PlaceholderBindings bindings;
880 bindings.allocate({input, output});
881
882 // Save and reload F.
883 Function *R;
884 Module reloadMod;
885 ASSIGN_VALUE_OR_FAIL_TEST(
886 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
887
888 RowwiseQuantizedFullyConnectedNode *rwqFCReloaded;
889 ASSIGN_VALUE_OR_FAIL_TEST(
890 rwqFCReloaded,
891 getSingleNodeWithKind<RowwiseQuantizedFullyConnectedNode>(
892 R, Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind));
893
894 EXPECT_TRUE(rwqFCReloaded->getInput().getType()->isEqual(
895 *rwqFC->getInput().getType()));
896 EXPECT_TRUE(rwqFCReloaded->getResult().getType()->isEqual(
897 *rwqFC->getResult().getType()));
898
899 EXPECT_TRUE(rwqFCReloaded->getWeights().getType()->isEqual(
900 *rwqFC->getWeights().getType()));
901 EXPECT_TRUE(
902 rwqFCReloaded->getBias().getType()->isEqual(*rwqFC->getBias().getType()));
903 EXPECT_TRUE(rwqFCReloaded->getScales().getType()->isEqual(
904 *rwqFC->getScales().getType()));
905 EXPECT_TRUE(rwqFCReloaded->getOffsets().getType()->isEqual(
906 *rwqFC->getOffsets().getType()));
907}
908
909TEST_F(ConstFoldReloadTest, exportGraphWithOneConstFoldingRecord) {
910 Placeholder *I =
911 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
912 /* isTrainable */ false);
913 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
914 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
915 ConvertToNode *convertW =
916 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
917 TransposeNode *transposeW =
918 F_->createTranspose("transpose", convertW, {1, 0});
919 MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
920 F_->createSave("save", MM);
921
922 bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10,
923 mod_.getPRNG());
924 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
925
926 serializeAndReloadAndCompareResults(1);
927}
928
929TEST_F(ConstFoldReloadTest, exportGraphWithTwoConstFoldingRecords) {
930 Placeholder *I =
931 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
932 /* isTrainable */ false);
933 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
934 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
935 ConvertToNode *convertW =
936 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
937 TransposeNode *transposeW =
938 F_->createTranspose("transpose", convertW, {1, 0});
939 MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
940 F_->createSave("save_mm", MM);
941
942 Constant *W2 = mod_.createConstant(ElemKind::Float16Ty, {2, 100}, "weight2");
943 TanhNode *tanhW = F_->createTanh("tanh", W2);
944 AddNode *add = F_->createAdd("add", tanhW, I);
945 F_->createSave("save_add", add);
946
947 bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10,
948 mod_.getPRNG());
949 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
950 W2->getPayloadMutable().getHandle<float16_t>().randomize(-10, 10,
951 mod_.getPRNG());
952
953 serializeAndReloadAndCompareResults(2);
954}
955
956TEST_F(ConstFoldReloadTest, exportGraphWithTwoConstFoldingMultiOutputRecord) {
957 Constant *W = mod_.createConstant(ElemKind::FloatTy, {100}, "weight");
958 SigmoidNode *sigmoidW = F_->createSigmoid("sig", W);
959 ConvertToNode *convertW =
960 F_->createConvertTo("conv", sigmoidW, ElemKind::Float16Ty);
961 TopKNode *TK = F_->createTopK("topk", convertW, 5);
962 F_->createSave("save_indices", TK->getIndices());
963
964 Placeholder *I = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "input",
965 /* isTrainable */ false);
966 AddNode *add = F_->createAdd("add", I, TK->getValues());
967 F_->createSave("save_add", add);
968
969 bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10,
970 mod_.getPRNG());
971 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
972
973 serializeAndReloadAndCompareResults(2);
974}
975
976/// Verify that exporting and reloading with placement hints retains the hints.
977TEST_F(ConstFoldReloadTest, exportWithPlacementHints) {
978 auto *input1 =
979 mod_.createPlaceholder(ElemKind::Float16Ty, {16, 32}, "input1", false);
980 auto *input2 =
981 mod_.createPlaceholder(ElemKind::Float16Ty, {16, 32}, "input2", false);
982 auto *weights =
983 F_->getParent()->createConstant(ElemKind::Float16Ty, {16, 16}, "weights");
984 auto *bias =
985 F_->getParent()->createConstant(ElemKind::Float16Ty, {16}, "bias");
986 weights->getPayloadMutable().getHandle<float16_t>().randomize(-1.0, 1.0,
987 mod_.getPRNG());
988 bias->getPayloadMutable().getHandle<float16_t>().randomize(-1.0, 1.0,
989 mod_.getPRNG());
990
991 auto *CI = F_->createConcat("concat", {input1, input2}, 1);
992 auto *TN = F_->createTranspose("transpose", CI, {1, 0});
993 auto *FC = F_->createFullyConnected("fc", TN, weights, bias);
994 auto *THN = F_->createTanh("tanh", FC);
995 auto *SN = F_->createSigmoid("sigmoid", THN);
996 F_->createSave("ret", SN);
997
998 auto *AN = F_->createAdd("add", input1, input2);
999 F_->createSave("add_save", AN);
1000
1001 bindings_.allocate(input1)->getHandle<float16_t>().randomize(-1.0, 1.0,
1002 mod_.getPRNG());
1003 bindings_.allocate(input2)->getHandle<float16_t>().randomize(-1.0, 1.0,
1004 mod_.getPRNG());
1005
1006 auto &nodeInfo = cctx_.backendOpts.backendSpecificNodeInfo[F_];
1007
1008 nodeInfo[AN]["Interpreter_Hint1"].push_back(CI->getName().str());
1009 nodeInfo[AN]["Interpreter_Hint2"].push_back("@1");
1010 nodeInfo[CI]["Interpreter_Hint1"].push_back(TN->getName().str());
1011 nodeInfo[CI]["Interpreter_Hint3"].push_back("@1");
1012 nodeInfo[CI]["Interpreter_Hint1"].push_back(FC->getName().str());
1013 nodeInfo[CI]["Interpreter_Hint3"].push_back("@1");
1014 nodeInfo[AN]["Interpreter_Hint1"].push_back(CI->getName().str());
1015 nodeInfo[AN]["Interpreter_Hint2"].push_back("@1");
1016 nodeInfo[TN]["Interpreter_Hint1"].push_back(FC->getName().str());
1017 nodeInfo[TN]["Interpreter_Hint1"].push_back(SN->getName().str());
1018 nodeInfo[FC]["Interpreter_Hint1"].push_back(THN->getName().str());
1019
1020 nodeInfo[TN]["Interpreter_Hint4"].push_back("3");
1021 nodeInfo[FC]["Interpreter_Hint4"].push_back("2");
1022 nodeInfo[CI]["Interpreter_Hint4"].push_back("1");
1023 nodeInfo[CI]["Interpreter_Hint5"].push_back("@0");
1024 nodeInfo[CI]["Interpreter_Hint4"].push_back("3");
1025 nodeInfo[CI]["Interpreter_Hint5"].push_back("@1");
1026
1027 serializeAndReloadAndCompareResults(0);
1028}
1029
1030TEST_F(ConstFoldReloadTest, exportParallelizedGraphWithTwoConstFoldingRecords) {
1031 Placeholder *I =
1032 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
1033 /* isTrainable */ false);
1034 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weights");
1035 Constant *B = mod_.createConstant(ElemKind::Float16Ty, {10}, "bias");
1036
1037 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
1038 ConvertToNode *convertW =
1039 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
1040 TransposeNode *transposeW =
1041 F_->createTranspose("transpose", convertW, {1, 0});
1042 FullyConnectedNode *FC = F_->createFullyConnected("fc", I, transposeW, B);
1043 F_->createSave("save", FC);
1044
1045 Constant *W2 = mod_.createConstant(ElemKind::Float16Ty, {2, 100}, "weight2");
1046 TanhNode *tanhW = F_->createTanh("tanh", W2);
1047 AddNode *add = F_->createAdd("add", tanhW, I);
1048 F_->createSave("save_add", add);
1049
1050 bindings_.allocate(I)->getHandle<float16_t>().randomize(-10, 10,
1051 mod_.getPRNG());
1052 W->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1053 B->getHandle<float16_t>().randomize(0.0, 0.5, mod_.getPRNG());
1054 W2->getPayloadMutable().getHandle<float16_t>().randomize(-10, 10,
1055 mod_.getPRNG());
1056
1057 serializeAndReloadAndCompareResults(2, {FC});
1058}
1059
1060TEST(exporter, VeryLongChain) {
1061 ExecutionEngine EE{};
1062 auto &mod = EE.getModule();
1063 auto *F = mod.createFunction("F");
1064
1065 Placeholder *input =
1066 mod.createPlaceholder(ElemKind::Float16Ty, {1, 6}, "input", false);
1067
1068 Node *cur = input;
1069 for (dim_t iter = 0; iter < 3000; iter++) {
1070 auto *mul = F->createMul("mul", cur, cur);
1071 auto *clip = F->createClip("clip", mul, 0.0, 128.0);
1072 if (iter == 0) {
1073 F->createSave("save_out0", clip);
1074 }
1075 cur = (Node *)clip;
1076 }
1077 auto *save = F->createSave("save_out", cur);
1078
1079 Placeholder *output = save->getPlaceholder();
1080
1081 ASSERT_TRUE(F->verify());
1082
1083 PlaceholderBindings bindings;
1084 bindings.allocate({input, output});
1085
1086 // Save and reload F.
1087 Function *R;
1088 Module reloadMod;
1089 ASSIGN_VALUE_OR_FAIL_TEST(
1090 R, saveAndReloadFunction(reloadMod, F, {"input"}, {input->getType()}));
1091 (void)R;
1092}
1093
1094/// Tests that we can serialize and then reload a model with OriginNameToTQPMap
1095/// added to the model. Note that we don't do anything with the reloaded map.
1096TEST(exporter, TestUniqueOffsetMapSerialization) {
1097 ExecutionEngine EE{};
1098 auto &mod = EE.getModule();
1099 auto *F = mod.createFunction("F");
1100
1101 Placeholder *I =
1102 mod.createPlaceholder(ElemKind::Float16Ty, {5, 3}, "input", false);
1103 Constant *W =
1104 mod.createConstant(ElemKind::Int8QTy, {3, 4}, 0.f, 0, "weights");
1105 Constant *B = mod.createConstant(ElemKind::Int8QTy, {4}, 0.f, 1, "bias");
1106 QuantizeNode *QI = F->createQuantize("quant", I, ElemKind::Int8QTy, 0.f, 2);
1107 FullyConnectedNode *FC = F->createFullyConnected("fc", QI, W, B);
1108 SaveNode *save = F->createSave("save_out", FC);
1109
1110 Placeholder *output = save->getPlaceholder();
1111
1112 ASSERT_TRUE(F->verify());
1113
1114 PlaceholderBindings bindings;
1115 bindings.allocate({I, output});
1116
1117#define GET_TQP(T_) TensorQuantizationParams{T_->getScale(), T_->getOffset()}
1118 OriginNameToTQPMap originNameToTQPMap;
1119 originNameToTQPMap.emplace(W->getName(), GET_TQP(W->getOutput().getType()));
1120 originNameToTQPMap.emplace(B->getName(), GET_TQP(B->getOutput().getType()));
1121 originNameToTQPMap.emplace(QI->getName(), GET_TQP(QI->getResult().getType()));
1122#undef GET_TQP
1123
1124 // Save and reload F.
1125 Function *R;
1126 Module reloadMod;
1127 ASSIGN_VALUE_OR_FAIL_TEST(
1128 R, saveAndReloadFunction(reloadMod, F, {"input"}, {I->getType()}, 7, 9,
1129 /* zipMode */ false,
1130 /* useGlowCustomOps */ true,
1131 /* useString */ false,
1132 /* includeConstantData */ true,
1133 /* record */ nullptr, /* reloadCctx */ nullptr,
1134 /* backendSpecificNodeInfo */ {},
1135 originNameToTQPMap));
1136 (void)R;
1137}
1138
1139/// Test that we can serialize tensor strides.
1140TEST(exporter, TestStridesSerialization) {
1141 ExecutionEngine EE{};
1142 auto &mod = EE.getModule();
1143 auto *F = mod.createFunction("F");
1144
1145 auto ty = mod.uniqueType(ElemKind::Float16Ty, {5, 3});
1146 // Create a type with non-standard strides.
1147 ty = mod.uniqueTypeWithNewStrides(ty, ty->dims(), {332, 1});
1148 Placeholder *I = mod.createPlaceholder(ty, "input", false);
1149 SaveNode *save = F->createSave("save_out", I);
1150
1151 Placeholder *output = save->getPlaceholder();
1152
1153 ASSERT_TRUE(F->verify());
1154
1155 PlaceholderBindings bindings;
1156 bindings.allocate({I, output});
1157
1158 // Save and reload F in text mode.
1159 {
1160 Function *R;
1161 Module reloadMod;
1162 ASSIGN_VALUE_OR_FAIL_TEST(
1163 R, saveAndReloadFunction(reloadMod, F, {"input"}, {I->getType()}, 7, 9,
1164 /* zipMode */ false,
1165 /* useGlowCustomOps */ true,
1166 /* useString */ false,
1167 /* includeConstantData */ true,
1168 /* record */ nullptr, /* reloadCctx */ nullptr,
1169 /* backendSpecificNodeInfo */ {}));
1170 (void)R;
1171 }
1172 // Save and reload F in zip mode.
1173 {
1174 Function *R;
1175 Module reloadMod;
1176 ASSIGN_VALUE_OR_FAIL_TEST(
1177 R, saveAndReloadFunction(reloadMod, F, {"input"}, {I->getType()}, 7, 9,
1178 /* zipMode */ true,
1179 /* useGlowCustomOps */ true,
1180 /* useString */ false,
1181 /* includeConstantData */ true,
1182 /* record */ nullptr, /* reloadCctx */ nullptr,
1183 /* backendSpecificNodeInfo */ {}));
1184 (void)R;
1185 }
1186}
1187