1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
#include "glow/LLVMIRCodeGen/LLVMCompiledFunction.h"

#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Support/Compiler.h"
#include "glow/Support/Memory.h"
#include "glow/Support/ThreadPool.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <mutex>
22 | |
23 | using namespace glow; |
24 | |
25 | LLVMCompiledFunction::LLVMCompiledFunction( |
26 | std::unique_ptr<GlowJIT> JIT, runtime::RuntimeBundle &&runtimeBundle) |
27 | : CompiledFunction(std::move(runtimeBundle)), JIT_(std::move(JIT)) {} |
28 | |
29 | void LLVMCompiledFunction::collectConstants(const Module *module) { |
30 | runtimeBundle_.collectConstants(module); |
31 | } |
32 | |
33 | void LLVMCompiledFunction::loadPlaceholders( |
34 | PlaceholderBindings *bindings, uint8_t *baseMutableWeightVarsAddress) { |
35 | // Make sure our inputs are on the host. |
36 | bindings->ensureOnHost(); |
37 | |
38 | // Copy Placeholders into allocated memory. |
39 | auto &symbolTable = runtimeBundle_.getSymbolTable(); |
40 | for (auto &PH : bindings->pairs()) { |
41 | auto it = symbolTable.find(PH.first->getName().str()); |
42 | if (it == symbolTable.end()) { |
43 | continue; |
44 | } |
45 | assert(!PH.second.isDeviceResident()); |
46 | auto symbolInfo = it->second; |
47 | auto payload = PH.second.getUnsafePtr(); |
48 | auto addr = symbolInfo.offset; |
49 | auto numBytes = PH.second.getUnpaddedSizeInBytes(); |
50 | // copy PH to allocated memory. |
51 | memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes); |
52 | } |
53 | } |
54 | |
55 | void LLVMCompiledFunction::updatePlaceholders( |
56 | PlaceholderBindings *bindings, uint8_t *baseMutableWeightVarsAddress) { |
57 | // Copy placeholders from device back into bindings. |
58 | auto &symbolTable = runtimeBundle_.getSymbolTable(); |
59 | for (auto &PH : bindings->pairs()) { |
60 | auto it = symbolTable.find(PH.first->getName().str()); |
61 | if (it == symbolTable.end()) { |
62 | continue; |
63 | } |
64 | auto symbolInfo = it->second; |
65 | auto payload = baseMutableWeightVarsAddress + symbolInfo.offset; |
66 | auto numBytes = PH.second.getUnpaddedSizeInBytes(); |
67 | auto addr = PH.second.getUnsafePtr(); |
68 | // copy PH from allocated memory. |
69 | memcpy(addr, payload, numBytes); |
70 | } |
71 | } |
72 | |
73 | Error LLVMCompiledFunction::execute(ExecutionContext *context) { |
74 | uint8_t *baseActivationsAddress{nullptr}; |
75 | |
76 | /// Base address for Mutable weights memory block, Inputs and Outputs. |
77 | uint8_t *baseMutableWeightVarsAddress{nullptr}; |
78 | |
79 | { |
80 | TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "allocBuffers" ); |
81 | if (runtimeBundle_.getActivationsSize() != 0) { |
82 | baseActivationsAddress = (uint8_t *)alignedAlloc( |
83 | runtimeBundle_.getActivationsSize(), TensorAlignment); |
84 | } |
85 | |
86 | if (runtimeBundle_.getMutableWeightSize() != 0) { |
87 | baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc( |
88 | runtimeBundle_.getMutableWeightSize(), TensorAlignment); |
89 | } |
90 | } |
91 | |
92 | { |
93 | TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "loadPlaceholders" ); |
94 | loadPlaceholders(context->getPlaceholderBindings(), |
95 | baseMutableWeightVarsAddress); |
96 | } |
97 | |
98 | auto *traceContext = context->getTraceContext(); |
99 | TRACE_EVENT_SCOPE_NAMED(traceContext, TraceLevel::RUNTIME, |
100 | "findJitmainSymbol" , fjEvent); |
101 | Expected<llvm::JITTargetAddress> address = NULL; |
102 | { |
103 | std::lock_guard<std::mutex> lock(JITLock_); |
104 | auto sym = JIT_->findSymbol("jitmain" ); |
105 | |
106 | DCHECK(sym) << "Unable to JIT the code!" ; |
107 | // We know address is success since we just made it. Mark it as checked. |
108 | if (address) { |
109 | auto addrOrLLVMError = sym.getAddress(); |
110 | if (addrOrLLVMError) { |
111 | address = addrOrLLVMError.get(); |
112 | } else { |
113 | address = MAKE_ERR( |
114 | strFormat("Failed to get address: %s" , |
115 | llvm::toString(addrOrLLVMError.takeError()).data())); |
116 | } |
117 | } |
118 | } |
119 | using JitFuncType = |
120 | void (*)(uint8_t * constantWeightVars, uint8_t * mutableWeightVars, |
121 | uint8_t * activations); |
122 | if (address) { |
123 | JitFuncType funcPtr = reinterpret_cast<JitFuncType>(address.get()); |
124 | TRACE_EVENT_SCOPE_END_NAMED(fjEvent); |
125 | TRACE_EVENT_SCOPE(traceContext, TraceLevel::RUNTIME, "execute" ); |
126 | funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress, |
127 | baseActivationsAddress); |
128 | } else { |
129 | return MAKE_ERR("Error getting address" ); |
130 | } |
131 | |
132 | { |
133 | TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "updatePlaceholders" ); |
134 | updatePlaceholders(context->getPlaceholderBindings(), |
135 | baseMutableWeightVarsAddress); |
136 | } |
137 | |
138 | { |
139 | TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "freeBuffers" ); |
140 | alignedFree(baseMutableWeightVarsAddress); |
141 | alignedFree(baseActivationsAddress); |
142 | } |
143 | |
144 | { |
145 | TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "processInstrumentation" ); |
146 | translateTraceEvents(context); |
147 | } |
148 | |
149 | return Error::success(); |
150 | } |
151 | |
152 | void LLVMCompiledFunction::translateTraceEvents( |
153 | ExecutionContext *context) const { |
154 | auto &traceInfo = getTraceInfo(); |
155 | if (!traceInfo.enabled) { |
156 | return; |
157 | } |
158 | |
159 | TraceContext *traceContext = context->getTraceContext(); |
160 | |
161 | if (!traceContext->shouldLog(TraceLevel::OPERATOR)) { |
162 | return; |
163 | } |
164 | |
165 | PlaceholderBindings *bindings = context->getPlaceholderBindings(); |
166 | |
167 | int tid = threads::getThreadId(); |
168 | for (auto &backing : traceInfo.events) { |
169 | Tensor *backingTensor = bindings->get(backing.first); |
170 | DCHECK(backingTensor) << "Could not get backing tensor for Placeholder: " |
171 | << backing.first->getName().str(); |
172 | |
173 | auto &traceEvents = traceContext->getTraceEvents(); |
174 | for (const TraceInfo::Event &event : backing.second) { |
175 | // If it's a complete event grab both timestamps. |
176 | if (event.type == TraceEvent::CompleteType) { |
177 | uint64_t start{0}, end{0}; |
178 | memcpy(&start, |
179 | backingTensor->getUnsafePtr() + |
180 | (event.startIndex * traceInfo.dataSize), |
181 | traceInfo.dataSize); |
182 | memcpy(&end, |
183 | backingTensor->getUnsafePtr() + |
184 | (event.endIndex * traceInfo.dataSize), |
185 | traceInfo.dataSize); |
186 | traceEvents.push_back({event.name, |
187 | TraceLevel::OPERATOR, |
188 | start, |
189 | end - start, |
190 | tid, |
191 | {{"kind" , event.kind}}}); |
192 | } else { |
193 | uint64_t ts{0}; |
194 | memcpy(&ts, |
195 | backingTensor->getUnsafePtr() + |
196 | (event.startIndex * traceInfo.dataSize), |
197 | traceInfo.dataSize); |
198 | traceEvents.push_back({event.name, |
199 | TraceLevel::OPERATOR, |
200 | ts, |
201 | event.type, |
202 | tid, |
203 | {{"kind" , event.kind}}}); |
204 | } |
205 | } |
206 | } |
207 | } |
208 | |