1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "Loader.h" |
18 | |
19 | #include "glow/Graph/Nodes.h" |
20 | #include "glow/Importer/Caffe2ModelLoader.h" |
21 | #include "glow/Importer/ONNXModelLoader.h" |
22 | |
23 | #include "llvm/Support/raw_ostream.h" |
24 | |
25 | #include <memory> |
26 | |
27 | using namespace glow; |
28 | |
29 | int main(int argc, char **argv) { |
30 | PlaceholderBindings bindings; |
31 | // Verify/initialize command line parameters, and then loader initializes |
32 | // the ExecutionEngine and Function. |
33 | parseCommandLine(argc, argv); |
34 | Loader loader; |
35 | |
36 | // Create the model based on the input net, and get SaveNode for the output. |
37 | std::unique_ptr<ProtobufLoader> LD; |
38 | if (!loader.getCaffe2NetDescFilename().empty()) { |
39 | LD.reset(new Caffe2ModelLoader(loader.getCaffe2NetDescFilename().str(), |
40 | loader.getCaffe2NetWeightFilename().str(), |
41 | {}, {}, *loader.getFunction())); |
42 | } else { |
43 | LD.reset(new ONNXModelLoader(loader.getOnnxModelFilename().str(), {}, {}, |
44 | *loader.getFunction())); |
45 | } |
46 | Placeholder *output = EXIT_ON_ERR(LD->getSingleOutput()); |
47 | auto *outputT = bindings.allocate(output); |
48 | |
49 | CHECK_EQ(0, std::distance(LD->getInputVarsMapping().keys().begin(), |
50 | LD->getInputVarsMapping().keys().end())) |
51 | << "ModelRunner only supports models with no external inputs." ; |
52 | |
53 | std::string modelName = loader.getFunction()->getName().str(); |
54 | |
55 | // Compile the model, and perform quantization/emit a bundle/dump debug info |
56 | // if requested from command line. |
57 | CompilationContext cctx = loader.getCompilationContext(); |
58 | cctx.bindings = &bindings; |
59 | // Disable constant folding, as the model runner is designed for models with |
60 | // all Constant inputs. |
61 | cctx.optimizationOpts.enableConstantFolding = false; |
62 | loader.compile(cctx); |
63 | |
64 | // If in bundle mode, do not run inference. |
65 | if (!emittingBundle()) { |
66 | loader.runInference(bindings); |
67 | |
68 | llvm::outs() << "Model: " << modelName << "\n" ; |
69 | |
70 | // Print out the result of output operator. |
71 | switch (outputT->getElementType()) { |
72 | case ElemKind::FloatTy: |
73 | outputT->getHandle<float>().dump(); |
74 | break; |
75 | case ElemKind::Int8QTy: |
76 | outputT->getHandle<int8_t>().dump(); |
77 | break; |
78 | default: |
79 | LOG(FATAL) << "Unexpected output type" ; |
80 | } |
81 | |
82 | // If profiling, generate and serialize the profiling infos now that we |
83 | // have run inference to gather the profile. |
84 | if (profilingGraph()) { |
85 | loader.generateAndSerializeProfilingInfos(bindings); |
86 | } |
87 | } |
88 | |
89 | return 0; |
90 | } |
91 | |