1 | /* |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "InlineOnnxifi.h" |
18 | |
19 | #include "glow/Flags/Flags.h" |
20 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
21 | #include "glow/Quantization/Quantization.h" |
22 | #include "glow/Quantization/Serialization.h" |
23 | #include "glow/Support/Support.h" |
24 | |
25 | #include "llvm/Support/MD5.h" |
26 | |
27 | namespace glow { |
28 | namespace onnxifi { |
29 | |
30 | namespace { |
31 | std::string getProfileFile(llvm::StringRef hash) { |
32 | return strFormat("/tmp/glow-profile-%s.yaml" , hash.str().c_str()); |
33 | } |
34 | |
35 | void computeModelHash(const void *onnxModel, size_t onnxModelSize, |
36 | llvm::SmallString<32> &str) { |
37 | llvm::MD5::MD5Result res; |
38 | llvm::MD5 MD5; |
39 | MD5.update(llvm::makeArrayRef((uint8_t *)onnxModel, onnxModelSize)); |
40 | MD5.final(res); |
41 | llvm::MD5::stringifyResult(res, str); |
42 | } |
43 | } // namespace |
44 | |
/// Loads the serialized ONNX model at \p onnxModel (\p onnxModelSize bytes)
/// plus \p weightCount static weights described by \p weightDescriptors into
/// this graph's module, then optimizes and compiles the single resulting
/// Function for inference. \p maxSeqLength sizes the zero-length-sequence
/// buffer. Returns ONNXIFI_STATUS_SUCCESS on success,
/// ONNXIFI_STATUS_INVALID_MODEL if parsing fails, or
/// ONNXIFI_STATUS_UNIDENTIFIED_NAME if quantization is requested but no
/// profile file exists for this model's hash.
onnxStatus InlineGraph::initGraph(
    const void *onnxModel, size_t onnxModelSize, uint32_t weightCount,
    const onnxTensorDescriptorV1 *weightDescriptors, uint32_t maxSeqLength,
    void * /*unused */, bool /*unused*/) {
  Module &mod = executionEngine_.getModule();
  // Note: Pass in a nullptr for PPC here because we do not currently support
  // pre-partitioned models here.
  std::unique_ptr<ONNXIFIModelLoader> loader;
  // The bindings are only needed to satisfy the CompilationContext ctor;
  // loading itself records lowering info into loweredMap_ via cctx.
  PlaceholderBindings dummyBindings;
  CompilationContext cctx{&dummyBindings, &loweredMap_};
  auto loaderOrErr = ONNXIFIModelLoader::parse(
      onnxModel, onnxModelSize, weightCount, weightDescriptors, mod, "function" ,
      cctx, /* staticPlaceholderTypes */ nullptr,
      true /*loadInputsAsPlaceholdersForOnnx*/, backendPtr_->getUseOnnx());
  if (loaderOrErr) {
    loader = std::move(*loaderOrErr);
  } else {
    LOG(ERROR) << "Error when loading model: "
               << ERR_TO_STRING(loaderOrErr.takeError());
    return ONNXIFI_STATUS_INVALID_MODEL;
  }

  // The loader creates exactly one Function named "function" in the module;
  // cache it for run()/profiling.
  CHECK_EQ(mod.getFunctions().size(), 1) << "Should have exactly one Function." ;
  function_ = *mod.getFunctions().begin();

  bindPlaceholders(*loader);
  if (flags::SaveModel) {
    saveOnnxifiModel(function_);
  }

  setZeroLengthSequence(maxSeqLength);
  // Hash the raw model bytes so profile files are keyed per model.
  computeModelHash(onnxModel, onnxModelSize, modelHash_);
  // Graph-level optimization happens before the precision config below is
  // filled in; ordering matters here.
  optimize(function_, CompilationMode::Infer);

  PrecisionConfiguration &precConfig = cctx.precisionConfig;
  precConfig.quantMode = quantizationMode_;

  // If quantizing, load quantization infos and setup the schema.
  if (quantizationMode_ == QuantizationMode::Quantize) {
    auto fileExists = deserializeProfilingInfosFromYaml(
        getProfileFile(modelHash_), precConfig.quantConfig.graphPreLowerHash,
        precConfig.quantConfig.infos);
    if (!fileExists) {
      return ONNXIFI_STATUS_UNIDENTIFIED_NAME;
    }
    precConfig.quantConfig.schema = quantization::Schema::Symmetric;
  }

  // NOTE(review): cctx (and the quantConfig populated above) is not passed to
  // compile() here — presumably the engine picks the config up elsewhere;
  // verify against ExecutionEngine::compile's contract.
  executionEngine_.compile(CompilationMode::Infer);

  return ONNXIFI_STATUS_SUCCESS;
}
97 | |
98 | onnxStatus InlineGraph::run(std::unique_ptr<ExecutionContext> ctx, |
99 | EventPtr outputEvent, |
100 | onnxTraceEventList *traceEvents) { |
101 | executionEngine_.run(*ctx); |
102 | |
103 | // Dump profile if requested. |
104 | if (quantizationMode_ == QuantizationMode::Profile) { |
105 | auto PI = quantization::generateNodeProfilingInfos( |
106 | *(ctx->getPlaceholderBindings()), function_, loweredMap_); |
107 | serializeProfilingInfosToYaml(getProfileFile(modelHash_), |
108 | /* graphPreLowerHash */ 0, PI); |
109 | } |
110 | |
111 | if (auto *traceContext = ctx->getTraceContext()) { |
112 | setTraceEvents(traceEvents, traceContext); |
113 | } |
114 | |
115 | outputEvent->signal(ONNXIFI_STATUS_SUCCESS); |
116 | return ONNXIFI_STATUS_SUCCESS; |
117 | } |
118 | |
119 | } // namespace onnxifi |
120 | } // namespace glow |
121 | |