1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/ExecutionEngine/ExecutionEngine.h"
17#include "glow/Graph/Graph.h"
18#include "glow/Support/Support.h"
19
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Timer.h"
22#include "llvm/Support/raw_ostream.h"
23
24#include <glog/logging.h>
25
26#include <fstream>
27#include <iostream>
28
29using namespace glow;
30
/// Selects which network topology the example builds and trains.
enum class ModelKind {
  MODEL_SIMPLE, ///< Small three-stage conv net (built by createDefaultModel).
  MODEL_VGG,    ///< VGG11-like net (built by createVGGModel).
};
35
namespace {
// Category grouping this example's flags in --help output.
llvm::cl::OptionCategory cifarCat("CIFAR10 Options");
// -backend=<name>: which Glow backend executes the network.
llvm::cl::opt<std::string> executionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(cifarCat));
// -model-simple / -model-vgg: which topology to train (simple by default).
llvm::cl::opt<ModelKind>
    model(llvm::cl::desc("Model to use:"), llvm::cl::Optional,
          llvm::cl::values(clEnumValN(ModelKind::MODEL_SIMPLE, "model-simple",
                                      "simple default model"),
                           clEnumValN(ModelKind::MODEL_VGG, "model-vgg",
                                      "model similar to vgg11")),
          llvm::cl::init(ModelKind::MODEL_SIMPLE), llvm::cl::cat(cifarCat));
} // namespace
50
/// The CIFAR file format is structured as one byte label in the range 0..9.
/// The label is followed by an image: 32 x 32 pixels, in RGB format. Each
/// color is 1 byte. The first 1024 red bytes are followed by 1024 of green
/// and blue. Each 1024 byte color slice is organized in row-major format.
/// The database contains 10000 images.
/// Size: (1 + (32 * 32 * 3)) * 10000 = 30730000.
// Bytes per record: 1 label byte + a 32x32x3 image.
const size_t cifarImageSize = 1 + (32 * 32 * 3);
// Number of records in one CIFAR-10 batch file.
const dim_t cifarNumImages = 10000;
// Number of CIFAR-10 classes; also the width of the final FC layer.
const unsigned numLabels = 10;
60
61static Placeholder *createDefaultModel(PlaceholderBindings &bindings,
62 Function *F, NodeValue input,
63 NodeValue expected) {
64 auto *CV0 = F->createConv(bindings, "conv", input, 16, 5, 1, 2, 1);
65 auto *RL0 = F->createRELU("relu", CV0);
66 auto *MP0 = F->createMaxPool("pool", RL0, 2, 2, 0);
67
68 auto *CV1 = F->createConv(bindings, "conv", MP0->getResult(), 20, 5, 1, 2, 1);
69 auto *RL1 = F->createRELU("relu", CV1);
70 auto *MP1 = F->createMaxPool("pool", RL1, 2, 2, 0);
71
72 auto *CV2 = F->createConv(bindings, "conv", MP1->getResult(), 20, 5, 1, 2, 1);
73 auto *RL2 = F->createRELU("relu", CV2);
74 auto *MP2 = F->createMaxPool("pool", RL2, 2, 2, 0);
75
76 auto *FCL1 =
77 F->createFullyConnected(bindings, "fc", MP2->getResult(), numLabels);
78 auto *SM = F->createSoftMax("softmax", FCL1, expected);
79 auto *save = F->createSave("ret", SM);
80 return save->getPlaceholder();
81}
82
83/// Creates a VGG Model. Inspired by pytorch/torchvision vgg.py/vgg11:
84/// https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
85static Placeholder *createVGGModel(PlaceholderBindings &bindings, Function *F,
86 NodeValue input, NodeValue expected) {
87 NodeValue v = input;
88
89 // Create feature detection part.
90 unsigned cfg[] = {64, 0, 128, 0, 256, 256, 0, 512, 512, 0, 512, 512, 0};
91 for (unsigned c : cfg) {
92 if (c == 0) {
93 v = F->createMaxPool("pool", v, 2, 2, 0);
94 } else {
95 auto *conv = F->createConv(bindings, "conv", v, c, 3, 1, 1, 1);
96 auto *relu = F->createRELU("relu", conv);
97 v = relu;
98 }
99 }
100
101 // Create classifier part.
102 for (unsigned i = 0; i < 2; ++i) {
103 auto *fc0 = F->createFullyConnected(bindings, "fc", v, 4096);
104 auto *relu0 = F->createRELU("relu", fc0);
105 // TODO: There is not builtin dropout node in glow yet
106 // Dropout
107 v = relu0;
108 }
109 v = F->createFullyConnected(bindings, "fc", v, numLabels);
110 auto *softmax = F->createSoftMax("softmax", v, expected);
111 auto *save = F->createSave("ret", softmax);
112 return save->getPlaceholder();
113}
114
115/// This test classifies digits from the CIFAR labeled dataset.
116/// Details: http://www.cs.toronto.edu/~kriz/cifar.html
117/// Dataset: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
118void testCIFAR10() {
119 (void)cifarImageSize;
120 const char *textualLabels[] = {"airplane", "automobile", "bird", "cat",
121 "deer", "dog", "frog", "horse",
122 "ship", "truck"};
123
124 std::ifstream dbInput("cifar-10-batches-bin/data_batch_1.bin",
125 std::ios::binary);
126
127 CHECK(dbInput.is_open())
128 << "Failed to open cifar10 data file, probably missing. Run 'python "
129 "../glow/utils/download_datasets_and_models.py -d cifar10'";
130
131 LOG(INFO) << "Loading the CIFAR-10 database.\n";
132
133 /// Load the CIFAR database into a 4d tensor.
134 Tensor images(ElemKind::FloatTy, {cifarNumImages, 32, 32, 3});
135 Tensor labels(ElemKind::Int64ITy, {cifarNumImages, 1});
136 size_t idx = 0;
137
138 auto labelsH = labels.getHandle<int64_t>();
139 auto imagesH = images.getHandle<>();
140 for (unsigned w = 0; w < cifarNumImages; w++) {
141 labelsH.at({w, 0}) = static_cast<uint8_t>(dbInput.get());
142 idx++;
143
144 for (unsigned z = 0; z < 3; z++) {
145 for (unsigned y = 0; y < 32; y++) {
146 for (unsigned x = 0; x < 32; x++) {
147 imagesH.at({w, x, y, z}) =
148 static_cast<float>(static_cast<uint8_t>(dbInput.get())) / 255.0;
149 idx++;
150 }
151 }
152 }
153 }
154 CHECK_EQ(idx, cifarImageSize * cifarNumImages) << "Invalid input file";
155
156 unsigned minibatchSize = 8;
157
158 // Construct the network:
159 TrainingConfig TC;
160
161 ExecutionEngine EE(executionBackend);
162 PlaceholderBindings bindings;
163
164 TC.learningRate = 0.001;
165 TC.momentum = 0.9;
166 TC.L2Decay = 0.0001;
167 TC.batchSize = minibatchSize;
168
169 auto &mod = EE.getModule();
170 Function *F = mod.createFunction("main");
171
172 // Create the input layer:
173 auto *A = mod.createPlaceholder(ElemKind::FloatTy, {minibatchSize, 32, 32, 3},
174 "input", false);
175 bindings.allocate(A);
176 auto *E = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, 1},
177 "expected", false);
178 bindings.allocate(E);
179
180 auto createModel =
181 model == ModelKind::MODEL_SIMPLE ? createDefaultModel : createVGGModel;
182 auto *resultPH = createModel(bindings, F, A, E);
183 auto *result = bindings.allocate(resultPH);
184
185 Function *TF = glow::differentiate(F, TC);
186 auto tfName = TF->getName();
187 EE.compile(CompilationMode::Train);
188 bindings.allocate(mod.getPlaceholders());
189
190 // Report progress every this number of training iterations.
191 // Report less often for fast models.
192 int reportRate = model == ModelKind::MODEL_SIMPLE ? 256 : 64;
193
194 LOG(INFO) << "Training.";
195
196 // This variable records the number of the next sample to be used for
197 // training.
198 size_t sampleCounter = 0;
199
200 for (int iter = 0; iter < 100000; iter++) {
201 unsigned epoch = (iter * reportRate) / labels.getType().sizes_[0];
202 LOG(INFO) << "Training - iteration #" << iter << " (epoch #" << epoch
203 << ")";
204
205 llvm::Timer timer("Training", "Training");
206 timer.startTimer();
207
208 // Bind the images tensor to the input array A, and the labels tensor
209 // to the softmax node SM.
210 runBatch(EE, bindings, reportRate, sampleCounter, {A, E},
211 {&images, &labels}, tfName);
212
213 unsigned score = 0;
214
215 for (unsigned int i = 0; i < 100 / minibatchSize; i++) {
216 Tensor sample(ElemKind::FloatTy, {minibatchSize, 32, 32, 3});
217 sample.copyConsecutiveSlices(&images, minibatchSize * i);
218 updateInputPlaceholders(bindings, {A}, {&sample});
219 EE.run(bindings);
220
221 for (unsigned int iter = 0; iter < minibatchSize; iter++) {
222 auto T = result->getHandle<>().extractSlice(iter);
223 size_t guess = T.getHandle<>().minMaxArg().second;
224 size_t correct = labelsH.at({minibatchSize * i + iter, 0});
225 score += guess == correct;
226
227 if ((iter < numLabels) && i == 0) {
228 LOG(INFO) << iter << ") Expected: " << textualLabels[correct]
229 << " Got: " << textualLabels[guess];
230 }
231 }
232 }
233
234 timer.stopTimer();
235
236 LOG(INFO) << "Iteration #" << iter << " score: " << score << "%";
237 }
238}
239
/// Entry point: parse the command-line flags declared above (backend and
/// model selection), then run the CIFAR-10 training loop.
int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, " The CIFAR10 test\n\n");
  testCIFAR10();

  return 0;
}
246