/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Backend/Backend.h"
#include "glow/Backends/DummyDeviceManager.h"

#include "glow/Flags/Flags.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Graph/TensorLayout.h"
#include "glow/IR/Instrs.h"
#include "glow/Optimizer/GraphOptimizer/CompilationContext.h"
#include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"

#include <thread>

using namespace glow;

namespace {
/// Structure for tracking an individual compilation thread's state.
struct PerCompilationThreadState {
  // Functions to compile.
  std::vector<Function *> functions;
  // Results of compilation, keyed by function name.
  llvm::StringMap<std::unique_ptr<CompiledFunction>> compiledFunctions;
  // Any error that occurred during compilation.
  Error err = Error::empty();
};
} // namespace

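/// Default DeviceManager factory. The base implementation warns and falls
/// back to a DummyDeviceManager; backends that target real devices should
/// override this.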
runtime::DeviceManager *
Backend::createDeviceManager(const runtime::DeviceConfig &deviceConfig) {
  LOG(WARNING) << "Creating a DummyDeviceManager.";
  return new runtime::DummyDeviceManager(deviceConfig);
}

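/// Scan \p F for TraceEventNodes and collect their backing Placeholders into
/// a TraceInfo, so manually placed trace events can be read back after a run.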
TraceInfo Backend::buildManualTraceInfo(Function *F) const {
  TraceInfo info(false, getTraceEventDataSize());
  const auto &nodes = F->getNodes();
  for (const auto &node : nodes) {
    if (const TraceEventNode *TEN = llvm::dyn_cast<TraceEventNode>(&node)) {
      Placeholder *backing =
          llvm::dyn_cast<Placeholder>(TEN->getData().getNode());
      assert(backing);
      char type = TraceEvent::InstantType;
      if (!TEN->getEventType().empty()) {
        type = TEN->getEventType()[0];
      }
      info.add(backing, TEN->getIndex(), TEN->getEventName(), type);
      info.enabled = true;
    }
  }

  return info;
}

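/// Instrument \p IR for profiling by inserting a TraceEventInst before every
/// Instruction, all writing timestamps into one shared backing Placeholder.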
void Backend::autoInstrument(TraceInfo &traceInfo, IRFunction *IR) const {
  if (getTraceEventDataSize() == 0) {
    LOG(ERROR) << "Auto instrumentation not supported on this backend";
    return;
  }

  Function *F = IR->getGraph();
  // Get all instructions in the IRFunction.
  IRFunction::InstListTy &instructions = IR->getInstrs();

  // First pass: count how many TraceEvents we should add. Existing
  // TraceEvents have their own backing Tensors, so don't count them.
  dim_t numEvents = 1; // Starts at 1 to account for the final end_trace event.
  for (const auto &I : instructions) {
    if (!llvm::isa<TraceEventInst>(&I)) {
      numEvents++;
    }
  }

  // Default name for the instrumentation placeholder; it will be made unique
  // by createPlaceholder.
  std::string name = F->getName().str() + "_instrumentation";
  Placeholder *backingPH = F->getParent()->getPlaceholderByNameSlow(name);

  auto &varmap = IR->getVariableMap();
  auto type = F->getParent()->uniqueType(
      ElemKind::Int64ITy,
      {numEvents, (dim_t)getTraceEventDataSize() /
                      Type::getElementSize(ElemKind::Int64ITy)});

  WeightVar *backingWeight = nullptr;

  if (backingPH) {
    // If the standard instrumentation placeholder exists, we might be able to
    // reuse it.
    auto it = varmap.find(backingPH);
    if (it != varmap.end() && backingPH->getType()->isEqual(type)) {
      // We already have a weight for it and the types match, so reuse it.
      backingWeight = llvm::cast<WeightVar>(it->second);
    } else {
      // This isn't ideal: the placeholder exists but we have no matching
      // weight. This probably indicates a bug in the graph; the best we can
      // do is create a new placeholder and weight for the instrumentation.
      backingPH = nullptr;
    }
  }

  // If we don't have a Placeholder, then we need to create one.
  if (!backingPH) {
    // Build a Placeholder to a backing tensor with space to fit all
    // timestamps.
    backingPH = F->getParent()->createPlaceholder(type, name,
                                                  /* isTrainable */ false);
    assert(backingPH);
  }

  // Add the Placeholder to the graph so we can add it to the runtimeBundle
  // later.
  F->addMetadataPlaceholder(backingPH);

  // If we don't have a weight, we need to create one too, whether or not we
  // just created a Placeholder.
  if (!backingWeight) {
    // Create an associated weight and add it to the IR.
    backingWeight =
        new WeightVar(IR->uniqueName(backingPH->getName()),
                      backingPH->getType(), WeightVar::MutabilityKind::Mutable);
    IR->getWeights().push_back(backingWeight);
    IR->getVariableMap()[backingPH] = backingWeight;
  }

  traceInfo.enabled = true;
  traceInfo.autoInstrumented = true;
  size_t index = 0;

  // For each instruction, insert a TraceEventInst before it to record a
  // timestamp. Each timestamp doubles as the end of the previous
  // Instruction's event and the start of the next, so register the
  // [index, index + 1] slot pair for each Instruction in the TraceInfo.
  auto it = instructions.begin();
  while (it != instructions.end()) {
    auto &I = *it;
    if (llvm::isa<TraceEventInst>(&I)) {
      // Don't instrument instrumentation.
      it++;
      continue;
    }

    auto instName = I.getName().str();

    // Start a new event.
    traceInfo.add(backingPH, index, index + 1, instName, std::string(),
                  Kinded::getKindName(I.getKind()));

    it = instructions.insert(
        it, new TraceEventInst(instName + "_trace", backingWeight, index++));

    // Skip over both I and the new TraceEvent.
    it++;
    it++;
  }

  IR->pushInstr(new TraceEventInst("end_trace", backingWeight, index));
}

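/// \returns true if every Node in \p F is supported by this backend, as
/// reported by isOpSupported(). When \p verbose is set, each unsupported
/// node is reported.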
bool Backend::checkAllNodesSupported(const Function &F, bool verbose) const {
  bool allSupported = true;
  for (const Node &N : F.getNodes()) {
    if (!isOpSupported(N)) {
      allSupported = false;
      if (verbose) {
        report("Unsupported node found while compiling Function " +
               F.getName().str() + " for backend " + getBackendName() + ": " +
               N.getDebugDesc());
      }
    }
  }
  return allSupported;
}

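/// Compile every Function in \p functions in parallel, using up to
/// runtime::flags::NumCompilationThreads threads. Each function's
/// BackendOptions are looked up by name in \p optsMap. \returns a map from
/// function name to CompiledFunction, or the first Error encountered.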
Expected<llvm::StringMap<std::unique_ptr<CompiledFunction>>>
Backend::compileFunctions(std::vector<Function *> &functions,
                          llvm::StringMap<BackendOptions> &optsMap) const {
  const size_t numThreads = runtime::flags::NumCompilationThreads;

  // Split the functions into contiguous, evenly sized chunks, one per thread.
  size_t functionsPerThread = (functions.size() + numThreads - 1) / numThreads;
  std::vector<PerCompilationThreadState> threadStates;
  for (size_t i = 0; i < numThreads; ++i) {
    PerCompilationThreadState state;
    // Mark the state's Error as checked to begin with, in case we discard
    // this state below.
    EXIT_ON_ERR(std::move(state.err));
    for (size_t j = 0; j < functionsPerThread; ++j) {
      size_t idx = i * functionsPerThread + j;
      if (idx >= functions.size()) {
        break;
      }
      state.functions.push_back(functions[idx]);
    }
    if (!state.functions.empty()) {
      threadStates.emplace_back(std::move(state));
    }
  }

  auto compileFn = [&](PerCompilationThreadState *threadState) {
    auto wrapped = [&]() -> Error {
      for (auto *function : threadState->functions) {
        auto functionName = function->getName();
        RETURN_ERR_IF_NOT(optsMap.count(functionName),
                          strFormat("Can't find corresponding option for "
                                    "compiling function with name %s",
                                    functionName.str().c_str()));
        auto backendOpts = optsMap.find(functionName)->second;
        auto resOrErr = compile(function, backendOpts);
        if (resOrErr) {
          threadState->compiledFunctions.insert(
              {functionName, std::move(*resOrErr)});
        } else {
          RETURN_ERR(resOrErr.takeError());
        }
      }
      return Error::success();
    };
    // There should be no error here yet, but this enforces that the Error is
    // checked before it is overwritten.
    EXIT_ON_ERR(std::move(threadState->err));
    threadState->err = wrapped();
  };

  // Launch the compilation threads.
  std::vector<std::thread> threads;
  for (auto &threadState : threadStates) {
    threads.emplace_back(compileFn, &threadState);
  }

  // Join all threads before inspecting their results; returning early on an
  // error while other threads are still joinable would call std::terminate.
  for (auto &t : threads) {
    t.join();
  }

  // Aggregate errors and compiledFunctions across threads.
  llvm::StringMap<std::unique_ptr<CompiledFunction>> compiledFunctions;
  for (auto &threadState : threadStates) {
    if (threadState.err) {
      RETURN_ERR(std::move(threadState.err));
    }
    for (auto &kv : threadState.compiledFunctions) {
      compiledFunctions.insert({kv.first(), std::move(kv.second)});
    }
  }

  return Expected<llvm::StringMap<std::unique_ptr<CompiledFunction>>>(
      std::move(compiledFunctions));
}

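/// \returns true if \p F passes graph verification and every node in it is
/// supported by this backend.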
bool Backend::verify(const Function &F, bool verbose) const {
  return F.verify(this) && checkAllNodesSupported(F, verbose);
}

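/// The base implementation accepts any IRFunction; backends override this to
/// add backend-specific IR checks.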
bool Backend::verify(const IRFunction &IR) const {
  (void)IR;
  return true;
}

TensorLayoutCommon &Backend::getTensorLayoutRequirements() const {
  return CanonicalTensorLayout::getInstance();
}

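/// \returns the graph optimization pipeline used for this backend.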
std::unique_ptr<FunctionPassPipeline> Backend::getOptimizationPipeline() const {
  auto pipeline = createDefaultGraphOptimizationPassPipeline();
  // Fold Tile followed by Add into BatchedAdd. Currently this is not part of
  // the default pipeline to avoid issues with some backends. Backends that do
  // not want this optimization should override getOptimizationPipeline().
  pipeline->pushFront({FunctionPassID::FoldTileAddIntoBatchedAdd});
  return pipeline;
}

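/// \returns the default IR optimization pipeline, unmodified.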
std::unique_ptr<IRFunctionPassPipeline>
Backend::getIROptimizationPipeline() const {
  auto pipeline = createDefaultIRFunctionOptimizationPipeline();
  return pipeline;
}