1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
#include "glow/Base/Image.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Importer/Caffe2ModelLoader.h"
#include "glow/Support/Error.h"

#include <cstddef>
#include <cstdio>
20 | |
21 | using namespace glow; |
22 | |
23 | /// A stripped-down example of how to load a Caffe2 protobuf and perform |
24 | /// inference. |
25 | int main() { |
26 | glow::PlaceholderBindings bindings; |
27 | glow::ExecutionEngine EE; |
28 | auto &mod = EE.getModule(); |
29 | auto *F = mod.createFunction("lenet_mnist" ); |
30 | auto *inputType = mod.uniqueType(glow::ElemKind::FloatTy, {1, 1, 28, 28}); |
31 | const char *inputName = "data" ; |
32 | |
33 | // Load and compile LeNet MNIST model. |
34 | glow::Caffe2ModelLoader loader("lenet_mnist/predict_net.pb" , |
35 | "lenet_mnist/init_net.pb" , {inputName}, |
36 | {inputType}, *F); |
37 | EE.compile(glow::CompilationMode::Infer); |
38 | |
39 | // Get input and output placeholders. |
40 | auto *input = llvm::cast<glow::Placeholder>( |
41 | EXIT_ON_ERR(loader.getNodeValueByName(inputName))); |
42 | auto *output = EXIT_ON_ERR(loader.getSingleOutput()); |
43 | |
44 | // Read an example PNG and add it to an input batch. |
45 | auto image = glow::readPngPpmImageAndPreprocess( |
46 | "tests/images/mnist/5_1087.png" , glow::ImageNormalizationMode::k0to1, |
47 | glow::ImageChannelOrder::BGR, glow::ImageLayout::NCHW); |
48 | glow::Tensor batch(inputType); |
49 | batch.getHandle<>().insertSlice(image, 0); |
50 | |
51 | // Allocate memory for input and bind it to the placeholders. |
52 | bindings.allocate(mod.getPlaceholders()); |
53 | glow::updateInputPlaceholders(bindings, {input}, {&batch}); |
54 | |
55 | // Perform inference. |
56 | EE.run(bindings); |
57 | |
58 | // Read output and find argmax. |
59 | auto out = bindings.get(output)->getHandle<float>(); |
60 | printf("digit: %zu\n" , (size_t)out.minMaxArg().second); |
61 | return 0; |
62 | } |
63 | |