1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "glow/Base/Image.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Importer/Caffe2ModelLoader.h"
#include "glow/Support/Error.h"

#include <cstdio>
20
21using namespace glow;
22
23/// A stripped-down example of how to load a Caffe2 protobuf and perform
24/// inference.
25int main() {
26 glow::PlaceholderBindings bindings;
27 glow::ExecutionEngine EE;
28 auto &mod = EE.getModule();
29 auto *F = mod.createFunction("lenet_mnist");
30 auto *inputType = mod.uniqueType(glow::ElemKind::FloatTy, {1, 1, 28, 28});
31 const char *inputName = "data";
32
33 // Load and compile LeNet MNIST model.
34 glow::Caffe2ModelLoader loader("lenet_mnist/predict_net.pb",
35 "lenet_mnist/init_net.pb", {inputName},
36 {inputType}, *F);
37 EE.compile(glow::CompilationMode::Infer);
38
39 // Get input and output placeholders.
40 auto *input = llvm::cast<glow::Placeholder>(
41 EXIT_ON_ERR(loader.getNodeValueByName(inputName)));
42 auto *output = EXIT_ON_ERR(loader.getSingleOutput());
43
44 // Read an example PNG and add it to an input batch.
45 auto image = glow::readPngPpmImageAndPreprocess(
46 "tests/images/mnist/5_1087.png", glow::ImageNormalizationMode::k0to1,
47 glow::ImageChannelOrder::BGR, glow::ImageLayout::NCHW);
48 glow::Tensor batch(inputType);
49 batch.getHandle<>().insertSlice(image, 0);
50
51 // Allocate memory for input and bind it to the placeholders.
52 bindings.allocate(mod.getPlaceholders());
53 glow::updateInputPlaceholders(bindings, {input}, {&batch});
54
55 // Perform inference.
56 EE.run(bindings);
57
58 // Read output and find argmax.
59 auto out = bindings.get(output)->getHandle<float>();
60 printf("digit: %zu\n", (size_t)out.minMaxArg().second);
61 return 0;
62}
63