1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/LLVMIRCodeGen/CommandLine.h"
18#include "glow/LLVMIRCodeGen/LLVMBackend.h"
19#include "glow/LLVMIRCodeGen/LLVMIRGen.h"
20
21#include "glow/IR/Instrs.h"
22#include "glow/Support/Debug.h"
23
24#include "llvm/ADT/Statistic.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Transforms/Utils/Cloning.h"
27#include "llvm/Transforms/Utils/ValueMapper.h"
28
29#define DEBUG_TYPE "ir-function-specializer"
30
31using namespace glow;
32
33using llvm::cast;
34using llvm::dyn_cast;
35using llvm::isa;
36
37namespace {
38/// Perform function specialization with constant arguments taking into account
39/// only dimensions, but not the buffer addresses. This allows for faster JIT
40/// compilation and the does degrade performance.
41static llvm::cl::opt<bool>
42 jitSpecializeDims("jit-specialize",
43 llvm::cl::desc("Create specialized functions for "
44 "operations with constant dimensions"),
45 llvm::cl::init(true), llvm::cl::cat(getLLVMBackendCat()));
46
47STATISTIC(NumSpecializations, "Number of created specializations");
48STATISTIC(NumSharedSpecializations, "Number of shared specializations");
49
50/// Check if the value \p Value is a constant for the purposes of the function
51/// specialization, i.e. it is an LLVM constant or it is a global constant
52/// variable, which is initialized by an LLVM constant. These variables are
53/// produced by IRGen e.g. for arrays of dimensions.
54///
55/// \returns the constant value if \p Value is a constant or nullptr.
56llvm::Value *getConstantValue(llvm::Value *v) {
57 // Check if it is a global variable which are constants and initialized by a
58 // const. This pattern is produced by the IRGen for the const arrays
59 // containing dimensions.
60 if (auto *GV = dyn_cast<llvm::GlobalVariable>(v)) {
61 auto *init = GV->getInitializer();
62 if (!GV->isConstant() || !init)
63 return nullptr;
64 return v;
65 }
66 if (isa<llvm::Constant>(v))
67 return v;
68 // This is an unknown pattern. Be conservative and assume it is not a
69 // constant.
70 return nullptr;
71}
72
/// Record in the bit mask \p argsToBeSpecialized that the argument at position
/// \p argIdx has to be specialized. Bit N of the mask corresponds to the
/// argument number N, so \p argIdx must be smaller than 64.
static void addArgToBeSpecialized(uint64_t &argsToBeSpecialized,
                                  unsigned argIdx) {
  assert(argIdx < 64 && "argIdx exceeds 64");
  const uint64_t argBit = static_cast<uint64_t>(1) << argIdx;
  argsToBeSpecialized = argsToBeSpecialized | argBit;
}
80
/// \returns true if the argument \p argIdx needs to be specialized according to
/// the \p argsToBeSpecialized mask.
///
/// The mask has one bit per argument position and therefore can only describe
/// the first 64 arguments. Arguments at position 64 or higher can never be
/// recorded in the mask, so they are reported as not specialized. This makes
/// it safe to query every argument of a call or function, including those
/// with more than 64 arguments.
static bool isArgToBeSpecialized(uint64_t argsToBeSpecialized,
                                 unsigned argIdx) {
  constexpr unsigned maskBits = 64;
  // Shifting a 64-bit value by 64 or more is undefined behavior, so handle
  // out-of-range positions before touching the mask.
  if (argIdx >= maskBits)
    return false;
  return ((argsToBeSpecialized >> argIdx) & static_cast<uint64_t>(1)) != 0;
}
88
89/// Specialize functions for constant arguments. Such specialized functions are
90/// marked as noinline and simply invoke the original function with constant
91/// arguments. This call later gets inlined and optimized.
92class FunctionSpecializer {
93 /// Create a unique name for each specialization.
94 std::string createUniqueName(llvm::StringRef name) {
95 return llvm::Twine(name)
96 .concat("_")
97 .concat(llvm::Twine(uniqueIdx_++))
98 .concat("_specialized")
99 .str();
100 }
101
102 /// \returns True if the argument \p arg needs to be specialized in the
103 /// function.
104 /// NOTE: Currently, the decision is based on the type of the argument
105 /// \p arg and position of the arg \p argIdx. \p callee is not used. In the
106 /// future, we may need to improve this logic by taking into account the
107 /// semantics of the argument or even the specifics of the function call being
108 /// specialized.
109 bool shouldSpecializeParameter(llvm::Value *arg, unsigned argIdx,
110 llvm::Function *callee) {
111 // Don't specialize argument index exceeding 63 because we only have 64
112 // bitmap to index the arguments (check `isArgToBeSpecialized` and
113 // `addArgToBeSpecialized`)
114 if (argIdx > 63) {
115 return false;
116 }
117
118 // This flag force-specializes all arguments.
119 if (jitSpecializeAllArguments_) {
120 return true;
121 }
122
123 // Don't specialize arguments that we were requested to skip.
124 if (dontSpecializeArgsSet_.count(arg)) {
125 return false;
126 }
127
128 // We don't specialize arguments which are pointers to floating point and
129 // quantized buffers, because this is likely to significantly increase the
130 // code size without any big performance benefits.
131 if (arg->getType()->isPointerTy()) {
132 auto elemTy = cast<llvm::PointerType>(arg->getType())->getElementType();
133 // Bail if it is an FP buffer.
134 if (elemTy->isFloatTy()) {
135 return false;
136 }
137 // Bail if it is a quantized buffer.
138 if (elemTy->isIntegerTy(8)) {
139 return false;
140 }
141 }
142
143 // We specialize all other arguments, which typically represent dimensions
144 // of tensors, indices, size of batches, etc.
145 return true;
146 }
147
148 /// Find an existing specialization or create a new one.
149 /// \param CI the call that is being specialized.
150 /// \param F the function being specialized.
151 /// \param ArgsToBeSpecialized the set of arguments that should be
152 /// specialized. See SpecializationKey docs for the explanation of how this
153 /// information is encoded.
154 /// \returns a specialized version of the function for
155 /// provided parameters.
156 llvm::Function *getOrCreateSpecializedFunction(llvm::CallInst *call,
157 llvm::Function *F,
158 uint64_t argsToBeSpecialized) {
159 // Bail if there is nothing to do
160 if (!jitSpecializeAllArguments_ && !jitSpecializeDims)
161 return F;
162
163 // A key representing the function and arguments to be specialized.
164 SpecializationKey key{call, argsToBeSpecialized};
165 // Check if there is any existing specialization for this hash key already.
166 auto &specializedF = specializations_[key];
167 if (specializedF) {
168 auto specializedFnTy = specializedF->getFunctionType();
169 auto FnTy = F->getFunctionType();
170 (void)specializedFnTy;
171 (void)FnTy;
172 assert(
173 specializedFnTy->getReturnType() == FnTy->getReturnType() &&
174 "A function and its specialization should have the same return type");
175 // The specialized function only takes non-specialized parameters from the
176 // original function call. Check that the types of these parameters are
177 // the same for the original and the specialized function.
178 for (size_t argIdx = 0, specializedFnArgIdx = 0, e = F->arg_size();
179 argIdx < e; ++argIdx) {
180 // If the parameter is specialized, it is not present in the specialized
181 // function.
182 if (isArgToBeSpecialized(argsToBeSpecialized, argIdx))
183 continue;
184 // The parameter of the original call is not specialized and should be
185 // present in the specialized function.
186 assert(specializedFnTy->getParamType(specializedFnArgIdx) ==
187 FnTy->getParamType(argIdx) &&
188 "A function and its specialization should have the same "
189 "parameter type for non-constant arguments");
190 specializedFnArgIdx++;
191 }
192 NumSharedSpecializations++;
193 return specializedF;
194 }
195
196 std::string specializedName = createUniqueName(F->getName());
197
198 // We are going to clone the body of the original function and substitute
199 // constant values for the (constant) arguments that are going to be
200 // specialized. The LLVM's cloning function requires a map for its
201 // operation. All arguments mapped by this map are removed from the argument
202 // list of the specialized function.
203 llvm::ValueToValueMapTy VMap;
204 size_t argIdx = 0;
205 for (auto &arg : F->args()) {
206 // If this argument needs to be specialized, use its constant
207 // value from the call instruction.
208 if (isArgToBeSpecialized(argsToBeSpecialized, argIdx)) {
209 auto *argValue = call->getArgOperand(argIdx);
210 // Map the argument to a constant value.
211 VMap[&arg] = argValue;
212 }
213 argIdx++;
214 }
215
216 // Create a specialized function by cloning the body of the original
217 // function and substituting the values of constant arguments. The
218 // specialized function should be marked as noinline, to avoid code bloat.
219 specializedF = llvm::CloneFunction(F, VMap);
220 specializedF->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
221 assert(specializedF && "Could not create a specialized function");
222 // Specializations should not be inlined.
223 specializedF->addFnAttr(llvm::Attribute::AttrKind::NoInline);
224 specializedF->setName(specializedName);
225 // No need to explicitly emit a debug info for the specialized function. If
226 // the original function had it, the cloner would have automatically copied
227 // it into the specialized function. And if the original function did not
228 // have any debug info, then its specialization should not have any debug
229 // info either.
230 DEBUG_GLOW(llvm::dbgs() << "\n\nCreated specialized function "
231 << specializedName << "\n";
232 specializedF->print(llvm::errs(), nullptr));
233 NumSpecializations++;
234 return specializedF;
235 }
236
237 /// \returns true if a function is eligible for specialization.
238 bool isEligibleForSpecialization(const llvm::CallInst *call) {
239 // For now, specialize all functions invoked from "main". In the future, we
240 // may introduce more complex logic for making this decision. It could be
241 // based in the number of invocations of a function, number of its
242 // arguments, its code size, etc.
243 const auto *caller = call->getFunction();
244 const auto *callee = call->getCalledFunction();
245 // Specialized only calls inside main.
246 assert(std::find(entryFunctions_.begin(), entryFunctions_.end(), caller) !=
247 entryFunctions_.end() &&
248 "Only calls inside the entry function are specialized");
249 (void)caller;
250 // Do not specialize any LLVM internal functions.
251 if (callee && callee->getName().startswith("llvm.")) {
252 return false;
253 }
254 // Do not specialize declarations.
255 if (callee && callee->isDeclaration()) {
256 return false;
257 }
258 // Do not specialize calls if LLVMIRGen is against it.
259 if (!irgen_.isEligibleForSpecialization(call)) {
260 return false;
261 }
262 // Do not specialize noinline functions, because it does not improve
263 // anything.
264 return callee != nullptr &&
265 !callee->hasFnAttribute(llvm::Attribute::AttrKind::NoInline);
266 }
267
268public:
269 FunctionSpecializer(llvm::SmallVectorImpl<llvm::Function *> &entryFunctions,
270 llvm::DenseSet<llvm::Value *> &dontSpec, LLVMIRGen &irgen)
271 : entryFunctions_(entryFunctions), dontSpecializeArgsSet_(dontSpec),
272 irgen_(irgen) {}
273
274 /// Specialize a single call.
275 /// \returns the specialized Call instruction if it was possible to specialize
276 /// the call or nullptr otherwise.
277 llvm::CallInst *specializeCall(llvm::CallInst *call) {
278 llvm::IRBuilder<> builder(call->getParent());
279 auto *callee = call->getCalledFunction();
280 // Args to be used for calling the specialized function.
281 llvm::SmallVector<llvm::Value *, 16> argsForSpecialized;
282 // Set of arguments that need to be specialized. See SpecializationKey
283 // documentation for more information about the encoding of this set.
284 uint64_t argsToBeSpecialized = 0;
285
286 // Go over all call arguments.
287 // Check that all arguments are constants.
288 // Form the set of arguments to be specialized.
289 unsigned argIdx = 0;
290 for (auto &arg : call->arg_operands()) {
291 auto curArgIdx = argIdx++;
292
293 if (!shouldSpecializeParameter(arg, curArgIdx, callee)) {
294 argsForSpecialized.push_back(arg);
295 continue;
296 }
297
298 addArgToBeSpecialized(argsToBeSpecialized, curArgIdx);
299
300 // Bail if the values of arguments are not constants.
301 if (!getConstantValue(arg)) {
302 DEBUG_GLOW(llvm::dbgs() << "Could not specialize call:\n";
303 call->print(llvm::dbgs()));
304 return nullptr;
305 }
306 }
307
308 auto *specializedF =
309 getOrCreateSpecializedFunction(call, callee, argsToBeSpecialized);
310 // Generate a call of the specialized function before the current call
311 // instruction.
312 builder.SetInsertPoint(call);
313 return irgen_.createCall(builder, specializedF, argsForSpecialized);
314 }
315
316 void run() {
317 // Bail if there is nothing to be specialized.
318 if (!jitSpecializeDims && !jitSpecializeAllArguments_)
319 return;
320 // Collect calls that were replaced by specialized calls and can be erased.
321 // The removal should happen after all specializations are done, because
322 // these call instructions are used by the keys in Specializations_ map.
323 llvm::DenseMap<llvm::Instruction *, llvm::Instruction *>
324 callToSpecializedCall;
325 llvm::SmallVector<llvm::CallInst *, 64> calls;
326 for (auto *F : entryFunctions_) {
327 // Collect all eligable calls in the current function.
328 for (auto &BB : *F) {
329 for (auto &I : BB) {
330 auto *CI = dyn_cast<llvm::CallInst>(&I);
331 if (!CI)
332 continue;
333 if (!isEligibleForSpecialization(CI))
334 continue;
335 calls.push_back(CI);
336 }
337 }
338 }
339 // Try to specialize all the collected calls.
340 for (auto *call : calls) {
341 if (auto *specializedCall = specializeCall(call)) {
342 callToSpecializedCall.insert(std::make_pair(call, specializedCall));
343 }
344 }
345
346 // Remove those calls that were successfully replaced by calls of
347 // specialized functions. This needs to be done after all specializations,
348 // because keys of Specializations_ use these Call instructions for the
349 // duration of the whole specialization pass.
350 for (auto &kv : callToSpecializedCall) {
351 // Check if the original call returns a result and replace all its uses.
352 if (!kv.first->getType()->isVoidTy()) {
353 kv.first->replaceAllUsesWith(kv.second);
354 }
355 kv.first->eraseFromParent();
356 }
357 DEBUG_GLOW(llvm::dbgs() << "Number of specializations: "
358 << NumSpecializations << "\n";
359 llvm::dbgs() << "Number of shared specializations: "
360 << NumSharedSpecializations << "\n");
361 }
362
363private:
364 /// This is a key into the specialization table. It consists of the call
365 /// instruction and an integer encoding which arguments of this call should be
366 /// used for the hash computation. If the Nth bit is set, then the Nth
367 /// argument of the call should participate in the hash computation.
368 ///
369 /// This encoding heavily relies on the fact that LLVM constants are uniqued
370 /// internally and their equality can be checked by means of a simple
371 /// pointer comparison.
372 struct SpecializationKey {
373 SpecializationKey(llvm::CallInst *CI, uint64_t Args)
374 : call_(CI), argsToBeSpecialized_(Args) {}
375
376 /// The first call instruction that was used to create this specialization.
377 llvm::CallInst *call_{nullptr};
378 /// The set of argument numbers that need to be specialized.
379 uint64_t argsToBeSpecialized_{0};
380 };
381
382 /// A helper class providing a hash function for FunctionSpecializer.
383 struct SpecializationKeyHasher {
384 size_t operator()(const SpecializationKey &key) const {
385 // Take the name of the callee into account.
386 llvm::hash_code hash =
387 llvm::hash_value(key.call_->getCalledFunction()->getName());
388 // Hash over all arguments required by the \p ArgsToBeSpecialized_.
389 // We can compute the hash this way, because these arguments are LLVM
390 // constants which are uniqued. Therefore, the address of a constant is
391 // its unique representation.
392 for (unsigned idx = 0, e = key.call_->getNumArgOperands(); idx < e;
393 ++idx) {
394 if (isArgToBeSpecialized(key.argsToBeSpecialized_, idx)) {
395 hash = llvm::hash_combine(
396 hash, getConstantValue(key.call_->getArgOperand(idx)));
397 }
398 }
399 return hash;
400 }
401 };
402
403 /// A helper class providing the equality function for FunctionSpecializer.
404 struct SpecializationKeyEq {
405 bool operator()(const SpecializationKey &lhs,
406 const SpecializationKey &rhs) const {
407 if (lhs.call_->getCalledFunction() != rhs.call_->getCalledFunction())
408 return false;
409 if (lhs.argsToBeSpecialized_ != rhs.argsToBeSpecialized_)
410 return false;
411 for (unsigned idx = 0, e = lhs.call_->getNumArgOperands(); idx < e;
412 ++idx) {
413 if (isArgToBeSpecialized(lhs.argsToBeSpecialized_, idx)) {
414 if (getConstantValue(lhs.call_->getArgOperand(idx)) !=
415 getConstantValue(rhs.call_->getArgOperand(idx)))
416 return false;
417 }
418 }
419 return true;
420 }
421 };
422
423 /// The entry functions of the module.
424 llvm::SmallVectorImpl<llvm::Function *> &entryFunctions_;
425 /// Mapping from specialization keys to the specialized functions.
426 std::unordered_map<SpecializationKey, llvm::Function *,
427 SpecializationKeyHasher, SpecializationKeyEq>
428 specializations_;
429
430 /// An index to create unique specialization names.
431 unsigned uniqueIdx_{0};
432
433 /// If set, specialize taking into account the whole set of arguments,
434 /// including buffer addresses.
435 bool jitSpecializeAllArguments_{false};
436
437 /// A reference to a set of values that the specializer was requested not to
438 /// specialize.
439 llvm::DenseSet<llvm::Value *> &dontSpecializeArgsSet_;
440 /// LLVMIRGen to be used.
441 LLVMIRGen &irgen_;
442};
443
444} // namespace
445
446void LLVMIRGen::performSpecialization() {
447 FunctionSpecializer FuncSpecializer(emittedLLVMFunctions_,
448 dontSpecializeArgsSet_, *this);
449 FuncSpecializer.run();
450}
451