1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Backend/BackendUtils.h"
17#include "glow/IR/IRUtils.h"
18#include "glow/IR/Instrs.h"
19#include "glow/Support/Debug.h"
20
21#include "llvm/Support/CommandLine.h"
22
23#include "glow/Graph/FXIRWrapper.h"
24#include <glog/logging.h>
25
26#define DEBUG_TYPE "backend-utils"
27
28using namespace glow;
29
30using llvm::cast;
31using llvm::dyn_cast;
32using llvm::isa;
33
34static llvm::cl::OptionCategory BackendUtilsCat("Glow Backend Utils Options");
35
36static llvm::cl::opt<bool> reuseActivationsMemory(
37 "reuse-activation-memory-allocations",
38 llvm::cl::desc("Should activation memory allocations be reused"),
39 llvm::cl::init(true), llvm::cl::cat(BackendUtilsCat));
40
/// Move constructor: delegates to the move-assignment operator so that the
/// member-stealing logic lives in exactly one place.
glow::runtime::RuntimeBundle::RuntimeBundle(
    glow::runtime::RuntimeBundle &&rhs) {
  *this = std::move(rhs);
}
45
/// Move assignment: steals the contents of \p rhs via member-wise swap and
/// then marks \p rhs invalid. Swapping (rather than copying) transfers
/// ownership of the packed constants_ buffer without a deep copy or a
/// double-free.
glow::runtime::RuntimeBundle &
glow::runtime::RuntimeBundle::operator=(glow::runtime::RuntimeBundle &&rhs) {
  if (this == &rhs) {
    // Do nothing if rhs is the same object as this.
    return *this;
  }

  std::swap(symbolTable_, rhs.symbolTable_);
  std::swap(constants_, rhs.constants_);
  std::swap(constantWeightVarsMemSize_, rhs.constantWeightVarsMemSize_);
  std::swap(mutableWeightVarsMemSize_, rhs.mutableWeightVarsMemSize_);
  std::swap(activationsMemSize_, rhs.activationsMemSize_);
  std::swap(isValid_, rhs.isValid_);
  // rhs is not valid now that all of its contents have been stolen.
  rhs.isValid_ = false;
  return *this;
}
63
/// Collect the constants referenced by this bundle from the Module that owns
/// the IR function \p F.
void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) {
  DCHECK(isValid_);
  // Constants are stored on the Module, so delegate to the Module overload.
  collectConstants(F->getParent());
}
68
69void glow::runtime::RuntimeBundle::freeConstants() {
70 DCHECK(isValid_);
71
72 if (constants_) {
73 glow::alignedFree(constants_);
74 constants_ = nullptr;
75 }
76}
/// Copies the payload of every Constant recorded in the symbol table into one
/// aligned buffer (constants_), each at the offset assigned at allocation
/// time.
void glow::runtime::RuntimeBundle::collectConstants(const Module *M) {
  DCHECK(isValid_);

  // At compile time condense constants to a single block of memory.
  // This allows the graph to go away after compile time.
  // If there are no constants return nullptr.
  if (constantWeightVarsMemSize_ == 0) {
    constants_ = nullptr;
    return;
  }

  assert(constants_ == nullptr && "constants already allocated");
  constants_ =
      (uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment);

  for (const auto &symbol : symbolTable_) {
    llvm::StringRef name = symbol.first;
    const RuntimeSymbolInfo &info = symbol.second;

    // Symbols that do not name a Constant in this module (e.g. placeholders
    // or activations) have no payload to copy; skip them.
    Constant *c = M->getConstantByName(name);
    if (!c) {
      continue;
    }
    auto *payload = c->getPayload().getUnsafePtr();
    assert(info.size == c->getPayload().getSizeInBytes() &&
           "Mismatched constant size");

    // Copy weight to offset.
    memcpy(constants_ + info.offset, payload, info.size);
  }
}
108
109#if FACEBOOK_INTERNAL
110void glow::runtime::RuntimeBundle::collectConstants(const FXIRWrapper *F) {
111 DCHECK(isValid_);
112
113 // At compile time condense constants to a single block of memory.
114 // This allows the graph to go away after compile time.
115 // If there are no constants return nullptr.
116 if (constantWeightVarsMemSize_ == 0) {
117 constants_ = nullptr;
118 return;
119 }
120
121 assert(constants_ == nullptr && "constants already allocated");
122 constants_ =
123 (uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment);
124
125 for (const auto &symbol : symbolTable_) {
126 llvm::StringRef name = symbol.first;
127 const RuntimeSymbolInfo &info = symbol.second;
128
129 // Only work with constants/weights here.
130 auto category = info.symbolCategory;
131 if (category != glow::runtime::SymbolCategory::Constant) {
132 continue;
133 }
134
135 auto mapToConstants = F->getMapNodeNameToStorage();
136 assert(mapToConstants.find(name.str()) != mapToConstants.end());
137 const auto *wt = mapToConstants[name.str()];
138 const auto *c = llvm::dyn_cast<const Constant>(wt);
139 if (!c) {
140 continue;
141 }
142 auto *payload = c->getPayload().getUnsafePtr();
143 assert(info.size == c->getPayload().getSizeInBytes() &&
144 "Mismatched constant size");
145
146 // Copy weight to offset.
147 memcpy(constants_ + info.offset, payload, info.size);
148 }
149}
150#endif
151
152size_t glow::runtime::RuntimeBundle::getValueOffset(const Named *v) const {
153 DCHECK(isValid_);
154 auto it = symbolTable_.find(std::string(v->getName()));
155 assert(it != symbolTable_.end() && "Symbol not found.");
156 return it->second.offset;
157}
158
/// \returns the full symbol-table entry for the symbol named after \p v.
/// Aborts (assert) if the symbol is not present in the table.
const runtime::RuntimeSymbolInfo &
runtime::RuntimeBundle::getSymbolInfo(const Named *v) const {
  DCHECK(isValid_);
  auto it = symbolTable_.find(std::string(v->getName()));
  assert(it != symbolTable_.end() && "Symbol not found.");
  return it->second;
}
166
167namespace glow {
168
169/// If \p W is an output weight \returns true. This is determined by checking if
170/// the weight has a user which uses it as a write output.
171bool isOutput(const Value *W) {
172 auto *weight = llvm::dyn_cast<WeightVar>(W);
173 DCHECK(weight) << "Expected WeightVar";
174 for (const auto &use : ValueUses(weight)) {
175 Instruction *user = use.get();
176 // Ignore deallocs.
177 if (isa<DeallocActivationInst>(user)) {
178 continue;
179 }
180 OperandKind kind = use.getOperand().second;
181 if (kind == OperandKind::Out || kind == OperandKind::InOut) {
182 return true;
183 }
184 }
185 return false;
186}
187
188/// If \p PH is an output placeholder in the function \p F, \returns true.
189/// This is determined by checking if the PH has a user which uses the PH as an
190/// overwritten input.
191bool isOutput(const Placeholder *PH, const IRFunction &F) {
192 auto *weight = F.getWeightForNode(PH);
193 DCHECK(weight) << "Weight for a node was not found";
194 return isOutput(weight);
195}
196
197/// If \p W is a weight that is first read from \returns true.
198bool isInput(const Value *W) {
199 auto *weight = llvm::dyn_cast<WeightVar>(W);
200 const glow::Instruction *firstUser = nullptr;
201 bool hasReads = false;
202 for (const auto &U : ValueUses(weight)) {
203 const auto *user = U.get();
204 // TensorView instruction doesn't read from a placeholder.
205 if (isa<TensorViewInst>(user)) {
206 continue;
207 }
208 // Remember the earliest use.
209 if (!firstUser || firstUser->getIterator() > user->getIterator()) {
210 firstUser = user;
211 }
212 // Ignore deallocs.
213 if (isa<DeallocActivationInst>(user)) {
214 continue;
215 }
216 OperandKind kind = U.getOperand().second;
217 if (kind == OperandKind::In || kind == OperandKind::InOut) {
218 hasReads = true;
219 }
220 }
221
222 if (!hasReads) {
223 return false;
224 }
225
226 // Check if the first use is a read.
227 if (firstUser) {
228 // If this instruction has reads, then the first use is an @in.
229 auto *weightOrigin = getOrigin(weight);
230 for (int idx = 0, e = firstUser->getNumOperands(); idx < e; ++idx) {
231 const auto op = firstUser->getOperand(idx);
232 auto *opOrigin = getOrigin(op.first);
233 auto opKind = op.second;
234 if (opOrigin == weightOrigin && opKind == OperandKind::In) {
235 return true;
236 }
237 }
238 // No reads were found, thus the first use is a write.
239 return false;
240 }
241 // If there are no users, it is not an input.
242 return false;
243}
244
245/// If \p PH is an input placeholder in the function \p F, \returns true.
246bool isInput(const Placeholder *PH, const IRFunction &F) {
247 // Check that the PH is always used as an @in parameter by the current
248 // function.
249 auto *weight = F.getWeightForNode(PH);
250 DCHECK(weight) << "Weight for a node was not found";
251 return isInput(weight);
252}
253
254bool isOutput(const Placeholder *PH,
255 const std::vector<const Function *> &funcs) {
256 for (const auto &f : funcs) {
257 if (isOutput(PH, *f)) {
258 return true;
259 }
260 }
261
262 return false;
263}
264
265/// \returns true if \p PH is an input Placeholder for any function in \p funcs.
266bool isInput(const Placeholder *PH,
267 const std::vector<const Function *> &funcs) {
268 for (const auto &f : funcs) {
269 if (isInput(PH, *f)) {
270 return true;
271 }
272 }
273
274 return false;
275}
276
/// If \p N does not have fused activation \returns true.
/// The switch is generated from AutoGenNodes.def: every node kind expands to
/// a case that casts to the concrete node class and forwards to the matching
/// checkNoFusion() overload.
#define DEF_NODE(CLASS, NAME)                                                  \
  case Kinded::Kind::CLASS##Kind: {                                            \
    const CLASS *CI = llvm::cast<CLASS>(&N);                                   \
    return checkNoFusion(*CI);                                                 \
    break;                                                                     \
  }
bool checkNoFusionForNode(const Node &N) {
  switch (N.getKind()) {
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Invalid node.");
  }
  // Unreachable: every case above returns.
  return true;
}
292
/// If \p I does not have fused activation \returns true.
/// Same dispatch pattern as checkNoFusionForNode(), driven by
/// AutoGenInstr.def: values expand to nothing, while generic and
/// backend-specific instructions each cast and forward to their
/// checkNoFusion() overload.
#define DEF_VALUE(CLASS, NAME)
#define DEF_INSTR(CLASS, NAME)                                                 \
  case Kinded::Kind::CLASS##Kind: {                                            \
    const CLASS *CI = llvm::cast<CLASS>(&I);                                   \
    return checkNoFusion(*CI);                                                 \
    break;                                                                     \
  }
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)                                \
  case Kinded::Kind::CLASS##Kind: {                                            \
    const CLASS *CI = llvm::cast<CLASS>(&I);                                   \
    return checkNoFusion(*CI);                                                 \
    break;                                                                     \
  }
bool checkNoFusionForInstr(const Instruction &I) {
  switch (I.getKind()) {
#include "glow/AutoGenInstr.def"
  default:
    llvm_unreachable("Invalid instruction.");
  }
  // Unreachable: every case above returns.
  return true;
}
315
316template <typename FUN, typename ARR>
317ContiguousPlaceholders getContiguousPlaceHolder(const ARR &holders,
318 const FUN &F) {
319 // Pure input placeholders.
320 std::vector<const Placeholder *> intputPlaceholders;
321 // Pure output placeholders.
322 std::vector<const Placeholder *> outputPlaceholders;
323 // Input&output placeholders.
324 std::vector<const Placeholder *> inputOutputPlaceholders;
325 // Neither input nor output placeholders.
326 std::vector<const Placeholder *> emptyPlaceholders;
327 // Return value.
328 ContiguousPlaceholders ret;
329
330 for (auto &v : holders) {
331 if (isInput(v, F)) {
332 if (!isOutput(v, F)) {
333 intputPlaceholders.push_back(v);
334 } else {
335 inputOutputPlaceholders.push_back(v);
336 }
337 } else {
338 if (isOutput(v, F)) {
339 outputPlaceholders.push_back(v);
340 } else {
341 emptyPlaceholders.push_back(v);
342 }
343 }
344 }
345
346 for (auto &v : intputPlaceholders) {
347 PlaceholderInputOutputInfo holder;
348 holder.addr = v;
349 holder.isInput = true;
350 holder.isOutput = false;
351 ret.push_back(holder);
352 }
353
354 for (auto &v : inputOutputPlaceholders) {
355 PlaceholderInputOutputInfo holder;
356 holder.addr = v;
357 holder.isInput = true;
358 holder.isOutput = true;
359 ret.push_back(holder);
360 }
361
362 for (auto &v : outputPlaceholders) {
363 PlaceholderInputOutputInfo holder;
364 holder.addr = v;
365 holder.isInput = false;
366 holder.isOutput = true;
367 ret.push_back(holder);
368 }
369
370 for (auto &v : emptyPlaceholders) {
371 PlaceholderInputOutputInfo holder;
372 holder.addr = v;
373 holder.isInput = false;
374 holder.isOutput = false;
375 ret.push_back(holder);
376 }
377
378 return ret;
379}
380
381/// \returns true if \p dst is capable of handling a partial tensor as input
382/// from \p src.
383static bool allowsPartialInput(const Node *src, const Node *dst) {
384 // If N is used as the indices or weights of a sparse lookup, it is safe to
385 // access a partial tensor.
386 if (auto *SLS =
387 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
388 dst)) {
389 return src == SLS->getIndices() || src == SLS->getWeights();
390 } else if (auto *SLS =
391 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>(
392 dst)) {
393 return src == SLS->getIndices();
394 } else if (auto *SLS = llvm::dyn_cast<SparseLengthsWeightedSumNode>(dst)) {
395 return src == SLS->getIndices() || src == SLS->getWeights();
396 } else if (auto *SLS = llvm::dyn_cast<SparseLengthsSumNode>(dst)) {
397 return src == SLS->getIndices();
398 } else if (auto *EBB = llvm::dyn_cast<EmbeddingBagNode>(dst)) {
399 return src == EBB->getIndices() || src == EBB->getWeights();
400 } else if (auto *EBB =
401 llvm::dyn_cast<EmbeddingBagByteRowwiseOffsetsNode>(dst)) {
402 return src == EBB->getIndices() || src == EBB->getWeights();
403 }
404 return false;
405}
406
407bool allowsPartialInput(const Placeholder *V, const Function *F) {
408 for (auto const &U : V->getUsers()) {
409 if (U.getUser()->getParent() != F) {
410 continue;
411 }
412 if (!allowsPartialInput(*U.get(), U.getUser())) {
413 return false;
414 }
415 }
416 return true;
417}
418
419/// \returns true if \p dst requires last-element padding for \p src
420/// It is assumed that \p src cannot be partial input
421static bool requiresPadding(const Node *src, const Node *dst) {
422 if (auto *EBB = llvm::dyn_cast<EmbeddingBagNode>(dst)) {
423 return src == EBB->getOffsets();
424 } else if (auto *EBB =
425 llvm::dyn_cast<EmbeddingBagByteRowwiseOffsetsNode>(dst)) {
426 return src == EBB->getOffsets();
427 }
428 return false;
429}
430
431bool requiresPadding(const Placeholder *V, const Function *F) {
432 // TODO: this function is largely duplicated with allowsPartialInput()
433 // we should consider merging the two
434 for (auto const &U : V->getUsers()) {
435 if (U.getUser()->getParent() != F) {
436 continue;
437 }
438 if (!requiresPadding(*U.get(), U.getUser())) {
439 return false;
440 }
441 }
442 return true;
443}
444
445bool usedInFunction(const Placeholder *V, const Function *F) {
446 for (auto const &U : V->getUsers()) {
447 if (U.getUser()->getParent() == F) {
448 return true;
449 }
450 }
451 return false;
452}
453
/// Allocate space for the Constants in \p constants using \p allocator and
/// store the resultant symbols in \p symbolTable.
template <typename ConstantsTy>
static void allocateConstantsImpl(const ConstantsTy &constants,
                                  MemoryAllocator &allocator,
                                  glow::runtime::SymbolTableTy &symbolTable) {
  for (auto const *C : constants) {
    // Same constant may be used multiple times by different functions. But it
    // should be assigned an address only once.
    if (symbolTable.count(std::string(C->getName()))) {
      continue;
    }
    auto size = C->getType()->getSizeInBytes();
    // The constant itself serves as the allocation handle.
    auto offset = allocator.allocate(size, C);
    runtime::RuntimeSymbolInfo symbol;
    symbol.offset = offset;
    symbol.size = size;
    symbol.type = *C->getType();
    // Constants are neither network inputs nor outputs.
    symbol.input = false;
    symbol.output = false;
    symbol.symbolCategory = glow::runtime::SymbolCategory::Constant;
    symbolTable.emplace(C->getName(), symbol);
    DEBUG_GLOW(LOG(INFO) << strFormat(
                   "Assigned address to constant %s: %zx (%zd bytes)\n",
                   C->getName().data(), symbol.offset, symbol.size));
  }
}
481
/// Allocate space for the Constants in \p constants (a ConstList) using
/// \p allocator and store the resultant symbols in \p symbolTable.
void allocateConstants(const ConstList &constants, MemoryAllocator &allocator,
                       glow::runtime::SymbolTableTy &symbolTable) {
  allocateConstantsImpl(constants, allocator, symbolTable);
}
486
/// Allocate space for the Constants in \p constants (a vector) using
/// \p allocator and store the resultant symbols in \p symbolTable.
void allocateConstants(const std::vector<const glow::Constant *> &constants,
                       MemoryAllocator &allocator,
                       glow::runtime::SymbolTableTy &symbolTable) {
  allocateConstantsImpl(constants, allocator, symbolTable);
}
492
493/// Allocate space for the Placeholders in \p placeholders using \p allocator
494/// and store the resultant symbols in \p symbolTable.
495void allocatePlaceholders(const ContiguousPlaceholders &placeholders,
496 MemoryAllocator &allocator,
497 glow::runtime::SymbolTableTy &symbolTable) {
498 for (const auto &p : placeholders) {
499 auto &V = p.addr;
500 assert(!symbolTable.count(std::string(V->getName())) &&
501 "Allocation already made!");
502 auto size = V->getType()->getSizeInBytes();
503 auto offset = allocator.allocate(size, V);
504 runtime::RuntimeSymbolInfo symbol;
505 symbol.offset = offset;
506 symbol.size = size;
507 symbol.type = *V->getType();
508 symbol.output = p.isOutput;
509 symbol.input = p.isInput;
510 symbol.symbolCategory = glow::runtime::SymbolCategory::Placeholder;
511 symbolTable.emplace(std::string(V->getName()), symbol);
512 DEBUG_GLOW(LOG(INFO) << strFormat(
513 "Assigned address to mutable weight %s: %zx (%zd bytes)\n",
514 V->getName().data(), symbol.offset, symbol.size));
515 }
516}
517
/// Allocate space for the activations of \p instrs using \p allocator and store
/// the resultant symbols in \p symbolTable. TensorViews get symbols too: their
/// offset is derived from the buffer they view into, not a fresh allocation.
void allocateActivations(const glow::IRFunction::InstListTy &instrs,
                         MemoryAllocator &allocator,
                         glow::runtime::SymbolTableTy &symbolTable) {

  // Gather allocation/deallocation sequence.
  std::list<Allocation> allocList;
  if (reuseActivationsMemory) {
    // When reusing memory we register allocs/deallocs in their original order.
    for (const auto &I : instrs) {
      if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
        auto numBytes = I.getSizeInBytes();
        allocList.emplace_back(A, /* alloc */ true, numBytes);
        continue;
      }
      if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
        auto *A = D->getAlloc();
        allocList.emplace_back(A, /* alloc */ false, 0);
        continue;
      }
    }
  } else {
    // When not reusing memory we register first the allocs then the deallocs.
    // Since no dealloc precedes any alloc, no address can ever be reused.
    for (const auto &I : instrs) {
      if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
        auto numBytes = I.getSizeInBytes();
        allocList.emplace_back(A, /* alloc */ true, numBytes);
        continue;
      }
    }
    for (const auto &I : instrs) {
      if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
        auto *A = D->getAlloc();
        allocList.emplace_back(A, /* alloc */ false, 0);
        continue;
      }
    }
  }

  // Allocate all segments at once for better allocation efficiency.
  // We use a separate allocator object since the function "allocateAll()"
  // does not work together with the function "allocate()" which could have
  // been used with the original allocator.
  MemoryAllocator activationsAllocator("mem", 0, allocator.getAlignment());
  uint64_t activationsSize = activationsAllocator.allocateAll(allocList);

  // Allocate a contiguous segment for the activations of the current function.
  // The individual buffers within this segment are placed according to the
  // logic of allocateAll for better efficiency.
  uint64_t activationsBaseAddr = 0;
  if (activationsSize) {
    MemoryAllocator::Handle activationsHandle = &instrs;
    activationsBaseAddr =
        allocator.allocate(activationsSize, activationsHandle);
    if (reuseActivationsMemory) {
      allocator.deallocate(activationsHandle);
    }
  }

  // Map addresses of allocated segments.
  for (const auto &I : instrs) {
    if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
      auto numBytes = I.getSizeInBytes();
      // Final address = base of this function's segment + the buffer's
      // relative placement chosen by allocateAll above.
      size_t addr = activationsBaseAddr + activationsAllocator.getAddress(A);
      assert(!symbolTable.count(std::string(A->getName())) &&
             "Allocation already made!");
      runtime::RuntimeSymbolInfo symbol;
      symbol.offset = addr;
      symbol.size = numBytes;
      symbol.type = *A->getType();
      symbol.input = false;
      symbol.output = false;
      symbol.symbolCategory = glow::runtime::SymbolCategory::Activation;
      symbolTable.emplace(std::string(A->getName()), symbol);
      DEBUG_GLOW(LOG(INFO) << strFormat(
                     "Assigned address to activation %s: %zx (%zd bytes)\n",
                     A->getName().data(), symbol.offset, symbol.size));
      continue;
    }

    if (auto *TV = dyn_cast<TensorViewInst>(&I)) {
      // Calculate and store the length of the offset into the base, using the
      // source of the tensorview.
      assert(!symbolTable.count(std::string(TV->getName())) &&
             "Allocation already made!");
      auto *tvSource = getOrigin(TV);
      assert(symbolTable.count(std::string(tvSource->getName())) &&
             "Source allocation not found!");
      runtime::RuntimeSymbolInfo symbol;
      size_t originAddr = symbolTable[std::string(tvSource->getName())].offset;
      size_t offset = calculateTensorViewOffset(TV);

      symbol.offset = originAddr + offset;
      symbol.size = TV->getSizeInBytes();
      symbol.type = *TV->getType();
      symbol.input = false;
      symbol.output = false;
      // The view's category tracks whether it aliases a placeholder or a
      // constant buffer.
      auto parentCategory = symbolTable.find(std::string(tvSource->getName()))
                                ->second.symbolCategory;
      if (parentCategory == glow::runtime::SymbolCategory::Placeholder) {
        symbol.symbolCategory =
            glow::runtime::SymbolCategory::PlaceholderTensorView;
      } else {
        symbol.symbolCategory =
            glow::runtime::SymbolCategory::ConstantTensorView;
      }
      symbolTable.emplace(std::string(TV->getName()), symbol);
      DEBUG_GLOW(LOG(INFO) << strFormat(
                     "Assigned address to activation %s: %zx (%zd bytes)\n",
                     TV->getName().data(), symbol.offset, symbol.size));
      continue;
    }

    // Sanity check: every dealloc must refer to a previously mapped alloc.
    if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
      assert(symbolTable.count(std::string(D->getAlloc()->getName())) &&
             "Invalid deallocation!");
    }
  }
}
638
639} // namespace glow
640
/// Builds a RuntimeBundle covering the Module of \p F and the IR of every
/// function in \p funcs. A single allocator is used, so constants,
/// placeholders and activations are laid out consecutively; per-segment sizes
/// are recovered by subtracting the allocator's successive high-water marks.
runtime::RuntimeBundle
runtime::RuntimeBundle::create(const Function &F,
                               const std::vector<const IRFunction *> &funcs) {
  std::map<std::string, runtime::RuntimeSymbolInfo> symbolTable;
  MemoryAllocator allocator("allocator", 0);
  uint64_t constantsMaxMem = 0, placeholdersMaxMem = 0, activationsMaxMem = 0;

  // Allocate constants.
  allocateConstants(F.getParent()->getConstants(), allocator, symbolTable);
  constantsMaxMem = allocator.getMaxMemoryUsage();

  // Allocate placeholders. Placeholders should be allocated in a order of
  // Input|InputOutput|Output.
  std::vector<const Function *> graphs;
  graphs.reserve(funcs.size());
  for (const auto &f : funcs) {
    graphs.emplace_back(f->getGraph());
  }

  auto contiguousPlaceholders =
      getContiguousPlaceHolder(F.getParent()->getPlaceholders(), graphs);
  allocatePlaceholders(contiguousPlaceholders, allocator, symbolTable);
  // The shared allocator's high-water mark includes the constants; subtract
  // them to obtain the placeholders' own footprint.
  placeholdersMaxMem = allocator.getMaxMemoryUsage() - constantsMaxMem;

  // Allocate activations.
  for (const auto &f : funcs) {
    allocateActivations(f->getInstrs(), allocator, symbolTable);
  }

  // Likewise strip the previously accounted segments from the activations.
  activationsMaxMem =
      allocator.getMaxMemoryUsage() - constantsMaxMem - placeholdersMaxMem;

  return runtime::RuntimeBundle(symbolTable, constantsMaxMem,
                                placeholdersMaxMem, activationsMaxMem);
}
676
677runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) {
678 std::map<std::string, runtime::RuntimeSymbolInfo> symbolTable;
679
680 MemoryAllocator constants("constants", 0);
681 MemoryAllocator placeholders("placeholders", 0);
682
683 // Allocate constants.
684 allocateConstants(F.findConstants(), constants, symbolTable);
685
686 // Allocate placeholders.
687 // Placeholders should be allocated in a order of Input|InputOutput|Output.
688 auto contiguousPlaceholders =
689 getContiguousPlaceHolder(F.findPlaceholders(), F);
690
691 // Compute the offsets for Placeholders.
692 allocatePlaceholders(contiguousPlaceholders, placeholders, symbolTable);
693
694 return runtime::RuntimeBundle(symbolTable, constants.getMaxMemoryUsage(),
695 placeholders.getMaxMemoryUsage(),
696 /*activationsMaxSize*/ 0);
697}
698
/// Builds a RuntimeBundle for the IR function \p F, placing constants,
/// placeholders and activations via the three given allocators. The
/// allocators may all be the same object, in which case the three regions
/// are laid out contiguously and the per-region sizes are disentangled from
/// the shared high-water mark.
runtime::RuntimeBundle
runtime::RuntimeBundle::create(const IRFunction &F,
                               MemoryAllocator &constantAllocator,
                               MemoryAllocator &placeholderAllocator,
                               MemoryAllocator &activationsAllocator) {

  // If all allocators refer to the same underlying allocator, Constants,
  // Placeholders and activations will be allocated contiguously. The maximum
  // memory usage reported by the allocator for each kind of storage will
  // include the memory usage of all previously allocated types of storage and
  // needs to be adjusted accordingly.
  bool contiguous = (&constantAllocator == &placeholderAllocator &&
                     &constantAllocator == &activationsAllocator);
  // Handle Constants, Placeholders, and Activations, in that order.
  // Symbol table mapping symbol name to offset for runtime.
  std::map<std::string, runtime::RuntimeSymbolInfo> symbolTable;

  allocateConstants(F.findConstants(), constantAllocator, symbolTable);
  auto constantMaxSize = constantAllocator.getMaxMemoryUsage();

  // Placeholders should be allocated in a order of Input|InputOutput|Output.
  auto contiguousPlaceholders =
      getContiguousPlaceHolder(F.findPlaceholders(), F);
  // Compute the offsets for Placeholders.
  allocatePlaceholders(contiguousPlaceholders, placeholderAllocator,
                       symbolTable);
  auto placeholderMaxSize = placeholderAllocator.getMaxMemoryUsage();
  if (contiguous) {
    // Strip the constants' footprint from the shared high-water mark.
    placeholderMaxSize -= constantMaxSize;
  }

  // Compute the offsets for Activations.
  allocateActivations(F.getInstrs(), activationsAllocator, symbolTable);

  auto activationsMaxSize = activationsAllocator.getMaxMemoryUsage();
  if (contiguous) {
    activationsMaxSize -= constantMaxSize + placeholderMaxSize;
    // The three adjusted sizes must exactly partition the shared allocator's
    // total high-water mark.
    DCHECK_EQ(constantAllocator.getMaxMemoryUsage(),
              constantMaxSize + placeholderMaxSize + activationsMaxSize);
  }

  return runtime::RuntimeBundle(symbolTable, constantMaxSize,
                                placeholderMaxSize, activationsMaxSize);
}
743
/// Convenience overload: use the same \p allocator for constants,
/// placeholders and activations, producing one contiguous memory layout.
runtime::RuntimeBundle
runtime::RuntimeBundle::create(const IRFunction &F,
                               MemoryAllocator &allocator) {
  return create(F, allocator, allocator, allocator);
}
749