/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Base/Image.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/Importer/Caffe2ModelLoader.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Optimizer/GraphOptimizer/TrainingPreparation.h"
#include "glow/Support/Support.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/raw_ostream.h"

#include <glog/logging.h>

#include <fstream>

using namespace glow;

const size_t mnistNumImages = 50000;

namespace {
llvm::cl::OptionCategory mnistCat("MNIST Options");
llvm::cl::opt<std::string> executionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(mnistCat));
} // namespace

/// Loads the MNIST database into a 4D tensor of images and a 2D tensor of
/// labels. Returns the number of images found.
unsigned loadMNIST(Tensor &imageInputs, Tensor &labelInputs) {
  LOG(INFO) << "Loading the mnist database.";

  imageInputs.reset(ElemKind::FloatTy, {50000u, 28, 28, 1});
  labelInputs.reset(ElemKind::Int64ITy, {50000u, 1});

  std::ifstream imgInput("mnist_images.bin", std::ios::binary);
  std::ifstream labInput("mnist_labels.bin", std::ios::binary);

  CHECK(imgInput.is_open()) << "Error loading mnist_images.bin";
  CHECK(labInput.is_open()) << "Error loading mnist_labels.bin";

  std::vector<char> images((std::istreambuf_iterator<char>(imgInput)),
                           (std::istreambuf_iterator<char>()));
  std::vector<char> labels((std::istreambuf_iterator<char>(labInput)),
                           (std::istreambuf_iterator<char>()));
  float *imagesAsFloatPtr = reinterpret_cast<float *>(&images[0]);

  CHECK_EQ(labels.size() * 28 * 28 * sizeof(float), images.size())
      << "The size of the image buffer does not match the labels vector";

  size_t idx = 0;

  auto LIH = labelInputs.getHandle<int64_t>();
  auto IIH = imageInputs.getHandle<>();

  for (unsigned w = 0; w < mnistNumImages; w++) {
    LIH.at({w, 0}) = labels[w];
    for (unsigned x = 0; x < 28; x++) {
      for (unsigned y = 0; y < 28; y++) {
        IIH.at({w, x, y, 0}) = imagesAsFloatPtr[idx++];
      }
    }
  }
  size_t numImages = labels.size();
  CHECK_GT(numImages, 0) << "No images were found.";
  LOG(INFO) << "Loaded " << numImages << " images.";
  return numImages;
}

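/// Builds a small convolutional classifier for MNIST in \p F: two
/// convolution/ReLU/max-pool stacks followed by a fully connected layer and a
/// softmax. Returns the input, output, and label placeholders through
/// \p inputPH, \p outputPH, and \p selectedPH.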
void createModel(ExecutionEngine &EE, Function *F,
                 PlaceholderBindings &bindings, unsigned minibatchSize,
                 Placeholder *&inputPH, Placeholder *&outputPH,
                 Placeholder *&selectedPH) {
  auto &mod = EE.getModule();

  inputPH = mod.createPlaceholder(ElemKind::FloatTy, {minibatchSize, 28, 28, 1},
                                  "input", false);

  auto *CV0 = F->createConv(bindings, "conv", inputPH, 16, 5, 1, 2, 1);
  auto *RL0 = F->createRELU("relu", CV0);
  auto *MP0 = F->createMaxPool("pool", RL0, 3, 3, 0);

  auto *CV1 = F->createConv(bindings, "conv", MP0->getResult(), 16, 5, 1, 2, 1);
  auto *RL1 = F->createRELU("relu", CV1);
  auto *MP1 = F->createMaxPool("pool", RL1, 3, 3, 0);

  auto *FCL1 = F->createFullyConnected(bindings, "fc", MP1->getResult(), 10);
  selectedPH = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, 1},
                                     "selected", false);
  auto *SM = F->createSoftMax("sm", FCL1, selectedPH);
  SaveNode *result = F->createSave("return", SM);
  outputPH = result->getPlaceholder();
}

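/// Differentiates \p F according to the training configuration, compiles it in
/// training mode on \p EE, and trains it for 60 epochs, running
/// \p numIterations minibatches of \p minibatchSize MNIST samples per epoch.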
void trainModel(ExecutionEngine &EE, PlaceholderBindings &bindings, Function *F,
                unsigned minibatchSize, unsigned numIterations,
                Tensor &imageInputs, Tensor &labelInputs, Placeholder *inputPH,
                Placeholder *selectedPH) {
  llvm::Timer timer("Training", "Training");
  // The training configuration.
  TrainingConfig TC;

  // Set up the training hyperparameters:
  TC.learningRate = 0.001;
  TC.momentum = 0.9;
  TC.L2Decay = 0.001;
  TC.batchSize = minibatchSize;

  // Create the gradient-computing (training) version of the network.
  Function *TF = glow::differentiate(F, TC);

  EE.compile(CompilationMode::Train);
  bindings.allocate(EE.getModule().getPlaceholders());

  LOG(INFO) << "Training.";

  // This variable records the index of the next sample to be used for
  // training.
  size_t sampleCounter = 0;
  auto tfName = TF->getName();

  for (int epoch = 0; epoch < 60; epoch++) {
    LOG(INFO) << "Training - epoch #" << epoch;

    timer.startTimer();

    // On each training iteration take a slice of imageInputs and labelInputs,
    // bind them to the input and label placeholders, then run the forward and
    // backward passes and update the weights.
    runBatch(EE, bindings, numIterations, sampleCounter, {inputPH, selectedPH},
             {&imageInputs, &labelInputs}, tfName);

    timer.stopTimer();
  }
}

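/// Converts the trained weights of \p F into constants, compiles it for
/// inference, and evaluates it on samples that follow the training range,
/// printing each digit with its expected and predicted label and checking the
/// overall accuracy.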
void validateModel(ExecutionEngine &EE, PlaceholderBindings &bindings,
                   Function *F, unsigned minibatchSize, unsigned numIterations,
                   Tensor &imageInputs, Tensor &labelInputs,
                   Placeholder *inputPH, Placeholder *outputPH,
                   bool transpose) {
  LOG(INFO) << "Validating.";

  ::glow::convertPlaceholdersToConstants(F, bindings, {inputPH, outputPH});
  EE.compile(CompilationMode::Infer);

  dim_t rightAnswer = 0;
  // Start validation right after the samples that were used for training.
  dim_t offset = numIterations * minibatchSize;
  size_t sampleCounter = offset;
  size_t iterations = 10;
  evalBatch(EE, bindings, iterations, sampleCounter, inputPH, outputPH,
            imageInputs, labelInputs, F->getName(),
            [&](const Tensor &sampleIn, const Tensor &sampleOut,
                const Tensor &label, size_t sampleIndex) {
              auto correct = label.getHandle<sdim_t>().at({0});
              auto guess = sampleOut.getHandle().minMaxArg().second;
              rightAnswer += (guess == correct);
              if (sampleIndex < offset + minibatchSize) {
                llvm::outs() << "MNIST Input";
                if (transpose) {
                  Tensor IT;
                  // Transpose back to the ASCII printable format.
                  // CHW -> HWC.
                  sampleIn.transpose(&IT, {1, 2, 0});
                  IT.getHandle().dumpAscii();
                } else {
                  sampleIn.getHandle().dumpAscii();
                }
                llvm::outs() << " Expected: " << correct
                             << " Guessed: " << guess << "\n";
                sampleOut.getHandle<>().dump();
                llvm::outs() << "\n-------------\n";
              }
            });

  llvm::outs() << "Results: guessed/total: " << rightAnswer << "/"
               << minibatchSize * iterations << "\n";
  CHECK_GE(rightAnswer, 74) << "Did not classify as many digits as expected";
}

/// This test classifies digits from the MNIST labeled dataset.
void testMNIST() {
  Tensor imageInputs;
  Tensor labelInputs;
  loadMNIST(imageInputs, labelInputs);

  unsigned minibatchSize = 8;
  const int numIterations = 30;

  PlaceholderBindings trainingBindings, inferBindings;
  Placeholder *A, *E, *selected;

  // Build the inference network.
  ExecutionEngine EEI_(executionBackend);
  auto &inferMod = EEI_.getModule();
  Function *F = inferMod.createFunction("mnist");
  createModel(EEI_, F, inferBindings, minibatchSize, A, E, selected);
  inferBindings.allocate(inferMod.getPlaceholders());

  // Build an identical network in a second ExecutionEngine and train it.
  ExecutionEngine EET_(executionBackend);
  auto &trainMod = EET_.getModule();
  Function *TF = trainMod.createFunction("mnist");
  createModel(EET_, TF, trainingBindings, minibatchSize, A, E, selected);

  trainModel(EET_, trainingBindings, TF, minibatchSize, numIterations,
             imageInputs, labelInputs, A, selected);

  // Copy the trained weights into the inference network and look up its input
  // and output placeholders.
  trainingBindings.copyTrainableWeightsTo(inferBindings);
  A = inferBindings.getPlaceholderByNameSlow("input");
  E = inferBindings.getPlaceholderByNameSlow("return");

  validateModel(EEI_, inferBindings, F, minibatchSize, numIterations,
                imageInputs, labelInputs, A, E, false /*transpose*/);
}

/// This test loads the LeNet MNIST model, transfers it into trainable form,
/// trains it, and runs prediction.
void testMNISTLoadAndTraining() {
  Tensor imageInputs;
  Tensor labelInputs;
  Tensor imageInputsTransposed;
  loadMNIST(imageInputsTransposed, labelInputs);
  imageInputsTransposed.transpose(&imageInputs, NHWC2NCHW);

  PlaceholderBindings trainingBindings, inferBindings;
  ExecutionEngine EEI_(executionBackend);
  auto &inferMod = EEI_.getModule();
  auto *F = inferMod.createFunction("lenet_mnist");
  unsigned minibatchSize = 8;

  auto *inputType =
      inferMod.uniqueType(glow::ElemKind::FloatTy, {minibatchSize, 1, 28, 28});
  const char *inputName = "data";

  Error errPtr = Error::empty();
  // Load and compile the LeNet MNIST model.
  glow::Caffe2ModelLoader loader("lenet_mnist/predict_net.pb",
                                 "lenet_mnist/init_net.pb", {inputName},
                                 {inputType}, *F, &errPtr);

  LOG(INFO) << "Loaded graph topology.";

  if (errPtr) {
    LOG(ERROR) << "Loader failed to load lenet_mnist model.";
    return;
  }

  Placeholder *selectedI{nullptr};
  if ((errPtr =
           glow::prepareFunctionForTraining(F, inferBindings, selectedI))) {
    return;
  }

  inferBindings.allocate(inferMod.getPlaceholders());

  // Load the model a second time for training.
  // TODO: remove once EE2 is able to compile in different modes.
  ExecutionEngine EET_(executionBackend);
  auto &trainMod = EET_.getModule();
  auto *TF = trainMod.createFunction("lenet_mnist_train");
  glow::Caffe2ModelLoader trainingLoader("lenet_mnist/predict_net.pb",
                                         "lenet_mnist/init_net.pb", {inputName},
                                         {inputType}, *TF, &errPtr);

  if (errPtr) {
    LOG(ERROR) << "Loader failed to load lenet_mnist model for training.";
    return;
  }

  Placeholder *selected{nullptr};
  if ((errPtr =
           glow::prepareFunctionForTraining(TF, trainingBindings, selected))) {
    return;
  }

  const int numIterations = 30;
  // Get the input placeholder of the training network.
  auto *A = llvm::cast<glow::Placeholder>(
      EXIT_ON_ERR(trainingLoader.getNodeValueByName(inputName)));

  trainModel(EET_, trainingBindings, TF, minibatchSize, numIterations,
             imageInputs, labelInputs, A, selected);

  // Get the input and output placeholders of the inference network.
  A = llvm::cast<glow::Placeholder>(
      EXIT_ON_ERR(loader.getNodeValueByName(inputName)));
  auto *E = EXIT_ON_ERR(loader.getSingleOutput());
  trainingBindings.copyTrainableWeightsTo(inferBindings);

  validateModel(EEI_, inferBindings, F, minibatchSize, numIterations,
                imageInputs, labelInputs, A, E, true /*transpose*/);
}

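/// Entry point. Expects mnist_images.bin and mnist_labels.bin, as well as the
/// lenet_mnist/ Caffe2 model files, to be reachable from the current working
/// directory.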
int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, " The MNIST test\n\n");
  testMNIST();
  testMNISTLoadAndTraining();

  return 0;
}