1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Backend/BackendUtils.h" |
17 | #include "glow/IR/IRUtils.h" |
18 | #include "glow/IR/Instrs.h" |
19 | #include "glow/Support/Debug.h" |
20 | |
21 | #include "llvm/Support/CommandLine.h" |
22 | |
23 | #include "glow/Graph/FXIRWrapper.h" |
24 | #include <glog/logging.h> |
25 | |
26 | #define DEBUG_TYPE "backend-utils" |
27 | |
28 | using namespace glow; |
29 | |
30 | using llvm::cast; |
31 | using llvm::dyn_cast; |
32 | using llvm::isa; |
33 | |
/// Command-line category grouping the backend-utils options below.
static llvm::cl::OptionCategory BackendUtilsCat("Glow Backend Utils Options");

/// When true (the default), activation allocs/deallocs are registered with the
/// memory allocator in their original program order, so freed activation
/// memory can be reused by later allocations; when false, all allocations are
/// registered before any deallocations and no reuse takes place.
static llvm::cl::opt<bool> reuseActivationsMemory(
    "reuse-activation-memory-allocations",
    llvm::cl::desc("Should activation memory allocations be reused"),
    llvm::cl::init(true), llvm::cl::cat(BackendUtilsCat));
40 | |
/// Move constructor: delegates to the move-assignment operator below, which
/// swaps this bundle's freshly-initialized state with \p rhs and marks \p rhs
/// invalid.
glow::runtime::RuntimeBundle::RuntimeBundle(
    glow::runtime::RuntimeBundle &&rhs) {
  *this = std::move(rhs);
}
45 | |
/// Move assignment: steals the contents of \p rhs and invalidates it.
glow::runtime::RuntimeBundle &
glow::runtime::RuntimeBundle::operator=(glow::runtime::RuntimeBundle &&rhs) {
  if (this == &rhs) {
    // Do nothing if rhs is the same object as this.
    return *this;
  }

  // Swap (rather than overwrite) each member so that any resources previously
  // owned by *this — in particular a non-null constants_ buffer — end up in
  // rhs and are released when rhs is destroyed, instead of leaking.
  std::swap(symbolTable_, rhs.symbolTable_);
  std::swap(constants_, rhs.constants_);
  std::swap(constantWeightVarsMemSize_, rhs.constantWeightVarsMemSize_);
  std::swap(mutableWeightVarsMemSize_, rhs.mutableWeightVarsMemSize_);
  std::swap(activationsMemSize_, rhs.activationsMemSize_);
  std::swap(isValid_, rhs.isValid_);
  // rhs is not valid now that all of its contents have been stolen.
  rhs.isValid_ = false;
  return *this;
}
63 | |
/// Collect the constants of the module that owns IR function \p F into this
/// bundle's single constants block (see the Module overload below).
void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) {
  DCHECK(isValid_);
  collectConstants(F->getParent());
}
68 | |
69 | void glow::runtime::RuntimeBundle::freeConstants() { |
70 | DCHECK(isValid_); |
71 | |
72 | if (constants_) { |
73 | glow::alignedFree(constants_); |
74 | constants_ = nullptr; |
75 | } |
76 | } |
77 | void glow::runtime::RuntimeBundle::collectConstants(const Module *M) { |
78 | DCHECK(isValid_); |
79 | |
80 | // At compile time condense constants to a single block of memory. |
81 | // This allows the graph to go away after compile time. |
82 | // If there are no constants return nullptr. |
83 | if (constantWeightVarsMemSize_ == 0) { |
84 | constants_ = nullptr; |
85 | return; |
86 | } |
87 | |
88 | assert(constants_ == nullptr && "constants already allocated" ); |
89 | constants_ = |
90 | (uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment); |
91 | |
92 | for (const auto &symbol : symbolTable_) { |
93 | llvm::StringRef name = symbol.first; |
94 | const RuntimeSymbolInfo &info = symbol.second; |
95 | |
96 | Constant *c = M->getConstantByName(name); |
97 | if (!c) { |
98 | continue; |
99 | } |
100 | auto *payload = c->getPayload().getUnsafePtr(); |
101 | assert(info.size == c->getPayload().getSizeInBytes() && |
102 | "Mismatched constant size" ); |
103 | |
104 | // Copy weight to offset. |
105 | memcpy(constants_ + info.offset, payload, info.size); |
106 | } |
107 | } |
108 | |
109 | #if FACEBOOK_INTERNAL |
110 | void glow::runtime::RuntimeBundle::collectConstants(const FXIRWrapper *F) { |
111 | DCHECK(isValid_); |
112 | |
113 | // At compile time condense constants to a single block of memory. |
114 | // This allows the graph to go away after compile time. |
115 | // If there are no constants return nullptr. |
116 | if (constantWeightVarsMemSize_ == 0) { |
117 | constants_ = nullptr; |
118 | return; |
119 | } |
120 | |
121 | assert(constants_ == nullptr && "constants already allocated" ); |
122 | constants_ = |
123 | (uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment); |
124 | |
125 | for (const auto &symbol : symbolTable_) { |
126 | llvm::StringRef name = symbol.first; |
127 | const RuntimeSymbolInfo &info = symbol.second; |
128 | |
129 | // Only work with constants/weights here. |
130 | auto category = info.symbolCategory; |
131 | if (category != glow::runtime::SymbolCategory::Constant) { |
132 | continue; |
133 | } |
134 | |
135 | auto mapToConstants = F->getMapNodeNameToStorage(); |
136 | assert(mapToConstants.find(name.str()) != mapToConstants.end()); |
137 | const auto *wt = mapToConstants[name.str()]; |
138 | const auto *c = llvm::dyn_cast<const Constant>(wt); |
139 | if (!c) { |
140 | continue; |
141 | } |
142 | auto *payload = c->getPayload().getUnsafePtr(); |
143 | assert(info.size == c->getPayload().getSizeInBytes() && |
144 | "Mismatched constant size" ); |
145 | |
146 | // Copy weight to offset. |
147 | memcpy(constants_ + info.offset, payload, info.size); |
148 | } |
149 | } |
150 | #endif |
151 | |
152 | size_t glow::runtime::RuntimeBundle::getValueOffset(const Named *v) const { |
153 | DCHECK(isValid_); |
154 | auto it = symbolTable_.find(std::string(v->getName())); |
155 | assert(it != symbolTable_.end() && "Symbol not found." ); |
156 | return it->second.offset; |
157 | } |
158 | |
159 | const runtime::RuntimeSymbolInfo & |
160 | runtime::RuntimeBundle::getSymbolInfo(const Named *v) const { |
161 | DCHECK(isValid_); |
162 | auto it = symbolTable_.find(std::string(v->getName())); |
163 | assert(it != symbolTable_.end() && "Symbol not found." ); |
164 | return it->second; |
165 | } |
166 | |
167 | namespace glow { |
168 | |
169 | /// If \p W is an output weight \returns true. This is determined by checking if |
170 | /// the weight has a user which uses it as a write output. |
171 | bool isOutput(const Value *W) { |
172 | auto *weight = llvm::dyn_cast<WeightVar>(W); |
173 | DCHECK(weight) << "Expected WeightVar" ; |
174 | for (const auto &use : ValueUses(weight)) { |
175 | Instruction *user = use.get(); |
176 | // Ignore deallocs. |
177 | if (isa<DeallocActivationInst>(user)) { |
178 | continue; |
179 | } |
180 | OperandKind kind = use.getOperand().second; |
181 | if (kind == OperandKind::Out || kind == OperandKind::InOut) { |
182 | return true; |
183 | } |
184 | } |
185 | return false; |
186 | } |
187 | |
/// If \p PH is an output placeholder in the function \p F, \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const IRFunction &F) {
  // Map the graph-level placeholder to its IR-level weight and defer to the
  // Value-based overload above.
  auto *weight = F.getWeightForNode(PH);
  DCHECK(weight) << "Weight for a node was not found";
  return isOutput(weight);
}
196 | |
197 | /// If \p W is a weight that is first read from \returns true. |
198 | bool isInput(const Value *W) { |
199 | auto *weight = llvm::dyn_cast<WeightVar>(W); |
200 | const glow::Instruction *firstUser = nullptr; |
201 | bool hasReads = false; |
202 | for (const auto &U : ValueUses(weight)) { |
203 | const auto *user = U.get(); |
204 | // TensorView instruction doesn't read from a placeholder. |
205 | if (isa<TensorViewInst>(user)) { |
206 | continue; |
207 | } |
208 | // Remember the earliest use. |
209 | if (!firstUser || firstUser->getIterator() > user->getIterator()) { |
210 | firstUser = user; |
211 | } |
212 | // Ignore deallocs. |
213 | if (isa<DeallocActivationInst>(user)) { |
214 | continue; |
215 | } |
216 | OperandKind kind = U.getOperand().second; |
217 | if (kind == OperandKind::In || kind == OperandKind::InOut) { |
218 | hasReads = true; |
219 | } |
220 | } |
221 | |
222 | if (!hasReads) { |
223 | return false; |
224 | } |
225 | |
226 | // Check if the first use is a read. |
227 | if (firstUser) { |
228 | // If this instruction has reads, then the first use is an @in. |
229 | auto *weightOrigin = getOrigin(weight); |
230 | for (int idx = 0, e = firstUser->getNumOperands(); idx < e; ++idx) { |
231 | const auto op = firstUser->getOperand(idx); |
232 | auto *opOrigin = getOrigin(op.first); |
233 | auto opKind = op.second; |
234 | if (opOrigin == weightOrigin && opKind == OperandKind::In) { |
235 | return true; |
236 | } |
237 | } |
238 | // No reads were found, thus the first use is a write. |
239 | return false; |
240 | } |
241 | // If there are no users, it is not an input. |
242 | return false; |
243 | } |
244 | |
/// If \p PH is an input placeholder in the function \p F, \returns true.
bool isInput(const Placeholder *PH, const IRFunction &F) {
  // Check that the PH is always used as an @in parameter by the current
  // function. Map the placeholder to its IR-level weight and defer to the
  // Value-based overload above.
  auto *weight = F.getWeightForNode(PH);
  DCHECK(weight) << "Weight for a node was not found";
  return isInput(weight);
}
253 | |
254 | bool isOutput(const Placeholder *PH, |
255 | const std::vector<const Function *> &funcs) { |
256 | for (const auto &f : funcs) { |
257 | if (isOutput(PH, *f)) { |
258 | return true; |
259 | } |
260 | } |
261 | |
262 | return false; |
263 | } |
264 | |
265 | /// \returns true if \p PH is an input Placeholder for any function in \p funcs. |
266 | bool isInput(const Placeholder *PH, |
267 | const std::vector<const Function *> &funcs) { |
268 | for (const auto &f : funcs) { |
269 | if (isInput(PH, *f)) { |
270 | return true; |
271 | } |
272 | } |
273 | |
274 | return false; |
275 | } |
276 | |
/// If \p N does not have fused activation \returns true.
/// Dispatches on the node kind via a switch generated from AutoGenNodes.def
/// and defers to the class-specific checkNoFusion() overload.
bool checkNoFusionForNode(const Node &N) {
// Expands to one case per node class: downcast and call checkNoFusion().
#define DEF_NODE(CLASS, NAME)                                                  \
  case Kinded::Kind::CLASS##Kind: {                                            \
    const CLASS *CI = llvm::cast<CLASS>(&N);                                   \
    return checkNoFusion(*CI);                                                 \
    break;                                                                     \
  }
  switch (N.getKind()) {
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Invalid node.");
  }
  // Unreachable: every valid kind returns inside the switch.
  return true;
}
292 | |
/// If \p I does not have fused activation \returns true.
/// Dispatches on the instruction kind via a switch generated from
/// AutoGenInstr.def and defers to the class-specific checkNoFusion() overload.
// Values have no fusion semantics, so DEF_VALUE expands to nothing.
bool checkNoFusionForInstr(const Instruction &I) {
#define DEF_VALUE(CLASS, NAME)
// Expands to one case per instruction class: downcast and call
// checkNoFusion().
#define DEF_INSTR(CLASS, NAME)                                                 \
  case Kinded::Kind::CLASS##Kind: {                                            \
    const CLASS *CI = llvm::cast<CLASS>(&I);                                   \
    return checkNoFusion(*CI);                                                 \
    break;                                                                     \
  }
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)                                \
  case Kinded::Kind::CLASS##Kind: {                                            \
    const CLASS *CI = llvm::cast<CLASS>(&I);                                   \
    return checkNoFusion(*CI);                                                 \
    break;                                                                     \
  }
  switch (I.getKind()) {
#include "glow/AutoGenInstr.def"
  default:
    llvm_unreachable("Invalid instruction.");
  }
  // Unreachable: every valid kind returns inside the switch.
  return true;
}
315 | |
316 | template <typename FUN, typename ARR> |
317 | ContiguousPlaceholders getContiguousPlaceHolder(const ARR &holders, |
318 | const FUN &F) { |
319 | // Pure input placeholders. |
320 | std::vector<const Placeholder *> intputPlaceholders; |
321 | // Pure output placeholders. |
322 | std::vector<const Placeholder *> outputPlaceholders; |
323 | // Input&output placeholders. |
324 | std::vector<const Placeholder *> inputOutputPlaceholders; |
325 | // Neither input nor output placeholders. |
326 | std::vector<const Placeholder *> emptyPlaceholders; |
327 | // Return value. |
328 | ContiguousPlaceholders ret; |
329 | |
330 | for (auto &v : holders) { |
331 | if (isInput(v, F)) { |
332 | if (!isOutput(v, F)) { |
333 | intputPlaceholders.push_back(v); |
334 | } else { |
335 | inputOutputPlaceholders.push_back(v); |
336 | } |
337 | } else { |
338 | if (isOutput(v, F)) { |
339 | outputPlaceholders.push_back(v); |
340 | } else { |
341 | emptyPlaceholders.push_back(v); |
342 | } |
343 | } |
344 | } |
345 | |
346 | for (auto &v : intputPlaceholders) { |
347 | PlaceholderInputOutputInfo holder; |
348 | holder.addr = v; |
349 | holder.isInput = true; |
350 | holder.isOutput = false; |
351 | ret.push_back(holder); |
352 | } |
353 | |
354 | for (auto &v : inputOutputPlaceholders) { |
355 | PlaceholderInputOutputInfo holder; |
356 | holder.addr = v; |
357 | holder.isInput = true; |
358 | holder.isOutput = true; |
359 | ret.push_back(holder); |
360 | } |
361 | |
362 | for (auto &v : outputPlaceholders) { |
363 | PlaceholderInputOutputInfo holder; |
364 | holder.addr = v; |
365 | holder.isInput = false; |
366 | holder.isOutput = true; |
367 | ret.push_back(holder); |
368 | } |
369 | |
370 | for (auto &v : emptyPlaceholders) { |
371 | PlaceholderInputOutputInfo holder; |
372 | holder.addr = v; |
373 | holder.isInput = false; |
374 | holder.isOutput = false; |
375 | ret.push_back(holder); |
376 | } |
377 | |
378 | return ret; |
379 | } |
380 | |
381 | /// \returns true if \p dst is capable of handling a partial tensor as input |
382 | /// from \p src. |
383 | static bool allowsPartialInput(const Node *src, const Node *dst) { |
384 | // If N is used as the indices or weights of a sparse lookup, it is safe to |
385 | // access a partial tensor. |
386 | if (auto *SLS = |
387 | llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>( |
388 | dst)) { |
389 | return src == SLS->getIndices() || src == SLS->getWeights(); |
390 | } else if (auto *SLS = |
391 | llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>( |
392 | dst)) { |
393 | return src == SLS->getIndices(); |
394 | } else if (auto *SLS = llvm::dyn_cast<SparseLengthsWeightedSumNode>(dst)) { |
395 | return src == SLS->getIndices() || src == SLS->getWeights(); |
396 | } else if (auto *SLS = llvm::dyn_cast<SparseLengthsSumNode>(dst)) { |
397 | return src == SLS->getIndices(); |
398 | } else if (auto *EBB = llvm::dyn_cast<EmbeddingBagNode>(dst)) { |
399 | return src == EBB->getIndices() || src == EBB->getWeights(); |
400 | } else if (auto *EBB = |
401 | llvm::dyn_cast<EmbeddingBagByteRowwiseOffsetsNode>(dst)) { |
402 | return src == EBB->getIndices() || src == EBB->getWeights(); |
403 | } |
404 | return false; |
405 | } |
406 | |
407 | bool allowsPartialInput(const Placeholder *V, const Function *F) { |
408 | for (auto const &U : V->getUsers()) { |
409 | if (U.getUser()->getParent() != F) { |
410 | continue; |
411 | } |
412 | if (!allowsPartialInput(*U.get(), U.getUser())) { |
413 | return false; |
414 | } |
415 | } |
416 | return true; |
417 | } |
418 | |
419 | /// \returns true if \p dst requires last-element padding for \p src |
420 | /// It is assumed that \p src cannot be partial input |
421 | static bool requiresPadding(const Node *src, const Node *dst) { |
422 | if (auto *EBB = llvm::dyn_cast<EmbeddingBagNode>(dst)) { |
423 | return src == EBB->getOffsets(); |
424 | } else if (auto *EBB = |
425 | llvm::dyn_cast<EmbeddingBagByteRowwiseOffsetsNode>(dst)) { |
426 | return src == EBB->getOffsets(); |
427 | } |
428 | return false; |
429 | } |
430 | |
431 | bool requiresPadding(const Placeholder *V, const Function *F) { |
432 | // TODO: this function is largely duplicated with allowsPartialInput() |
433 | // we should consider merging the two |
434 | for (auto const &U : V->getUsers()) { |
435 | if (U.getUser()->getParent() != F) { |
436 | continue; |
437 | } |
438 | if (!requiresPadding(*U.get(), U.getUser())) { |
439 | return false; |
440 | } |
441 | } |
442 | return true; |
443 | } |
444 | |
445 | bool usedInFunction(const Placeholder *V, const Function *F) { |
446 | for (auto const &U : V->getUsers()) { |
447 | if (U.getUser()->getParent() == F) { |
448 | return true; |
449 | } |
450 | } |
451 | return false; |
452 | } |
453 | |
454 | /// Allocate space for the Constants in \p constants using \p allocator and |
455 | /// store the resultant symbols in \p symbolTable. |
456 | template <typename ConstantsTy> |
457 | static void allocateConstantsImpl(const ConstantsTy &constants, |
458 | MemoryAllocator &allocator, |
459 | glow::runtime::SymbolTableTy &symbolTable) { |
460 | for (auto const *C : constants) { |
461 | // Same constant may be used multiple times by different functions. But it |
462 | // should be assigned an address only once. |
463 | if (symbolTable.count(std::string(C->getName()))) { |
464 | continue; |
465 | } |
466 | auto size = C->getType()->getSizeInBytes(); |
467 | auto offset = allocator.allocate(size, C); |
468 | runtime::RuntimeSymbolInfo symbol; |
469 | symbol.offset = offset; |
470 | symbol.size = size; |
471 | symbol.type = *C->getType(); |
472 | symbol.input = false; |
473 | symbol.output = false; |
474 | symbol.symbolCategory = glow::runtime::SymbolCategory::Constant; |
475 | symbolTable.emplace(C->getName(), symbol); |
476 | DEBUG_GLOW(LOG(INFO) << strFormat( |
477 | "Assigned address to constant %s: %zx (%zd bytes)\n" , |
478 | C->getName().data(), symbol.offset, symbol.size)); |
479 | } |
480 | } |
481 | |
/// Allocate space for the Constants in \p constants (a ConstList) using
/// \p allocator and store the resultant symbols in \p symbolTable.
void allocateConstants(const ConstList &constants, MemoryAllocator &allocator,
                       glow::runtime::SymbolTableTy &symbolTable) {
  allocateConstantsImpl(constants, allocator, symbolTable);
}
486 | |
/// Allocate space for the Constants in \p constants (a vector of Constant
/// pointers) using \p allocator and store the resultant symbols in
/// \p symbolTable.
void allocateConstants(const std::vector<const glow::Constant *> &constants,
                       MemoryAllocator &allocator,
                       glow::runtime::SymbolTableTy &symbolTable) {
  allocateConstantsImpl(constants, allocator, symbolTable);
}
492 | |
493 | /// Allocate space for the Placeholders in \p placeholders using \p allocator |
494 | /// and store the resultant symbols in \p symbolTable. |
495 | void allocatePlaceholders(const ContiguousPlaceholders &placeholders, |
496 | MemoryAllocator &allocator, |
497 | glow::runtime::SymbolTableTy &symbolTable) { |
498 | for (const auto &p : placeholders) { |
499 | auto &V = p.addr; |
500 | assert(!symbolTable.count(std::string(V->getName())) && |
501 | "Allocation already made!" ); |
502 | auto size = V->getType()->getSizeInBytes(); |
503 | auto offset = allocator.allocate(size, V); |
504 | runtime::RuntimeSymbolInfo symbol; |
505 | symbol.offset = offset; |
506 | symbol.size = size; |
507 | symbol.type = *V->getType(); |
508 | symbol.output = p.isOutput; |
509 | symbol.input = p.isInput; |
510 | symbol.symbolCategory = glow::runtime::SymbolCategory::Placeholder; |
511 | symbolTable.emplace(std::string(V->getName()), symbol); |
512 | DEBUG_GLOW(LOG(INFO) << strFormat( |
513 | "Assigned address to mutable weight %s: %zx (%zd bytes)\n" , |
514 | V->getName().data(), symbol.offset, symbol.size)); |
515 | } |
516 | } |
517 | |
/// Allocate space for the activations of \p instrs using \p allocator and store
/// the resultant symbols in \p symbolTable. Also records symbols for
/// TensorViews, whose offset is derived from their origin allocation rather
/// than freshly allocated.
void allocateActivations(const glow::IRFunction::InstListTy &instrs,
                         MemoryAllocator &allocator,
                         glow::runtime::SymbolTableTy &symbolTable) {

  // Gather allocation/deallocation sequence.
  std::list<Allocation> allocList;
  if (reuseActivationsMemory) {
    // When reusing memory we register allocs/deallocs in their original order,
    // so a buffer freed by a dealloc can be reused by a later alloc.
    for (const auto &I : instrs) {
      if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
        auto numBytes = I.getSizeInBytes();
        allocList.emplace_back(A, /* alloc */ true, numBytes);
        continue;
      }
      if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
        auto *A = D->getAlloc();
        allocList.emplace_back(A, /* alloc */ false, 0);
        continue;
      }
    }
  } else {
    // When not reusing memory we register first the allocs then the deallocs,
    // which makes every buffer's live range span the whole function so no two
    // activations share memory.
    for (const auto &I : instrs) {
      if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
        auto numBytes = I.getSizeInBytes();
        allocList.emplace_back(A, /* alloc */ true, numBytes);
        continue;
      }
    }
    for (const auto &I : instrs) {
      if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
        auto *A = D->getAlloc();
        allocList.emplace_back(A, /* alloc */ false, 0);
        continue;
      }
    }
  }

  // Allocate all segments at once for better allocation efficiency.
  // We use a separate allocator object since the function "allocateAll()"
  // does not work together with the function "allocate()" which could have
  // been used with the original allocator.
  MemoryAllocator activationsAllocator("mem", 0, allocator.getAlignment());
  uint64_t activationsSize = activationsAllocator.allocateAll(allocList);

  // Allocate a contiguous segment for the activations of the current function.
  // The individual buffers within this segment are placed according to the
  // logic of allocateAll for better efficiency.
  uint64_t activationsBaseAddr = 0;
  if (activationsSize) {
    MemoryAllocator::Handle activationsHandle = &instrs;
    activationsBaseAddr =
        allocator.allocate(activationsSize, activationsHandle);
    if (reuseActivationsMemory) {
      // Free the segment in the shared allocator so a later function's
      // activations can reuse this address range.
      allocator.deallocate(activationsHandle);
    }
  }

  // Map addresses of allocated segments: each activation symbol's offset is
  // the segment base plus its position inside the segment.
  for (const auto &I : instrs) {
    if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
      auto numBytes = I.getSizeInBytes();
      size_t addr = activationsBaseAddr + activationsAllocator.getAddress(A);
      assert(!symbolTable.count(std::string(A->getName())) &&
             "Allocation already made!");
      runtime::RuntimeSymbolInfo symbol;
      symbol.offset = addr;
      symbol.size = numBytes;
      symbol.type = *A->getType();
      symbol.input = false;
      symbol.output = false;
      symbol.symbolCategory = glow::runtime::SymbolCategory::Activation;
      symbolTable.emplace(std::string(A->getName()), symbol);
      DEBUG_GLOW(LOG(INFO) << strFormat(
                     "Assigned address to activation %s: %zx (%zd bytes)\n",
                     A->getName().data(), symbol.offset, symbol.size));
      continue;
    }

    if (auto *TV = dyn_cast<TensorViewInst>(&I)) {
      // Calculate and store the length of the offset into the base, using the
      // source of the tensorview. TensorViews alias their origin buffer and
      // get no allocation of their own.
      assert(!symbolTable.count(std::string(TV->getName())) &&
             "Allocation already made!");
      auto *tvSource = getOrigin(TV);
      assert(symbolTable.count(std::string(tvSource->getName())) &&
             "Source allocation not found!");
      runtime::RuntimeSymbolInfo symbol;
      size_t originAddr = symbolTable[std::string(tvSource->getName())].offset;
      size_t offset = calculateTensorViewOffset(TV);

      symbol.offset = originAddr + offset;
      symbol.size = TV->getSizeInBytes();
      symbol.type = *TV->getType();
      symbol.input = false;
      symbol.output = false;
      // Propagate the view's category from its origin (placeholder vs
      // constant).
      auto parentCategory = symbolTable.find(std::string(tvSource->getName()))
                                ->second.symbolCategory;
      if (parentCategory == glow::runtime::SymbolCategory::Placeholder) {
        symbol.symbolCategory =
            glow::runtime::SymbolCategory::PlaceholderTensorView;
      } else {
        symbol.symbolCategory =
            glow::runtime::SymbolCategory::ConstantTensorView;
      }
      symbolTable.emplace(std::string(TV->getName()), symbol);
      DEBUG_GLOW(LOG(INFO) << strFormat(
                     "Assigned address to activation %s: %zx (%zd bytes)\n",
                     TV->getName().data(), symbol.offset, symbol.size));
      continue;
    }

    if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
      // Sanity check: a dealloc must refer to a previously-mapped alloc.
      assert(symbolTable.count(std::string(D->getAlloc()->getName())) &&
             "Invalid deallocation!");
    }
  }
}
638 | |
639 | } // namespace glow |
640 | |
/// Build a RuntimeBundle for graph \p F, allocating constants, placeholders,
/// and the activations of every IR function in \p funcs out of one shared
/// allocator (so all three regions are laid out contiguously).
runtime::RuntimeBundle
runtime::RuntimeBundle::create(const Function &F,
                               const std::vector<const IRFunction *> &funcs) {
  // Symbol table mapping symbol name to offset for runtime.
  std::map<std::string, runtime::RuntimeSymbolInfo> symbolTable;
  MemoryAllocator allocator("allocator", 0);
  uint64_t constantsMaxMem = 0, placeholdersMaxMem = 0, activationsMaxMem = 0;

  // Allocate constants.
  allocateConstants(F.getParent()->getConstants(), allocator, symbolTable);
  constantsMaxMem = allocator.getMaxMemoryUsage();

  // Allocate placeholders. Placeholders should be allocated in a order of
  // Input|InputOutput|Output.
  std::vector<const Function *> graphs;
  graphs.reserve(funcs.size());
  for (const auto &f : funcs) {
    graphs.emplace_back(f->getGraph());
  }

  auto contiguousPlaceholders =
      getContiguousPlaceHolder(F.getParent()->getPlaceholders(), graphs);
  allocatePlaceholders(contiguousPlaceholders, allocator, symbolTable);
  // The shared allocator's max usage is cumulative, so subtract the previous
  // region's size to get this region's own footprint.
  placeholdersMaxMem = allocator.getMaxMemoryUsage() - constantsMaxMem;

  // Allocate activations.
  for (const auto &f : funcs) {
    allocateActivations(f->getInstrs(), allocator, symbolTable);
  }

  // Again, subtract the earlier regions from the cumulative max usage.
  activationsMaxMem =
      allocator.getMaxMemoryUsage() - constantsMaxMem - placeholdersMaxMem;

  return runtime::RuntimeBundle(symbolTable, constantsMaxMem,
                                placeholdersMaxMem, activationsMaxMem);
}
676 | |
677 | runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) { |
678 | std::map<std::string, runtime::RuntimeSymbolInfo> symbolTable; |
679 | |
680 | MemoryAllocator constants("constants" , 0); |
681 | MemoryAllocator placeholders("placeholders" , 0); |
682 | |
683 | // Allocate constants. |
684 | allocateConstants(F.findConstants(), constants, symbolTable); |
685 | |
686 | // Allocate placeholders. |
687 | // Placeholders should be allocated in a order of Input|InputOutput|Output. |
688 | auto contiguousPlaceholders = |
689 | getContiguousPlaceHolder(F.findPlaceholders(), F); |
690 | |
691 | // Compute the offsets for Placeholders. |
692 | allocatePlaceholders(contiguousPlaceholders, placeholders, symbolTable); |
693 | |
694 | return runtime::RuntimeBundle(symbolTable, constants.getMaxMemoryUsage(), |
695 | placeholders.getMaxMemoryUsage(), |
696 | /*activationsMaxSize*/ 0); |
697 | } |
698 | |
/// Build a RuntimeBundle for IR function \p F using caller-supplied allocators
/// for constants, placeholders, and activations. The three allocators may all
/// be the same object, in which case the regions are laid out contiguously.
runtime::RuntimeBundle
runtime::RuntimeBundle::create(const IRFunction &F,
                               MemoryAllocator &constantAllocator,
                               MemoryAllocator &placeholderAllocator,
                               MemoryAllocator &activationsAllocator) {

  // If all allocators refer to the same underlying allocator, Constants,
  // Placeholders and activations will be allocated contiguously. The maximum
  // memory usage reported by the allocator for each kind of storage will
  // include the memory usage of all previously allocated types of storage and
  // needs to be adjusted accordingly.
  bool contiguous = (&constantAllocator == &placeholderAllocator &&
                     &constantAllocator == &activationsAllocator);
  // Handle Constants, Placeholders, and Activations, in that order.
  // Symbol table mapping symbol name to offset for runtime.
  std::map<std::string, runtime::RuntimeSymbolInfo> symbolTable;

  allocateConstants(F.findConstants(), constantAllocator, symbolTable);
  auto constantMaxSize = constantAllocator.getMaxMemoryUsage();

  // Placeholders should be allocated in a order of Input|InputOutput|Output.
  auto contiguousPlaceholders =
      getContiguousPlaceHolder(F.findPlaceholders(), F);
  // Compute the offsets for Placeholders.
  allocatePlaceholders(contiguousPlaceholders, placeholderAllocator,
                       symbolTable);
  auto placeholderMaxSize = placeholderAllocator.getMaxMemoryUsage();
  if (contiguous) {
    // Strip the constants' share from the cumulative usage.
    placeholderMaxSize -= constantMaxSize;
  }

  // Compute the offsets for Activations.
  allocateActivations(F.getInstrs(), activationsAllocator, symbolTable);

  auto activationsMaxSize = activationsAllocator.getMaxMemoryUsage();
  if (contiguous) {
    // Strip the earlier regions' share, then verify the three adjusted sizes
    // add back up to the allocator's total usage.
    activationsMaxSize -= constantMaxSize + placeholderMaxSize;
    DCHECK_EQ(constantAllocator.getMaxMemoryUsage(),
              constantMaxSize + placeholderMaxSize + activationsMaxSize);
  }

  return runtime::RuntimeBundle(symbolTable, constantMaxSize,
                                placeholderMaxSize, activationsMaxSize);
}
743 | |
/// Convenience overload: build a RuntimeBundle for \p F with one allocator
/// shared by constants, placeholders, and activations (contiguous layout).
runtime::RuntimeBundle
runtime::RuntimeBundle::create(const IRFunction &F,
                               MemoryAllocator &allocator) {
  return create(F, allocator, allocator, allocator);
}
749 | |