1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
17 | #include "glow/Graph/Graph.h" |
18 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
19 | #include "glow/Quantization/Quantization.h" |
20 | #include "glow/Quantization/Serialization.h" |
21 | |
22 | #include "llvm/Support/CommandLine.h" |
23 | #include "llvm/Support/Timer.h" |
24 | |
25 | #include <glog/logging.h> |
26 | |
#include <algorithm>
#include <array>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
34 | |
35 | using namespace glow; |
36 | |
namespace {
/// Debugging options.
llvm::cl::OptionCategory debugCat("Glow Debugging Options");

/// When non-empty, the graph is dumped in DOT format to this file after
/// compilation (see main).
llvm::cl::opt<std::string> dumpGraphDAGFileOpt(
    "dump-graph-DAG",
    llvm::cl::desc("Dump the graph to the given file in DOT format."),
    llvm::cl::value_desc("file.dot"), llvm::cl::cat(debugCat));

/// Translator options.
llvm::cl::OptionCategory fr2enCat("French-to-English Translator Options");

/// Number of sentences translated per inference run; also fixes the batch
/// dimension of the compiled graph.
llvm::cl::opt<unsigned> batchSizeOpt(
    "batchsize", llvm::cl::desc("Process batches of N sentences at a time."),
    llvm::cl::init(1), llvm::cl::value_desc("N"), llvm::cl::cat(fr2enCat));
llvm::cl::alias batchSizeA("b", llvm::cl::desc("Alias for -batchsize"),
                           llvm::cl::aliasopt(batchSizeOpt),
                           llvm::cl::cat(fr2enCat));

/// When set, wraps the translate loop in an llvm::Timer and reports timing.
llvm::cl::opt<bool>
    timeOpt("time",
            llvm::cl::desc("Print timer data detailing how long it "
                           "takes for the program to execute translate phase. "
                           "This option will be useful if input is read from "
                           "the file directly."),
            llvm::cl::Optional, llvm::cl::cat(fr2enCat));

/// Name of the Glow backend used to run the model (defaults to Interpreter).
llvm::cl::opt<std::string> ExecutionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL:"),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(fr2enCat));

/// Quantization options.
llvm::cl::OptionCategory quantizationCat("Quantization Options");

/// When non-empty, the run collects a quantization profile and writes it to
/// this YAML file (see Model::translate).
llvm::cl::opt<std::string> dumpProfileFileOpt(
    "dump-profile",
    llvm::cl::desc("Perform quantization profiling for a given graph "
                   "and dump result to the file."),
    llvm::cl::value_desc("profile.yaml"), llvm::cl::Optional,
    llvm::cl::cat(quantizationCat));

/// When non-empty, a previously dumped profile is loaded and the graph is
/// quantized accordingly (see Model::compile).
llvm::cl::opt<std::string> loadProfileFileOpt(
    "load-profile",
    llvm::cl::desc("Load quantization profile file and quantize the graph"),
    llvm::cl::value_desc("profile.yaml"), llvm::cl::Optional,
    llvm::cl::cat(quantizationCat));

} // namespace
86 | |
/// Maximum number of words per sentence, including the terminating "EOS".
constexpr unsigned MAX_LENGTH = 10;
/// Width of the word-embedding vectors used by both encoder and decoder.
constexpr unsigned EMBEDDING_SIZE = 256;
/// Width of the concatenated GRU gate projections: three gates
/// (reset / input / new), each EMBEDDING_SIZE wide.
constexpr unsigned HIDDEN_SIZE = EMBEDDING_SIZE * 3;
90 | |
91 | /// Stores vocabulary of a language. Contains mapping from word to index and |
92 | /// vice versa. |
93 | struct Vocabulary { |
94 | std::vector<std::string> index2word_; |
95 | std::unordered_map<std::string, int64_t> word2index_; |
96 | |
97 | void addWord(llvm::StringRef word) { |
98 | word2index_[word.str()] = index2word_.size(); |
99 | index2word_.push_back(word.str()); |
100 | } |
101 | |
102 | Vocabulary() = default; |
103 | |
104 | void loadVocabularyFromFile(llvm::StringRef filename) { |
105 | std::ifstream file(filename.str()); |
106 | std::string word; |
107 | while (getline(file, word)) |
108 | addWord(word); |
109 | } |
110 | }; |
111 | |
112 | /// Loads tensor of floats from binary file. |
113 | void loadMatrixFromFile(llvm::StringRef filename, Tensor &result) { |
114 | std::ifstream file(filename.str(), std::ios::binary); |
115 | if (!file.read(result.getUnsafePtr(), result.size() * sizeof(float))) { |
116 | std::cout |
117 | << "Error reading file: " << filename.str() << '\n' |
118 | << "Need to be downloaded by calling:\n" |
119 | << "python ../glow/utils/download_datasets_and_models.py -d fr2en\n" ; |
120 | exit(1); |
121 | } |
122 | } |
123 | |
124 | /// Represents a single RNN model: encoder combined with decoder. |
125 | /// Stores vocabulary, compiled Graph (ready to be executed), and |
126 | /// few references to input/output Variables. |
/// Represents a single RNN model: encoder combined with decoder.
/// Stores vocabulary, compiled Graph (ready to be executed), and
/// few references to input/output Variables.
struct Model {
  unsigned batchSize_;                 // sentences per inference run
  ExecutionEngine EE_{ExecutionBackend};
  Function *F_;                        // graph under construction / compiled
  Vocabulary en_, fr_;
  Placeholder *input_;     // encoder input: {batchSize_, MAX_LENGTH} word ids
  Placeholder *seqLength_; // gather index of each sentence's last word
  Placeholder *output_;    // decoder output: {MAX_LENGTH, batchSize_} word ids
  PlaceholderBindings bindings;
  LoweredInfoMap loweredMap_; // lowered-node mapping, needed for profiling

  void loadLanguages();
  void loadEncoder();
  void loadDecoder();
  void translate(const std::vector<std::string> &batch);

  Model(unsigned batchSize) : batchSize_(batchSize) {
    F_ = EE_.getModule().createFunction("main");
  }

  /// Dumps the current graph in DOT format to \p filename.
  void dumpGraphDAG(const char *filename) { F_->dumpDAG(filename); }

  /// Compiles the graph, optionally instrumenting it for quantization
  /// profiling (-dump-profile) or quantizing it from a previously dumped
  /// profile (-load-profile).
  void compile() {
    CompilationContext cctx{&bindings, &loweredMap_};
    PrecisionConfiguration &precConfig = cctx.precisionConfig;

    // Everything except the listed model inputs/outputs is a trained weight;
    // turn those placeholders into constants so the optimizer can fold them.
    ::glow::convertPlaceholdersToConstants(F_, bindings,
                                           {input_, seqLength_, output_});

    if (!dumpProfileFileOpt.empty()) {
      precConfig.quantMode = QuantizationMode::Profile;
    }

    // Load the quantization profile and transform the graph.
    if (!loadProfileFileOpt.empty()) {
      precConfig.quantMode = QuantizationMode::Quantize;
      deserializeProfilingInfosFromYaml(
          loadProfileFileOpt, precConfig.quantConfig.graphPreLowerHash,
          precConfig.quantConfig.infos);
      precConfig.quantConfig.assertAllNodesQuantized = true;
    }

    EE_.compile(cctx);

    // After compilation, the original function may be removed/replaced. Need to
    // update F_.
    F_ = EE_.getModule().getFunctions().front();
  }

private:
  Placeholder *embedding_fr_, *embedding_en_; // per-language embedding tables
  Node *encoderHiddenOutput_; // encoder final hidden state, consumed by decoder

  /// Creates a {langSize, EMBEDDING_SIZE} embedding placeholder named after
  /// \p langPrefix and fills it from "fr2en/<langPrefix>_embedding.bin".
  Placeholder *loadEmbedding(llvm::StringRef langPrefix, dim_t langSize) {
    auto &mod = EE_.getModule();
    auto *result =
        mod.createPlaceholder(ElemKind::FloatTy, {langSize, EMBEDDING_SIZE},
                              "embedding." + langPrefix.str(), false);
    loadMatrixFromFile("fr2en/" + langPrefix.str() + "_embedding.bin",
                       *bindings.allocate(result));

    return result;
  }

  /// Builds one GRU step mirroring PyTorch's GRUCell: \p input and \p hidden
  /// are {batchSize_, EMBEDDING_SIZE}; the weights project to HIDDEN_SIZE
  /// (= 3 * EMBEDDING_SIZE, one slice per gate). Returns the new hidden state.
  Node *createPyTorchGRUCell(Function *G, Node *input, Node *hidden,
                             Placeholder *wIh, Placeholder *bIh,
                             Placeholder *wHh, Placeholder *bHh) {
    // reference implementation:
    // https://github.com/pytorch/pytorch/blob/dd5c195646b941d3e20a72847ac48c41e272b8b2/torch/nn/_functions/rnn.py#L46
    Node *gi = G->createFullyConnected("pytorch.GRU.gi", input, wIh, bIh);
    Node *gh = G->createFullyConnected("pytorch.GRU.gh", hidden, wHh, bHh);

    // Slice the three gate pre-activations out of the input projection ...
    Node *i_r = G->createSlice("pytorch.GRU.i_r", gi, {0, 0},
                               {batchSize_, EMBEDDING_SIZE});
    Node *i_i = G->createSlice("pytorch.GRU.i_i", gi, {0, EMBEDDING_SIZE},
                               {batchSize_, 2 * EMBEDDING_SIZE});
    Node *i_n = G->createSlice("pytorch.GRU.i_n", gi, {0, 2 * EMBEDDING_SIZE},
                               {batchSize_, 3 * EMBEDDING_SIZE});

    // ... and out of the hidden-state projection.
    Node *h_r = G->createSlice("pytorch.GRU.h_r", gh, {0, 0},
                               {batchSize_, EMBEDDING_SIZE});
    Node *h_i = G->createSlice("pytorch.GRU.h_i", gh, {0, EMBEDDING_SIZE},
                               {batchSize_, 2 * EMBEDDING_SIZE});
    Node *h_n = G->createSlice("pytorch.GRU.h_n", gh, {0, 2 * EMBEDDING_SIZE},
                               {batchSize_, 3 * EMBEDDING_SIZE});

    // Standard GRU equations: r, z gates via sigmoid, candidate via tanh,
    // then hy = n + z * (h - n).
    Node *resetgate = G->createSigmoid("pytorch.GRU.resetgate",
                                       G->createAdd("i_r_plus_h_r", i_r, h_r));
    Node *inputgate = G->createSigmoid("pytorch.GRU.inputgate",
                                       G->createAdd("i_i_plus_h_i", i_i, h_i));
    Node *newgate = G->createTanh(
        "pytorch.GRU.newgate",
        G->createAdd("i_n_plus_rg_mult_h_n", i_n,
                     G->createMul("rg_mult_h_n", resetgate, h_n)));
    return G->createAdd(
        "pytorch.GRU.hy", newgate,
        G->createMul("ig_mult_hmng", inputgate,
                     G->createSub("hidden_minus_newgate", hidden, newgate)));
  }
};
227 | |
228 | void Model::loadLanguages() { |
229 | fr_.loadVocabularyFromFile("fr2en/fr_vocabulary.txt" ); |
230 | en_.loadVocabularyFromFile("fr2en/en_vocabulary.txt" ); |
231 | embedding_fr_ = loadEmbedding("fr" , fr_.index2word_.size()); |
232 | embedding_en_ = loadEmbedding("en" , en_.index2word_.size()); |
233 | } |
234 | |
235 | /// Model part representing Encoder. Remembers input sentence into hidden layer. |
236 | /// \p input is Variable representing the sentence. |
237 | /// \p seqLength is Variable representing the length of sentence. |
238 | /// \p encoderHiddenOutput saves resulting hidden layer. |
/// Model part representing Encoder. Remembers input sentence into hidden layer.
/// \p input is Variable representing the sentence.
/// \p seqLength is Variable representing the length of sentence.
/// \p encoderHiddenOutput saves resulting hidden layer.
void Model::loadEncoder() {
  auto &mod = EE_.getModule();
  // Sentence as word indices, padded to MAX_LENGTH (padding is index 0).
  input_ = mod.createPlaceholder(ElemKind::Int64ITy, {batchSize_, MAX_LENGTH},
                                 "encoder.inputsentence", false);
  bindings.allocate(input_);
  // Per-sentence flattened index of the last word; used by the final gather.
  seqLength_ = mod.createPlaceholder(ElemKind::Int64ITy, {batchSize_},
                                     "encoder.seqLength", false);
  bindings.allocate(seqLength_);

  // Initial hidden state is all zeros.
  auto *hiddenInit =
      mod.createPlaceholder(ElemKind::FloatTy, {batchSize_, EMBEDDING_SIZE},
                            "encoder.hiddenInit", false);
  auto *hiddenInitTensor = bindings.allocate(hiddenInit);
  hiddenInitTensor->zero();

  Node *hidden = hiddenInit;

  // GRU weights/biases, shared across all MAX_LENGTH steps.
  auto *wIh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "encoder.w_ih", false);
  auto *bIh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "encoder.b_ih", false);
  auto *wHh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "encoder.w_hh", false);
  auto *bHh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "encoder.b_hh", false);

  loadMatrixFromFile("fr2en/encoder_w_ih.bin", *bindings.allocate(wIh));
  loadMatrixFromFile("fr2en/encoder_b_ih.bin", *bindings.allocate(bIh));
  loadMatrixFromFile("fr2en/encoder_w_hh.bin", *bindings.allocate(wHh));
  loadMatrixFromFile("fr2en/encoder_b_hh.bin", *bindings.allocate(bHh));

  // Look up the French embedding for every word of every sentence:
  // result is {batchSize_, MAX_LENGTH, EMBEDDING_SIZE}.
  Node *inputEmbedded =
      F_->createGather("encoder.embedding", embedding_fr_, input_);

  // TODO: encoder does exactly MAX_LENGTH steps, while input size is smaller.
  // We could use control flow here.
  // Unroll the GRU over time; keep every step's hidden state so the final
  // gather can pick the state at each sentence's true end.
  std::vector<NodeValue> outputs;
  for (unsigned step = 0; step < MAX_LENGTH; step++) {
    // Slice out the embeddings of word #step and drop the length-1 time axis.
    Node *inputSlice = F_->createSlice(
        "encoder." + std::to_string(step) + ".inputSlice", inputEmbedded,
        {0, step, 0}, {batchSize_, step + 1, EMBEDDING_SIZE});
    Node *reshape =
        F_->createReshape("encoder." + std::to_string(step) + ".reshape",
                          inputSlice, {batchSize_, EMBEDDING_SIZE}, ANY_LAYOUT);
    hidden = createPyTorchGRUCell(F_, reshape, hidden, wIh, bIh, wHh, bHh);
    outputs.push_back(hidden);
  }

  // Concatenate per-step hidden states and flatten to
  // {MAX_LENGTH * batchSize_, EMBEDDING_SIZE} so seqLength_ can index rows.
  Node *output = F_->createConcat("encoder.output", outputs, 1);
  Node *r2 =
      F_->createReshape("encoder.output.r2", output,
                        {MAX_LENGTH * batchSize_, EMBEDDING_SIZE}, ANY_LAYOUT);

  // Pick, per sentence, the hidden state at its last word (the "memory"
  // handed to the decoder).
  encoderHiddenOutput_ = F_->createGather("encoder.outputNth", r2, seqLength_);
}
294 | |
295 | /// Model part representing Decoder. |
296 | /// Uses \p encoderHiddenOutput as final state from Encoder. |
297 | /// Resulting translation is put into \p output Variable. |
/// Model part representing Decoder.
/// Uses \p encoderHiddenOutput as final state from Encoder.
/// Resulting translation is put into \p output Variable.
void Model::loadDecoder() {
  auto &mod = EE_.getModule();
  // Decoding starts from the "SOS" (start-of-sentence) token for every
  // sentence in the batch.
  auto *input = mod.createPlaceholder(ElemKind::Int64ITy, {batchSize_},
                                      "decoder.input", false);
  auto *inputTensor = bindings.allocate(input);
  for (dim_t i = 0; i < batchSize_; i++) {
    inputTensor->getHandle<int64_t>().at({i}) = en_.word2index_["SOS"];
  }

  // GRU weights/biases plus the output projection to English vocabulary size.
  auto *wIh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "decoder.w_ih", false);
  auto *bIh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "decoder.b_ih", false);
  auto *wHh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "decoder.w_hh", false);
  auto *bHh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "decoder.b_hh", false);
  auto *outW = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, (dim_t)en_.index2word_.size()},
      "decoder.out_w", false);
  auto *outB =
      mod.createPlaceholder(ElemKind::FloatTy, {(dim_t)en_.index2word_.size()},
                            "decoder.out_b", false);
  loadMatrixFromFile("fr2en/decoder_w_ih.bin", *bindings.allocate(wIh));
  loadMatrixFromFile("fr2en/decoder_b_ih.bin", *bindings.allocate(bIh));
  loadMatrixFromFile("fr2en/decoder_w_hh.bin", *bindings.allocate(wHh));
  loadMatrixFromFile("fr2en/decoder_b_hh.bin", *bindings.allocate(bHh));
  loadMatrixFromFile("fr2en/decoder_out_w.bin", *bindings.allocate(outW));
  loadMatrixFromFile("fr2en/decoder_out_b.bin", *bindings.allocate(outB));

  Node *hidden = encoderHiddenOutput_;
  Node *lastWordIdx = input;

  std::vector<NodeValue> outputs;
  // TODO: decoder does exactly MAX_LENGTH steps, while translation could be
  // smaller. We could use control flow here.
  for (unsigned step = 0; step < MAX_LENGTH; step++) {
    // Use last translated word as an input at the current step.
    Node *embedded =
        F_->createGather("decoder.embedding." + std::to_string(step),
                         embedding_en_, lastWordIdx);

    Node *relu = F_->createRELU("decoder.relu", embedded);
    hidden = createPyTorchGRUCell(F_, relu, hidden, wIh, bIh, wHh, bHh);

    // Project the hidden state onto the vocabulary and take the argmax
    // (greedy decoding, TopK with k=1).
    Node *FC = F_->createFullyConnected("decoder.outFC", hidden, outW, outB);
    auto *topK = F_->createTopK("decoder.topK", FC, 1);

    // Flatten {batchSize_, 1} indices back to {batchSize_}; this becomes the
    // next step's input.
    lastWordIdx = F_->createReshape("decoder.reshape", topK->getIndices(),
                                    {batchSize_}, "N");
    outputs.push_back(lastWordIdx);
  }

  // Stack the per-step word indices into a {MAX_LENGTH, batchSize_} result
  // (time-major; translate() reads it as OH.at({step, sentence})).
  Node *concat = F_->createConcat("decoder.output.concat", outputs, 0);
  Node *reshape = F_->createReshape("decoder.output.reshape", concat,
                                    {MAX_LENGTH, batchSize_}, ANY_LAYOUT);
  auto *save = F_->createSave("decoder.output", reshape);
  output_ = save->getPlaceholder();
  bindings.allocate(output_);
}
358 | |
359 | /// Translation has 2 stages: |
360 | /// 1) Input sentence is fed into Encoder word by word. |
361 | /// 2) "Memory" of Encoder is written into memory of Decoder. |
362 | /// Now Decoder streams resulting translation word by word. |
363 | void Model::translate(const std::vector<std::string> &batch) { |
364 | Tensor input(ElemKind::Int64ITy, {batchSize_, MAX_LENGTH}); |
365 | Tensor seqLength(ElemKind::Int64ITy, {batchSize_}); |
366 | input.zero(); |
367 | |
368 | for (dim_t j = 0; j < batch.size(); j++) { |
369 | std::istringstream iss(batch[j]); |
370 | std::vector<std::string> words; |
371 | std::string word; |
372 | while (iss >> word) |
373 | words.push_back(word); |
374 | words.push_back("EOS" ); |
375 | |
376 | CHECK_LE(words.size(), MAX_LENGTH) << "sentence is too long." ; |
377 | |
378 | for (dim_t i = 0; i < words.size(); i++) { |
379 | auto iter = fr_.word2index_.find(words[i]); |
380 | CHECK(iter != fr_.word2index_.end()) << "Unknown word: " << words[i]; |
381 | input.getHandle<int64_t>().at({j, i}) = iter->second; |
382 | } |
383 | seqLength.getHandle<int64_t>().at({j}) = |
384 | (words.size() - 1) + j * MAX_LENGTH; |
385 | } |
386 | |
387 | updateInputPlaceholders(bindings, {input_, seqLength_}, {&input, &seqLength}); |
388 | EE_.run(bindings); |
389 | |
390 | auto OH = bindings.get(output_)->getHandle<int64_t>(); |
391 | for (unsigned j = 0; j < batch.size(); j++) { |
392 | for (unsigned i = 0; i < MAX_LENGTH; i++) { |
393 | dim_t wordIdx = OH.at({i, j}); |
394 | if (wordIdx == en_.word2index_["EOS" ]) |
395 | break; |
396 | |
397 | if (i) |
398 | std::cout << ' '; |
399 | if (en_.index2word_.size() > (wordIdx)) |
400 | std::cout << en_.index2word_[wordIdx]; |
401 | else |
402 | std::cout << "[" << wordIdx << "]" ; |
403 | } |
404 | std::cout << "\n\n" ; |
405 | } |
406 | |
407 | if (!dumpProfileFileOpt.empty()) { |
408 | std::vector<NodeProfilingInfo> PI = |
409 | quantization::generateNodeProfilingInfos(bindings, F_, loweredMap_); |
410 | serializeProfilingInfosToYaml(dumpProfileFileOpt, |
411 | /* graphPreLowerHash */ 0, PI); |
412 | } |
413 | } |
414 | |
415 | int main(int argc, char **argv) { |
416 | std::array<const llvm::cl::OptionCategory *, 3> showCategories = { |
417 | {&debugCat, &quantizationCat, &fr2enCat}}; |
418 | llvm::cl::HideUnrelatedOptions(showCategories); |
419 | llvm::cl::ParseCommandLineOptions( |
420 | argc, argv, "Translate sentences from French to English" ); |
421 | |
422 | Model seq2seq(batchSizeOpt); |
423 | seq2seq.loadLanguages(); |
424 | seq2seq.loadEncoder(); |
425 | seq2seq.loadDecoder(); |
426 | seq2seq.compile(); |
427 | |
428 | if (!dumpGraphDAGFileOpt.empty()) { |
429 | seq2seq.dumpGraphDAG(dumpGraphDAGFileOpt.c_str()); |
430 | } |
431 | |
432 | std::cout << "Please enter a sentence in French, such that its English " |
433 | << "translation starts with one of the following:\n" |
434 | << "\ti am\n" |
435 | << "\the is\n" |
436 | << "\tshe is\n" |
437 | << "\tyou are\n" |
438 | << "\twe are\n" |
439 | << "\tthey are\n" |
440 | << "\n" |
441 | << "Here are some examples:\n" |
442 | << "\tnous sommes desormais en securite .\n" |
443 | << "\tvous etes puissantes .\n" |
444 | << "\til etudie l histoire a l universite .\n" |
445 | << "\tje ne suis pas timide .\n" |
446 | << "\tj y songe encore .\n" |
447 | << "\tje suis maintenant a l aeroport .\n\n" ; |
448 | |
449 | llvm::Timer timer("Translate" , "Translate" ); |
450 | if (timeOpt) { |
451 | timer.startTimer(); |
452 | } |
453 | |
454 | std::vector<std::string> batch; |
455 | do { |
456 | batch.clear(); |
457 | for (size_t i = 0; i < batchSizeOpt; i++) { |
458 | std::string sentence; |
459 | if (!getline(std::cin, sentence)) { |
460 | break; |
461 | } |
462 | batch.push_back(sentence); |
463 | } |
464 | if (!batch.empty()) { |
465 | seq2seq.translate(batch); |
466 | } |
467 | } while (batch.size() == batchSizeOpt); |
468 | |
469 | if (timeOpt) { |
470 | timer.stopTimer(); |
471 | } |
472 | |
473 | return 0; |
474 | } |
475 | |