/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Base/Image.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Hook.h"
#include "glow/Importer/Caffe2ModelLoader.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"

using namespace glow;

const char inputName[] = "gpu_0/data";

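/// Loads ResNet-50 from its Caffe2 protos on a single backend and re-runs the
/// network with an instrumentation hook placed after any named node.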
class Tester {
  PlaceholderBindings bindings, inferBindings;
  ExecutionEngine EEI;
  std::unique_ptr<Module> mod;
  Function *F;
  TypeRef inputType;
  Placeholder *input;
  Placeholder *output;

public:
  explicit Tester(llvm::StringRef backendName)
      : EEI(backendName), mod(new Module), F(mod->createFunction("resnet50")),
        inputType(mod->uniqueType(ElemKind::FloatTy, {1, 3, 224, 224})) {
    // Load and compile ResNet-50.
    Caffe2ModelLoader loader("resnet50/predict_net.pb", "resnet50/init_net.pb",
                             {inputName}, {inputType}, *F);
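    // Keep handles to the input and output placeholders created by the loader.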
    input = llvm::cast<Placeholder>(
        EXIT_ON_ERR(loader.getNodeValueByName(inputName)));
    output = EXIT_ON_ERR(loader.getSingleOutput());
  }

  void bindInput(Tensor *batch) {
    // Allocate memory for input and bind it to the placeholders.
    bindings.allocate(mod->getPlaceholders());
    updateInputPlaceholders(bindings, {input}, {batch});
  }

  TypeRef getInputType() const { return inputType; }

  Function *getFunction() const { return F; }

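  /// Reload the network on the same backend, hook the outputs of the node
  /// named \p name, run inference, and return the tensors captured by the
  /// hook.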
  std::list<Tensor *> hookAndRun(llvm::StringRef name) {
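    // Re-setting the same backend name gives the ExecutionEngine a fresh
    // module, so the network can be reloaded from scratch below.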
    EEI.setBackendName(EEI.getBackendName());
    inferBindings.clear();
    auto modI = &EEI.getModule();
    auto *FI = modI->createFunction("resnet50");
    Caffe2ModelLoader loader(
        "resnet50/predict_net.pb", "resnet50/init_net.pb", {inputName},
        {mod->uniqueType(ElemKind::FloatTy, {1, 3, 224, 224})}, *FI);
    auto hook = hookNode(FI, name);
    inferBindings.allocate(modI->getPlaceholders());
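    // Copy the tensors bound in the original bindings into the fresh ones,
    // matching placeholders by name.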
    for (const auto &PH : bindings.pairs()) {
      auto iPH = inferBindings.getPlaceholderByNameSlow(PH.first->getName());
      inferBindings.get(iPH)->assign(&PH.second);
    }

    std::list<Tensor *> outs;
    for (const auto &P : hook.outputs) {
      outs.emplace_back(inferBindings.get(P));
    }

    auto fName = hook.function->getName();
    EEI.compile(CompilationMode::Infer);
    EEI.run(inferBindings, fName);
    return outs;
  }
};

/// Compare layer-by-layer execution of ResNet on two backends.
int main() {
  Tester interp{"Interpreter"};
  Tester cpu{"CPU"};

  // Read an example PNG and add it to an input batch.
  auto image = readPngPpmImageAndPreprocess(
      "tests/images/imagenet/cat_285.png", ImageNormalizationMode::k0to1,
      ImageChannelOrder::BGR, ImageLayout::NCHW, imagenetNormMean,
      imagenetNormStd);
  Tensor batch(interp.getInputType());
  batch.getHandle<float>().insertSlice(image, 0);

  interp.bindInput(&batch);
  cpu.bindInput(&batch);

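  // Hook every node's outputs on both backends, run inference, and compare
  // the tensors the hooks captured.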
  for (auto const &node : interp.getFunction()->getNodes()) {
    if (llvm::isa<SaveNode>(&node)) {
      continue;
    }
    llvm::errs() << "Verifying layer: " << node.getName() << "\n";
    auto interpOuts = interp.hookAndRun(node.getName());
    auto cpuOuts = cpu.hookAndRun(node.getName());

    if (interpOuts.size() == cpuOuts.size()) {
      auto interpOutIt = interpOuts.begin(), interpOutEnd = interpOuts.end();
      auto cpuOutIt = cpuOuts.begin(), cpuOutEnd = cpuOuts.end();

      while (interpOutIt != interpOutEnd && cpuOutIt != cpuOutEnd) {
        auto *interpOut = *interpOutIt;
        auto *cpuOut = *cpuOutIt;

        if (!interpOut->isEqual(*cpuOut)) {
          llvm::errs() << "Results differ\n";
          dumpImpl(interpOut);
          dumpImpl(cpuOut);
        }

        ++interpOutIt;
        ++cpuOutIt;
      }
    } else {
      llvm::errs()
          << "Backends produced a different number of results using hook at "
          << node.getName() << "\n";
    }
  }

  return 0;
}