1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/LLVMIRCodeGen/BundleSaver.h" |
18 | #include "glow/LLVMIRCodeGen/CommandLine.h" |
19 | #include "glow/LLVMIRCodeGen/LLVMBackend.h" |
20 | |
21 | #include "glow/Graph/Graph.h" |
22 | #include "glow/Graph/PlaceholderBindings.h" |
23 | #include "glow/IR/Instrs.h" |
24 | #include "glow/IR/LLVMAPIMacros.h" |
25 | #include "glow/Support/Debug.h" |
26 | |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include "llvm/Bitcode/BitcodeWriter.h" |
29 | #include "llvm/IR/LLVMContext.h" |
30 | #include "llvm/IR/LegacyPassManager.h" |
31 | #include "llvm/IR/Module.h" |
32 | #include "llvm/Object/Archive.h" |
33 | #include "llvm/Object/ArchiveWriter.h" |
34 | #include "llvm/Support/Casting.h" |
35 | #include "llvm/Support/Debug.h" |
36 | #include "llvm/Support/FileSystem.h" |
37 | #include "llvm/Support/raw_ostream.h" |
38 | |
39 | #include <glog/logging.h> |
40 | |
41 | #define DEBUG_TYPE "jit" |
42 | |
43 | using namespace glow; |
44 | using llvm::cast; |
45 | using llvm::dyn_cast; |
46 | using llvm::isa; |
47 | |
/// Header file string template. This is a printf-style format string with
/// seven "%s" fields which are filled in, in this order, with: the banner
/// comment, the upper-case bundle name (twice, for the include guard), the
/// common definitions, the model description, the model API and the extra
/// header content.
static const char *headerFileTemplate =
    R"RAW(%s
#ifndef _GLOW_BUNDLE_%s_H
#define _GLOW_BUNDLE_%s_H

#include <stdint.h>

// ---------------------------------------------------------------
// Common definitions
// ---------------------------------------------------------------
#ifndef _GLOW_BUNDLE_COMMON_DEFS
#define _GLOW_BUNDLE_COMMON_DEFS

// Glow bundle error code for correct execution.
#define GLOW_SUCCESS 0
%s
#endif

// ---------------------------------------------------------------
// Bundle API
// ---------------------------------------------------------------
%s
// NOTE: Placeholders are allocated within the "mutableWeight"
// buffer and are identified using an offset relative to base.
// ---------------------------------------------------------------
#ifdef __cplusplus
extern "C" {
#endif
%s%s
#ifdef __cplusplus
}
#endif
#endif
)RAW";
83 | |
84 | /// Function to print the header file using the template. |
85 | static void (llvm::StringRef , |
86 | llvm::StringRef bundleName, |
87 | llvm::StringRef commonDefines, |
88 | llvm::StringRef modelInfo, llvm::StringRef modelApi, |
89 | llvm::StringRef ) { |
90 | std::error_code EC; |
91 | llvm::raw_fd_ostream (headerFileName, EC, |
92 | llvm::sys::fs::OpenFlags::F_Text); |
93 | CHECK(!EC) << "Could not open header file!" ; |
94 | std::string ; |
95 | header += "// Bundle API auto-generated header file. Do not edit!\n" ; |
96 | #ifdef GLOW_VERSION |
97 | header += "// Glow Tools version: " + std::string(GLOW_VERSION) + "\n" ; |
98 | #endif |
99 | headerFile << strFormat(headerFileTemplate, header.c_str(), |
100 | bundleName.upper().data(), bundleName.upper().data(), |
101 | commonDefines.data(), modelInfo.data(), |
102 | modelApi.data(), headerExtra.data()); |
103 | headerFile.close(); |
104 | } |
105 | |
/// Header file common definitions for dynamic API. This text is emitted
/// verbatim into the generated header and declares the SymbolTableEntry and
/// BundleConfig structures exposed to the bundle clients.
static const char *dynamicApiCommonDefines = R"RAW(
// Type describing a symbol table entry of a generated bundle.
struct SymbolTableEntry {
// Name of a variable.
const char *name;
// Offset of the variable inside the memory area.
uint64_t offset;
// The number of elements inside this variable.
uint64_t size;
// Variable kind: 1 if it is a mutable variable, 0 otherwise.
char kind;
};

// Type describing the config of a generated bundle.
struct BundleConfig {
// Size of the constant weight variables memory area.
uint64_t constantWeightVarsMemSize;
// Size of the mutable weight variables memory area.
uint64_t mutableWeightVarsMemSize;
// Size of the activations memory area.
uint64_t activationsMemSize;
// Alignment to be used for weights and activations.
uint64_t alignment;
// Number of symbols in the symbol table.
uint64_t numSymbols;
// Symbol table.
const SymbolTableEntry *symbolTable;
};
)RAW";
136 | |
/// Header file common definitions for static API. This text is emitted
/// verbatim into the generated header and provides the GLOW_MEM_ALIGN and
/// GLOW_GET_ADDR helper macros.
static const char *staticApiCommonDefines = R"RAW(
// Memory alignment definition with given alignment size
// for static allocation of memory.
#define GLOW_MEM_ALIGN(size) __attribute__((aligned(size)))

// Macro function to get the absolute address of a
// placeholder using the base address of the mutable
// weight buffer and placeholder offset definition.
#define GLOW_GET_ADDR(mutableBaseAddr, placeholderOff) (((uint8_t*)(mutableBaseAddr)) + placeholderOff)
)RAW";
148 | |
149 | /// Utility function to serialize a binary file to text file as a C array. |
150 | static void serializeBinaryToText(llvm::StringRef binFileName, |
151 | llvm::StringRef txtFileName) { |
152 | FILE *inpFile = fopen(binFileName.str().c_str(), "rb" ); |
153 | CHECK(inpFile) << "Could not open binary input file: " << binFileName.str(); |
154 | FILE *outFile = fopen(txtFileName.str().c_str(), "w" ); |
155 | CHECK(outFile) << "Could not open text output file: " << txtFileName.str(); |
156 | const size_t numBytesPerLine = 20; |
157 | for (size_t i = 0;; i++) { |
158 | int ch = fgetc(inpFile); |
159 | if (ch == EOF) { |
160 | break; |
161 | } |
162 | fprintf(outFile, " 0X%02X," , ch); |
163 | if ((i % numBytesPerLine) == (numBytesPerLine - 1)) { |
164 | fprintf(outFile, "\n" ); |
165 | } |
166 | } |
167 | fprintf(outFile, "\n" ); |
168 | fclose(inpFile); |
169 | fclose(outFile); |
170 | } |
171 | |
172 | BundleSaver::BundleSaver(const LLVMBackend &llvmBackend, |
173 | llvm::StringRef outputDir, llvm::StringRef bundleName) |
174 | : irgen_(llvmBackend.createIRGen(nullptr, allocationsInfo_)), |
175 | bundleAPI_(llvmBackend.getOptions().getBundleAPI()) { |
176 | llvm::SmallVector<std::string, 8> targetFeatures(llvmTargetFeatures.begin(), |
177 | llvmTargetFeatures.end()); |
178 | irgen_->setBundleName(bundleName.str()); |
179 | irgen_->setOutputDir(outputDir); |
180 | irgen_->setObjectRegistry(llvmBackend.getObjectRegistry()); |
181 | // Use the bundle code model as a code model for the TargetMachine. |
182 | auto opts = llvmBackend.getOptions(); |
183 | opts.setCodeModel(opts.getBundleCodeModel()); |
184 | irgen_->initTargetMachine(opts); |
185 | irgen_->initCodeGen(); |
186 | } |
187 | |
188 | void BundleSaver::setIRFunction(llvm::StringRef mainEntryName, |
189 | const IRFunction *F) { |
190 | irgen_->setIRFunction(F); |
191 | if (F) { |
192 | savedIRFunctions_.push_back(SavedIRFunction{mainEntryName.str(), F}); |
193 | } |
194 | } |
195 | |
196 | bool BundleSaver::WeightAddrComparator::operator()( |
197 | const WeightInfo &LHS, const WeightInfo &RHS) const { |
198 | auto lhsAddr = |
199 | bundleSaver_->allocationsInfo_.allocatedAddress_.lookup(LHS.first); |
200 | auto rhsAddr = |
201 | bundleSaver_->allocationsInfo_.allocatedAddress_.lookup(RHS.first); |
202 | return lhsAddr < rhsAddr; |
203 | } |
204 | |
205 | std::set<BundleSaver::WeightInfo, BundleSaver::WeightAddrComparator> |
206 | BundleSaver::findConstantWeights() const { |
207 | std::set<BundleSaver::WeightInfo, BundleSaver::WeightAddrComparator> |
208 | constants(WeightAddrComparator(*const_cast<BundleSaver *>(this))); |
209 | for (auto &savedIRFunction : savedIRFunctions_) { |
210 | for (auto *c : savedIRFunction.savedF->findConstants()) { |
211 | auto *w = cast<WeightVar>(savedIRFunction.savedF->getWeightForNode(c)); |
212 | constants.insert({w, c}); |
213 | } |
214 | } |
215 | return constants; |
216 | } |
217 | |
218 | std::set<const Placeholder *> BundleSaver::findPlaceholders() const { |
219 | std::set<const Placeholder *> placeholders; |
220 | for (auto &savedIRFunction : savedIRFunctions_) { |
221 | for (auto *ph : savedIRFunction.savedF->findPlaceholders()) { |
222 | placeholders.insert(ph); |
223 | } |
224 | } |
225 | return placeholders; |
226 | } |
227 | |
228 | Value *BundleSaver::getWeightForNode(const Storage *V) const { |
229 | for (auto &savedIRFunction : savedIRFunctions_) { |
230 | if (auto *W = savedIRFunction.savedF->getWeightForNode(V)) { |
231 | return W; |
232 | } |
233 | } |
234 | return nullptr; |
235 | } |
236 | |
237 | void BundleSaver::saveWeights(llvm::StringRef weightsFileName) { |
238 | std::error_code EC; |
239 | llvm::raw_fd_ostream weightsFile(weightsFileName, EC, llvm::sys::fs::F_None); |
240 | CHECK(!EC) << "Could not open the output file for saving the bundle weights " |
241 | "with file name: " |
242 | << weightsFileName.str(); |
243 | // Serialize only constant weights. |
244 | // Do not serialize mutable weights representing inputs and outputs, because |
245 | // it should be configurable and set by the client. |
246 | size_t pos = 0; |
247 | size_t maxPos = 0; |
248 | for (auto &weightInfo : findConstantWeights()) { |
249 | auto *w = weightInfo.first; |
250 | auto *c = weightInfo.second; |
251 | auto numBytes = w->getSizeInBytes(); |
252 | auto payload = c->getPayload().getUnsafePtr(); |
253 | auto addr = allocationsInfo_.allocatedAddress_[weightInfo.first]; |
254 | if (addr < pos) { |
255 | // The payload was written already. It aliases something we have seen |
256 | // already. |
257 | continue; |
258 | } |
259 | weightsFile.seek(addr); |
260 | CHECK(!weightsFile.has_error()) << "Could not set file write position" ; |
261 | weightsFile.write(payload, numBytes); |
262 | CHECK(!weightsFile.has_error()) << "Could not write bytes" ; |
263 | pos = addr + numBytes; |
264 | maxPos = std::max(pos, maxPos); |
265 | } |
266 | // Make sure that the file is as long as the constantWeightVarsMemSize_. |
267 | // This is needed to properly handle alignments. |
268 | weightsFile.seek(maxPos); |
269 | for (size_t endPos = irgen_->getAllocationsInfo().constantWeightVarsMemSize_; |
270 | maxPos < endPos; maxPos++) { |
271 | weightsFile.write(0); |
272 | } |
273 | weightsFile.close(); |
274 | } |
275 | |
276 | void BundleSaver::(llvm::StringRef ) { |
277 | auto bundleName = irgen_->getBundleName(); |
278 | auto bundleNameUpper = llvm::StringRef(bundleName).upper(); |
279 | auto constMemSize = irgen_->getAllocationsInfo().constantWeightVarsMemSize_; |
280 | auto mutableMemSize = irgen_->getAllocationsInfo().mutableWeightVarsMemSize_; |
281 | auto activationsMemSize = irgen_->getAllocationsInfo().activationsMemSize_; |
282 | auto activationsMemAllocEff = irgen_->getAllocationsInfo() |
283 | .getActivationsAllocator() |
284 | .getAllocationEfficiency(); |
285 | auto memAlignSize = TensorAlignment; |
286 | auto totMemSize = constMemSize + mutableMemSize + activationsMemSize; |
287 | |
288 | // Format common bundle definitions. |
289 | auto commonDefines = (bundleAPI_ == BundleApiType::Dynamic) |
290 | ? dynamicApiCommonDefines |
291 | : staticApiCommonDefines; |
292 | |
293 | // Format model description. |
294 | std::string modelInfo = |
295 | strFormat("// Model name: \"%s\"\n" |
296 | "// Total data size: %lu (bytes)\n" |
297 | "// Activations allocation efficiency: %.4f\n" , |
298 | bundleName.data(), totMemSize, activationsMemAllocEff); |
299 | // Print placeholders (mandatory). |
300 | modelInfo += "// Placeholders:\n" ; |
301 | auto placeholders = findPlaceholders(); |
302 | for (auto &v : placeholders) { |
303 | auto *w = cast<WeightVar>(getWeightForNode(v)); |
304 | // Get placeholder properties. |
305 | auto name = w->getName(); |
306 | auto type = w->getType(); |
307 | auto typeName = type->toString(); |
308 | auto sizeElem = type->size(); |
309 | auto sizeByte = type->getSizeInBytes(); |
310 | auto offset = allocationsInfo_.allocatedAddress_[w]; |
311 | modelInfo += strFormat("//\n" |
312 | "// Name: \"%s\"\n" |
313 | "// Type: %s\n" |
314 | "// Size: %" PRIuDIM " (elements)\n" |
315 | "// Size: %zu (bytes)\n" |
316 | "// Offset: %lu (bytes)\n" , |
317 | name.data(), typeName.c_str(), sizeElem, sizeByte, |
318 | (unsigned long)offset); |
319 | } |
320 | // Print constants (optional). |
321 | if (bundleAPIVerbose) { |
322 | modelInfo += "//\n" |
323 | "// Constants:\n" ; |
324 | auto constantWeights = findConstantWeights(); |
325 | for (auto &weightInfo : constantWeights) { |
326 | auto *w = weightInfo.first; |
327 | // Get constant properties. |
328 | auto name = w->getName(); |
329 | auto type = w->getType(); |
330 | auto typeName = type->toString(); |
331 | auto sizeElem = type->size(); |
332 | auto sizeByte = type->getSizeInBytes(); |
333 | auto offset = allocationsInfo_.allocatedAddress_[w]; |
334 | modelInfo += strFormat("//\n" |
335 | "// Name: \"%s\"\n" |
336 | "// Type: %s\n" |
337 | "// Size: %" PRIuDIM " (elements)\n" |
338 | "// Size: %zu (bytes)\n" |
339 | "// Offset: %lu (bytes)\n" , |
340 | name.data(), typeName.c_str(), sizeElem, sizeByte, |
341 | (unsigned long)offset); |
342 | } |
343 | } |
344 | modelInfo += "//" ; |
345 | |
346 | std::string modelApi = "\n" ; |
347 | if (bundleAPI_ == BundleApiType::Dynamic) { |
348 | // Print bundle memory configuration. |
349 | modelApi += strFormat("// Bundle memory configuration (memory layout).\n" |
350 | "extern BundleConfig %s_config;\n" |
351 | "\n" , |
352 | bundleName.data()); |
353 | |
354 | } else { |
355 | // Get placeholder names and offsets. Compute also the maximum placeholder |
356 | // name length for print purposes. |
357 | unsigned nameMaxLen = 0; |
358 | std::vector<std::pair<llvm::StringRef, unsigned>> nameAddrPairs; |
359 | for (auto &v : placeholders) { |
360 | auto *w = cast<WeightVar>(getWeightForNode(v)); |
361 | auto name = w->getName(); |
362 | auto addr = allocationsInfo_.allocatedAddress_[w]; |
363 | nameMaxLen = name.size() > nameMaxLen ? name.size() : nameMaxLen; |
364 | nameAddrPairs.push_back(std::pair<llvm::StringRef, unsigned>(name, addr)); |
365 | } |
366 | |
367 | // Print placeholder address offsets. |
368 | modelApi += |
369 | "// Placeholder address offsets within mutable buffer (bytes).\n" ; |
370 | for (auto &pair : nameAddrPairs) { |
371 | modelApi += strFormat( |
372 | "#define %s_%s%s %u\n" , bundleNameUpper.data(), pair.first.data(), |
373 | std::string(nameMaxLen - pair.first.size(), ' ').c_str(), |
374 | pair.second); |
375 | } |
376 | modelApi += "\n" ; |
377 | |
378 | // Print memory sizes and memory alignment. |
379 | modelApi += |
380 | strFormat("// Memory sizes (bytes).\n" |
381 | "#define %s_CONSTANT_MEM_SIZE %lu\n" |
382 | "#define %s_MUTABLE_MEM_SIZE %lu\n" |
383 | "#define %s_ACTIVATIONS_MEM_SIZE %lu\n" |
384 | "\n" |
385 | "// Memory alignment (bytes).\n" |
386 | "#define %s_MEM_ALIGN %d\n" |
387 | "\n" , |
388 | bundleNameUpper.data(), constMemSize, bundleNameUpper.data(), |
389 | mutableMemSize, bundleNameUpper.data(), activationsMemSize, |
390 | bundleNameUpper.data(), memAlignSize); |
391 | } |
392 | |
393 | // Print bundle entry functions. |
394 | for (auto &savedIRFunction : savedIRFunctions_) { |
395 | modelApi += |
396 | strFormat("// Bundle entry point (inference function). Returns 0\n" |
397 | "// for correct execution or some error code otherwise.\n" |
398 | "int %s(" |
399 | "uint8_t *constantWeight, " |
400 | "uint8_t *mutableWeight, " |
401 | "uint8_t *activations" |
402 | ");\n" , |
403 | savedIRFunction.entryName.c_str()); |
404 | } |
405 | |
406 | // Get bundle header extra content. |
407 | std::string = irgen_->getBundleHeaderExtra(); |
408 | |
409 | // Print header file. |
410 | printHeader(headerFileName, bundleName, commonDefines, modelInfo, modelApi, |
411 | headerExtra); |
412 | } |
413 | |
414 | void BundleSaver::emitSymbolTable() { |
415 | // Define a struct for symbol table entries: |
416 | // struct SymbolTableEntry { |
417 | // const char *name; |
418 | // uint64_t offset; |
419 | // uint64_t size; |
420 | // char kind; |
421 | // }; |
422 | auto *charTy = llvm::Type::getInt8Ty(irgen_->getLLVMContext()); |
423 | auto *uint64TTy = |
424 | llvm::Type::getIntNTy(irgen_->getLLVMContext(), sizeof(uint64_t) * 8); |
425 | auto symbolTableEntryTy = |
426 | GET_TYPE_BY_NAME(irgen_->getModule(), "struct.SymbolTableEntry" ); |
427 | if (!symbolTableEntryTy) { |
428 | symbolTableEntryTy = llvm::StructType::get( |
429 | irgen_->getLLVMContext(), |
430 | {charTy->getPointerTo(), uint64TTy, uint64TTy, charTy}); |
431 | } |
432 | // Set of entries in the symbol table. |
433 | llvm::SmallVector<llvm::Constant *, 128> entries; |
434 | // Iterate over all Placeholders and record information about their names, |
435 | // offset, size and kind. |
436 | for (auto &v : findPlaceholders()) { |
437 | auto *w = cast<WeightVar>(getWeightForNode(v)); |
438 | auto size = w->getType()->size(); |
439 | auto addr = allocationsInfo_.allocatedAddress_[w]; |
440 | // Create an SymbolTableEntry. |
441 | auto *entry = llvm::ConstantStruct::get( |
442 | symbolTableEntryTy, |
443 | {// name. |
444 | dyn_cast<llvm::Constant>(irgen_->getBuilder().CreateBitCast( |
445 | irgen_->emitStringConst(irgen_->getBuilder(), w->getName()), |
446 | charTy->getPointerTo())), |
447 | // offset. |
448 | llvm::ConstantInt::get(uint64TTy, addr), |
449 | // size. |
450 | llvm::ConstantInt::get(uint64TTy, size), |
451 | // 1 for Mutable Kind |
452 | llvm::ConstantInt::get(charTy, 1)}); |
453 | entries.push_back(entry); |
454 | } |
455 | |
456 | // Create a constant array with these entries. |
457 | auto *arr = llvm::ConstantArray::get( |
458 | llvm::ArrayType::get(symbolTableEntryTy, entries.size()), entries); |
459 | new llvm::GlobalVariable(irgen_->getModule(), arr->getType(), true, |
460 | llvm::GlobalValue::InternalLinkage, arr, |
461 | irgen_->getBundleName() + "SymbolTable" ); |
462 | } |
463 | |
464 | void BundleSaver::createBundleArchive( |
465 | llvm::StringRef bundlePath, |
466 | llvm::ArrayRef<llvm::MemoryBufferRef> bundleObjectRegistry, |
467 | const std::vector<std::string> &bundleObjects) { |
468 | |
469 | // If we do not have extra object files then return early. |
470 | if (bundleObjects.empty()) { |
471 | return; |
472 | } |
473 | |
474 | // Read original bundle object file as archive member. |
475 | std::vector<llvm::NewArchiveMember> newMembers; |
476 | llvm::Expected<llvm::NewArchiveMember> newMember = |
477 | llvm::NewArchiveMember::getFile(bundlePath.str(), |
478 | /* Deterministic */ true); |
479 | newMembers.push_back(std::move(*newMember)); |
480 | |
481 | // Add other object files as archive members. |
482 | for (const auto &objectName : bundleObjects) { |
483 | // If this object was already added then we skip it. |
484 | bool objectAdded = false; |
485 | for (const auto &member : newMembers) { |
486 | if (member.MemberName.str() == objectName) { |
487 | objectAdded = true; |
488 | break; |
489 | } |
490 | } |
491 | if (objectAdded) { |
492 | continue; |
493 | } |
494 | // Find current object and add it as archive member. |
495 | bool objectFound = false; |
496 | for (const auto &memBuffRef : bundleObjectRegistry) { |
497 | if (memBuffRef.getBufferIdentifier().str() == objectName) { |
498 | llvm::NewArchiveMember newMember(memBuffRef); |
499 | newMembers.push_back(std::move(newMember)); |
500 | objectFound = true; |
501 | break; |
502 | } |
503 | } |
504 | // If object is not found (not registered) then throw error. |
505 | if (!objectFound) { |
506 | std::string errMsg; |
507 | errMsg += "Object '" + objectName + "' is not registered in Glow and " ; |
508 | errMsg += "cannot be archived into the bundle. The following objects " ; |
509 | errMsg += "are available for archiving:\n" ; |
510 | for (const auto &memBuffRef : bundleObjectRegistry) { |
511 | errMsg += " - " + memBuffRef.getBufferIdentifier().str() + "\n" ; |
512 | } |
513 | CHECK(false) << errMsg; |
514 | } |
515 | } |
516 | |
517 | // Write the new bundle as archive. |
518 | llvm::Error err = |
519 | llvm::writeArchive(bundlePath.str(), newMembers, /* WriteSymtab */ true, |
520 | llvm::object::Archive::K_GNU, |
521 | /* Deterministic */ true, /* Thin */ false, |
522 | /* OldArchiveBuf */ std::move(nullptr)); |
523 | CHECK(!err) << "Could not add extra objects to bundle " << bundlePath.str(); |
524 | } |
525 | |
526 | void BundleSaver::produceBundle() { |
527 | DCHECK(!isSaved_) << "produceBundle can be invoked only once" ; |
528 | isSaved_ = true; |
529 | // Emit entry functions. |
530 | for (auto &savedFunction : savedIRFunctions_) { |
531 | emitBundleEntryFunction(savedFunction); |
532 | } |
533 | // Finish code generation. |
534 | irgen_->finishCodeGen(); |
535 | setIRFunction("<noname>" , nullptr); |
536 | // Emit symbol table and bundle config only for dynamic API |
537 | if (bundleAPI_ == BundleApiType::Dynamic) { |
538 | // Emit the symbol table for weight variables. |
539 | emitSymbolTable(); |
540 | // Emit the config for the bundle. |
541 | emitBundleConfig(); |
542 | } |
543 | |
544 | auto &M = irgen_->getModule(); |
545 | auto outputDir = irgen_->getOutputDir(); |
546 | auto bundleName = irgen_->getBundleName(); |
547 | auto savedBundleName = irgen_->getSavedBundleName().empty() |
548 | ? bundleName |
549 | : irgen_->getSavedBundleName(); |
550 | std::string extension = (llvmCompiler.empty()) ? ".o" : ".bc" ; |
551 | std::string bundleCodeOutput; |
552 | bundleCodeOutput = (outputDir + "/" + savedBundleName + extension).str(); |
553 | auto bundleWeightsBinOut = |
554 | (outputDir + "/" + savedBundleName + ".weights.bin" ).str(); |
555 | auto = (outputDir + "/" + savedBundleName + ".h" ).str(); |
556 | DEBUG_GLOW(llvm::dbgs() << "Producing a bundle:\n" |
557 | << "saved bundle name: " << savedBundleName << "\n" |
558 | << "bundle name: " << bundleName << "\n" |
559 | << "bundle code: " << bundleCodeOutput << "\n" |
560 | << "bundle weights:" << bundleWeightsBinOut << "\n" |
561 | << "header file: " << bundleHeaderOutput << "\n" ); |
562 | llvm::StringRef fileName = bundleCodeOutput; |
563 | std::error_code EC; |
564 | llvm::raw_fd_ostream outputFile(fileName, EC, llvm::sys::fs::OF_None); |
565 | CHECK(!EC) << "Could not open the output file for saving the bundle " |
566 | "code with file name: " |
567 | << fileName.str(); |
568 | if (fileName.endswith(".bc" )) { |
569 | // Emit the bitcode file. |
570 | llvm::WriteBitcodeToFile(M, outputFile); |
571 | outputFile.flush(); |
572 | if (!llvmCompiler.empty()) { |
573 | // Compile bitcode using an external LLVM compiler. |
574 | // The code is optimized twice with the external opt tool. |
575 | std::string cmd = llvmCompiler; |
576 | for (auto option : llvmCompilerOptions) { |
577 | cmd += " " + option + " " ; |
578 | } |
579 | cmd += " " + bundleCodeOutput; |
580 | std::string bundleObjectCodeOutputOpt; |
581 | if (!llvmOpt.empty()) { |
582 | bundleObjectCodeOutputOpt = |
583 | " -emit-llvm -o " + |
584 | (outputDir + "/" + savedBundleName + ".beforeopt.bc" ).str(); |
585 | } else { |
586 | bundleObjectCodeOutputOpt = |
587 | " -o " + (outputDir + "/" + savedBundleName + ".o" ).str(); |
588 | } |
589 | |
590 | cmd += bundleObjectCodeOutputOpt; |
591 | CHECK(!system(cmd.c_str())) |
592 | << "Error running external LLVM compiler: " << cmd; |
593 | |
594 | // Running opt tool to optimize a second time. |
595 | // TODO: Only run the appropriate passes as needed. |
596 | if (!llvmOpt.empty()) { |
597 | cmd.clear(); |
598 | cmd = llvmOpt; |
599 | cmd += |
600 | " " + (outputDir + "/" + savedBundleName + ".beforeopt.bc" ).str(); |
601 | cmd += |
602 | " -O3 -o " + (outputDir + "/" + savedBundleName + ".opt.bc" ).str(); |
603 | CHECK(!system(cmd.c_str())) |
604 | << "Error running external opt compiler: " << cmd; |
605 | |
606 | if (llvmSaveAsm) { |
607 | cmd.clear(); |
608 | cmd = llvmCompiler; |
609 | for (auto option : llvmCompilerOptions) { |
610 | cmd += " " + option + " " ; |
611 | } |
612 | cmd += " " + (outputDir + "/" + savedBundleName + ".opt.bc" ).str(); |
613 | cmd += " -S -o " + (outputDir + "/" + savedBundleName + ".s" ).str(); |
614 | CHECK(!system(cmd.c_str())) |
615 | << "Error running external LLVM compiler: " << cmd; |
616 | } |
617 | |
618 | cmd.clear(); |
619 | cmd = llvmCompiler; |
620 | for (auto option : llvmCompilerOptions) { |
621 | cmd += " " + option + " " ; |
622 | } |
623 | cmd += " " + (outputDir + "/" + savedBundleName + ".opt.bc" ).str(); |
624 | cmd += " -o " + (outputDir + "/" + savedBundleName + ".o" ).str(); |
625 | CHECK(!system(cmd.c_str())) |
626 | << "Error running external LLVM compiler: " << cmd; |
627 | } |
628 | } |
629 | } else if (fileName.endswith(".o" )) { |
630 | // Emit the object file. |
631 | llvm::legacy::PassManager PM; |
632 | auto &TM = irgen_->getTargetMachine(); |
633 | |
634 | // Create asm output file. |
635 | if (llvmSaveAsm) { |
636 | auto asm_FileName = (outputDir + "/" + savedBundleName + ".s" ).str(); |
637 | llvm::StringRef asmFileName = asm_FileName; |
638 | std::error_code EC2; |
639 | llvm::raw_fd_ostream outputFileAsm(asmFileName, EC2, |
640 | llvm::sys::fs::OF_None); |
641 | CHECK(!EC2) << "Could not open the output file for saving the asm " |
642 | "code with file name: " |
643 | << asmFileName.str(); |
644 | llvm::legacy::PassManager PM2; |
645 | #if FACEBOOK_INTERNAL && LLVM_VERSION_MAJOR < 8 |
646 | TM.addPassesToEmitFile( |
647 | PM2, outputFileAsm, |
648 | llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile); |
649 | #elif LLVM_VERSION_MAJOR < 10 |
650 | TM.addPassesToEmitFile( |
651 | PM2, outputFileAsm, nullptr, |
652 | llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile); |
653 | #else |
654 | TM.addPassesToEmitFile(PM2, outputFileAsm, nullptr, |
655 | llvm::CGFT_AssemblyFile); |
656 | #endif |
657 | PM2.run(M); |
658 | } |
659 | |
660 | #if FACEBOOK_INTERNAL && LLVM_VERSION_MAJOR < 8 |
661 | TM.addPassesToEmitFile( |
662 | PM, outputFile, llvm::TargetMachine::CodeGenFileType::CGFT_ObjectFile); |
663 | #elif LLVM_VERSION_MAJOR < 10 |
664 | TM.addPassesToEmitFile( |
665 | PM, outputFile, nullptr, |
666 | llvm::TargetMachine::CodeGenFileType::CGFT_ObjectFile); |
667 | #else |
668 | TM.addPassesToEmitFile(PM, outputFile, nullptr, llvm::CGFT_ObjectFile); |
669 | #endif |
670 | PM.run(M); |
671 | } |
672 | outputFile.close(); |
673 | // Create bundle archive with additional object files. |
674 | createBundleArchive(fileName, irgen_->getObjectRegistry(), |
675 | irgen_->getBundleObjects()); |
676 | // Output weights. |
677 | if (saveWeights_) { |
678 | saveWeights(bundleWeightsBinOut); |
679 | } |
680 | // Header file. |
681 | if (saveHeader_) { |
682 | saveHeader(bundleHeaderOutput); |
683 | } |
684 | // Save weights also in text format for Static API. |
685 | if (saveWeightsAsText_) { |
686 | if (bundleAPI_ == BundleApiType::Static) { |
687 | auto bundleWeightsTxtOut = |
688 | (outputDir + "/" + savedBundleName + ".weights.txt" ).str(); |
689 | serializeBinaryToText(bundleWeightsBinOut, bundleWeightsTxtOut); |
690 | } |
691 | } |
692 | } |
693 | |
694 | /// Emit the entry function for the bundle. It simply calls the main entry of |
695 | /// the module and forwards its arguments to it. As the last argument it |
696 | /// provides the constant array of offsets. Since these offsets are constants, |
697 | /// the LLVM optimizer will constant propagate them into relative addressing |
698 | /// computations and the like and produce a very efficient code that uses |
699 | /// absolute addressing whenever possible. |
700 | void BundleSaver::emitBundleEntryFunction( |
701 | BundleSaver::SavedIRFunction &savedF) { |
702 | auto *func = irgen_->getModule().getFunction(savedF.entryName); |
703 | if (!func) { |
704 | // The bundle entry point has the following API: |
705 | // int entry(uint8_t *constantWeight, |
706 | // uint8_t *mutableWeight, |
707 | // uint8_t *activations); |
708 | auto int8PtrTy = llvm::Type::getInt8PtrTy(irgen_->getLLVMContext()); |
709 | llvm::Type *retTy = llvm::Type::getIntNTy(irgen_->getLLVMContext(), |
710 | irgen_->getLibjitIntWidth()); |
711 | llvm::FunctionType *bundleFuncTy = llvm::FunctionType::get( |
712 | retTy, {int8PtrTy, int8PtrTy, int8PtrTy}, false); |
713 | func = llvm::Function::Create(bundleFuncTy, llvm::Function::ExternalLinkage, |
714 | savedF.entryName, &irgen_->getModule()); |
715 | } |
716 | CHECK(func->isDeclaration()) << "Function definition of " << savedF.entryName |
717 | << " already exists in the LLVM module" ; |
718 | |
719 | llvm::BasicBlock *entry_bb = |
720 | llvm::BasicBlock::Create(irgen_->getLLVMContext(), "entry" , func); |
721 | llvm::IRBuilder<> builder(entry_bb); |
722 | // Add a provisional terminator to make the function well-formed. |
723 | auto *zero = builder.getIntN(irgen_->getLibjitIntWidth(), 0); |
724 | auto *ret = builder.CreateRet(zero); |
725 | builder.SetInsertPoint(ret); |
726 | |
727 | // Prepare arguments for the "main" function. |
728 | llvm::SmallVector<llvm::Value *, 4> initFunctionCallArgs; |
729 | initFunctionCallArgs.push_back(func->args().begin()); |
730 | initFunctionCallArgs.push_back(func->args().begin() + 1); |
731 | initFunctionCallArgs.push_back(func->args().begin() + 2); |
732 | // Now form the offsets array and pass it as the last argument. |
733 | auto offsetsArray = irgen_->emitConstOffsetsArray(builder, allocationsInfo_); |
734 | initFunctionCallArgs.push_back(offsetsArray); |
735 | // Invoke the main entry with constant arguments and let LLVM optimizer make |
736 | // use of it. |
737 | auto *entryF = savedF.llvmF; |
738 | entryF->setLinkage(llvm::Function::InternalLinkage); |
739 | auto *result = irgen_->createCall(builder, entryF, initFunctionCallArgs); |
740 | // Terminate the function. |
741 | builder.CreateRet(result); |
742 | // Remove the provisional terminator. |
743 | ret->eraseFromParent(); |
744 | // Create the debug info for the bundle entry point function. |
745 | irgen_->generateFunctionDebugInfo(func); |
746 | } |
747 | |
748 | // Create a config for this network. It will be exposed to the clients, |
749 | // so that they know how much memory they need to allocate, etc. |
750 | // Config consists of the following fields: |
751 | // struct BundleConfig { |
752 | // uint64_t constantWeightVarsMemSize; |
753 | // uint64_t mutableWeightVarsMemSize; |
754 | // uint64_t activationsMemSize; |
755 | // uint64_t alignment; |
756 | // uint64_t numSymbols; |
757 | // SymbolTableEntry *symbolTable; |
758 | // }; |
759 | void BundleSaver::emitBundleConfig() { |
760 | auto symbolTableName = irgen_->getBundleName().str() + "SymbolTable" ; |
761 | auto symbolTable = |
762 | irgen_->getModule().getGlobalVariable(symbolTableName, true); |
763 | CHECK(symbolTable) |
764 | << "Expected to find a symbol table for the AOT bundle with name: " |
765 | << symbolTableName; |
766 | // Get the integer type having the same size in bits as uint64_t. |
767 | auto *uint64TType = irgen_->getBuilder().getIntNTy(sizeof(uint64_t) * 8); |
768 | auto symbolTableEntryTy = symbolTable->getType()->getPointerElementType(); |
769 | auto *bundleConfigTy = |
770 | llvm::StructType::get(irgen_->getLLVMContext(), |
771 | {uint64TType, uint64TType, uint64TType, uint64TType, |
772 | uint64TType, symbolTableEntryTy->getPointerTo()}); |
773 | // Checking if LLVM module already has <bundle>_config otherwise creating new. |
774 | auto config = irgen_->getModule().getGlobalVariable( |
775 | irgen_->getBundleName().str() + "_config" ); |
776 | if (!config) { |
777 | config = new llvm::GlobalVariable( |
778 | irgen_->getModule(), bundleConfigTy, /* isConst */ true, |
779 | llvm::GlobalValue::LinkageTypes::ExternalLinkage, nullptr, |
780 | irgen_->getBundleName().str() + "_config" ); |
781 | } else { |
782 | bundleConfigTy = llvm::dyn_cast<llvm::StructType>( |
783 | config->getType()->getPointerElementType()); |
784 | } |
785 | |
786 | // If symbolTable is not the same type as bundleConfig struct's symbolTable |
787 | // member, bitcast the pointer to the appropriate type. |
788 | llvm::Constant *symbolTableTyped = symbolTable; |
789 | llvm::Type *configSymbolTableType = |
790 | config->getValueType()->getStructElementType(5); |
791 | if (symbolTableEntryTy->getPointerTo() != configSymbolTableType) { |
792 | symbolTableTyped = llvm::ConstantExpr::getPointerCast( |
793 | symbolTable, config->getValueType()->getStructElementType(5)); |
794 | } |
795 | |
796 | CHECK(!config->hasInitializer()) |
797 | << "Bundle config has already been initialized" ; |
798 | |
799 | config->setInitializer(llvm::ConstantStruct::get( |
800 | bundleConfigTy, |
801 | llvm::ConstantInt::get( |
802 | uint64TType, irgen_->getAllocationsInfo().constantWeightVarsMemSize_), |
803 | llvm::ConstantInt::get( |
804 | uint64TType, irgen_->getAllocationsInfo().mutableWeightVarsMemSize_), |
805 | llvm::ConstantInt::get(uint64TType, |
806 | irgen_->getAllocationsInfo().activationsMemSize_), |
807 | llvm::ConstantInt::get(uint64TType, TensorAlignment), |
808 | llvm::ConstantInt::get(uint64TType, findPlaceholders().size()), |
809 | symbolTableTyped)); |
810 | } |
811 | |
812 | void BundleSaver::performBundleMemoryAllocation() { |
813 | // Perform memory allocation for the current function. |
814 | auto *F = savedIRFunctions_.back().savedF; |
815 | allocationsInfo_.numberValues(F); |
816 | // Tell the allocateWeightVars to not reuse any existing addresses for |
817 | // weights and to assign new ones. |
818 | allocationsInfo_.allocateWeightVars(F); |
819 | allocationsInfo_.allocateActivations(F); |
820 | allocationsInfo_.allocateTensorViews(F); |
821 | } |
822 | |
823 | void BundleSaver::save(llvm::StringRef mainEntryName, const IRFunction *F) { |
824 | // Object files generation works properly only in small mode. |
825 | irgen_->setMainEntryName(mainEntryName.str()); |
826 | // Set current IRFunction using the legalized name. |
827 | setIRFunction(irgen_->getMainEntryName(), F); |
828 | // irgen_->initCodeGen(); |
829 | // Perform the address assignment for activations and WeightVars. |
830 | performBundleMemoryAllocation(); |
831 | // Emit the code for the body of the entry function. |
832 | irgen_->performCodeGen(); |
833 | savedIRFunctions_.back().llvmF = irgen_->getLLVMFunction(); |
834 | } |
835 | |
836 | LLVMIRGen *BundleSaver::getLLVMIRGen() { return irgen_.get(); } |
837 | |