/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// This file contains a set of tests related to a toy neural network
// that can hyphenate words. This isn't meant to represent great strides
// in machine learning research, but rather to exercise the CPU JIT
// compiler with a small end-to-end example. The toy network is small
// enough that it can be trained as part of the unit test suite.

#include "BackendTestUtils.h"

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/IR/IR.h"
#include "glow/IR/IRBuilder.h"
#include "glow/IR/Instrs.h"
#include "glow/Support/Random.h"

#include "gtest/gtest.h"

#include <cctype>
#include <string>

using namespace glow;
using llvm::cast;
using std::string;
using std::vector;
// Network architecture
// ====================
//
// The network is a simple multi-layer perceptron with 6 x 27 input
// nodes, 10 inner nodes, and a 2-way soft-max output node.
//
// The input nodes represent 6 letters of a candidate word, and the
// output node indicates the probability that the word can be hyphenated
// between the 3rd and 4th letters. As a word slides past the 6-letter
// window, the network classifies each possible hyphen position.
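//
// With 6 x 27 = 162 inputs, the two fully connected layers below hold
// 162 * 10 + 10 = 1630 and 10 * 2 + 2 = 22 learnable parameters (1652 in
// total), which is why the network is small enough to train in a unit test.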
//
// Example: "hyphenate"
//
// "..hyph" -> 0, h-yphenate is wrong.
// ".hyphe" -> 1, hy-phenate is right.
// "hyphen" -> 0, hyp-henate is wrong.
// "yphena" -> 0, hyph-enate is wrong.
// "phenat" -> 0, hyphe-nate is wrong.
// "henate" -> 1, hyphen-ate is right.
// "enate." -> 0, hyphena-te is wrong.
// "nate.." -> 0, hyphenat-e is wrong.

/// Parse an already hyphenated word into word windows and hyphen labels.
///
/// Given a word with embedded hyphens, generate a sequence of sliding
/// 6-character windows and associated boolean labels, like the table above.
///
static void dehyphenate(const char *hword, vector<string> &words,
                        vector<bool> &hyphens) {
  EXPECT_EQ(words.size(), hyphens.size());

  // The first character can't be a hyphen, and the word can't be empty.
  EXPECT_TRUE(std::islower(*hword));
  string word = "..";
  word.push_back(*hword++);

  // Parse `hword` and add all the letters to `word` and hyphen/no-hyphen
  // entries to `hyphens`.
  for (; *hword; hword++) {
    bool hyph = (*hword == '-');
    hyphens.push_back(hyph);
    if (hyph) {
      hword++;
    }
    // There can't be multiple adjacent hyphens, and the word can't
    // end with a hyphen.
    EXPECT_TRUE(std::islower(*hword));
    word.push_back(*hword);
  }
  word += "..";

  // Now `word` contains the letters of `hword` surrounded by ".." on both
  // sides. Generate all 6-character windows and append them to `words`.
  for (size_t i = 0, e = word.size(); i + 5 < e; i++) {
    words.push_back(word.substr(i, 6));
  }
  EXPECT_EQ(words.size(), hyphens.size());
}

TEST(HyphenTest, dehyphenate) {
  vector<string> words;
  vector<bool> hyphens;

  dehyphenate("x", words, hyphens);
  EXPECT_EQ(words.size(), 0);
  EXPECT_EQ(hyphens.size(), 0);

  dehyphenate("xy", words, hyphens);
  EXPECT_EQ(words, (vector<string>{"..xy.."}));
  EXPECT_EQ(hyphens, (vector<bool>{0}));

  dehyphenate("y-z", words, hyphens);
  EXPECT_EQ(words, (vector<string>{"..xy..", "..yz.."}));
  EXPECT_EQ(hyphens, (vector<bool>{0, 1}));

  words.clear();
  hyphens.clear();
  dehyphenate("hy-phen-ate", words, hyphens);
  EXPECT_EQ(words, (vector<string>{"..hyph", ".hyphe", "hyphen", "yphena",
                                   "phenat", "henate", "enate.", "nate.."}));
  EXPECT_EQ(hyphens, (vector<bool>{0, 1, 0, 0, 0, 1, 0, 0}));
}

/// Map a lower-case letter to an input index in the range 0-26.
/// Use 0 to represent any characters outside the a-z range.
static size_t mapLetter(char l) {
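  // Characters below 'a' wrap around to large unsigned values, so the single
  // unsigned comparison below rejects inputs on both sides of the a-z range.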
  unsigned d = l - unsigned('a');
  return d < 26 ? d + 1 : 0;
}

TEST(HyphenTest, mapLetter) {
  EXPECT_EQ(mapLetter('a'), 1);
  EXPECT_EQ(mapLetter('d'), 4);
  EXPECT_EQ(mapLetter('z'), 26);
  EXPECT_EQ(mapLetter('.'), 0);
}

/// Map a 6-letter window of a word to an input tensor using a one-hot
/// encoding.
///
/// The tensor must be N x 6 x 27: batch x position x letter.
static void mapLetterWindow(const string &window, dim_t idx,
                            Handle<float> tensor) {
  EXPECT_EQ(window.size(), 6);
  for (dim_t row = 0; row < 6; row++) {
    dim_t col = mapLetter(window[row]);
    tensor.at({idx, row, col}) = 1;
  }
}
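
// Sanity-check the one-hot encoding produced by mapLetterWindow(): every
// row of the window must activate exactly the column that mapLetter()
// assigns to its character, and no others.
TEST(HyphenTest, mapLetterWindow) {
  Tensor inputs(ElemKind::FloatTy, {1, 6, 27});
  inputs.zero();
  auto handle = inputs.getHandle<float>();
  const string window = "..hyph";
  mapLetterWindow(window, 0, handle);

  for (dim_t row = 0; row < 6; row++) {
    for (dim_t col = 0; col < 27; col++) {
      float expected = (col == mapLetter(window[row])) ? 1 : 0;
      EXPECT_EQ(handle.at({0, row, col}), expected);
    }
  }
}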

// Training data consisting of pre-hyphenated common words.
const vector<const char *> TrainingData{
    "ad-mi-ni-stra-tion",
    "ad-mit",
    "al-low",
    "al-though",
    "an-i-mal",
    "any-one",
    "ar-rive",
    "art",
    "at-tor-ney",
    "be-cause",
    "be-fore",
    "be-ha-vior",
    "can-cer",
    "cer-tain-ly",
    "con-gress",
    "coun-try",
    "cul-tural",
    "cul-ture",
    "de-cide",
    "de-fense",
    "de-gree",
    "de-sign",
    "de-spite",
    "de-velop",
    "di-rec-tion",
    "di-rec-tor",
    "dis-cus-sion",
    "eco-nomy",
    "elec-tion",
    "en-vi-ron-men-tal",
    "es-tab-lish",
    "ev-ery-one",
    "ex-actly",
    "ex-ec-u-tive",
    "ex-ist",
    "ex-pe-ri-ence",
    "ex-plain",
    "fi-nally",
    "for-get",
    "hun-dred",
    "in-crease",
    "in-di-vid-ual",
    "it-self",
    "lan-guage",
    "le-gal",
    "lit-tle",
    "lo-cal",
    "ma-jo-ri-ty",
    "ma-te-rial",
    "may-be",
    "me-di-cal",
    "meet-ing",
    "men-tion",
    "mid-dle",
    "na-tion",
    "na-tional",
    "oc-cur",
    "of-fi-cer",
    "par-tic-u-lar-ly",
    "pat-tern",
    "pe-riod",
    "phy-si-cal",
    "po-si-tion",
    "pol-icy",
    "pos-si-ble",
    "pre-vent",
    "pres-sure",
    "pro-per-ty",
    "pur-pose",
    "re-cog-nize",
    "re-gion",
    "re-la-tion-ship",
    "re-main",
    "re-sponse",
    "re-sult",
    "rea-son",
    "sea-son",
    "sex-ual",
    "si-mi-lar",
    "sig-ni-fi-cant",
    "sim-ple",
    "sud-den-ly",
    "sum-mer",
    "thou-sand",
    "to-day",
    "train-ing",
    "treat-ment",
    "va-ri-ous",
    "value",
    "vi-o-lence",
};

namespace {
struct HyphenNetwork {
  /// The execution context.
  PlaceholderBindings bindings_;

  /// The input variable is N x 6 x 27 as encoded by mapLetterWindow().
  Placeholder *input_;

  /// The expected output index when training: 0 = no hyphen, 1 = hyphen.
  Placeholder *expected_;

  /// The forward inference function.
  Function *infer_;

  /// The result of the forward inference. N x 2 floats with the softmax
  /// probabilities for no-hyphen vs. hyphen.
  SaveNode *result_;

  /// The corresponding gradient function for training.
  Function *train_;

  HyphenNetwork(Module &mod, TrainingConfig &conf)
      : input_(mod.createPlaceholder(ElemKind::FloatTy,
                                     {conf.batchSize, 6, 27}, "input", false)),
        expected_(mod.createPlaceholder(ElemKind::Int64ITy, {conf.batchSize, 1},
                                        "expected", false)),
        infer_(mod.createFunction("infer")), result_(nullptr), train_(nullptr) {
    bindings_.allocate(input_);
    bindings_.allocate(expected_);
    Node *n;

    n = infer_->createFullyConnected(bindings_, "hidden_fc", input_, 10);
    n = infer_->createRELU("hidden", n);
    n = infer_->createFullyConnected(bindings_, "output_fc", n, 2);
    n = infer_->createSoftMax("output", n, expected_);
    result_ = infer_->createSave("result", n);
    bindings_.allocate(result_->getPlaceholder());
    train_ = glow::differentiate(infer_, conf);
  }

  // Run `inputs` through the inference function and check the results against
  // `hyphens`. Return the number of errors.
  unsigned inferenceErrors(ExecutionEngine &EE, llvm::StringRef fName,
                           Tensor &inputs, const vector<bool> &hyphens,
                           TrainingConfig &TC) {
    dim_t batchSize = TC.batchSize;
    dim_t numSamples = inputs.dims()[0];
    EXPECT_LE(batchSize, numSamples);
    auto resultHandle =
        bindings_.get(bindings_.getPlaceholderByNameSlow("result"))
            ->getHandle<>();
    unsigned errors = 0;

    for (dim_t bi = 0; bi < numSamples; bi += batchSize) {
      // Get a batch-sized slice of inputs and run them through the inference
      // function. Do a bit of overlapping if the batch size doesn't divide the
      // number of samples.
      if (bi + batchSize > numSamples) {
        bi = numSamples - batchSize;
      }
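      // For example, with numSamples = 566 and batchSize = 50 (the values
      // used in the network test below), the final batch starts at sample
      // 516, so samples 516-549 are checked a second time and any errors
      // among them are counted twice.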
      auto batchInputs = inputs.getUnowned({batchSize, 6, 27}, {bi, 0, 0});
      updateInputPlaceholders(bindings_, {input_}, {&batchInputs});
      EE.run(bindings_, fName);

      // Check each output in the batch.
      for (dim_t i = 0; i != batchSize; i++) {
        // Note that the two softmax outputs always sum to 1, so we only
        // look at one.
        float value = resultHandle.at({i, 1});
        if ((value > 0.5) != hyphens[bi + i]) {
          errors++;
        }
      }
    }
    return errors;
  }
};
} // namespace

TEST(HyphenTest, network) {
  ExecutionEngine EE("CPU");

  // Convert the training data to word windows and labels.
  vector<string> words;
  vector<bool> hyphens;
  for (auto *hword : TrainingData) {
    dehyphenate(hword, words, hyphens);
  }

  // This depends on the training data, of course.
  const dim_t numSamples = 566;
  ASSERT_EQ(hyphens.size(), numSamples);
  ASSERT_EQ(words.size(), numSamples);

  // Randomly shuffle the training data.
  // This is required for stochastic gradient descent training.
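  // (The loop below is a Fisher-Yates shuffle driven by the module's PRNG,
  // so the permutation is deterministic for a fixed seed.)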
  auto &PRNG = EE.getModule().getPRNG();
  for (size_t i = numSamples - 1; i > 0; i--) {
    size_t j = PRNG.nextRandInt(0, i);
    std::swap(words[i], words[j]);
    std::swap(hyphens[i], hyphens[j]);
  }

  // Convert words and hyphens to a tensor representation.
  Tensor inputs(ElemKind::FloatTy, {numSamples, 6, 27});
  Tensor expected(ElemKind::Int64ITy, {numSamples, 1});
  inputs.zero();
  auto inputHandle = inputs.getHandle<float>();
  auto expectedHandle = expected.getHandle<int64_t>();
  for (dim_t i = 0; i != numSamples; i++) {
    mapLetterWindow(words[i], i, inputHandle);
    expectedHandle.at({i, 0}) = hyphens[i];
  }

  // Now build the network.
  TrainingConfig TC;
  TC.learningRate = 0.8;
  TC.batchSize = 50;
  HyphenNetwork net(EE.getModule(), TC);
  auto fName = net.infer_->getName();
  auto tfName = net.train_->getName();

  // This variable records the number of the next sample to be used for
  // training.
  size_t sampleCounter = 0;

  // Train using mini-batch SGD.
  EE.compile(CompilationMode::Train);
  runBatch(EE, net.bindings_, 1000, sampleCounter, {net.input_, net.expected_},
           {&inputs, &expected}, tfName);
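  // 1000 iterations with a batch size of 50 consume 50,000 samples, roughly
  // 88 passes over the 566-sample training set.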

  // Now test inference on the trained network.
  // Note that we have probably overfitted the data, so we expect 100%
  // accuracy.
  EXPECT_EQ(net.inferenceErrors(EE, fName, inputs, hyphens, TC), 0);

  // See if the interpreter gets the same result.
  ExecutionEngine EE2("Interpreter");
  HyphenNetwork netInterpreter(EE2.getModule(), TC);
  EE2.compile(CompilationMode::Train);
  // Copy the trained weights from the CPU run.
  net.bindings_.copyToTarget("bias", netInterpreter.bindings_);
  net.bindings_.copyToTarget("bias__1", netInterpreter.bindings_);
  net.bindings_.copyToTarget("weights", netInterpreter.bindings_);
  net.bindings_.copyToTarget("weights__1", netInterpreter.bindings_);

  EXPECT_EQ(netInterpreter.inferenceErrors(EE2, fName, inputs, hyphens, TC), 0);
}