1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Base/Image.h" |
17 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
18 | #include "glow/Graph/Graph.h" |
19 | #include "glow/Importer/Caffe2ModelLoader.h" |
20 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
21 | #include "glow/Optimizer/GraphOptimizer/TrainingPreparation.h" |
22 | #include "glow/Support/Support.h" |
23 | |
24 | #include "llvm/Support/CommandLine.h" |
25 | #include "llvm/Support/Timer.h" |
26 | #include "llvm/Support/raw_ostream.h" |
27 | |
28 | #include <glog/logging.h> |
29 | |
30 | #include <fstream> |
31 | |
32 | using namespace glow; |
33 | |
/// Number of MNIST samples loaded from the database files; the binary image
/// and label files must contain at least this many entries.
constexpr size_t mnistNumImages = 50000;
35 | |
namespace {
/// Category that groups all MNIST-specific command line options together in
/// the -help output.
llvm::cl::OptionCategory mnistCat("MNIST Options" );
/// Name of the Glow backend used to compile and run the model; defaults to
/// the reference Interpreter backend.
llvm::cl::opt<std::string> executionBackend(
    "backend" ,
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:" ),
    llvm::cl::Optional, llvm::cl::init("Interpreter" ), llvm::cl::cat(mnistCat));
} // namespace
43 | |
44 | unsigned loadMNIST(Tensor &imageInputs, Tensor &labelInputs) { |
45 | /// Load the MNIST database into 4D tensor of images and 2D tensor of labels. |
46 | LOG(INFO) << "Loading the mnist database." ; |
47 | |
48 | imageInputs.reset(ElemKind::FloatTy, {50000u, 28, 28, 1}); |
49 | labelInputs.reset(ElemKind::Int64ITy, {50000u, 1}); |
50 | |
51 | std::ifstream imgInput("mnist_images.bin" , std::ios::binary); |
52 | std::ifstream labInput("mnist_labels.bin" , std::ios::binary); |
53 | |
54 | CHECK(imgInput.is_open()) << "Error loading mnist_images.bin" ; |
55 | CHECK(labInput.is_open()) << "Error loading mnist_labels.bin" ; |
56 | |
57 | std::vector<char> images((std::istreambuf_iterator<char>(imgInput)), |
58 | (std::istreambuf_iterator<char>())); |
59 | std::vector<char> labels((std::istreambuf_iterator<char>(labInput)), |
60 | (std::istreambuf_iterator<char>())); |
61 | float *imagesAsFloatPtr = reinterpret_cast<float *>(&images[0]); |
62 | |
63 | CHECK_EQ(labels.size() * 28 * 28 * sizeof(float), images.size()) |
64 | << "The size of the image buffer does not match the labels vector" ; |
65 | |
66 | size_t idx = 0; |
67 | |
68 | auto LIH = labelInputs.getHandle<int64_t>(); |
69 | auto IIH = imageInputs.getHandle<>(); |
70 | |
71 | for (unsigned w = 0; w < mnistNumImages; w++) { |
72 | LIH.at({w, 0}) = labels[w]; |
73 | for (unsigned x = 0; x < 28; x++) { |
74 | for (unsigned y = 0; y < 28; y++) { |
75 | IIH.at({w, x, y, 0}) = imagesAsFloatPtr[idx++]; |
76 | } |
77 | } |
78 | } |
79 | size_t numImages = labels.size(); |
80 | CHECK_GT(numImages, 0) << "No images were found." ; |
81 | LOG(INFO) << "Loaded " << numImages << " images." ; |
82 | return numImages; |
83 | } |
84 | |
85 | void createModel(ExecutionEngine &EE, Function *F, |
86 | PlaceholderBindings &bindings, unsigned minibatchSize, |
87 | Placeholder *&inputPH, Placeholder *&outputPH, |
88 | Placeholder *&selectedPH) { |
89 | auto &mod = EE.getModule(); |
90 | |
91 | inputPH = mod.createPlaceholder(ElemKind::FloatTy, {minibatchSize, 28, 28, 1}, |
92 | "input" , false); |
93 | |
94 | auto *CV0 = F->createConv(bindings, "conv" , inputPH, 16, 5, 1, 2, 1); |
95 | auto *RL0 = F->createRELU("relu" , CV0); |
96 | auto *MP0 = F->createMaxPool("pool" , RL0, 3, 3, 0); |
97 | |
98 | auto *CV1 = F->createConv(bindings, "conv" , MP0->getResult(), 16, 5, 1, 2, 1); |
99 | auto *RL1 = F->createRELU("relu" , CV1); |
100 | auto *MP1 = F->createMaxPool("pool" , RL1, 3, 3, 0); |
101 | |
102 | auto *FCL1 = F->createFullyConnected(bindings, "fc" , MP1->getResult(), 10); |
103 | selectedPH = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, 1}, |
104 | "selected" , false); |
105 | auto *SM = F->createSoftMax("sm" , FCL1, selectedPH); |
106 | SaveNode *result = F->createSave("return" , SM); |
107 | outputPH = result->getPlaceholder(); |
108 | } |
109 | |
110 | void trainModel(ExecutionEngine &EE, PlaceholderBindings &bindings, Function *F, |
111 | unsigned minibatchSize, unsigned numIterations, |
112 | Tensor &imageInputs, Tensor &labelInputs, Placeholder *inputPH, |
113 | Placeholder *selectedPH) { |
114 | llvm::Timer timer("Training" , "Training" ); |
115 | /// The training configuration. |
116 | TrainingConfig TC; |
117 | |
118 | // Construct the network: |
119 | TC.learningRate = 0.001; |
120 | TC.momentum = 0.9; |
121 | TC.L2Decay = 0.001; |
122 | TC.batchSize = minibatchSize; |
123 | |
124 | Function *TF = glow::differentiate(F, TC); |
125 | |
126 | EE.compile(CompilationMode::Train); |
127 | bindings.allocate(EE.getModule().getPlaceholders()); |
128 | |
129 | LOG(INFO) << "Training." ; |
130 | |
131 | // This variable records the number of the next sample to be used for |
132 | // training. |
133 | size_t sampleCounter = 0; |
134 | auto tfName = TF->getName(); |
135 | |
136 | for (int epoch = 0; epoch < 60; epoch++) { |
137 | LOG(INFO) << "Training - epoch #" << epoch; |
138 | |
139 | timer.startTimer(); |
140 | |
141 | // On each training iteration take a slice of imageInputs and labelInputs |
142 | // and put them into variables A and B, then run forward and backward passes |
143 | // and update weights. |
144 | runBatch(EE, bindings, numIterations, sampleCounter, {inputPH, selectedPH}, |
145 | {&imageInputs, &labelInputs}, tfName); |
146 | |
147 | timer.stopTimer(); |
148 | } |
149 | } |
150 | |
151 | void validateModel(ExecutionEngine &EE, PlaceholderBindings &bindings, |
152 | Function *F, unsigned minibatchSize, unsigned numIterations, |
153 | Tensor &imageInputs, Tensor &labelInputs, |
154 | Placeholder *inputPH, Placeholder *outputPH, |
155 | bool transpose) { |
156 | LOG(INFO) << "Validating." ; |
157 | |
158 | ::glow::convertPlaceholdersToConstants(F, bindings, {inputPH, outputPH}); |
159 | EE.compile(CompilationMode::Infer); |
160 | |
161 | dim_t rightAnswer = 0; |
162 | dim_t offset = numIterations * minibatchSize; |
163 | size_t sampleCounter = offset; |
164 | size_t iterations = 10; |
165 | std::vector<Tensor> estimates; |
166 | evalBatch(EE, bindings, iterations, sampleCounter, inputPH, outputPH, |
167 | imageInputs, labelInputs, F->getName(), |
168 | [&](const Tensor &sampleIn, const Tensor &sampleOut, |
169 | const Tensor &label, size_t sampleIndex) { |
170 | auto correct = label.getHandle<sdim_t>().at({0}); |
171 | auto guess = sampleOut.getHandle().minMaxArg().second; |
172 | rightAnswer += (guess == correct); |
173 | if (sampleIndex < offset + minibatchSize) { |
174 | llvm::outs() << "MNIST Input" ; |
175 | if (transpose) { |
176 | Tensor IT; |
177 | // Transpose back to the ASCII printable format. |
178 | // CHW -> HWC. |
179 | sampleIn.transpose(&IT, {1, 2, 0}); |
180 | IT.getHandle().dumpAscii(); |
181 | } else { |
182 | sampleIn.getHandle().dumpAscii(); |
183 | } |
184 | llvm::outs() << " Expected: " << correct |
185 | << " Guessed: " << guess << "\n" ; |
186 | sampleOut.getHandle<>().dump(); |
187 | llvm::outs() << "\n-------------\n" ; |
188 | } |
189 | }); |
190 | |
191 | llvm::outs() << "Results: guessed/total:" << rightAnswer << "/" |
192 | << minibatchSize * 10 << "\n" ; |
193 | CHECK_GE(rightAnswer, 74) << "Did not classify as many digits as expected" ; |
194 | } |
195 | |
196 | /// This test classifies digits from the MNIST labeled dataset. |
197 | void testMNIST() { |
198 | Tensor imageInputs; |
199 | Tensor labelInputs; |
200 | loadMNIST(imageInputs, labelInputs); |
201 | |
202 | unsigned minibatchSize = 8; |
203 | const int numIterations = 30; |
204 | |
205 | PlaceholderBindings trainingBindings, inferBindings; |
206 | Placeholder *A, *E, *selected; |
207 | |
208 | ExecutionEngine EEI_(executionBackend); |
209 | auto &inferMod = EEI_.getModule(); |
210 | Function *F = inferMod.createFunction("mnist" ); |
211 | createModel(EEI_, F, inferBindings, minibatchSize, A, E, selected); |
212 | inferBindings.allocate(inferMod.getPlaceholders()); |
213 | |
214 | ExecutionEngine EET_(executionBackend); |
215 | auto &trainMod = EET_.getModule(); |
216 | Function *TF = trainMod.createFunction("mnist" ); |
217 | createModel(EET_, TF, trainingBindings, minibatchSize, A, E, selected); |
218 | |
219 | trainModel(EET_, trainingBindings, TF, minibatchSize, numIterations, |
220 | imageInputs, labelInputs, A, selected); |
221 | |
222 | trainingBindings.copyTrainableWeightsTo(inferBindings); |
223 | A = inferBindings.getPlaceholderByNameSlow("input" ); |
224 | E = inferBindings.getPlaceholderByNameSlow("return" ); |
225 | |
226 | validateModel(EEI_, inferBindings, F, minibatchSize, numIterations, |
227 | imageInputs, labelInputs, A, E, false /*transpose*/); |
228 | } |
229 | |
/// This test loads the LeNet-MNIST model from Caffe2 protobufs, transforms it
/// into a trainable form, trains it, and then runs prediction with the
/// trained weights.
void testMNISTLoadAndTraining() {
  Tensor imageInputs;
  Tensor labelInputs;
  Tensor imageInputsTransposed;
  loadMNIST(imageInputsTransposed, labelInputs);
  // The Caffe2 model takes NCHW input, so transpose the NHWC MNIST images.
  imageInputsTransposed.transpose(&imageInputs, NHWC2NCHW);

  PlaceholderBindings trainingBindings, inferBindings;
  ExecutionEngine EEI_(executionBackend);
  auto &inferMod = EEI_.getModule();
  auto *F = inferMod.createFunction("lenet_mnist" );
  unsigned minibatchSize = 8;

  auto *inputType =
      inferMod.uniqueType(glow::ElemKind::FloatTy, {minibatchSize, 1, 28, 28});
  const char *inputName = "data" ;

  Error errPtr = Error::empty();
  // Load and compile LeNet MNIST model.
  glow::Caffe2ModelLoader loader("lenet_mnist/predict_net.pb" ,
                                 "lenet_mnist/init_net.pb" , {inputName},
                                 {inputType}, *F, &errPtr);

  LOG(INFO) << "Loaded graph topology." ;

  // Bail out early if the protobuf files could not be loaded; the loader
  // reports failure through the errPtr out-parameter.
  if (errPtr) {
    LOG(ERROR) << "Loader failed to load lenet_mnist model." ;
    return;
  }

  // Prepare the inference copy for training as well, so its set of
  // placeholders (including the "selected" label input written into
  // selectedI) matches the trained copy when weights are transferred later.
  Placeholder *selectedI{nullptr};
  if ((errPtr =
           glow::prepareFunctionForTraining(F, inferBindings, selectedI))) {
    return;
  }

  inferBindings.allocate(inferMod.getPlaceholders());

  // Load the model a second time for training.
  // TODO: remove once EE2 is able to compile in different modes.
  ExecutionEngine EET_(executionBackend);
  auto &trainMod = EET_.getModule();
  auto *TF = trainMod.createFunction("lenet_mnist_train" );
  glow::Caffe2ModelLoader trainingLoader("lenet_mnist/predict_net.pb" ,
                                         "lenet_mnist/init_net.pb" , {inputName},
                                         {inputType}, *TF, &errPtr);

  if (errPtr) {
    LOG(ERROR) << "Loader failed to load lenet_mnist model for training." ;
    return;
  }

  // Prepare the second copy for training; `selected` receives the label
  // placeholder fed to trainModel below.
  Placeholder *selected{nullptr};
  if ((errPtr =
           glow::prepareFunctionForTraining(TF, trainingBindings, selected))) {
    return;
  }

  const int numIterations = 30;
  // Get input placeholder.
  auto *A = llvm::cast<glow::Placeholder>(
      EXIT_ON_ERR(trainingLoader.getNodeValueByName(inputName)));

  trainModel(EET_, trainingBindings, TF, minibatchSize, numIterations,
             imageInputs, labelInputs, A, selected);

  // Get input and output placeholders.
  // NOTE: these are re-looked-up from the *inference* loader, since
  // validation runs on the inference copy of the model.
  A = llvm::cast<glow::Placeholder>(
      EXIT_ON_ERR(loader.getNodeValueByName(inputName)));
  auto *E = EXIT_ON_ERR(loader.getSingleOutput());
  // Copy the trained weights into the inference model before validating.
  trainingBindings.copyTrainableWeightsTo(inferBindings);

  validateModel(EEI_, inferBindings, F, minibatchSize, numIterations,
                imageInputs, labelInputs, A, E, true /*transpose*/);
}
307 | |
/// Entry point: parses the command line flags (e.g. -backend) and runs both
/// MNIST tests — the hand-built model first, then the loaded LeNet model.
int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, " The MNIST test\n\n" );
  testMNIST();
  testMNISTLoadAndTraining();

  return 0;
}
315 | |