1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/LLVMIRCodeGen/LLVMCompiledFunction.h"
17
18#include "glow/Graph/PlaceholderBindings.h"
19#include "glow/Support/Compiler.h"
20#include "glow/Support/Memory.h"
21#include "glow/Support/ThreadPool.h"
22
23using namespace glow;
24
25LLVMCompiledFunction::LLVMCompiledFunction(
26 std::unique_ptr<GlowJIT> JIT, runtime::RuntimeBundle &&runtimeBundle)
27 : CompiledFunction(std::move(runtimeBundle)), JIT_(std::move(JIT)) {}
28
29void LLVMCompiledFunction::collectConstants(const Module *module) {
30 runtimeBundle_.collectConstants(module);
31}
32
33void LLVMCompiledFunction::loadPlaceholders(
34 PlaceholderBindings *bindings, uint8_t *baseMutableWeightVarsAddress) {
35 // Make sure our inputs are on the host.
36 bindings->ensureOnHost();
37
38 // Copy Placeholders into allocated memory.
39 auto &symbolTable = runtimeBundle_.getSymbolTable();
40 for (auto &PH : bindings->pairs()) {
41 auto it = symbolTable.find(PH.first->getName().str());
42 if (it == symbolTable.end()) {
43 continue;
44 }
45 assert(!PH.second.isDeviceResident());
46 auto symbolInfo = it->second;
47 auto payload = PH.second.getUnsafePtr();
48 auto addr = symbolInfo.offset;
49 auto numBytes = PH.second.getUnpaddedSizeInBytes();
50 // copy PH to allocated memory.
51 memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes);
52 }
53}
54
55void LLVMCompiledFunction::updatePlaceholders(
56 PlaceholderBindings *bindings, uint8_t *baseMutableWeightVarsAddress) {
57 // Copy placeholders from device back into bindings.
58 auto &symbolTable = runtimeBundle_.getSymbolTable();
59 for (auto &PH : bindings->pairs()) {
60 auto it = symbolTable.find(PH.first->getName().str());
61 if (it == symbolTable.end()) {
62 continue;
63 }
64 auto symbolInfo = it->second;
65 auto payload = baseMutableWeightVarsAddress + symbolInfo.offset;
66 auto numBytes = PH.second.getUnpaddedSizeInBytes();
67 auto addr = PH.second.getUnsafePtr();
68 // copy PH from allocated memory.
69 memcpy(addr, payload, numBytes);
70 }
71}
72
73Error LLVMCompiledFunction::execute(ExecutionContext *context) {
74 uint8_t *baseActivationsAddress{nullptr};
75
76 /// Base address for Mutable weights memory block, Inputs and Outputs.
77 uint8_t *baseMutableWeightVarsAddress{nullptr};
78
79 {
80 TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "allocBuffers");
81 if (runtimeBundle_.getActivationsSize() != 0) {
82 baseActivationsAddress = (uint8_t *)alignedAlloc(
83 runtimeBundle_.getActivationsSize(), TensorAlignment);
84 }
85
86 if (runtimeBundle_.getMutableWeightSize() != 0) {
87 baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc(
88 runtimeBundle_.getMutableWeightSize(), TensorAlignment);
89 }
90 }
91
92 {
93 TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "loadPlaceholders");
94 loadPlaceholders(context->getPlaceholderBindings(),
95 baseMutableWeightVarsAddress);
96 }
97
98 auto *traceContext = context->getTraceContext();
99 TRACE_EVENT_SCOPE_NAMED(traceContext, TraceLevel::RUNTIME,
100 "findJitmainSymbol", fjEvent);
101 Expected<llvm::JITTargetAddress> address = NULL;
102 {
103 std::lock_guard<std::mutex> lock(JITLock_);
104 auto sym = JIT_->findSymbol("jitmain");
105
106 DCHECK(sym) << "Unable to JIT the code!";
107 // We know address is success since we just made it. Mark it as checked.
108 if (address) {
109 auto addrOrLLVMError = sym.getAddress();
110 if (addrOrLLVMError) {
111 address = addrOrLLVMError.get();
112 } else {
113 address = MAKE_ERR(
114 strFormat("Failed to get address: %s",
115 llvm::toString(addrOrLLVMError.takeError()).data()));
116 }
117 }
118 }
119 using JitFuncType =
120 void (*)(uint8_t * constantWeightVars, uint8_t * mutableWeightVars,
121 uint8_t * activations);
122 if (address) {
123 JitFuncType funcPtr = reinterpret_cast<JitFuncType>(address.get());
124 TRACE_EVENT_SCOPE_END_NAMED(fjEvent);
125 TRACE_EVENT_SCOPE(traceContext, TraceLevel::RUNTIME, "execute");
126 funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress,
127 baseActivationsAddress);
128 } else {
129 return MAKE_ERR("Error getting address");
130 }
131
132 {
133 TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "updatePlaceholders");
134 updatePlaceholders(context->getPlaceholderBindings(),
135 baseMutableWeightVarsAddress);
136 }
137
138 {
139 TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "freeBuffers");
140 alignedFree(baseMutableWeightVarsAddress);
141 alignedFree(baseActivationsAddress);
142 }
143
144 {
145 TRACE_EVENT_SCOPE(context, TraceLevel::RUNTIME, "processInstrumentation");
146 translateTraceEvents(context);
147 }
148
149 return Error::success();
150}
151
152void LLVMCompiledFunction::translateTraceEvents(
153 ExecutionContext *context) const {
154 auto &traceInfo = getTraceInfo();
155 if (!traceInfo.enabled) {
156 return;
157 }
158
159 TraceContext *traceContext = context->getTraceContext();
160
161 if (!traceContext->shouldLog(TraceLevel::OPERATOR)) {
162 return;
163 }
164
165 PlaceholderBindings *bindings = context->getPlaceholderBindings();
166
167 int tid = threads::getThreadId();
168 for (auto &backing : traceInfo.events) {
169 Tensor *backingTensor = bindings->get(backing.first);
170 DCHECK(backingTensor) << "Could not get backing tensor for Placeholder: "
171 << backing.first->getName().str();
172
173 auto &traceEvents = traceContext->getTraceEvents();
174 for (const TraceInfo::Event &event : backing.second) {
175 // If it's a complete event grab both timestamps.
176 if (event.type == TraceEvent::CompleteType) {
177 uint64_t start{0}, end{0};
178 memcpy(&start,
179 backingTensor->getUnsafePtr() +
180 (event.startIndex * traceInfo.dataSize),
181 traceInfo.dataSize);
182 memcpy(&end,
183 backingTensor->getUnsafePtr() +
184 (event.endIndex * traceInfo.dataSize),
185 traceInfo.dataSize);
186 traceEvents.push_back({event.name,
187 TraceLevel::OPERATOR,
188 start,
189 end - start,
190 tid,
191 {{"kind", event.kind}}});
192 } else {
193 uint64_t ts{0};
194 memcpy(&ts,
195 backingTensor->getUnsafePtr() +
196 (event.startIndex * traceInfo.dataSize),
197 traceInfo.dataSize);
198 traceEvents.push_back({event.name,
199 TraceLevel::OPERATOR,
200 ts,
201 event.type,
202 tid,
203 {{"kind", event.kind}}});
204 }
205 }
206 }
207}
208