1/*
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "InlineOnnxifi.h"
18
19#include "glow/Flags/Flags.h"
20#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
21#include "glow/Quantization/Quantization.h"
22#include "glow/Quantization/Serialization.h"
23#include "glow/Support/Support.h"
24
25#include "llvm/Support/MD5.h"
26
27namespace glow {
28namespace onnxifi {
29
30namespace {
31std::string getProfileFile(llvm::StringRef hash) {
32 return strFormat("/tmp/glow-profile-%s.yaml", hash.str().c_str());
33}
34
35void computeModelHash(const void *onnxModel, size_t onnxModelSize,
36 llvm::SmallString<32> &str) {
37 llvm::MD5::MD5Result res;
38 llvm::MD5 MD5;
39 MD5.update(llvm::makeArrayRef((uint8_t *)onnxModel, onnxModelSize));
40 MD5.final(res);
41 llvm::MD5::stringifyResult(res, str);
42}
43} // namespace
44
45onnxStatus InlineGraph::initGraph(
46 const void *onnxModel, size_t onnxModelSize, uint32_t weightCount,
47 const onnxTensorDescriptorV1 *weightDescriptors, uint32_t maxSeqLength,
48 void * /*unused */, bool /*unused*/) {
49 Module &mod = executionEngine_.getModule();
50 // Note: Pass in a nullptr for PPC here because we do not currently support
51 // pre-partitioned models here.
52 std::unique_ptr<ONNXIFIModelLoader> loader;
53 PlaceholderBindings dummyBindings;
54 CompilationContext cctx{&dummyBindings, &loweredMap_};
55 auto loaderOrErr = ONNXIFIModelLoader::parse(
56 onnxModel, onnxModelSize, weightCount, weightDescriptors, mod, "function",
57 cctx, /* staticPlaceholderTypes */ nullptr,
58 true /*loadInputsAsPlaceholdersForOnnx*/, backendPtr_->getUseOnnx());
59 if (loaderOrErr) {
60 loader = std::move(*loaderOrErr);
61 } else {
62 LOG(ERROR) << "Error when loading model: "
63 << ERR_TO_STRING(loaderOrErr.takeError());
64 return ONNXIFI_STATUS_INVALID_MODEL;
65 }
66
67 CHECK_EQ(mod.getFunctions().size(), 1) << "Should have exactly one Function.";
68 function_ = *mod.getFunctions().begin();
69
70 bindPlaceholders(*loader);
71 if (flags::SaveModel) {
72 saveOnnxifiModel(function_);
73 }
74
75 setZeroLengthSequence(maxSeqLength);
76 computeModelHash(onnxModel, onnxModelSize, modelHash_);
77 optimize(function_, CompilationMode::Infer);
78
79 PrecisionConfiguration &precConfig = cctx.precisionConfig;
80 precConfig.quantMode = quantizationMode_;
81
82 // If quantizing, load quantization infos and setup the schema.
83 if (quantizationMode_ == QuantizationMode::Quantize) {
84 auto fileExists = deserializeProfilingInfosFromYaml(
85 getProfileFile(modelHash_), precConfig.quantConfig.graphPreLowerHash,
86 precConfig.quantConfig.infos);
87 if (!fileExists) {
88 return ONNXIFI_STATUS_UNIDENTIFIED_NAME;
89 }
90 precConfig.quantConfig.schema = quantization::Schema::Symmetric;
91 }
92
93 executionEngine_.compile(CompilationMode::Infer);
94
95 return ONNXIFI_STATUS_SUCCESS;
96}
97
98onnxStatus InlineGraph::run(std::unique_ptr<ExecutionContext> ctx,
99 EventPtr outputEvent,
100 onnxTraceEventList *traceEvents) {
101 executionEngine_.run(*ctx);
102
103 // Dump profile if requested.
104 if (quantizationMode_ == QuantizationMode::Profile) {
105 auto PI = quantization::generateNodeProfilingInfos(
106 *(ctx->getPlaceholderBindings()), function_, loweredMap_);
107 serializeProfilingInfosToYaml(getProfileFile(modelHash_),
108 /* graphPreLowerHash */ 0, PI);
109 }
110
111 if (auto *traceContext = ctx->getTraceContext()) {
112 setTraceEvents(traceEvents, traceContext);
113 }
114
115 outputEvent->signal(ONNXIFI_STATUS_SUCCESS);
116 return ONNXIFI_STATUS_SUCCESS;
117}
118
119} // namespace onnxifi
120} // namespace glow
121