1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/LLVMIRCodeGen/CommandLine.h" |
18 | #include "glow/LLVMIRCodeGen/LLVMBackend.h" |
19 | #include "glow/LLVMIRCodeGen/LLVMIRGen.h" |
20 | |
21 | #include "glow/IR/Instrs.h" |
22 | #include "glow/Support/Debug.h" |
23 | |
24 | #include "llvm/ADT/Statistic.h" |
25 | #include "llvm/Support/Debug.h" |
26 | #include "llvm/Transforms/Utils/Cloning.h" |
27 | #include "llvm/Transforms/Utils/ValueMapper.h" |
28 | |
29 | #define DEBUG_TYPE "ir-function-specializer" |
30 | |
31 | using namespace glow; |
32 | |
33 | using llvm::cast; |
34 | using llvm::dyn_cast; |
35 | using llvm::isa; |
36 | |
37 | namespace { |
/// Command-line switch ("jit-specialize", enabled by default) that turns on
/// function specialization with constant arguments taking into account only
/// dimensions, but not the buffer addresses. This allows for faster JIT
/// compilation.
/// NOTE(review): the original comment continued with "and the does degrade
/// performance"; presumably "does not degrade performance" was meant -- confirm.
///
/// The option is registered in the category returned by getLLVMBackendCat().
static llvm::cl::opt<bool>
    jitSpecializeDims("jit-specialize",
                      llvm::cl::desc("Create specialized functions for "
                                     "operations with constant dimensions"),
                      llvm::cl::init(true), llvm::cl::cat(getLLVMBackendCat()));
46 | |
// Incremented each time a new specialized function is created by cloning.
STATISTIC(NumSpecializations, "Number of created specializations");
// Incremented each time a call reuses an already existing specialization.
STATISTIC(NumSharedSpecializations, "Number of shared specializations");
49 | |
50 | /// Check if the value \p Value is a constant for the purposes of the function |
51 | /// specialization, i.e. it is an LLVM constant or it is a global constant |
52 | /// variable, which is initialized by an LLVM constant. These variables are |
53 | /// produced by IRGen e.g. for arrays of dimensions. |
54 | /// |
55 | /// \returns the constant value if \p Value is a constant or nullptr. |
56 | llvm::Value *getConstantValue(llvm::Value *v) { |
57 | // Check if it is a global variable which are constants and initialized by a |
58 | // const. This pattern is produced by the IRGen for the const arrays |
59 | // containing dimensions. |
60 | if (auto *GV = dyn_cast<llvm::GlobalVariable>(v)) { |
61 | auto *init = GV->getInitializer(); |
62 | if (!GV->isConstant() || !init) |
63 | return nullptr; |
64 | return v; |
65 | } |
66 | if (isa<llvm::Constant>(v)) |
67 | return v; |
68 | // This is an unknown pattern. Be conservative and assume it is not a |
69 | // constant. |
70 | return nullptr; |
71 | } |
72 | |
/// Record in the bit mask \p argsToBeSpecialized that the argument with the
/// index \p argIdx has to be specialized. Bit N of the mask corresponds to the
/// argument N, so only indices in the range [0, 63] can be recorded.
static void addArgToBeSpecialized(uint64_t &argsToBeSpecialized,
                                  unsigned argIdx) {
  assert(argIdx < 64 && "argIdx exceeds 64");
  // Single-bit mask selecting the requested argument.
  const uint64_t argBit = static_cast<uint64_t>(1) << argIdx;
  argsToBeSpecialized = argsToBeSpecialized | argBit;
}
80 | |
/// \returns true if the argument \p argIdx needs to be specialized according to
/// the \p argsToBeSpecialized mask.
///
/// Bit N of the mask corresponds to the argument N, so the mask can only
/// describe the first 64 arguments. Arguments with an index of 64 or more can
/// never be recorded in the mask and are therefore reported as "not
/// specialized". Users of this helper query every argument index of a call or
/// function, so an out-of-range index must be answered here instead of
/// asserting (or shifting by 64 or more bits, which is undefined behavior).
static bool isArgToBeSpecialized(uint64_t argsToBeSpecialized,
                                 unsigned argIdx) {
  if (argIdx >= 64)
    return false;
  return ((argsToBeSpecialized >> argIdx) & static_cast<uint64_t>(1)) != 0;
}
88 | |
89 | /// Specialize functions for constant arguments. Such specialized functions are |
90 | /// marked as noinline and simply invoke the original function with constant |
91 | /// arguments. This call later gets inlined and optimized. |
92 | class FunctionSpecializer { |
93 | /// Create a unique name for each specialization. |
94 | std::string createUniqueName(llvm::StringRef name) { |
95 | return llvm::Twine(name) |
96 | .concat("_" ) |
97 | .concat(llvm::Twine(uniqueIdx_++)) |
98 | .concat("_specialized" ) |
99 | .str(); |
100 | } |
101 | |
102 | /// \returns True if the argument \p arg needs to be specialized in the |
103 | /// function. |
104 | /// NOTE: Currently, the decision is based on the type of the argument |
105 | /// \p arg and position of the arg \p argIdx. \p callee is not used. In the |
106 | /// future, we may need to improve this logic by taking into account the |
107 | /// semantics of the argument or even the specifics of the function call being |
108 | /// specialized. |
109 | bool shouldSpecializeParameter(llvm::Value *arg, unsigned argIdx, |
110 | llvm::Function *callee) { |
111 | // Don't specialize argument index exceeding 63 because we only have 64 |
112 | // bitmap to index the arguments (check `isArgToBeSpecialized` and |
113 | // `addArgToBeSpecialized`) |
114 | if (argIdx > 63) { |
115 | return false; |
116 | } |
117 | |
118 | // This flag force-specializes all arguments. |
119 | if (jitSpecializeAllArguments_) { |
120 | return true; |
121 | } |
122 | |
123 | // Don't specialize arguments that we were requested to skip. |
124 | if (dontSpecializeArgsSet_.count(arg)) { |
125 | return false; |
126 | } |
127 | |
128 | // We don't specialize arguments which are pointers to floating point and |
129 | // quantized buffers, because this is likely to significantly increase the |
130 | // code size without any big performance benefits. |
131 | if (arg->getType()->isPointerTy()) { |
132 | auto elemTy = cast<llvm::PointerType>(arg->getType())->getElementType(); |
133 | // Bail if it is an FP buffer. |
134 | if (elemTy->isFloatTy()) { |
135 | return false; |
136 | } |
137 | // Bail if it is a quantized buffer. |
138 | if (elemTy->isIntegerTy(8)) { |
139 | return false; |
140 | } |
141 | } |
142 | |
143 | // We specialize all other arguments, which typically represent dimensions |
144 | // of tensors, indices, size of batches, etc. |
145 | return true; |
146 | } |
147 | |
148 | /// Find an existing specialization or create a new one. |
149 | /// \param CI the call that is being specialized. |
150 | /// \param F the function being specialized. |
151 | /// \param ArgsToBeSpecialized the set of arguments that should be |
152 | /// specialized. See SpecializationKey docs for the explanation of how this |
153 | /// information is encoded. |
154 | /// \returns a specialized version of the function for |
155 | /// provided parameters. |
156 | llvm::Function *getOrCreateSpecializedFunction(llvm::CallInst *call, |
157 | llvm::Function *F, |
158 | uint64_t argsToBeSpecialized) { |
159 | // Bail if there is nothing to do |
160 | if (!jitSpecializeAllArguments_ && !jitSpecializeDims) |
161 | return F; |
162 | |
163 | // A key representing the function and arguments to be specialized. |
164 | SpecializationKey key{call, argsToBeSpecialized}; |
165 | // Check if there is any existing specialization for this hash key already. |
166 | auto &specializedF = specializations_[key]; |
167 | if (specializedF) { |
168 | auto specializedFnTy = specializedF->getFunctionType(); |
169 | auto FnTy = F->getFunctionType(); |
170 | (void)specializedFnTy; |
171 | (void)FnTy; |
172 | assert( |
173 | specializedFnTy->getReturnType() == FnTy->getReturnType() && |
174 | "A function and its specialization should have the same return type" ); |
175 | // The specialized function only takes non-specialized parameters from the |
176 | // original function call. Check that the types of these parameters are |
177 | // the same for the original and the specialized function. |
178 | for (size_t argIdx = 0, specializedFnArgIdx = 0, e = F->arg_size(); |
179 | argIdx < e; ++argIdx) { |
180 | // If the parameter is specialized, it is not present in the specialized |
181 | // function. |
182 | if (isArgToBeSpecialized(argsToBeSpecialized, argIdx)) |
183 | continue; |
184 | // The parameter of the original call is not specialized and should be |
185 | // present in the specialized function. |
186 | assert(specializedFnTy->getParamType(specializedFnArgIdx) == |
187 | FnTy->getParamType(argIdx) && |
188 | "A function and its specialization should have the same " |
189 | "parameter type for non-constant arguments" ); |
190 | specializedFnArgIdx++; |
191 | } |
192 | NumSharedSpecializations++; |
193 | return specializedF; |
194 | } |
195 | |
196 | std::string specializedName = createUniqueName(F->getName()); |
197 | |
198 | // We are going to clone the body of the original function and substitute |
199 | // constant values for the (constant) arguments that are going to be |
200 | // specialized. The LLVM's cloning function requires a map for its |
201 | // operation. All arguments mapped by this map are removed from the argument |
202 | // list of the specialized function. |
203 | llvm::ValueToValueMapTy VMap; |
204 | size_t argIdx = 0; |
205 | for (auto &arg : F->args()) { |
206 | // If this argument needs to be specialized, use its constant |
207 | // value from the call instruction. |
208 | if (isArgToBeSpecialized(argsToBeSpecialized, argIdx)) { |
209 | auto *argValue = call->getArgOperand(argIdx); |
210 | // Map the argument to a constant value. |
211 | VMap[&arg] = argValue; |
212 | } |
213 | argIdx++; |
214 | } |
215 | |
216 | // Create a specialized function by cloning the body of the original |
217 | // function and substituting the values of constant arguments. The |
218 | // specialized function should be marked as noinline, to avoid code bloat. |
219 | specializedF = llvm::CloneFunction(F, VMap); |
220 | specializedF->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage); |
221 | assert(specializedF && "Could not create a specialized function" ); |
222 | // Specializations should not be inlined. |
223 | specializedF->addFnAttr(llvm::Attribute::AttrKind::NoInline); |
224 | specializedF->setName(specializedName); |
225 | // No need to explicitly emit a debug info for the specialized function. If |
226 | // the original function had it, the cloner would have automatically copied |
227 | // it into the specialized function. And if the original function did not |
228 | // have any debug info, then its specialization should not have any debug |
229 | // info either. |
230 | DEBUG_GLOW(llvm::dbgs() << "\n\nCreated specialized function " |
231 | << specializedName << "\n" ; |
232 | specializedF->print(llvm::errs(), nullptr)); |
233 | NumSpecializations++; |
234 | return specializedF; |
235 | } |
236 | |
237 | /// \returns true if a function is eligible for specialization. |
238 | bool isEligibleForSpecialization(const llvm::CallInst *call) { |
239 | // For now, specialize all functions invoked from "main". In the future, we |
240 | // may introduce more complex logic for making this decision. It could be |
241 | // based in the number of invocations of a function, number of its |
242 | // arguments, its code size, etc. |
243 | const auto *caller = call->getFunction(); |
244 | const auto *callee = call->getCalledFunction(); |
245 | // Specialized only calls inside main. |
246 | assert(std::find(entryFunctions_.begin(), entryFunctions_.end(), caller) != |
247 | entryFunctions_.end() && |
248 | "Only calls inside the entry function are specialized" ); |
249 | (void)caller; |
250 | // Do not specialize any LLVM internal functions. |
251 | if (callee && callee->getName().startswith("llvm." )) { |
252 | return false; |
253 | } |
254 | // Do not specialize declarations. |
255 | if (callee && callee->isDeclaration()) { |
256 | return false; |
257 | } |
258 | // Do not specialize calls if LLVMIRGen is against it. |
259 | if (!irgen_.isEligibleForSpecialization(call)) { |
260 | return false; |
261 | } |
262 | // Do not specialize noinline functions, because it does not improve |
263 | // anything. |
264 | return callee != nullptr && |
265 | !callee->hasFnAttribute(llvm::Attribute::AttrKind::NoInline); |
266 | } |
267 | |
268 | public: |
269 | FunctionSpecializer(llvm::SmallVectorImpl<llvm::Function *> &entryFunctions, |
270 | llvm::DenseSet<llvm::Value *> &dontSpec, LLVMIRGen &irgen) |
271 | : entryFunctions_(entryFunctions), dontSpecializeArgsSet_(dontSpec), |
272 | irgen_(irgen) {} |
273 | |
274 | /// Specialize a single call. |
275 | /// \returns the specialized Call instruction if it was possible to specialize |
276 | /// the call or nullptr otherwise. |
277 | llvm::CallInst *specializeCall(llvm::CallInst *call) { |
278 | llvm::IRBuilder<> builder(call->getParent()); |
279 | auto *callee = call->getCalledFunction(); |
280 | // Args to be used for calling the specialized function. |
281 | llvm::SmallVector<llvm::Value *, 16> argsForSpecialized; |
282 | // Set of arguments that need to be specialized. See SpecializationKey |
283 | // documentation for more information about the encoding of this set. |
284 | uint64_t argsToBeSpecialized = 0; |
285 | |
286 | // Go over all call arguments. |
287 | // Check that all arguments are constants. |
288 | // Form the set of arguments to be specialized. |
289 | unsigned argIdx = 0; |
290 | for (auto &arg : call->arg_operands()) { |
291 | auto curArgIdx = argIdx++; |
292 | |
293 | if (!shouldSpecializeParameter(arg, curArgIdx, callee)) { |
294 | argsForSpecialized.push_back(arg); |
295 | continue; |
296 | } |
297 | |
298 | addArgToBeSpecialized(argsToBeSpecialized, curArgIdx); |
299 | |
300 | // Bail if the values of arguments are not constants. |
301 | if (!getConstantValue(arg)) { |
302 | DEBUG_GLOW(llvm::dbgs() << "Could not specialize call:\n" ; |
303 | call->print(llvm::dbgs())); |
304 | return nullptr; |
305 | } |
306 | } |
307 | |
308 | auto *specializedF = |
309 | getOrCreateSpecializedFunction(call, callee, argsToBeSpecialized); |
310 | // Generate a call of the specialized function before the current call |
311 | // instruction. |
312 | builder.SetInsertPoint(call); |
313 | return irgen_.createCall(builder, specializedF, argsForSpecialized); |
314 | } |
315 | |
316 | void run() { |
317 | // Bail if there is nothing to be specialized. |
318 | if (!jitSpecializeDims && !jitSpecializeAllArguments_) |
319 | return; |
320 | // Collect calls that were replaced by specialized calls and can be erased. |
321 | // The removal should happen after all specializations are done, because |
322 | // these call instructions are used by the keys in Specializations_ map. |
323 | llvm::DenseMap<llvm::Instruction *, llvm::Instruction *> |
324 | callToSpecializedCall; |
325 | llvm::SmallVector<llvm::CallInst *, 64> calls; |
326 | for (auto *F : entryFunctions_) { |
327 | // Collect all eligable calls in the current function. |
328 | for (auto &BB : *F) { |
329 | for (auto &I : BB) { |
330 | auto *CI = dyn_cast<llvm::CallInst>(&I); |
331 | if (!CI) |
332 | continue; |
333 | if (!isEligibleForSpecialization(CI)) |
334 | continue; |
335 | calls.push_back(CI); |
336 | } |
337 | } |
338 | } |
339 | // Try to specialize all the collected calls. |
340 | for (auto *call : calls) { |
341 | if (auto *specializedCall = specializeCall(call)) { |
342 | callToSpecializedCall.insert(std::make_pair(call, specializedCall)); |
343 | } |
344 | } |
345 | |
346 | // Remove those calls that were successfully replaced by calls of |
347 | // specialized functions. This needs to be done after all specializations, |
348 | // because keys of Specializations_ use these Call instructions for the |
349 | // duration of the whole specialization pass. |
350 | for (auto &kv : callToSpecializedCall) { |
351 | // Check if the original call returns a result and replace all its uses. |
352 | if (!kv.first->getType()->isVoidTy()) { |
353 | kv.first->replaceAllUsesWith(kv.second); |
354 | } |
355 | kv.first->eraseFromParent(); |
356 | } |
357 | DEBUG_GLOW(llvm::dbgs() << "Number of specializations: " |
358 | << NumSpecializations << "\n" ; |
359 | llvm::dbgs() << "Number of shared specializations: " |
360 | << NumSharedSpecializations << "\n" ); |
361 | } |
362 | |
363 | private: |
364 | /// This is a key into the specialization table. It consists of the call |
365 | /// instruction and an integer encoding which arguments of this call should be |
366 | /// used for the hash computation. If the Nth bit is set, then the Nth |
367 | /// argument of the call should participate in the hash computation. |
368 | /// |
369 | /// This encoding heavily relies on the fact that LLVM constants are uniqued |
370 | /// internally and their equality can be checked by means of a simple |
371 | /// pointer comparison. |
372 | struct SpecializationKey { |
373 | SpecializationKey(llvm::CallInst *CI, uint64_t Args) |
374 | : call_(CI), argsToBeSpecialized_(Args) {} |
375 | |
376 | /// The first call instruction that was used to create this specialization. |
377 | llvm::CallInst *call_{nullptr}; |
378 | /// The set of argument numbers that need to be specialized. |
379 | uint64_t argsToBeSpecialized_{0}; |
380 | }; |
381 | |
382 | /// A helper class providing a hash function for FunctionSpecializer. |
383 | struct SpecializationKeyHasher { |
384 | size_t operator()(const SpecializationKey &key) const { |
385 | // Take the name of the callee into account. |
386 | llvm::hash_code hash = |
387 | llvm::hash_value(key.call_->getCalledFunction()->getName()); |
388 | // Hash over all arguments required by the \p ArgsToBeSpecialized_. |
389 | // We can compute the hash this way, because these arguments are LLVM |
390 | // constants which are uniqued. Therefore, the address of a constant is |
391 | // its unique representation. |
392 | for (unsigned idx = 0, e = key.call_->getNumArgOperands(); idx < e; |
393 | ++idx) { |
394 | if (isArgToBeSpecialized(key.argsToBeSpecialized_, idx)) { |
395 | hash = llvm::hash_combine( |
396 | hash, getConstantValue(key.call_->getArgOperand(idx))); |
397 | } |
398 | } |
399 | return hash; |
400 | } |
401 | }; |
402 | |
403 | /// A helper class providing the equality function for FunctionSpecializer. |
404 | struct SpecializationKeyEq { |
405 | bool operator()(const SpecializationKey &lhs, |
406 | const SpecializationKey &rhs) const { |
407 | if (lhs.call_->getCalledFunction() != rhs.call_->getCalledFunction()) |
408 | return false; |
409 | if (lhs.argsToBeSpecialized_ != rhs.argsToBeSpecialized_) |
410 | return false; |
411 | for (unsigned idx = 0, e = lhs.call_->getNumArgOperands(); idx < e; |
412 | ++idx) { |
413 | if (isArgToBeSpecialized(lhs.argsToBeSpecialized_, idx)) { |
414 | if (getConstantValue(lhs.call_->getArgOperand(idx)) != |
415 | getConstantValue(rhs.call_->getArgOperand(idx))) |
416 | return false; |
417 | } |
418 | } |
419 | return true; |
420 | } |
421 | }; |
422 | |
423 | /// The entry functions of the module. |
424 | llvm::SmallVectorImpl<llvm::Function *> &entryFunctions_; |
425 | /// Mapping from specialization keys to the specialized functions. |
426 | std::unordered_map<SpecializationKey, llvm::Function *, |
427 | SpecializationKeyHasher, SpecializationKeyEq> |
428 | specializations_; |
429 | |
430 | /// An index to create unique specialization names. |
431 | unsigned uniqueIdx_{0}; |
432 | |
433 | /// If set, specialize taking into account the whole set of arguments, |
434 | /// including buffer addresses. |
435 | bool jitSpecializeAllArguments_{false}; |
436 | |
437 | /// A reference to a set of values that the specializer was requested not to |
438 | /// specialize. |
439 | llvm::DenseSet<llvm::Value *> &dontSpecializeArgsSet_; |
440 | /// LLVMIRGen to be used. |
441 | LLVMIRGen &irgen_; |
442 | }; |
443 | |
444 | } // namespace |
445 | |
446 | void LLVMIRGen::performSpecialization() { |
447 | FunctionSpecializer FuncSpecializer(emittedLLVMFunctions_, |
448 | dontSpecializeArgsSet_, *this); |
449 | FuncSpecializer.run(); |
450 | } |
451 | |