1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | // This file contains a set of tests related to a toy neural network |
18 | // that can hyphenate words. This isn't meant to represent great strides |
19 | // in machine learning research, but rather to exercise the CPU JIT |
20 | // compiler with a small end-to-end example. The toy network is small |
21 | // enough that it can be trained as part of the unit test suite. |
22 | |
23 | #include "BackendTestUtils.h" |
24 | |
25 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
26 | #include "glow/Graph/Graph.h" |
27 | #include "glow/IR/IR.h" |
28 | #include "glow/IR/IRBuilder.h" |
29 | #include "glow/IR/Instrs.h" |
30 | #include "glow/Support/Random.h" |
31 | |
32 | #include "gtest/gtest.h" |
33 | |
34 | #include <cctype> |
35 | #include <string> |
36 | |
37 | using namespace glow; |
38 | using llvm::cast; |
39 | using std::string; |
40 | using std::vector; |
41 | |
42 | // Network architecture |
43 | // ==================== |
44 | // |
45 | // The network is a simple multi-layer perceptron with 27 x 6 input |
46 | // nodes, 10 inner nodes, and a 2-way soft-max output node. |
47 | // |
48 | // The input nodes represent 6 letters of a candidate word, and the |
49 | // output node indicates the probability that the word can be hyphenated |
50 | // between the 3rd and 4th letters. As a word slides past the 6-letter |
51 | // window, the network classifies each possible hyphen position. |
52 | // |
53 | // Example: "hyphenate" |
54 | // |
55 | // "..hyph" -> 0, h-yphenate is wrong. |
56 | // ".hyphe" -> 1, hy-phenate is right. |
57 | // "hyphen" -> 0, hyp-henate is wrong. |
58 | // "yphena" -> 0, hyph-enate is wrong. |
59 | // "phenat" -> 0, hyphe-nate is wrong. |
60 | // "henate" -> 1, hyphen-ate is right. |
61 | // "enate." -> 0, hyphena-te is wrong. |
62 | // "nate.." -> 0, hyphenat-e is wrong. |
63 | |
64 | /// Parse an already hyphenated word into word windows and hyphen labels. |
65 | /// |
66 | /// Given a word with embedded hyphens, generate a sequence of sliding |
67 | /// 6-character windows and associated boolean labels, like the table above. |
68 | /// |
69 | static void dehyphenate(const char *hword, vector<string> &words, |
70 | vector<bool> &hyphens) { |
71 | EXPECT_EQ(words.size(), hyphens.size()); |
72 | |
73 | // The first character can't be a hyphen, and the word can't be null. |
74 | EXPECT_TRUE(std::islower(*hword)); |
75 | string word = ".." ; |
76 | word.push_back(*hword++); |
77 | |
78 | // Parse `hword` and add all the letters to `word` and hyphen/no-hyphen |
79 | // entries to `hyphens`. |
80 | for (; *hword; hword++) { |
81 | bool hyph = (*hword == '-'); |
82 | hyphens.push_back(hyph); |
83 | if (hyph) { |
84 | hword++; |
85 | } |
86 | // There can't be multiple adjacent hyphens, and the word can't |
87 | // end with a hyphen. |
88 | EXPECT_TRUE(std::islower(*hword)); |
89 | word.push_back(*hword); |
90 | } |
91 | word += ".." ; |
92 | |
93 | // Now `word` contains the letters of `hword` surrounded by '..' on both |
94 | // sides. Generate all 6-character windows and append them to `words`. |
95 | for (size_t i = 0, e = word.size(); i + 5 < e; i++) { |
96 | words.push_back(word.substr(i, 6)); |
97 | } |
98 | EXPECT_EQ(words.size(), hyphens.size()); |
99 | } |
100 | |
101 | TEST(HyphenTest, dehyphenate) { |
102 | vector<string> words; |
103 | vector<bool> hyphens; |
104 | |
105 | dehyphenate("x" , words, hyphens); |
106 | EXPECT_EQ(words.size(), 0); |
107 | EXPECT_EQ(hyphens.size(), 0); |
108 | |
109 | dehyphenate("xy" , words, hyphens); |
110 | EXPECT_EQ(words, (vector<string>{"..xy.." })); |
111 | EXPECT_EQ(hyphens, (vector<bool>{0})); |
112 | |
113 | dehyphenate("y-z" , words, hyphens); |
114 | EXPECT_EQ(words, (vector<string>{"..xy.." , "..yz.." })); |
115 | EXPECT_EQ(hyphens, (vector<bool>{0, 1})); |
116 | |
117 | words.clear(); |
118 | hyphens.clear(); |
119 | dehyphenate("hy-phen-ate" , words, hyphens); |
120 | EXPECT_EQ(words, (vector<string>{"..hyph" , ".hyphe" , "hyphen" , "yphena" , |
121 | "phenat" , "henate" , "enate." , "nate.." })); |
122 | EXPECT_EQ(hyphens, (vector<bool>{0, 1, 0, 0, 0, 1, 0, 0})); |
123 | } |
124 | |
/// Map a lower-case letter to an input index in the range 0-26.
/// 'a'..'z' map to 1..26; any character outside that range maps to 0.
static size_t mapLetter(char l) {
  if (l < 'a' || l > 'z') {
    return 0;
  }
  return static_cast<size_t>(l - 'a') + 1;
}
131 | |
132 | TEST(HyphenTest, mapLetter) { |
133 | EXPECT_EQ(mapLetter('a'), 1); |
134 | EXPECT_EQ(mapLetter('d'), 4); |
135 | EXPECT_EQ(mapLetter('z'), 26); |
136 | EXPECT_EQ(mapLetter('.'), 0); |
137 | } |
138 | |
139 | /// Map a 6-letter window of a word to an input tensor using a one-hot encoding. |
140 | /// |
141 | /// The tensor must be N x 6 x 27: batch x position x letter. |
142 | static void mapLetterWindow(const string &window, dim_t idx, |
143 | Handle<float> tensor) { |
144 | EXPECT_EQ(window.size(), 6); |
145 | for (dim_t row = 0; row < 6; row++) { |
146 | dim_t col = mapLetter(window[row]); |
147 | tensor.at({idx, row, col}) = 1; |
148 | } |
149 | } |
150 | |
// Training data consisting of pre-hyphenated common words. Each '-' marks a
// valid hyphenation point; dehyphenate() expands every word into labeled
// 6-letter windows. The "network" test below hard-codes the total window
// count (566), so adding or removing entries requires updating it.
// Note: `const` at namespace scope gives this internal linkage.
const vector<const char *> TrainingData{
    "ad-mi-ni-stra-tion",
    "ad-mit",
    "al-low",
    "al-though",
    "an-i-mal",
    "any-one",
    "ar-rive",
    "art",
    "at-tor-ney",
    "be-cause",
    "be-fore",
    "be-ha-vior",
    "can-cer",
    "cer-tain-ly",
    "con-gress",
    "coun-try",
    "cul-tural",
    "cul-ture",
    "de-cide",
    "de-fense",
    "de-gree",
    "de-sign",
    "de-spite",
    "de-velop",
    "di-rec-tion",
    "di-rec-tor",
    "dis-cus-sion",
    "eco-nomy",
    "elec-tion",
    "en-vi-ron-men-tal",
    "es-tab-lish",
    "ev-ery-one",
    "ex-actly",
    "ex-ec-u-tive",
    "ex-ist",
    "ex-pe-ri-ence",
    "ex-plain",
    "fi-nally",
    "for-get",
    "hun-dred",
    "in-crease",
    "in-di-vid-ual",
    "it-self",
    "lan-guage",
    "le-gal",
    "lit-tle",
    "lo-cal",
    "ma-jo-ri-ty",
    "ma-te-rial",
    "may-be",
    "me-di-cal",
    "meet-ing",
    "men-tion",
    "mid-dle",
    "na-tion",
    "na-tional",
    "oc-cur",
    "of-fi-cer",
    "par-tic-u-lar-ly",
    "pat-tern",
    "pe-riod",
    "phy-si-cal",
    "po-si-tion",
    "pol-icy",
    "pos-si-ble",
    "pre-vent",
    "pres-sure",
    "pro-per-ty",
    "pur-pose",
    "re-cog-nize",
    "re-gion",
    "re-la-tion-ship",
    "re-main",
    "re-sponse",
    "re-sult",
    "rea-son",
    "sea-son",
    "sex-ual",
    "si-mi-lar",
    "sig-ni-fi-cant",
    "sim-ple",
    "sud-den-ly",
    "sum-mer",
    "thou-sand",
    "to-day",
    "train-ing",
    "treat-ment",
    "va-ri-ous",
    "value",
    "vi-o-lence",
};
244 | |
namespace {
/// The toy hyphenation MLP together with the placeholders, bindings, and
/// Glow functions needed to train it and run inference.
struct HyphenNetwork {
  /// The execution context.
  PlaceholderBindings bindings_;

  /// The input variable is N x 6 x 27 as encoded by mapLetterWindow().
  Placeholder *input_;

  /// The expected output index when training: 0 = no hyphen, 1 = hyphen.
  Placeholder *expected_;

  /// The forward inference function.
  Function *infer_;

  /// The result of the forward inference. N x 1 float with a probability.
  SaveNode *result_;

  /// The corresponding gradient function for training.
  Function *train_;

  /// Build the network in \p mod: input -> FC(10) -> RELU -> FC(2) ->
  /// soft-max, then derive the gradient function used for training.
  /// \p conf supplies the batch size that shapes both placeholders.
  HyphenNetwork(Module &mod, TrainingConfig &conf)
      : input_(mod.createPlaceholder(ElemKind::FloatTy, {conf.batchSize, 6, 27},
                                     "input", false)),
        expected_(mod.createPlaceholder(ElemKind::Int64ITy, {conf.batchSize, 1},
                                        "expected", false)),
        infer_(mod.createFunction("infer")), result_(nullptr), train_(nullptr) {
    bindings_.allocate(input_);
    bindings_.allocate(expected_);
    Node *n;

    n = infer_->createFullyConnected(bindings_, "hidden_fc", input_, 10);
    n = infer_->createRELU("hidden", n);
    n = infer_->createFullyConnected(bindings_, "output_fc", n, 2);
    n = infer_->createSoftMax("output", n, expected_);
    result_ = infer_->createSave("result", n);
    bindings_.allocate(result_->getPlaceholder());
    train_ = glow::differentiate(infer_, conf);
  }

  // Run `inputs` through the inference function and check the results against
  // `hyphens`. Return the number of errors.
  //
  // `inputs` must be numSamples x 6 x 27 with numSamples >= TC.batchSize.
  // Samples are processed in batch-sized slices; when the batch size does not
  // evenly divide numSamples, the final slice overlaps the previous one, so
  // the overlapped samples are checked (and can be counted) twice.
  unsigned inferenceErrors(ExecutionEngine &EE, llvm::StringRef fName,
                           Tensor &inputs, const vector<bool> &hyphens,
                           TrainingConfig &TC) {
    dim_t batchSize = TC.batchSize;
    dim_t numSamples = inputs.dims()[0];
    EXPECT_LE(batchSize, numSamples);
    // Bind the result handle once up front; the placeholder's backing tensor
    // was allocated in the constructor, and each EE.run() below presumably
    // writes into that same tensor — confirm against PlaceholderBindings.
    auto resultHandle =
        bindings_.get(bindings_.getPlaceholderByNameSlow("result"))
            ->getHandle<>();
    unsigned errors = 0;

    for (dim_t bi = 0; bi < numSamples; bi += batchSize) {
      // Get a batch-sized slice of inputs and run them through the inference
      // function. Do a bit of overlapping if the batch size doesn't divide the
      // number of samples.
      if (bi + batchSize > numSamples) {
        bi = numSamples - batchSize;
      }
      auto batchInputs = inputs.getUnowned({batchSize, 6, 27}, {bi, 0, 0});
      updateInputPlaceholders(bindings_, {input_}, {&batchInputs});
      EE.run(bindings_, fName);

      // Check each output in the batch.
      for (dim_t i = 0; i != batchSize; i++) {
        // Note that the two softmax outputs always sum to 1, so we only look at
        // one. Index 1 is the "hyphen" class; > 0.5 means "predict hyphen".
        float value = resultHandle.at({i, 1});
        if ((value > 0.5) != hyphens[bi + i]) {
          errors++;
        }
      }
    }
    return errors;
  }
};
} // namespace
322 | |
// End-to-end test: build the toy network, train it with SGD on the CPU
// backend, and expect perfect (overfitted) accuracy on the training set.
// Then rebuild the network in a second engine, copy the trained weights
// over, and expect the same result.
TEST(HyphenTest, network) {
  ExecutionEngine EE("CPU");

  // Convert the training data to word windows and labels.
  vector<string> words;
  vector<bool> hyphens;
  for (auto *hword : TrainingData) {
    dehyphenate(hword, words, hyphens);
  }

  // This depends on the training data, of course.
  const dim_t numSamples = 566;
  ASSERT_EQ(hyphens.size(), numSamples);
  ASSERT_EQ(words.size(), numSamples);

  // Randomly shuffle the training data (Fisher-Yates: swap each element with
  // a random earlier-or-equal position, keeping words/labels in lockstep).
  // This is required for stochastic gradient descent training.
  auto &PRNG = EE.getModule().getPRNG();
  for (size_t i = numSamples - 1; i > 0; i--) {
    size_t j = PRNG.nextRandInt(0, i);
    std::swap(words[i], words[j]);
    std::swap(hyphens[i], hyphens[j]);
  }

  // Convert words and hyphens to a tensor representation: a one-hot input per
  // window and an int64 class label (0 or 1) per sample.
  Tensor inputs(ElemKind::FloatTy, {numSamples, 6, 27});
  Tensor expected(ElemKind::Int64ITy, {numSamples, 1});
  inputs.zero();
  auto inputHandle = inputs.getHandle<float>();
  auto expectedHandle = expected.getHandle<int64_t>();
  for (dim_t i = 0; i != numSamples; i++) {
    mapLetterWindow(words[i], i, inputHandle);
    expectedHandle.at({i, 0}) = hyphens[i];
  }

  // Now build the network.
  TrainingConfig TC;
  TC.learningRate = 0.8;
  TC.batchSize = 50;
  HyphenNetwork net(EE.getModule(), TC);
  auto fName = net.infer_->getName();
  auto tfName = net.train_->getName();

  // This variable records the number of the next sample to be used for
  // training.
  size_t sampleCounter = 0;

  // Train using mini-batch SGD: 1000 batches drawn from the shuffled samples.
  EE.compile(CompilationMode::Train);
  runBatch(EE, net.bindings_, 1000, sampleCounter, {net.input_, net.expected_},
           {&inputs, &expected}, tfName);

  // Now test inference on the trained network.
  // Note that we have probably overfitted the data, so we expect 100% accuracy.
  EXPECT_EQ(net.inferenceErrors(EE, fName, inputs, hyphens, TC), 0);

  // See if a second, freshly built engine gets the same result.
  // NOTE(review): this section was described as "the interpreter", but EE2 is
  // also created with the "CPU" backend — confirm which backend was intended.
  ExecutionEngine EE2("CPU");
  HyphenNetwork netInterpreter(EE2.getModule(), TC);
  EE2.compile(CompilationMode::Train);
  // Copy the trained weights from the CPU run. The names presumably refer to
  // the auto-named bias/weights placeholders of the two FC layers ("__1"
  // being the second layer) — verify against Glow's placeholder naming.
  net.bindings_.copyToTarget("bias", netInterpreter.bindings_);
  net.bindings_.copyToTarget("bias__1", netInterpreter.bindings_);
  net.bindings_.copyToTarget("weights", netInterpreter.bindings_);
  net.bindings_.copyToTarget("weights__1", netInterpreter.bindings_);

  EXPECT_EQ(netInterpreter.inferenceErrors(EE2, fName, inputs, hyphens, TC), 0);
}
392 | |