/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
17 | #include "glow/Graph/Graph.h" |
18 | #include "glow/IR/IR.h" |
19 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
20 | #include "glow/Support/Support.h" |
21 | |
22 | #include "llvm/Support/CommandLine.h" |
23 | #include "llvm/Support/Format.h" |
24 | #include "llvm/Support/MemoryBuffer.h" |
25 | #include "llvm/Support/Timer.h" |
26 | |
27 | #include <glog/logging.h> |
28 | |
29 | #include <string> |
30 | |
//----------------------------------------------------------------------------//
// This is a small program that's based on Andrej's char-rnn generator. This is
// a small RNN-based neural network that's used to generate random text after
// analyzing some other text. The network is described here:
// http://karpathy.github.io/2015/05/21/rnn-effectiveness/
//----------------------------------------------------------------------------//
37 | |
38 | using namespace glow; |
39 | using llvm::format; |
40 | |
namespace {
/// Command-line category grouping all char-rnn specific options, so they show
/// up together in --help output.
llvm::cl::OptionCategory category("char-rnn Options");

/// Positional argument naming the training text; the default "-" reads the
/// text from stdin (see loadFile).
static llvm::cl::opt<std::string> inputFilename(llvm::cl::desc("input file"),
                                                llvm::cl::init("-"),
                                                llvm::cl::Positional,
                                                llvm::cl::cat(category));

/// Name of the Glow backend used for both the training and the inference
/// ExecutionEngine.
llvm::cl::opt<std::string> executionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(category));

/// Number of passes over the whole input text during training.
llvm::cl::opt<unsigned> numEpochs("epochs",
                                  llvm::cl::desc("Process the input N times."),
                                  llvm::cl::init(4), llvm::cl::value_desc("N"),
                                  llvm::cl::cat(category));

/// Number of characters to sample from the model after every epoch.
llvm::cl::opt<unsigned>
    generateChars("chars", llvm::cl::desc("Generate this number of chars."),
                  llvm::cl::init(10), llvm::cl::value_desc("N"),
                  llvm::cl::cat(category));

} // namespace
63 | |
/// Clip the value \p c to the range 0..127, which is standard ascii.
static size_t clipASCII(char c) {
  // Convert through unsigned char first: bytes outside the ASCII range
  // (including negative values when char is signed) land above 127 and are
  // clamped, exactly as the unsigned comparison below requires.
  const auto byte = static_cast<unsigned char>(c);
  return byte > 127 ? 127 : byte;
}
72 | |
73 | /// Load text into \p inputText that has the format [B, S, 128], where B is |
74 | /// the batch size, S is the length of the sentence, and 128 is the one-hot |
75 | /// representation of the text (https://en.wikipedia.org/wiki/One-hot). |
76 | /// Load the expected index into \p nextChar that has the format [B, S], where |
77 | /// each element is the softmax index of the next char. If \p train is false |
78 | /// then only load the first slice of inputText. |
79 | static void loadText(Tensor &inputText, Tensor &nextChar, llvm::StringRef text, |
80 | bool train) { |
81 | DCHECK_GT(text.size(), 2) << "The buffer must contain at least two chars" ; |
82 | inputText.zero(); |
83 | nextChar.zero(); |
84 | |
85 | auto idim = inputText.dims(); |
86 | DCHECK_EQ(idim.size(), 3) << "invalid input tensor" ; |
87 | auto B = idim[0]; |
88 | auto S = idim[1]; |
89 | |
90 | auto IH = inputText.getHandle(); |
91 | auto NH = nextChar.getHandle<int64_t>(); |
92 | |
93 | // Fill the tensor with slices from the sentence with an offset of 1. |
94 | // Example: |
95 | // |Hell|o| World |
96 | // |ello| |World |
97 | // |llo |W|orld |
98 | // |lo W|o|rld |
99 | for (dim_t i = 0; i < B; i++) { |
100 | for (dim_t j = 0; j < S; j++) { |
101 | dim_t c = clipASCII(text[i + j]); |
102 | |
103 | IH.at({i, j, c}) = 1.0; |
104 | if (train) { |
105 | size_t c1 = clipASCII(text[i + j + 1]); |
106 | NH.at({i, j}) = c1; |
107 | } |
108 | } |
109 | |
110 | // Only load the first slice in the batch when in inference mode. |
111 | if (!train) { |
112 | return; |
113 | } |
114 | } |
115 | } |
116 | |
117 | PseudoRNG &getRNG() { |
118 | static PseudoRNG RNG; |
119 | |
120 | return RNG; |
121 | } |
122 | |
123 | /// This method selects a random number based on a softmax distribution. One |
124 | /// property of this distribution is that the sum of all probabilities is equal |
125 | /// to one. The algorithm that we use here picks a random number between zero |
126 | /// and one. Then, we scan the tensor and accumulate the probabilities. We stop |
127 | /// and pick the index when sum is greater than the selected random number. |
128 | static char getPredictedChar(Tensor &inputText, dim_t slice, dim_t word) { |
129 | auto IH = inputText.getHandle(); |
130 | |
131 | // Pick a random number between zero and one. |
132 | double x = std::abs(getRNG().nextRand()); |
133 | double sum = 0; |
134 | // Accumulate the probabilities into 'sum'. |
135 | for (dim_t i = 0; i < 128; i++) { |
136 | sum += IH.at({slice, word, i}); |
137 | // As soon as we cross the threshold return the index. |
138 | if (sum > x) { |
139 | return i; |
140 | } |
141 | } |
142 | return 127; |
143 | } |
144 | |
145 | /// Loads the content of a file or stdin to a memory buffer. |
146 | /// The default filename of "-" reads from stdin. |
147 | static std::unique_ptr<llvm::MemoryBuffer> loadFile(llvm::StringRef filename) { |
148 | llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileBufOrErr = |
149 | llvm::MemoryBuffer::getFileOrSTDIN(filename); |
150 | if (!fileBufOrErr) { |
151 | LOG(ERROR) << "Error! Failed to open file: " << filename.str() << "\n" ; |
152 | LOG(ERROR) << fileBufOrErr.getError().message() << "\n" ; |
153 | exit(-1); |
154 | } |
155 | |
156 | return std::move(fileBufOrErr.get()); |
157 | } |
158 | |
/// Creates a new RNN network. The network answers the question, given N chars
/// of input, what is the character following each one of these chars.
/// \p mod owns the created function; \p bindings receives allocations for the
/// "input" and "expected" placeholders; \p minibatchSize is the batch
/// dimension, \p numSteps the unrolled sequence length, and \p hiddenSize the
/// LSTM hidden-state size. \returns the constructed function.
static Function *createNetwork(Module &mod, PlaceholderBindings &bindings,
                               dim_t minibatchSize, dim_t numSteps,
                               dim_t hiddenSize) {
  Function *F = mod.createFunction("main");

  // One-hot encoded input characters: [batch, step, 128].
  auto *X = mod.createPlaceholder(
      ElemKind::FloatTy, {minibatchSize, numSteps, 128}, "input", false);
  bindings.allocate(X);

  // Index of the expected next character at every step: [batch, step].
  auto *Y = mod.createPlaceholder(ElemKind::Int64ITy, {minibatchSize, numSteps},
                                  "expected", false);
  bindings.allocate(Y);

  std::vector<NodeValue> slicesX;
  std::vector<Node *> expectedX;

  // Unroll the sequence: cut one time-step slice out of the input and the
  // expected tensor per step.
  for (unsigned t = 0; t < numSteps; t++) {
    auto XtName = "X." + std::to_string(t);
    auto *Xt =
        F->createSlice(XtName, X, {0, t, 0}, {minibatchSize, t + 1, 128});
    slicesX.push_back(Xt);

    auto YtName = "Y." + std::to_string(t);
    auto *Yt = F->createSlice(YtName, Y, {0, t}, {minibatchSize, t + 1});
    expectedX.push_back(Yt);
  }

  // Build the recurrent core over the per-step input slices; one output node
  // per step is appended to outputNodes.
  std::vector<NodeValue> outputNodes;
  F->createLSTM(bindings, "rnn", slicesX, minibatchSize, hiddenSize, 128,
                outputNodes);

  // Attach a softmax (with the expected index as selector) to each step's
  // output, and reshape to [batch, 1, 128] so the steps can be concatenated
  // along axis 1 below.
  std::vector<NodeValue> resX;
  for (unsigned i = 0; i < numSteps; i++) {
    auto *R =
        F->createReshape("reshapeSelector", expectedX[i], {minibatchSize, 1});
    auto *SM = F->createSoftMax("softmax", outputNodes[i], R);
    auto *K = F->createReshape("reshapeSM", SM, {minibatchSize, 1, 128});
    resX.push_back(K);
  }

  // Concatenate the per-step distributions into [batch, step, 128] and save
  // the result so it can be read back after inference.
  Node *O = F->createConcat("output", resX, 1);
  auto *S = F->createSave("result", O);
  bindings.allocate(S->getPlaceholder());

  return F;
}
207 | |
int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, " The char-rnn test\n\n");
  auto mb = loadFile(inputFilename);
  auto text = mb.get()->getBuffer();
  LOG(INFO) << "Loaded " << text.size() << " chars.\n";
  PlaceholderBindings inferBindings, trainingBindings;

  // Unrolled sequence length, SGD minibatch size, number of training windows
  // (one per starting offset in the text), and LSTM hidden-state size.
  const dim_t numSteps = 50;
  const dim_t minibatchSize = 32;
  const dim_t batchSize = text.size() - numSteps;
  const dim_t hiddenSize = 256;

  CHECK_GT(text.size(), numSteps) << "Text is too short";
  TrainingConfig TC;

  // Engine used for training only; a fresh engine is created per epoch for
  // inference below.
  ExecutionEngine EET(executionBackend);
  TC.learningRate = 0.001;
  TC.momentum = 0.9;
  TC.batchSize = minibatchSize;

  auto &modT = EET.getModule();

  //// Train the network ////
  Function *F2 = createNetwork(modT, trainingBindings, minibatchSize, numSteps,
                               hiddenSize);
  differentiate(F2, TC);
  EET.compile(CompilationMode::Train);
  trainingBindings.allocate(modT.getPlaceholders());

  // Look up the placeholders created inside createNetwork by name.
  auto *XT = modT.getPlaceholderByNameSlow("input");
  auto *YT = modT.getPlaceholderByNameSlow("expected");

  // One-hot input windows and their expected next-char indices for the whole
  // text (see loadText).
  Tensor thisCharTrain(ElemKind::FloatTy, {batchSize, numSteps, 128});
  Tensor nextCharTrain(ElemKind::Int64ITy, {batchSize, numSteps});
  loadText(thisCharTrain, nextCharTrain, text, true);

  // This variable records the number of the next sample to be used for
  // training.
  size_t sampleCounter = 0;

  // Run this number of iterations over the input. On each iteration: train the
  // network on the whole input and then generate some sample text.
  for (unsigned i = 0; i < numEpochs; i++) {

    // Train the network on the whole input.
    LOG(INFO) << "Iteration " << i + 1 << "/" << numEpochs;
    runBatch(EET, trainingBindings, batchSize / minibatchSize, sampleCounter,
             {XT, YT}, {&thisCharTrain, &nextCharTrain});

    // Build a fresh inference engine/module each epoch and copy the freshly
    // trained weights into it.
    ExecutionEngine EEO(executionBackend);
    inferBindings.clear();
    auto &mod = EEO.getModule();
    auto OF =
        createNetwork(mod, inferBindings, minibatchSize, numSteps, hiddenSize);
    auto *X = mod.getPlaceholderByNameSlow("input");
    inferBindings.allocate(mod.getPlaceholders());
    trainingBindings.copyTrainableWeightsTo(inferBindings);

    //// Use the trained network to generate some text ////
    auto *res =
        llvm::cast<SaveNode>(OF->getNodeByName("result"))->getPlaceholder();
    // Promote placeholders to constants, keeping only the input and the
    // result as placeholders.
    ::glow::convertPlaceholdersToConstants(OF, inferBindings, {X, res});
    EEO.compile(CompilationMode::Infer);

    // Load a few characters to start the text that we generate.
    Tensor currCharInfer(ElemKind::FloatTy, {minibatchSize, numSteps, 128});
    Tensor nextCharInfer(ElemKind::Int64ITy, {minibatchSize, numSteps});
    loadText(currCharInfer, nextCharInfer, text.slice(0, 128), false);

    // The saved output tensor; read after every run below.
    auto *T = inferBindings.get(res);
    std::string result;
    std::string input;
    // Seed both the rolling input window and the printed result with the
    // first numSteps characters of the text.
    input.insert(input.begin(), text.begin(), text.begin() + numSteps);
    result = input;

    // Generate a sentence by running inference over and over again.
    // NOTE: this loop variable shadows the epoch counter `i` above.
    for (unsigned i = 0; i < generateChars; i++) {
      // Generate a char:
      updateInputPlaceholders(inferBindings, {X}, {&currCharInfer});
      EEO.run(inferBindings);

      // Pick a char at random from the softmax distribution at the last step
      // of the first batch slice.
      char c = getPredictedChar(*T, 0, numSteps - 1);

      // Update the inputs for the next iteration: slide the window forward by
      // one character and re-encode it.
      result.push_back(c);
      input.push_back(c);
      input.erase(input.begin());
      loadText(currCharInfer, nextCharInfer, input, false);
    }

    llvm::outs() << "Generated output:\n" << result << "\n";
  }

  return 0;
}
305 | |