1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Importer/ONNXIFIModelLoader.h" |
18 | #include "glow/Importer/Caffe2ModelLoader.h" |
19 | |
20 | #include "caffe2/proto/caffe2.pb.h" |
21 | #include "onnx/onnx_pb.h" |
22 | |
23 | namespace glow { |
24 | |
25 | Expected<std::unique_ptr<ONNXIFIModelLoader>> ONNXIFIModelLoader::parse( |
26 | const void *model, uint32_t modelSize, uint32_t weightsCount, |
27 | const onnxTensorDescriptorV1 *weightDescriptors, Module &mod, |
28 | llvm::StringRef netName, CompilationContext &cctx, |
29 | std::map<std::string, Type> *staticPlaceholderTypes, |
30 | bool loadInputsAsPlaceholdersForOnnx, bool use_onnx, |
31 | bool constFoldInLoader) { |
32 | |
33 | std::unique_ptr<ONNXIFIModelLoader> loader(new ONNXIFIModelLoader()); |
34 | Error loaderConstructionErr = Error::empty(); |
35 | |
36 | if (use_onnx) { |
37 | // If we're loading an ONNX model then we will always be replacing dummy |
38 | // TQPs if they're found. |
39 | cctx.precisionConfig.replaceDummyTQPs = true; |
40 | std::unique_ptr<ONNXModelLoader> onnxLoader(new ONNXModelLoader( |
41 | model, modelSize, weightsCount, weightDescriptors, mod, netName, |
42 | cctx.prepartitionedConfig, loadInputsAsPlaceholdersForOnnx, |
43 | &loaderConstructionErr, constFoldInLoader, |
44 | &cctx.backendOpts.backendSpecificNodeInfo, staticPlaceholderTypes, |
45 | cctx.precisionConfig.replaceDummyTQPs, |
46 | cctx.precisionConfig.clipQuantRangeToFP16)); |
47 | if (loaderConstructionErr) { |
48 | return std::move(loaderConstructionErr); |
49 | } |
50 | // Keep hold of the context |
51 | loader->core_ = std::move(onnxLoader); |
52 | } else { |
53 | // Use Caffe2 Model loader |
54 | std::unique_ptr<Caffe2ModelLoader> c2Loader(new Caffe2ModelLoader( |
55 | model, modelSize, weightsCount, weightDescriptors, mod, netName, |
56 | cctx.prepartitionedConfig, &loaderConstructionErr, constFoldInLoader, |
57 | cctx.precisionConfig.originNameToTQPMap, |
58 | cctx.precisionConfig.loadUniquedDummyQParams, |
59 | cctx.precisionConfig.zeroScaleFP16Clip, |
60 | cctx.precisionConfig.clipQuantRangeToFP16)); |
61 | if (loaderConstructionErr) { |
62 | return std::move(loaderConstructionErr); |
63 | } |
64 | // Keep hold of the context |
65 | loader->core_ = std::move(c2Loader); |
66 | } |
67 | |
68 | return Expected<std::unique_ptr<ONNXIFIModelLoader>>(std::move(loader)); |
69 | } |
70 | } // namespace glow |
71 | |