/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/IR/IR.h"
#include "glow/Support/Support.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/Timer.h"

#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

using namespace glow;
using llvm::format;
namespace {
llvm::cl::OptionCategory ptbCat("PTB Options");
llvm::cl::opt<std::string> executionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(ptbCat));

llvm::cl::opt<std::string> dumpInitialGraphDAGFileOpt(
    "dumpInitialGraphDAG",
    llvm::cl::desc(
        "Specify the file to export the initial Graph in DOT format"),
    llvm::cl::value_desc("file.dot"), llvm::cl::cat(ptbCat));

llvm::cl::opt<std::string> dumpTrainingGraphDAGFileOpt(
    "dumpTrainingGraphDAG",
    llvm::cl::desc(
        "Specify the file to export the training Graph in DOT format"),
    llvm::cl::value_desc("file.dot"), llvm::cl::cat(ptbCat));

} // namespace

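/// Loads up to \p maxNumWords words of the PTB training corpus into the
/// one-hot encoded inputs \p inputWords and next-word targets \p targetWords,
/// using a vocabulary capped at \p vocabSize. Each row holds \p numSteps
/// consecutive words; rows are grouped into minibatches of \p minibatchSize.
/// Returns the number of words loaded.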
unsigned loadPTB(Tensor &inputWords, Tensor &targetWords, dim_t numSteps,
                 dim_t vocabSize, dim_t minibatchSize, dim_t maxNumWords) {

  std::ifstream ptbInput("ptb/simple-examples/data/ptb.train.txt");
  CHECK(ptbInput.is_open()) << "Error loading ptb.train.txt";

  std::vector<std::string> words;
  std::string line;

  while (getline(ptbInput, line)) {
    std::istringstream ss(line);
    std::string token;
    while (getline(ss, token, ' ')) {
      if (!token.empty()) {
        words.push_back(token);
      }
    }
    words.push_back("<eos>");
  }
  ptbInput.close();

  // Cap the corpus at maxNumWords words; using more makes everything slower.
  words.resize(std::min<size_t>(words.size(), maxNumWords));
  size_t numWords = words.size();

  CHECK_GT(numWords, 0) << "No words were found.";

  // Count the occurrences of each word in the input text.
  std::map<std::string, int> counter;
  for (const auto &word : words) {
    counter[word] += 1;
  }

  // Sort the words by descending frequency, breaking ties alphabetically.
  std::vector<std::pair<std::string, int>> counters(counter.begin(),
                                                    counter.end());

  std::sort(counters.begin(), counters.end(),
            [](const std::pair<std::string, int> &lhs,
               const std::pair<std::string, int> &rhs) {
              if (lhs.second == rhs.second) {
                return rhs.first > lhs.first;
              }
              return lhs.second > rhs.second;
            });

  // Build the word-to-id map. Words ranked beyond the vocabulary limit all
  // map to the last id (vocabSize - 1), which serves as an out-of-vocabulary
  // bucket.
  std::map<std::string, int> wordToId;
  for (unsigned i = 0; i < counters.size(); i++) {
    auto const &word = counters[i].first;
    wordToId[word] = std::min<size_t>(i, vocabSize - 1);
  }

  // Load the PTB database into two 2-D tensors for word inputs and targets.
  dim_t batchLength = numWords / minibatchSize;
  dim_t numBatches = (batchLength - 1) / numSteps;
  dim_t numSequences = minibatchSize * numBatches;
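  // The corpus is split into minibatchSize parallel streams of batchLength
  // consecutive words each; every stream is then cut into numBatches windows
  // of numSteps words, giving numSequences training sequences in total.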

  // While we don't have an embedding layer, we use one-hot encoding to
  // represent the input words. To limit the size of the data we use an upper
  // bound on the vocabulary size.
  inputWords.reset(ElemKind::FloatTy, {numSequences, vocabSize * numSteps});
  targetWords.reset(ElemKind::Int64ITy, {numSequences, numSteps});
  auto IIH = inputWords.getHandle<>();
  auto TIH = targetWords.getHandle<int64_t>();
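  // For every (batch, iter, step) triple, set the one-hot bit of the current
  // word and record the id of the following word as the prediction target.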
  for (unsigned batch = 0; batch < minibatchSize; batch++) {
    for (unsigned iter = 0; iter < numBatches; iter++) {
      dim_t sequence = batch + iter * minibatchSize;
      for (unsigned step = 0; step < numSteps; step++) {
        int wordCounterId = step + iter * numSteps + batch * batchLength;
        const std::string word1 = words[wordCounterId];
        const std::string word2 = words[wordCounterId + 1];
        IIH.at({sequence, step * vocabSize + wordToId[word1]}) = 1;
        TIH.at({sequence, step}) = wordToId[word2];
      }
    }
  }
  return numWords;
}

/// This test builds an RNN language model on the Penn TreeBank dataset.
/// Results for RNN word-level perplexity are reported in
/// https://arxiv.org/pdf/1409.2329.pdf. Here we simplify the problem so that
/// it can run on a single CPU.
/// The results were cross-checked with an equivalent TensorFlow implementation
/// as well as a vanilla implementation inspired by Karpathy's Char-RNN code.
/// TensorFlow https://gist.github.com/mcaounfb/7ba05b0a62383c36e24a33defa3f11aa
/// Vanilla https://gist.github.com/mcaounfb/c4ee98bbddaa6f8505f283ac018f8c34
///
/// The perplexity results are expected to look as follows:
///
/// Iteration 1: 105.4579
/// Iteration 2: 82.3274
/// Iteration 4: 70.8094
/// Iteration 6: 63.8546
/// Iteration 8: 58.4330
/// Iteration 10: 53.7943
/// Iteration 12: 49.7214
/// Iteration 14: 46.1715
/// Iteration 16: 43.1474
/// Iteration 18: 40.5605
/// Iteration 20: 38.2837
///
/// For reference, we expect that using an LSTM instead of the current
/// simple RNN block would improve the perplexity to ~20.
void testPTB() {
  LOG(INFO) << "Loading the ptb database.";

  Tensor inputWords;
  Tensor targetWords;

  const dim_t minibatchSize = 10;
  const dim_t numSteps = 10;
  const dim_t numEpochs = 20;

  const dim_t hiddenSize = 20;
  const dim_t vocabSize = 500;
  const dim_t maxNumWords = 10000;

  float learningRate = .1;

  unsigned numWords = loadPTB(inputWords, targetWords, numSteps, vocabSize,
                              minibatchSize, maxNumWords);
  LOG(INFO) << "Loaded " << numWords << " words.";
  ExecutionEngine EE(executionBackend);
  PlaceholderBindings bindings;

  // Construct the network:
  TrainingConfig TC;
  TC.learningRate = learningRate;
  TC.momentum = 0;
  TC.batchSize = minibatchSize;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  LOG(INFO) << "Building";

  auto *X = mod.createPlaceholder(
      ElemKind::FloatTy, {minibatchSize, vocabSize * numSteps}, "input", false);
  bindings.allocate(X);
  auto *Y = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, numSteps},
                                  "selected", false);
  bindings.allocate(Y);

  std::vector<NodeValue> slicesX;

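  // Slice the flattened one-hot input into numSteps tensors of shape
  // {minibatchSize, vocabSize}, one per time step of the unrolled RNN.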
  for (unsigned t = 0; t < numSteps; t++) {
    auto XtName = "X." + std::to_string(t);
    auto *Xt = F->createSlice(XtName, X, {0, t * vocabSize},
                              {minibatchSize, (t + 1) * vocabSize});
    slicesX.push_back(Xt);
  }

  std::vector<NodeValue> outputNodes;
  F->createSimpleRNN(bindings, "rnn", slicesX, minibatchSize, hiddenSize,
                     vocabSize, outputNodes);

  // O has a shape of {numSteps * minibatchSize, vocabSize}.
  Node *O = F->createConcat("output", outputNodes, 0);
  // T has a shape of {numSteps * minibatchSize, 1}.
  Node *TN = F->createTranspose("Y.transpose", Y, {1, 0});
  Node *T = F->createReshape("Y.reshape", TN, {numSteps * minibatchSize, 1});
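  // The concat above stacks the per-step outputs along dimension 0, so rows
  // are ordered step-major; the {minibatchSize, numSteps} targets are
  // transposed and flattened to match that layout.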

  auto *SM = F->createSoftMax("softmax", O, T);
  auto *save = F->createSave("result", SM);
  auto *result = bindings.allocate(save->getPlaceholder());

  if (!dumpInitialGraphDAGFileOpt.empty()) {
    LOG(INFO) << "Dumping initial graph";
    F->dumpDAG(dumpInitialGraphDAGFileOpt.c_str());
  }

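  // Differentiate the inference function to obtain the training function,
  // which updates the weights with the learning rate and momentum from TC.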
  Function *TF = glow::differentiate(F, TC);
  auto tfName = TF->getName();

  EE.compile(CompilationMode::Train);
  bindings.allocate(mod.getPlaceholders());

  if (!dumpTrainingGraphDAGFileOpt.empty()) {
    LOG(INFO) << "Dumping training graph";
    TF->dumpDAG(dumpTrainingGraphDAGFileOpt.c_str());
  }

  size_t numBatches = (numWords / minibatchSize - 1) / numSteps;

  LOG(INFO) << "Training for " << numBatches << " rounds";

  float metricValues[numEpochs];

  for (size_t iter = 0; iter < numEpochs; iter++) {
    llvm::outs() << "Training - iteration #" << (iter + 1) << "\n";

    llvm::Timer timer("Training", "Training");
    timer.startTimer();

    // Compute the perplexity over a few minibatches.
    float perplexity = 0;
    size_t perplexityWordsCount = 0;
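    // Perplexity is the exponential of the average per-word cross-entropy:
    // exp((1/N) * sum(-log p(correct word))).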

    // This variable records the index of the next sample to be used for
    // training.
    size_t sampleCounter = 0;

    for (unsigned batch = 0; batch < numBatches; batch++) {
      Tensor inputWordsBatch(ElemKind::FloatTy,
                             {minibatchSize, vocabSize * numSteps});
      inputWordsBatch.copyConsecutiveSlices(&inputWords, minibatchSize * batch);

      Tensor targetWordsBatch(ElemKind::Int64ITy, {minibatchSize, numSteps});
      targetWordsBatch.copyConsecutiveSlices(&targetWords,
                                             minibatchSize * batch);

      runBatch(EE, bindings, 1, sampleCounter, {X, Y},
               {&inputWordsBatch, &targetWordsBatch}, tfName);
      for (dim_t step = 0; step < numSteps; step++) {
        for (unsigned int i = 0; i < minibatchSize; i++) {
          auto probs =
              result->getHandle<float>().extractSlice(step * minibatchSize + i);
          dim_t correct = targetWords.getHandle<int64_t>().at(
              {minibatchSize * batch + i, step});
          float softGuess = -std::log(probs.getHandle<float>().at({correct}));
          perplexity += softGuess;
          perplexityWordsCount += 1;
        }
      }
      if (batch % 10 == 1) {
        llvm::outs() << "perplexity: "
                     << format("%0.4f",
                               std::exp(perplexity / perplexityWordsCount))
                     << "\n";
      }
    }
    metricValues[iter] = std::exp(perplexity / perplexityWordsCount);
    llvm::outs() << "perplexity: " << format("%0.4f", metricValues[iter])
                 << "\n\n";

    timer.stopTimer();
  }

  llvm::outs() << "Perplexity scores in copy-pastable format:\n";
  for (size_t iter = 0; iter < numEpochs; iter++) {
    // Print iterations 1, 2, 4, 6, ..., 20 to match the expected results
    // listed in the comment above testPTB().
    if (iter != 0 && iter % 2 == 0)
      continue;
    llvm::outs() << "/// Iteration " << iter + 1 << ": "
                 << format("%0.4f", metricValues[iter]) << "\n";
  }
  llvm::outs() << "Note that a small 1E-4 error is considered acceptable and "
               << "may come from fast-math optimizations.\n";
}

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, " The PTB test\n\n");
  testPTB();

  return 0;
}