/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Quantization/Quantization.h"
#include "glow/Quantization/Serialization.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Timer.h"

#include <glog/logging.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

using namespace glow;

namespace {
/// Debugging options.
llvm::cl::OptionCategory debugCat("Glow Debugging Options");

llvm::cl::opt<std::string> dumpGraphDAGFileOpt(
    "dump-graph-DAG",
    llvm::cl::desc("Dump the graph to the given file in DOT format."),
    llvm::cl::value_desc("file.dot"), llvm::cl::cat(debugCat));

/// Translator options.
llvm::cl::OptionCategory fr2enCat("French-to-English Translator Options");

llvm::cl::opt<unsigned> batchSizeOpt(
    "batchsize", llvm::cl::desc("Process batches of N sentences at a time."),
    llvm::cl::init(1), llvm::cl::value_desc("N"), llvm::cl::cat(fr2enCat));
llvm::cl::alias batchSizeA("b", llvm::cl::desc("Alias for -batchsize"),
                           llvm::cl::aliasopt(batchSizeOpt),
                           llvm::cl::cat(fr2enCat));

llvm::cl::opt<bool>
    timeOpt("time",
            llvm::cl::desc("Print timer data detailing how long it takes "
                           "the program to execute the translate phase. "
                           "This option is useful when input is read "
                           "directly from a file."),
            llvm::cl::Optional, llvm::cl::cat(fr2enCat));

llvm::cl::opt<std::string> ExecutionBackend(
    "backend",
    llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, OpenCL."),
    llvm::cl::Optional, llvm::cl::init("Interpreter"), llvm::cl::cat(fr2enCat));

/// Quantization options.
llvm::cl::OptionCategory quantizationCat("Quantization Options");

llvm::cl::opt<std::string> dumpProfileFileOpt(
    "dump-profile",
    llvm::cl::desc("Perform quantization profiling for a given graph "
                   "and dump the result to the given file."),
    llvm::cl::value_desc("profile.yaml"), llvm::cl::Optional,
    llvm::cl::cat(quantizationCat));

llvm::cl::opt<std::string> loadProfileFileOpt(
    "load-profile",
    llvm::cl::desc("Load the quantization profile file and quantize the "
                   "graph."),
    llvm::cl::value_desc("profile.yaml"), llvm::cl::Optional,
    llvm::cl::cat(quantizationCat));

} // namespace

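/// The model processes sentences of at most MAX_LENGTH tokens (including the
/// EOS marker). HIDDEN_SIZE is 3 * EMBEDDING_SIZE because the GRU cell below
/// computes its three gates (reset, input, new) with a single fully-connected
/// layer whose output is sliced into three EMBEDDING_SIZE-wide chunks.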
const unsigned MAX_LENGTH = 10;
const unsigned EMBEDDING_SIZE = 256;
const unsigned HIDDEN_SIZE = EMBEDDING_SIZE * 3;

/// Stores the vocabulary of a language. Contains mappings from word to index
/// and vice versa.
struct Vocabulary {
  std::vector<std::string> index2word_;
  std::unordered_map<std::string, int64_t> word2index_;

  void addWord(llvm::StringRef word) {
    word2index_[word.str()] = index2word_.size();
    index2word_.push_back(word.str());
  }

  Vocabulary() = default;

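  /// Reads a vocabulary file that is expected to contain one word per line;
  /// a word's line number (starting from 0) becomes its index.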
  void loadVocabularyFromFile(llvm::StringRef filename) {
    std::ifstream file(filename.str());
    std::string word;
    while (std::getline(file, word))
      addWord(word);
  }
};

/// Loads a tensor of floats from a binary file.
void loadMatrixFromFile(llvm::StringRef filename, Tensor &result) {
  std::ifstream file(filename.str(), std::ios::binary);
  if (!file.read(result.getUnsafePtr(), result.size() * sizeof(float))) {
    std::cout
        << "Error reading file: " << filename.str() << '\n'
        << "The model files need to be downloaded first by running:\n"
        << "python ../glow/utils/download_datasets_and_models.py -d fr2en\n";
    exit(1);
  }
}

/// Represents a single RNN model: encoder combined with decoder.
/// Stores vocabulary, the compiled Graph (ready to be executed), and
/// a few references to the input/output Placeholders.
struct Model {
  unsigned batchSize_;
  ExecutionEngine EE_{ExecutionBackend};
  Function *F_;
  Vocabulary en_, fr_;
  Placeholder *input_;
  Placeholder *seqLength_;
  Placeholder *output_;
  PlaceholderBindings bindings;
  LoweredInfoMap loweredMap_;

  void loadLanguages();
  void loadEncoder();
  void loadDecoder();
  void translate(const std::vector<std::string> &batch);

  Model(unsigned batchSize) : batchSize_(batchSize) {
    F_ = EE_.getModule().createFunction("main");
  }

  void dumpGraphDAG(const char *filename) { F_->dumpDAG(filename); }

  void compile() {
    CompilationContext cctx{&bindings, &loweredMap_};
    PrecisionConfiguration &precConfig = cctx.precisionConfig;

    ::glow::convertPlaceholdersToConstants(F_, bindings,
                                           {input_, seqLength_, output_});

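    // When profiling is requested, instrument the graph so that running it
    // records the ranges of intermediate values; the profile is written out
    // after translation (see translate()).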
    if (!dumpProfileFileOpt.empty()) {
      precConfig.quantMode = QuantizationMode::Profile;
    }

    // Load the quantization profile and transform the graph.
    if (!loadProfileFileOpt.empty()) {
      precConfig.quantMode = QuantizationMode::Quantize;
      deserializeProfilingInfosFromYaml(
          loadProfileFileOpt, precConfig.quantConfig.graphPreLowerHash,
          precConfig.quantConfig.infos);
      precConfig.quantConfig.assertAllNodesQuantized = true;
    }

    EE_.compile(cctx);

    // After compilation the original function may have been removed or
    // replaced, so we need to update F_.
    F_ = EE_.getModule().getFunctions().front();
  }

private:
  Placeholder *embedding_fr_, *embedding_en_;
  Node *encoderHiddenOutput_;

  Placeholder *loadEmbedding(llvm::StringRef langPrefix, dim_t langSize) {
    auto &mod = EE_.getModule();
    auto *result =
        mod.createPlaceholder(ElemKind::FloatTy, {langSize, EMBEDDING_SIZE},
                              "embedding." + langPrefix.str(), false);
    loadMatrixFromFile("fr2en/" + langPrefix.str() + "_embedding.bin",
                       *bindings.allocate(result));

    return result;
  }

  Node *createPyTorchGRUCell(Function *G, Node *input, Node *hidden,
                             Placeholder *wIh, Placeholder *bIh,
                             Placeholder *wHh, Placeholder *bHh) {
    // Reference implementation:
    // https://github.com/pytorch/pytorch/blob/dd5c195646b941d3e20a72847ac48c41e272b8b2/torch/nn/_functions/rnn.py#L46
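    // The cell computes, in the reference's notation:
    //   resetgate r = sigmoid(i_r + h_r)
    //   inputgate z = sigmoid(i_i + h_i)
    //   newgate   n = tanh(i_n + r * h_n)
    //   hy          = n + z * (hidden - n)
    // where gi = wIh * input + bIh and gh = wHh * hidden + bHh each pack the
    // three gate projections side by side, hence the three slices below.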
    Node *gi = G->createFullyConnected("pytorch.GRU.gi", input, wIh, bIh);
    Node *gh = G->createFullyConnected("pytorch.GRU.gh", hidden, wHh, bHh);

    Node *i_r = G->createSlice("pytorch.GRU.i_r", gi, {0, 0},
                               {batchSize_, EMBEDDING_SIZE});
    Node *i_i = G->createSlice("pytorch.GRU.i_i", gi, {0, EMBEDDING_SIZE},
                               {batchSize_, 2 * EMBEDDING_SIZE});
    Node *i_n = G->createSlice("pytorch.GRU.i_n", gi, {0, 2 * EMBEDDING_SIZE},
                               {batchSize_, 3 * EMBEDDING_SIZE});

    Node *h_r = G->createSlice("pytorch.GRU.h_r", gh, {0, 0},
                               {batchSize_, EMBEDDING_SIZE});
    Node *h_i = G->createSlice("pytorch.GRU.h_i", gh, {0, EMBEDDING_SIZE},
                               {batchSize_, 2 * EMBEDDING_SIZE});
    Node *h_n = G->createSlice("pytorch.GRU.h_n", gh, {0, 2 * EMBEDDING_SIZE},
                               {batchSize_, 3 * EMBEDDING_SIZE});

    Node *resetgate = G->createSigmoid("pytorch.GRU.resetgate",
                                       G->createAdd("i_r_plus_h_r", i_r, h_r));
    Node *inputgate = G->createSigmoid("pytorch.GRU.inputgate",
                                       G->createAdd("i_i_plus_h_i", i_i, h_i));
    Node *newgate = G->createTanh(
        "pytorch.GRU.newgate",
        G->createAdd("i_n_plus_rg_mult_h_n", i_n,
                     G->createMul("rg_mult_h_n", resetgate, h_n)));
    return G->createAdd(
        "pytorch.GRU.hy", newgate,
        G->createMul("ig_mult_hmng", inputgate,
                     G->createSub("hidden_minus_newgate", hidden, newgate)));
  }
};

void Model::loadLanguages() {
  fr_.loadVocabularyFromFile("fr2en/fr_vocabulary.txt");
  en_.loadVocabularyFromFile("fr2en/en_vocabulary.txt");
  embedding_fr_ = loadEmbedding("fr", fr_.index2word_.size());
  embedding_en_ = loadEmbedding("en", en_.index2word_.size());
}

/// Model part representing the Encoder. Encodes the input sentence into the
/// hidden state.
/// \p input_ is a Placeholder representing the sentence.
/// \p seqLength_ is a Placeholder representing the length of the sentence.
/// \p encoderHiddenOutput_ saves the resulting hidden state.
void Model::loadEncoder() {
  auto &mod = EE_.getModule();
  input_ = mod.createPlaceholder(ElemKind::Int64ITy, {batchSize_, MAX_LENGTH},
                                 "encoder.inputsentence", false);
  bindings.allocate(input_);
  seqLength_ = mod.createPlaceholder(ElemKind::Int64ITy, {batchSize_},
                                     "encoder.seqLength", false);
  bindings.allocate(seqLength_);

  auto *hiddenInit =
      mod.createPlaceholder(ElemKind::FloatTy, {batchSize_, EMBEDDING_SIZE},
                            "encoder.hiddenInit", false);
  auto *hiddenInitTensor = bindings.allocate(hiddenInit);
  hiddenInitTensor->zero();

  Node *hidden = hiddenInit;

  auto *wIh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "encoder.w_ih", false);
  auto *bIh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "encoder.b_ih", false);
  auto *wHh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "encoder.w_hh", false);
  auto *bHh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "encoder.b_hh", false);

  loadMatrixFromFile("fr2en/encoder_w_ih.bin", *bindings.allocate(wIh));
  loadMatrixFromFile("fr2en/encoder_b_ih.bin", *bindings.allocate(bIh));
  loadMatrixFromFile("fr2en/encoder_w_hh.bin", *bindings.allocate(wHh));
  loadMatrixFromFile("fr2en/encoder_b_hh.bin", *bindings.allocate(bHh));

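  // Look up the embedding row for every word index in the input:
  // {batchSize_, MAX_LENGTH} -> {batchSize_, MAX_LENGTH, EMBEDDING_SIZE}.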
  Node *inputEmbedded =
      F_->createGather("encoder.embedding", embedding_fr_, input_);

  // TODO: the encoder always runs exactly MAX_LENGTH steps even when the
  // input is shorter. We could use control flow here.
  std::vector<NodeValue> outputs;
  for (unsigned step = 0; step < MAX_LENGTH; step++) {
    Node *inputSlice = F_->createSlice(
        "encoder." + std::to_string(step) + ".inputSlice", inputEmbedded,
        {0, step, 0}, {batchSize_, step + 1, EMBEDDING_SIZE});
    Node *reshape =
        F_->createReshape("encoder." + std::to_string(step) + ".reshape",
                          inputSlice, {batchSize_, EMBEDDING_SIZE}, ANY_LAYOUT);
    hidden = createPyTorchGRUCell(F_, reshape, hidden, wIh, bIh, wHh, bHh);
    outputs.push_back(hidden);
  }

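  // Flatten the per-step hidden states so that row (j * MAX_LENGTH + step)
  // holds sentence j's hidden state after `step`. seqLength_ carries
  // precomputed flat indices (see translate()), so the gather selects the
  // hidden state produced at each sentence's EOS step.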
  Node *output = F_->createConcat("encoder.output", outputs, 1);
  Node *r2 =
      F_->createReshape("encoder.output.r2", output,
                        {MAX_LENGTH * batchSize_, EMBEDDING_SIZE}, ANY_LAYOUT);

  encoderHiddenOutput_ = F_->createGather("encoder.outputNth", r2, seqLength_);
}

/// Model part representing the Decoder.
/// Uses \p encoderHiddenOutput_ as the final state from the Encoder.
/// The resulting translation is written into the \p output_ Placeholder.
void Model::loadDecoder() {
  auto &mod = EE_.getModule();
  auto *input = mod.createPlaceholder(ElemKind::Int64ITy, {batchSize_},
                                      "decoder.input", false);
  auto *inputTensor = bindings.allocate(input);
  for (dim_t i = 0; i < batchSize_; i++) {
    inputTensor->getHandle<int64_t>().at({i}) = en_.word2index_["SOS"];
  }

  auto *wIh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "decoder.w_ih", false);
  auto *bIh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "decoder.b_ih", false);
  auto *wHh = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, HIDDEN_SIZE}, "decoder.w_hh", false);
  auto *bHh = mod.createPlaceholder(ElemKind::FloatTy, {HIDDEN_SIZE},
                                    "decoder.b_hh", false);
  auto *outW = mod.createPlaceholder(
      ElemKind::FloatTy, {EMBEDDING_SIZE, (dim_t)en_.index2word_.size()},
      "decoder.out_w", false);
  auto *outB =
      mod.createPlaceholder(ElemKind::FloatTy, {(dim_t)en_.index2word_.size()},
                            "decoder.out_b", false);
  loadMatrixFromFile("fr2en/decoder_w_ih.bin", *bindings.allocate(wIh));
  loadMatrixFromFile("fr2en/decoder_b_ih.bin", *bindings.allocate(bIh));
  loadMatrixFromFile("fr2en/decoder_w_hh.bin", *bindings.allocate(wHh));
  loadMatrixFromFile("fr2en/decoder_b_hh.bin", *bindings.allocate(bHh));
  loadMatrixFromFile("fr2en/decoder_out_w.bin", *bindings.allocate(outW));
  loadMatrixFromFile("fr2en/decoder_out_b.bin", *bindings.allocate(outB));

  Node *hidden = encoderHiddenOutput_;
  Node *lastWordIdx = input;

  std::vector<NodeValue> outputs;
  // TODO: the decoder always runs exactly MAX_LENGTH steps even when the
  // translation is shorter. We could use control flow here.
  for (unsigned step = 0; step < MAX_LENGTH; step++) {
    // Use the last translated word as the input at the current step.
    Node *embedded =
        F_->createGather("decoder.embedding." + std::to_string(step),
                         embedding_en_, lastWordIdx);

    Node *relu = F_->createRELU("decoder.relu", embedded);
    hidden = createPyTorchGRUCell(F_, relu, hidden, wIh, bIh, wHh, bHh);

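    // Project the hidden state onto the vocabulary and take the single most
    // likely word (k = 1); feeding it back in implements greedy decoding.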
    Node *FC = F_->createFullyConnected("decoder.outFC", hidden, outW, outB);
    auto *topK = F_->createTopK("decoder.topK", FC, 1);

    lastWordIdx = F_->createReshape("decoder.reshape", topK->getIndices(),
                                    {batchSize_}, "N");
    outputs.push_back(lastWordIdx);
  }

  Node *concat = F_->createConcat("decoder.output.concat", outputs, 0);
  Node *reshape = F_->createReshape("decoder.output.reshape", concat,
                                    {MAX_LENGTH, batchSize_}, ANY_LAYOUT);
  auto *save = F_->createSave("decoder.output", reshape);
  output_ = save->getPlaceholder();
  bindings.allocate(output_);
}

/// Translation has 2 stages:
/// 1) The input sentence is fed into the Encoder word by word.
/// 2) The "memory" of the Encoder is written into the memory of the Decoder,
///    which then streams out the resulting translation word by word.
void Model::translate(const std::vector<std::string> &batch) {
  Tensor input(ElemKind::Int64ITy, {batchSize_, MAX_LENGTH});
  Tensor seqLength(ElemKind::Int64ITy, {batchSize_});
  input.zero();

  for (dim_t j = 0; j < batch.size(); j++) {
    std::istringstream iss(batch[j]);
    std::vector<std::string> words;
    std::string word;
    while (iss >> word)
      words.push_back(word);
    words.push_back("EOS");

    CHECK_LE(words.size(), MAX_LENGTH) << "Sentence is too long.";

    for (dim_t i = 0; i < words.size(); i++) {
      auto iter = fr_.word2index_.find(words[i]);
      CHECK(iter != fr_.word2index_.end()) << "Unknown word: " << words[i];
      input.getHandle<int64_t>().at({j, i}) = iter->second;
    }
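    // The encoder flattens its per-step hidden states into rows indexed by
    // (j * MAX_LENGTH + step), so store the flat index of sentence j's EOS
    // step here.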
    seqLength.getHandle<int64_t>().at({j}) =
        (words.size() - 1) + j * MAX_LENGTH;
  }

  updateInputPlaceholders(bindings, {input_, seqLength_}, {&input, &seqLength});
  EE_.run(bindings);

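  // The saved output has shape {MAX_LENGTH, batchSize_}: row i holds the i-th
  // word of every sentence in the batch. Print each sentence up to its EOS.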
  auto OH = bindings.get(output_)->getHandle<int64_t>();
  for (unsigned j = 0; j < batch.size(); j++) {
    for (unsigned i = 0; i < MAX_LENGTH; i++) {
      dim_t wordIdx = OH.at({i, j});
      if (wordIdx == en_.word2index_["EOS"])
        break;

      if (i)
        std::cout << ' ';
      if (wordIdx < en_.index2word_.size())
        std::cout << en_.index2word_[wordIdx];
      else
        std::cout << "[" << wordIdx << "]";
    }
    std::cout << "\n\n";
  }

  if (!dumpProfileFileOpt.empty()) {
    std::vector<NodeProfilingInfo> PI =
        quantization::generateNodeProfilingInfos(bindings, F_, loweredMap_);
    serializeProfilingInfosToYaml(dumpProfileFileOpt,
                                  /* graphPreLowerHash */ 0, PI);
  }
}

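// Typical invocations (a sketch; the binary name "fr2en" is an assumption of
// this comment, and the fr2en dataset must have been downloaded first):
//   echo "je suis maintenant a l aeroport ." | ./fr2en
//   ./fr2en -dump-profile=profile.yaml  # run FP32 and collect a profile
//   ./fr2en -load-profile=profile.yaml  # quantize the graph with the profile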
int main(int argc, char **argv) {
  std::array<const llvm::cl::OptionCategory *, 3> showCategories = {
      {&debugCat, &quantizationCat, &fr2enCat}};
  llvm::cl::HideUnrelatedOptions(showCategories);
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "Translate sentences from French to English");

  Model seq2seq(batchSizeOpt);
  seq2seq.loadLanguages();
  seq2seq.loadEncoder();
  seq2seq.loadDecoder();
  seq2seq.compile();

  if (!dumpGraphDAGFileOpt.empty()) {
    seq2seq.dumpGraphDAG(dumpGraphDAGFileOpt.c_str());
  }

  std::cout << "Please enter a sentence in French, such that its English "
            << "translation starts with one of the following:\n"
            << "\ti am\n"
            << "\the is\n"
            << "\tshe is\n"
            << "\tyou are\n"
            << "\twe are\n"
            << "\tthey are\n"
            << "\n"
            << "Here are some examples:\n"
            << "\tnous sommes desormais en securite .\n"
            << "\tvous etes puissantes .\n"
            << "\til etudie l histoire a l universite .\n"
            << "\tje ne suis pas timide .\n"
            << "\tj y songe encore .\n"
            << "\tje suis maintenant a l aeroport .\n\n";

  llvm::Timer timer("Translate", "Translate");
  if (timeOpt) {
    timer.startTimer();
  }

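  // Read sentences from stdin in groups of batchSizeOpt. A short batch means
  // EOF was reached; it is still translated, and then the loop exits.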
  std::vector<std::string> batch;
  do {
    batch.clear();
    for (size_t i = 0; i < batchSizeOpt; i++) {
      std::string sentence;
      if (!std::getline(std::cin, sentence)) {
        break;
      }
      batch.push_back(sentence);
    }
    if (!batch.empty()) {
      seq2seq.translate(batch);
    }
  } while (batch.size() == batchSizeOpt);

  if (timeOpt) {
    timer.stopTimer();
  }

  return 0;
}