1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/ExecutionEngine/ExecutionEngine.h"
17#include "glow/Graph/Graph.h"
18#include "glow/IR/IR.h"
19#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
20#include "glow/Support/Support.h"
21
22#include "llvm/Support/CommandLine.h"
23#include "llvm/Support/Format.h"
24#include "llvm/Support/MemoryBuffer.h"
25#include "llvm/Support/Timer.h"
26
27#include <glog/logging.h>
28
29#include <string>
30
31//----------------------------------------------------------------------------//
32// This is a small program that's based on Andrej's char-rnn generator. This is
33// a small RNN-based neural network that's used to generate random text after
34// analyzing some other text. The network is described here:
35// http://karpathy.github.io/2015/05/21/rnn-effectiveness/
36//----------------------------------------------------------------------------//
37
38using namespace glow;
39using llvm::format;
40
namespace {
/// Category under which all char-rnn specific flags are grouped in --help.
llvm::cl::OptionCategory category("char-rnn Options");

/// Positional input file; the default of "-" reads the training text from
/// stdin (see loadFile).
static llvm::cl::opt<std::string> inputFilename(llvm::cl::desc("input file"),
                                                llvm::cl::init("-"),
                                                llvm::cl::Positional,
                                                llvm::cl::cat(category));

/// Name of the Glow backend to execute on; defaults to the reference
/// interpreter.
llvm::cl::opt<std::string> executionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(category));

/// Number of training epochs; each epoch runs over the whole input text and
/// then prints a generated sample.
llvm::cl::opt<unsigned> numEpochs("epochs",
                                  llvm::cl::desc("Process the input N times."),
                                  llvm::cl::init(4), llvm::cl::value_desc("N"),
                                  llvm::cl::cat(category));

/// Number of characters to generate after each epoch.
llvm::cl::opt<unsigned>
    generateChars("chars", llvm::cl::desc("Generate this number of chars."),
                  llvm::cl::init(10), llvm::cl::value_desc("N"),
                  llvm::cl::cat(category));

} // namespace
63
64/// Clip the value \p c to the range 0..127, which is standard ascii.
/// Clip the value \p c to the range 0..127, which is standard ascii.
/// \p c may be a signed type on some platforms; bytes with the high bit set
/// (negative when signed) are mapped to 127. The previous implementation got
/// the same result only via unsigned wrap-around of the implicit
/// char -> size_t conversion; the cast below makes the intent explicit and
/// portable across signed/unsigned-char platforms.
static size_t clipASCII(char c) {
  // Reinterpret the raw byte as 0..255 before clamping.
  size_t code = static_cast<unsigned char>(c);
  if (code > 127) {
    code = 127;
  }
  return code;
}
72
73/// Load text into \p inputText that has the format [B, S, 128], where B is
74/// the batch size, S is the length of the sentence, and 128 is the one-hot
75/// representation of the text (https://en.wikipedia.org/wiki/One-hot).
76/// Load the expected index into \p nextChar that has the format [B, S], where
77/// each element is the softmax index of the next char. If \p train is false
78/// then only load the first slice of inputText.
79static void loadText(Tensor &inputText, Tensor &nextChar, llvm::StringRef text,
80 bool train) {
81 DCHECK_GT(text.size(), 2) << "The buffer must contain at least two chars";
82 inputText.zero();
83 nextChar.zero();
84
85 auto idim = inputText.dims();
86 DCHECK_EQ(idim.size(), 3) << "invalid input tensor";
87 auto B = idim[0];
88 auto S = idim[1];
89
90 auto IH = inputText.getHandle();
91 auto NH = nextChar.getHandle<int64_t>();
92
93 // Fill the tensor with slices from the sentence with an offset of 1.
94 // Example:
95 // |Hell|o| World
96 // |ello| |World
97 // |llo |W|orld
98 // |lo W|o|rld
99 for (dim_t i = 0; i < B; i++) {
100 for (dim_t j = 0; j < S; j++) {
101 dim_t c = clipASCII(text[i + j]);
102
103 IH.at({i, j, c}) = 1.0;
104 if (train) {
105 size_t c1 = clipASCII(text[i + j + 1]);
106 NH.at({i, j}) = c1;
107 }
108 }
109
110 // Only load the first slice in the batch when in inference mode.
111 if (!train) {
112 return;
113 }
114 }
115}
116
117PseudoRNG &getRNG() {
118 static PseudoRNG RNG;
119
120 return RNG;
121}
122
123/// This method selects a random number based on a softmax distribution. One
124/// property of this distribution is that the sum of all probabilities is equal
125/// to one. The algorithm that we use here picks a random number between zero
126/// and one. Then, we scan the tensor and accumulate the probabilities. We stop
127/// and pick the index when sum is greater than the selected random number.
128static char getPredictedChar(Tensor &inputText, dim_t slice, dim_t word) {
129 auto IH = inputText.getHandle();
130
131 // Pick a random number between zero and one.
132 double x = std::abs(getRNG().nextRand());
133 double sum = 0;
134 // Accumulate the probabilities into 'sum'.
135 for (dim_t i = 0; i < 128; i++) {
136 sum += IH.at({slice, word, i});
137 // As soon as we cross the threshold return the index.
138 if (sum > x) {
139 return i;
140 }
141 }
142 return 127;
143}
144
145/// Loads the content of a file or stdin to a memory buffer.
146/// The default filename of "-" reads from stdin.
147static std::unique_ptr<llvm::MemoryBuffer> loadFile(llvm::StringRef filename) {
148 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileBufOrErr =
149 llvm::MemoryBuffer::getFileOrSTDIN(filename);
150 if (!fileBufOrErr) {
151 LOG(ERROR) << "Error! Failed to open file: " << filename.str() << "\n";
152 LOG(ERROR) << fileBufOrErr.getError().message() << "\n";
153 exit(-1);
154 }
155
156 return std::move(fileBufOrErr.get());
157}
158
159/// Creates a new RNN network. The network answers the question, given N chars
160/// of input, what is the character following each one of these chars.
/// Creates a new RNN network. The network answers the question, given N chars
/// of input, what is the character following each one of these chars.
/// Placeholders are created in \p mod and their backing tensors allocated in
/// \p bindings: "input" is the one-hot text [minibatchSize, numSteps, 128] and
/// "expected" holds the target next-char indices [minibatchSize, numSteps].
/// \returns the function, whose prediction is stored by a SaveNode named
/// "result" with shape [minibatchSize, numSteps, 128].
static Function *createNetwork(Module &mod, PlaceholderBindings &bindings,
                               dim_t minibatchSize, dim_t numSteps,
                               dim_t hiddenSize) {
  Function *F = mod.createFunction("main");

  // One-hot encoded input characters.
  auto *X = mod.createPlaceholder(
      ElemKind::FloatTy, {minibatchSize, numSteps, 128}, "input", false);
  bindings.allocate(X);

  // Softmax index of the expected next character at each step.
  auto *Y = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, numSteps},
                                  "expected", false);
  bindings.allocate(Y);

  std::vector<NodeValue> slicesX;
  std::vector<Node *> expectedX;

  // Cut the input and the expected tensors into one slice per timestep; the
  // LSTM builder below consumes per-step inputs.
  for (unsigned t = 0; t < numSteps; t++) {
    auto XtName = "X." + std::to_string(t);
    auto *Xt =
        F->createSlice(XtName, X, {0, t, 0}, {minibatchSize, t + 1, 128});
    slicesX.push_back(Xt);

    auto YtName = "Y." + std::to_string(t);
    auto *Yt = F->createSlice(YtName, Y, {0, t}, {minibatchSize, t + 1});
    expectedX.push_back(Yt);
  }

  // Build the recurrent core: one LSTM output node per timestep.
  std::vector<NodeValue> outputNodes;
  F->createLSTM(bindings, "rnn", slicesX, minibatchSize, hiddenSize, 128,
                outputNodes);

  // For each timestep: softmax over the 128 ascii classes against the
  // expected index, reshaped so the per-step results can be concatenated.
  std::vector<NodeValue> resX;
  for (unsigned i = 0; i < numSteps; i++) {
    auto *R =
        F->createReshape("reshapeSelector", expectedX[i], {minibatchSize, 1});
    auto *SM = F->createSoftMax("softmax", outputNodes[i], R);
    auto *K = F->createReshape("reshapeSM", SM, {minibatchSize, 1, 128});
    resX.push_back(K);
  }

  // Concatenate along the timestep dimension and save the prediction.
  Node *O = F->createConcat("output", resX, 1);
  auto *S = F->createSave("result", O);
  bindings.allocate(S->getPlaceholder());

  return F;
}
207
208int main(int argc, char **argv) {
209 llvm::cl::ParseCommandLineOptions(argc, argv, " The char-rnn test\n\n");
210 auto mb = loadFile(inputFilename);
211 auto text = mb.get()->getBuffer();
212 LOG(INFO) << "Loaded " << text.size() << " chars.\n";
213 PlaceholderBindings inferBindings, trainingBindings;
214
215 const dim_t numSteps = 50;
216 const dim_t minibatchSize = 32;
217 const dim_t batchSize = text.size() - numSteps;
218 const dim_t hiddenSize = 256;
219
220 CHECK_GT(text.size(), numSteps) << "Text is too short";
221 TrainingConfig TC;
222
223 ExecutionEngine EET(executionBackend);
224 TC.learningRate = 0.001;
225 TC.momentum = 0.9;
226 TC.batchSize = minibatchSize;
227
228 auto &modT = EET.getModule();
229
230 //// Train the network ////
231 Function *F2 = createNetwork(modT, trainingBindings, minibatchSize, numSteps,
232 hiddenSize);
233 differentiate(F2, TC);
234 EET.compile(CompilationMode::Train);
235 trainingBindings.allocate(modT.getPlaceholders());
236
237 auto *XT = modT.getPlaceholderByNameSlow("input");
238 auto *YT = modT.getPlaceholderByNameSlow("expected");
239
240 Tensor thisCharTrain(ElemKind::FloatTy, {batchSize, numSteps, 128});
241 Tensor nextCharTrain(ElemKind::Int64ITy, {batchSize, numSteps});
242 loadText(thisCharTrain, nextCharTrain, text, true);
243
244 // This variable records the number of the next sample to be used for
245 // training.
246 size_t sampleCounter = 0;
247
248 // Run this number of iterations over the input. On each iteration: train the
249 // network on the whole input and then generate some sample text.
250 for (unsigned i = 0; i < numEpochs; i++) {
251
252 // Train the network on the whole input.
253 LOG(INFO) << "Iteration " << i + 1 << "/" << numEpochs;
254 runBatch(EET, trainingBindings, batchSize / minibatchSize, sampleCounter,
255 {XT, YT}, {&thisCharTrain, &nextCharTrain});
256
257 ExecutionEngine EEO(executionBackend);
258 inferBindings.clear();
259 auto &mod = EEO.getModule();
260 auto OF =
261 createNetwork(mod, inferBindings, minibatchSize, numSteps, hiddenSize);
262 auto *X = mod.getPlaceholderByNameSlow("input");
263 inferBindings.allocate(mod.getPlaceholders());
264 trainingBindings.copyTrainableWeightsTo(inferBindings);
265
266 //// Use the trained network to generate some text ////
267 auto *res =
268 llvm::cast<SaveNode>(OF->getNodeByName("result"))->getPlaceholder();
269 // Promote placeholders to constants.
270 ::glow::convertPlaceholdersToConstants(OF, inferBindings, {X, res});
271 EEO.compile(CompilationMode::Infer);
272
273 // Load a few characters to start the text that we generate.
274 Tensor currCharInfer(ElemKind::FloatTy, {minibatchSize, numSteps, 128});
275 Tensor nextCharInfer(ElemKind::Int64ITy, {minibatchSize, numSteps});
276 loadText(currCharInfer, nextCharInfer, text.slice(0, 128), false);
277
278 auto *T = inferBindings.get(res);
279 std::string result;
280 std::string input;
281 input.insert(input.begin(), text.begin(), text.begin() + numSteps);
282 result = input;
283
284 // Generate a sentence by running inference over and over again.
285 for (unsigned i = 0; i < generateChars; i++) {
286 // Generate a char:
287 updateInputPlaceholders(inferBindings, {X}, {&currCharInfer});
288 EEO.run(inferBindings);
289
290 // Pick a char at random from the softmax distribution.
291 char c = getPredictedChar(*T, 0, numSteps - 1);
292
293 // Update the inputs for the next iteration:
294 result.push_back(c);
295 input.push_back(c);
296 input.erase(input.begin());
297 loadText(currCharInfer, nextCharInfer, input, false);
298 }
299
300 llvm::outs() << "Generated output:\n" << result << "\n";
301 }
302
303 return 0;
304}
305