/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Loader.h"

#include "glow/Graph/Nodes.h"
#include "glow/Importer/Caffe2ModelLoader.h"
#include "glow/Importer/ONNXModelLoader.h"

#include "llvm/Support/raw_ostream.h"

#include <memory>

using namespace glow;

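// ModelRunner: loads a Caffe2 or ONNX model whose inputs are all baked in as
// Constants, compiles it, runs a single inference (unless emitting a bundle),
// and dumps the output tensor. Example invocation (a sketch only; the exact
// Loader flag names depend on the Glow build):
//
//   model-runner -model=<model.onnx> -backend=CPU
//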
int main(int argc, char **argv) {
  PlaceholderBindings bindings;
  // Verify/initialize the command line parameters, and then the Loader
  // initializes the ExecutionEngine and Function.
  parseCommandLine(argc, argv);
  Loader loader;

  // Create the model from the input net: Caffe2 if a Caffe2 net description
  // was given, otherwise ONNX.
  std::unique_ptr<ProtobufLoader> LD;
  if (!loader.getCaffe2NetDescFilename().empty()) {
    LD.reset(new Caffe2ModelLoader(loader.getCaffe2NetDescFilename().str(),
                                   loader.getCaffe2NetWeightFilename().str(),
                                   {}, {}, *loader.getFunction()));
  } else {
    LD.reset(new ONNXModelLoader(loader.getOnnxModelFilename().str(), {}, {},
                                 *loader.getFunction()));
  }
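  // Fetch the model's single output placeholder and allocate a backing tensor
  // for it in the bindings; its element type selects which dump to run below.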
  Placeholder *output = EXIT_ON_ERR(LD->getSingleOutput());
  auto *outputT = bindings.allocate(output);

  CHECK_EQ(0, std::distance(LD->getInputVarsMapping().keys().begin(),
                            LD->getInputVarsMapping().keys().end()))
      << "ModelRunner only supports models with no external inputs.";

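  // Capture the model (Function) name up front so it can be reported after
  // inference.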
  std::string modelName = loader.getFunction()->getName().str();

  // Compile the model, and perform quantization/emit a bundle/dump debug info
  // if requested from the command line.
  CompilationContext cctx = loader.getCompilationContext();
  cctx.bindings = &bindings;
  // Disable constant folding, as the model runner is designed for models with
  // all Constant inputs.
  cctx.optimizationOpts.enableConstantFolding = false;
  loader.compile(cctx);

  // If in bundle mode, do not run inference.
  if (!emittingBundle()) {
    loader.runInference(bindings);

    llvm::outs() << "Model: " << modelName << "\n";

    // Print out the result of the output operator.
    switch (outputT->getElementType()) {
    case ElemKind::FloatTy:
      outputT->getHandle<float>().dump();
      break;
    case ElemKind::Int8QTy:
      outputT->getHandle<int8_t>().dump();
      break;
    default:
      LOG(FATAL) << "Unexpected output type";
    }

    // If profiling, generate and serialize the profiling infos now that we
    // have run inference to gather the profile.
    if (profilingGraph()) {
      loader.generateAndSerializeProfilingInfos(bindings);
    }
  }

  return 0;
}