1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
17 | #include "glow/Graph/Graph.h" |
18 | #include "glow/Support/Support.h" |
19 | |
20 | #include "llvm/Support/CommandLine.h" |
21 | #include "llvm/Support/Timer.h" |
22 | #include "llvm/Support/raw_ostream.h" |
23 | |
24 | #include <glog/logging.h> |
25 | |
26 | #include <fstream> |
27 | #include <iostream> |
28 | |
29 | using namespace glow; |
30 | |
/// Selects which network architecture the demo builds and trains.
enum class ModelKind {
  MODEL_SIMPLE, ///< Small three-conv-stage network (fast to train).
  MODEL_VGG,    ///< Deeper network similar to VGG11 (see createVGGModel).
};
35 | |
namespace {
/// Command-line category grouping all options of this CIFAR10 demo.
llvm::cl::OptionCategory cifarCat("CIFAR10 Options");
/// -backend=<name>: which Glow backend executes the network
/// (defaults to the reference Interpreter backend).
llvm::cl::opt<std::string> executionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(cifarCat));
/// -model-simple / -model-vgg: which architecture to train
/// (defaults to the simple model).
llvm::cl::opt<ModelKind>
    model(llvm::cl::desc("Model to use:"), llvm::cl::Optional,
          llvm::cl::values(clEnumValN(ModelKind::MODEL_SIMPLE, "model-simple",
                                      "simple default model"),
                           clEnumValN(ModelKind::MODEL_VGG, "model-vgg",
                                      "model similar to vgg11")),
          llvm::cl::init(ModelKind::MODEL_SIMPLE), llvm::cl::cat(cifarCat));
} // namespace
50 | |
51 | /// The CIFAR file format is structured as one byte label in the range 0..9. |
52 | /// The label is followed by an image: 32 x 32 pixels, in RGB format. Each |
53 | /// color is 1 byte. The first 1024 red bytes are followed by 1024 of green |
54 | /// and blue. Each 1024 byte color slice is organized in row-major format. |
55 | /// The database contains 10000 images. |
56 | /// Size: (1 + (32 * 32 * 3)) * 10000 = 30730000. |
57 | const size_t cifarImageSize = 1 + (32 * 32 * 3); |
58 | const dim_t cifarNumImages = 10000; |
59 | const unsigned numLabels = 10; |
60 | |
61 | static Placeholder *createDefaultModel(PlaceholderBindings &bindings, |
62 | Function *F, NodeValue input, |
63 | NodeValue expected) { |
64 | auto *CV0 = F->createConv(bindings, "conv" , input, 16, 5, 1, 2, 1); |
65 | auto *RL0 = F->createRELU("relu" , CV0); |
66 | auto *MP0 = F->createMaxPool("pool" , RL0, 2, 2, 0); |
67 | |
68 | auto *CV1 = F->createConv(bindings, "conv" , MP0->getResult(), 20, 5, 1, 2, 1); |
69 | auto *RL1 = F->createRELU("relu" , CV1); |
70 | auto *MP1 = F->createMaxPool("pool" , RL1, 2, 2, 0); |
71 | |
72 | auto *CV2 = F->createConv(bindings, "conv" , MP1->getResult(), 20, 5, 1, 2, 1); |
73 | auto *RL2 = F->createRELU("relu" , CV2); |
74 | auto *MP2 = F->createMaxPool("pool" , RL2, 2, 2, 0); |
75 | |
76 | auto *FCL1 = |
77 | F->createFullyConnected(bindings, "fc" , MP2->getResult(), numLabels); |
78 | auto *SM = F->createSoftMax("softmax" , FCL1, expected); |
79 | auto *save = F->createSave("ret" , SM); |
80 | return save->getPlaceholder(); |
81 | } |
82 | |
83 | /// Creates a VGG Model. Inspired by pytorch/torchvision vgg.py/vgg11: |
84 | /// https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py |
85 | static Placeholder *createVGGModel(PlaceholderBindings &bindings, Function *F, |
86 | NodeValue input, NodeValue expected) { |
87 | NodeValue v = input; |
88 | |
89 | // Create feature detection part. |
90 | unsigned cfg[] = {64, 0, 128, 0, 256, 256, 0, 512, 512, 0, 512, 512, 0}; |
91 | for (unsigned c : cfg) { |
92 | if (c == 0) { |
93 | v = F->createMaxPool("pool" , v, 2, 2, 0); |
94 | } else { |
95 | auto *conv = F->createConv(bindings, "conv" , v, c, 3, 1, 1, 1); |
96 | auto *relu = F->createRELU("relu" , conv); |
97 | v = relu; |
98 | } |
99 | } |
100 | |
101 | // Create classifier part. |
102 | for (unsigned i = 0; i < 2; ++i) { |
103 | auto *fc0 = F->createFullyConnected(bindings, "fc" , v, 4096); |
104 | auto *relu0 = F->createRELU("relu" , fc0); |
105 | // TODO: There is not builtin dropout node in glow yet |
106 | // Dropout |
107 | v = relu0; |
108 | } |
109 | v = F->createFullyConnected(bindings, "fc" , v, numLabels); |
110 | auto *softmax = F->createSoftMax("softmax" , v, expected); |
111 | auto *save = F->createSave("ret" , softmax); |
112 | return save->getPlaceholder(); |
113 | } |
114 | |
115 | /// This test classifies digits from the CIFAR labeled dataset. |
116 | /// Details: http://www.cs.toronto.edu/~kriz/cifar.html |
117 | /// Dataset: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz |
118 | void testCIFAR10() { |
119 | (void)cifarImageSize; |
120 | const char *textualLabels[] = {"airplane" , "automobile" , "bird" , "cat" , |
121 | "deer" , "dog" , "frog" , "horse" , |
122 | "ship" , "truck" }; |
123 | |
124 | std::ifstream dbInput("cifar-10-batches-bin/data_batch_1.bin" , |
125 | std::ios::binary); |
126 | |
127 | CHECK(dbInput.is_open()) |
128 | << "Failed to open cifar10 data file, probably missing. Run 'python " |
129 | "../glow/utils/download_datasets_and_models.py -d cifar10'" ; |
130 | |
131 | LOG(INFO) << "Loading the CIFAR-10 database.\n" ; |
132 | |
133 | /// Load the CIFAR database into a 4d tensor. |
134 | Tensor images(ElemKind::FloatTy, {cifarNumImages, 32, 32, 3}); |
135 | Tensor labels(ElemKind::Int64ITy, {cifarNumImages, 1}); |
136 | size_t idx = 0; |
137 | |
138 | auto labelsH = labels.getHandle<int64_t>(); |
139 | auto imagesH = images.getHandle<>(); |
140 | for (unsigned w = 0; w < cifarNumImages; w++) { |
141 | labelsH.at({w, 0}) = static_cast<uint8_t>(dbInput.get()); |
142 | idx++; |
143 | |
144 | for (unsigned z = 0; z < 3; z++) { |
145 | for (unsigned y = 0; y < 32; y++) { |
146 | for (unsigned x = 0; x < 32; x++) { |
147 | imagesH.at({w, x, y, z}) = |
148 | static_cast<float>(static_cast<uint8_t>(dbInput.get())) / 255.0; |
149 | idx++; |
150 | } |
151 | } |
152 | } |
153 | } |
154 | CHECK_EQ(idx, cifarImageSize * cifarNumImages) << "Invalid input file" ; |
155 | |
156 | unsigned minibatchSize = 8; |
157 | |
158 | // Construct the network: |
159 | TrainingConfig TC; |
160 | |
161 | ExecutionEngine EE(executionBackend); |
162 | PlaceholderBindings bindings; |
163 | |
164 | TC.learningRate = 0.001; |
165 | TC.momentum = 0.9; |
166 | TC.L2Decay = 0.0001; |
167 | TC.batchSize = minibatchSize; |
168 | |
169 | auto &mod = EE.getModule(); |
170 | Function *F = mod.createFunction("main" ); |
171 | |
172 | // Create the input layer: |
173 | auto *A = mod.createPlaceholder(ElemKind::FloatTy, {minibatchSize, 32, 32, 3}, |
174 | "input" , false); |
175 | bindings.allocate(A); |
176 | auto *E = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, 1}, |
177 | "expected" , false); |
178 | bindings.allocate(E); |
179 | |
180 | auto createModel = |
181 | model == ModelKind::MODEL_SIMPLE ? createDefaultModel : createVGGModel; |
182 | auto *resultPH = createModel(bindings, F, A, E); |
183 | auto *result = bindings.allocate(resultPH); |
184 | |
185 | Function *TF = glow::differentiate(F, TC); |
186 | auto tfName = TF->getName(); |
187 | EE.compile(CompilationMode::Train); |
188 | bindings.allocate(mod.getPlaceholders()); |
189 | |
190 | // Report progress every this number of training iterations. |
191 | // Report less often for fast models. |
192 | int reportRate = model == ModelKind::MODEL_SIMPLE ? 256 : 64; |
193 | |
194 | LOG(INFO) << "Training." ; |
195 | |
196 | // This variable records the number of the next sample to be used for |
197 | // training. |
198 | size_t sampleCounter = 0; |
199 | |
200 | for (int iter = 0; iter < 100000; iter++) { |
201 | unsigned epoch = (iter * reportRate) / labels.getType().sizes_[0]; |
202 | LOG(INFO) << "Training - iteration #" << iter << " (epoch #" << epoch |
203 | << ")" ; |
204 | |
205 | llvm::Timer timer("Training" , "Training" ); |
206 | timer.startTimer(); |
207 | |
208 | // Bind the images tensor to the input array A, and the labels tensor |
209 | // to the softmax node SM. |
210 | runBatch(EE, bindings, reportRate, sampleCounter, {A, E}, |
211 | {&images, &labels}, tfName); |
212 | |
213 | unsigned score = 0; |
214 | |
215 | for (unsigned int i = 0; i < 100 / minibatchSize; i++) { |
216 | Tensor sample(ElemKind::FloatTy, {minibatchSize, 32, 32, 3}); |
217 | sample.copyConsecutiveSlices(&images, minibatchSize * i); |
218 | updateInputPlaceholders(bindings, {A}, {&sample}); |
219 | EE.run(bindings); |
220 | |
221 | for (unsigned int iter = 0; iter < minibatchSize; iter++) { |
222 | auto T = result->getHandle<>().extractSlice(iter); |
223 | size_t guess = T.getHandle<>().minMaxArg().second; |
224 | size_t correct = labelsH.at({minibatchSize * i + iter, 0}); |
225 | score += guess == correct; |
226 | |
227 | if ((iter < numLabels) && i == 0) { |
228 | LOG(INFO) << iter << ") Expected: " << textualLabels[correct] |
229 | << " Got: " << textualLabels[guess]; |
230 | } |
231 | } |
232 | } |
233 | |
234 | timer.stopTimer(); |
235 | |
236 | LOG(INFO) << "Iteration #" << iter << " score: " << score << "%" ; |
237 | } |
238 | } |
239 | |
/// Entry point: parse the command-line flags declared above, then run the
/// CIFAR-10 training loop (which only returns on fatal errors — the loop
/// itself runs for a fixed, very large number of iterations).
int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, " The CIFAR10 test\n\n");
  testCIFAR10();

  return 0;
}
246 | |