1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifdef TVM_LLVM_VERSION |
21 | |
22 | #include "llvm_instance.h" |
23 | |
24 | #include <dmlc/base.h> |
25 | #include <llvm/ADT/ArrayRef.h> |
26 | #include <llvm/ADT/StringRef.h> |
27 | #if TVM_LLVM_VERSION >= 150 |
28 | #include <llvm/IR/FMF.h> |
29 | #else |
30 | #include <llvm/IR/Operator.h> |
31 | #endif |
32 | #include <llvm/IR/LLVMContext.h> |
33 | #include <llvm/IR/Metadata.h> |
34 | #include <llvm/IR/Module.h> |
35 | #include <llvm/IRReader/IRReader.h> |
36 | #if TVM_LLVM_VERSION >= 140 |
37 | #include <llvm/MC/TargetRegistry.h> |
38 | #else |
39 | #include <llvm/Support/TargetRegistry.h> |
40 | #endif |
41 | #include <llvm/Support/CodeGen.h> |
42 | #include <llvm/Support/CommandLine.h> |
43 | #include <llvm/Support/ErrorOr.h> |
44 | #include <llvm/Support/Host.h> |
45 | #include <llvm/Support/MemoryBuffer.h> |
46 | #include <llvm/Support/SourceMgr.h> |
47 | #include <llvm/Support/TargetSelect.h> |
48 | #include <llvm/Support/raw_ostream.h> |
49 | #include <llvm/Target/TargetMachine.h> |
50 | #include <llvm/Target/TargetOptions.h> |
51 | #include <tvm/runtime/container/array.h> |
52 | #include <tvm/runtime/container/map.h> |
53 | #include <tvm/runtime/container/optional.h> |
54 | #include <tvm/runtime/container/string.h> |
55 | #include <tvm/runtime/logging.h> |
56 | #include <tvm/runtime/object.h> |
57 | #include <tvm/target/target.h> |
58 | |
#include <algorithm>
#include <atomic>
#include <cctype>
#include <iterator>
#include <memory>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
68 | |
69 | namespace tvm { |
70 | namespace codegen { |
71 | |
72 | namespace { |
73 | namespace defaults { |
74 | static const char* cpu = "generic" ; |
75 | static const llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Aggressive; |
76 | } // namespace defaults |
77 | } // namespace |
78 | |
79 | namespace { |
80 | bool InitializeLLVM() { |
81 | static std::atomic_flag initialized = ATOMIC_FLAG_INIT; |
82 | if (!initialized.test_and_set()) { |
83 | llvm::InitializeAllTargetInfos(); |
84 | llvm::InitializeAllTargets(); |
85 | llvm::InitializeAllTargetMCs(); |
86 | llvm::InitializeAllAsmParsers(); |
87 | llvm::InitializeAllAsmPrinters(); |
88 | } |
89 | return true; |
90 | } |
91 | |
92 | std::string Join(std::string sep, llvm::ArrayRef<std::string> strings) { |
93 | std::string result; |
94 | bool is_first = true; |
95 | for (const std::string& s : strings) { |
96 | if (!is_first) { |
97 | result += sep; |
98 | } |
99 | result += s; |
100 | is_first = false; |
101 | } |
102 | return result; |
103 | } |
104 | |
105 | } // namespace |
106 | |
107 | // LLVMInstance |
108 | |
109 | LLVMInstance::LLVMInstance() { |
110 | // Call InitializeLLVM before anything else. |
111 | static const bool DMLC_ATTRIBUTE_UNUSED init_llvm = InitializeLLVM(); |
112 | ctx_ = std::make_shared<llvm::LLVMContext>(); |
113 | } |
114 | |
115 | LLVMInstance::~LLVMInstance() = default; |
116 | |
117 | std::unique_ptr<llvm::Module> LLVMInstance::ParseIR(const std::string& llvm_ir) const { |
118 | auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm_ir, /*BufferName=*/"" , |
119 | /*RequiresNullTerminator=*/false); |
120 | return ParseBuffer(*buffer); |
121 | } |
122 | |
123 | std::unique_ptr<llvm::Module> LLVMInstance::LoadIR(const std::string& file_name) const { |
124 | llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> maybe_buffer = |
125 | llvm::MemoryBuffer::getFileAsStream(file_name); |
126 | if (std::error_code ec = maybe_buffer.getError()) { |
127 | LOG(FATAL) << ec.message(); |
128 | } |
129 | return ParseBuffer(**maybe_buffer); |
130 | } |
131 | |
132 | std::unique_ptr<llvm::Module> LLVMInstance::ParseBuffer(const llvm::MemoryBuffer& buffer) const { |
133 | llvm::SMDiagnostic error; |
134 | std::unique_ptr<llvm::Module> module = llvm::parseIR(buffer.getMemBufferRef(), error, *ctx_); |
135 | if (module == nullptr) { |
136 | std::string message; |
137 | llvm::raw_string_ostream ostream(message); |
138 | error.print(/*ProgName=*/nullptr, ostream, /*ShowColors=*/false, /*ShowKindLabel=*/true); |
139 | LOG(FATAL) << ostream.str(); |
140 | } |
141 | |
142 | return module; |
143 | } |
144 | |
145 | // LLVMTargetInfo |
146 | |
147 | std::ostream& operator<<(std::ostream& os, const LLVMTargetInfo::Option& opt) { |
148 | os << '-' << opt.name; |
149 | switch (opt.type) { |
150 | case LLVMTargetInfo::Option::OptType::Bool: |
151 | return os << ":bool=" << (opt.value.b ? "true" : "false" ); |
152 | case LLVMTargetInfo::Option::OptType::Int: |
153 | return os << ":int=" << opt.value.i; |
154 | case LLVMTargetInfo::Option::OptType::UInt: |
155 | return os << ":uint=" << opt.value.u; |
156 | case LLVMTargetInfo::Option::OptType::String: |
157 | return os << ":string=" << opt.value.s; |
158 | default: |
159 | os << ":?(" << static_cast<int>(opt.type) << ")" ; |
160 | break; |
161 | } |
162 | return os; |
163 | } |
164 | |
165 | LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) { |
166 | triple_ = target->GetAttr<String>("mtriple" ).value_or("default" ); |
167 | |
168 | if (triple_.empty() || triple_ == "default" ) { |
169 | triple_ = llvm::sys::getDefaultTargetTriple(); |
170 | } |
171 | cpu_ = target->GetAttr<String>("mcpu" ).value_or(defaults::cpu); |
172 | |
173 | if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr" )) { |
174 | for (const String& s : v.value()) { |
175 | attrs_.push_back(s); |
176 | } |
177 | } |
178 | |
179 | if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("cl-opt" )) { |
180 | llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions(); |
181 | bool parse_error = false; |
182 | for (const String& s : v.value()) { |
183 | Option opt = ParseOptionString(s); |
184 | if (opt.type == Option::OptType::Invalid) { |
185 | parse_error = true; |
186 | continue; |
187 | } |
188 | if (options.count(opt.name)) { |
189 | llvm_options_.push_back(opt); |
190 | } else { |
191 | // Flag an error, but don't abort. LLVM flags may change, and this would |
192 | // give the code a chance to run even if the option no longer applies. |
193 | LOG(ERROR) << "\"" << opt.name << "\" is not an LLVM option, option ignored" ; |
194 | } |
195 | } |
196 | ICHECK(!parse_error) << "there were errors parsing command-line options" ; |
197 | } |
198 | |
199 | llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; |
200 | if (const Optional<String>& v = target->GetAttr<String>("mfloat-abi" )) { |
201 | String value = v.value(); |
202 | if (value == "hard" ) { |
203 | float_abi = llvm::FloatABI::Hard; |
204 | } else if (value == "soft" ) { |
205 | float_abi = llvm::FloatABI::Soft; |
206 | } else { |
207 | LOG(FATAL) << "invalid -mfloat-abi option " << value; |
208 | } |
209 | } |
210 | |
211 | // Target options |
212 | |
213 | #if TVM_LLVM_VERSION < 50 |
214 | target_options_.LessPreciseFPMADOption = true; |
215 | #endif |
216 | // In clang, these are fed from LangOpts which describe language specific features |
217 | // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags |
218 | target_options_.AllowFPOpFusion = llvm::FPOpFusion::Fast; |
219 | target_options_.UnsafeFPMath = false; |
220 | target_options_.NoInfsFPMath = false; |
221 | target_options_.NoNaNsFPMath = true; |
222 | target_options_.FloatABIType = float_abi; |
223 | if (const Optional<String>& v = target->GetAttr<String>("mabi" )) { |
224 | target_options_.MCOptions.ABIName = v.value(); |
225 | } |
226 | |
227 | auto maybe_level = target->GetAttr<Integer>("opt-level" ); |
228 | |
229 | if (maybe_level.defined()) { |
230 | int level = maybe_level.value()->value; |
231 | if (level <= 0) { |
232 | opt_level_ = llvm::CodeGenOpt::None; |
233 | } else if (level == 1) { |
234 | opt_level_ = llvm::CodeGenOpt::Less; |
235 | } else if (level == 2) { |
236 | opt_level_ = llvm::CodeGenOpt::Default; |
237 | } else { |
238 | // level >= 3 |
239 | opt_level_ = llvm::CodeGenOpt::Aggressive; |
240 | } |
241 | } else { |
242 | opt_level_ = defaults::opt_level; |
243 | } |
244 | |
245 | target_options_.UseInitArray = true; |
246 | |
247 | // Fast math options |
248 | |
249 | auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { |
250 | return target->GetAttr<Bool>(flag.str()).value_or(Bool(false)); |
251 | }; |
252 | if (GetBoolFlag("fast-math" )) { |
253 | #if TVM_LLVM_VERSION >= 60 |
254 | fast_math_flags_.setFast(); |
255 | #else |
256 | fast_math_flags_.setUnsafeAlgebra(); |
257 | #endif |
258 | } else { |
259 | #if TVM_LLVM_VERSION >= 50 |
260 | // This option was added in 5.x, and has a boolean argument, |
261 | // unlike the rest of options at the time. |
262 | fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract" )); |
263 | #endif |
264 | #if TVM_LLVM_VERSION >= 70 |
265 | fast_math_flags_.setNoNaNs(GetBoolFlag("fast-math-nnan" )); |
266 | fast_math_flags_.setNoInfs(GetBoolFlag("fast-math-ninf" )); |
267 | fast_math_flags_.setNoSignedZeros(GetBoolFlag("fast-math-nsz" )); |
268 | fast_math_flags_.setAllowReciprocal(GetBoolFlag("fast-math-arcp" )); |
269 | fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract" )); |
270 | fast_math_flags_.setAllowReassoc(GetBoolFlag("fast-math-reassoc" )); |
271 | fast_math_flags_.setApproxFunc(GetBoolFlag("fast-math-afn" )); |
272 | #else |
273 | // LLVM 4.x, 5.x, and 6.x |
274 | if (GetBoolFlag("fast-math-nnan" )) fast_math_flags_.setNoNaNs(); |
275 | if (GetBoolFlag("fast-math-ninf" )) fast_math_flags_.setNoInfs(); |
276 | if (GetBoolFlag("fast-math-nsz" )) fast_math_flags_.setNoSignedZeros(); |
277 | if (GetBoolFlag("fast-math-arcp" )) fast_math_flags_.setAllowReciprocal(); |
278 | #if TVM_LLVM_VERSION >= 60 |
279 | if (GetBoolFlag("fast-math-reassoc" )) fast_math_flags_.setAllowReassoc(); |
280 | if (GetBoolFlag("fast-math-afn" )) fast_math_flags_.setApproxFunc(); |
281 | #endif |
282 | #endif |
283 | } |
284 | } |
285 | |
286 | LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& scope, const std::string& target_str) |
287 | : LLVMTargetInfo(scope, Target(target_str)) {} |
288 | |
289 | LLVMTargetInfo::~LLVMTargetInfo() = default; |
290 | |
291 | llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) { |
292 | if (target_machine_) return target_machine_.get(); |
293 | |
294 | std::string error; |
295 | if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { |
296 | llvm::TargetMachine* tm = |
297 | llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, |
298 | reloc_model_, code_model_, opt_level_); |
299 | target_machine_ = std::unique_ptr<llvm::TargetMachine>(tm); |
300 | } |
301 | if (!allow_missing) { |
302 | ICHECK(target_machine_ != nullptr) << error; |
303 | } |
304 | return target_machine_.get(); |
305 | } |
306 | |
307 | std::string LLVMTargetInfo::GetTargetFeatureString() const { // |
308 | return Join("," , attrs_); |
309 | } |
310 | |
311 | std::string LLVMTargetInfo::str() const { |
312 | std::ostringstream os; |
313 | os << "llvm" ; |
314 | if (!triple_.empty()) { |
315 | os << " -mtriple=" << triple_; |
316 | } |
317 | if (!cpu_.empty() && cpu_ != defaults::cpu) { |
318 | os << " -mcpu=" << cpu_; |
319 | } |
320 | if (!attrs_.empty()) { |
321 | os << " -mattr=" << GetTargetFeatureString(); |
322 | } |
323 | |
324 | switch (target_options_.FloatABIType) { |
325 | case llvm::FloatABI::Soft: |
326 | os << " -mfloat-abi=soft" ; |
327 | break; |
328 | case llvm::FloatABI::Hard: |
329 | os << " -mfloat-abi=hard" ; |
330 | break; |
331 | case llvm::FloatABI::Default: |
332 | break; |
333 | } |
334 | if (!target_options_.MCOptions.ABIName.empty()) { |
335 | os << " -mabi=" << target_options_.MCOptions.ABIName; |
336 | } |
337 | |
338 | bool do_individual = true; |
339 | #if TVM_LLVM_VERSION >= 60 |
340 | if (fast_math_flags_.isFast()) { |
341 | os << " -fast-math" ; |
342 | do_individual = false; |
343 | } |
344 | #else |
345 | if (fast_math_flags_.unsafeAlgebra()) { |
346 | os << " -fast-math" ; |
347 | do_individual = false; |
348 | } |
349 | #endif |
350 | |
351 | if (do_individual) { |
352 | if (fast_math_flags_.noNaNs()) os << " -fast-math-nnan" ; |
353 | if (fast_math_flags_.noInfs()) os << " -fast-math-ninf" ; |
354 | if (fast_math_flags_.noSignedZeros()) os << " -fast-math-nsz" ; |
355 | if (fast_math_flags_.allowReciprocal()) os << " -fast-math-arcp" ; |
356 | #if TVM_LLVM_VERSION >= 50 |
357 | if (fast_math_flags_.allowContract()) os << " -fast-math-contract" ; |
358 | #endif |
359 | #if TVM_LLVM_VERSION >= 60 |
360 | if (fast_math_flags_.allowReassoc()) os << " -fast-math-reassoc" ; |
361 | if (fast_math_flags_.approxFunc()) os << " -fast-math-afn" ; |
362 | #endif |
363 | } |
364 | |
365 | if (opt_level_ != defaults::opt_level) { |
366 | os << " -opt-level=" ; |
367 | switch (opt_level_) { |
368 | case llvm::CodeGenOpt::None: |
369 | os << "0" ; |
370 | break; |
371 | case llvm::CodeGenOpt::Less: |
372 | os << "1" ; |
373 | break; |
374 | case llvm::CodeGenOpt::Default: |
375 | os << "2" ; |
376 | break; |
377 | case llvm::CodeGenOpt::Aggressive: |
378 | os << "3" ; |
379 | break; |
380 | } |
381 | } |
382 | |
383 | if (size_t num = llvm_options_.size(); num > 0) { |
384 | os << " -cl-opt=" ; |
385 | std::vector<std::string> opts; |
386 | for (const Option& opt : llvm_options_) { |
387 | std::stringstream os; |
388 | os << opt; |
389 | opts.emplace_back(os.str()); |
390 | } |
391 | auto* quote = num > 1 ? "'" : "" ; |
392 | os << quote << Join("," , opts) << quote; |
393 | } |
394 | |
395 | return os.str(); |
396 | } |
397 | |
398 | LLVMTargetInfo::Option LLVMTargetInfo::ParseOptionString(const std::string& str) { |
399 | Option opt; |
400 | opt.type = Option::OptType::Invalid; |
401 | |
402 | // Option string: "-"+ <option_name> ":" <type> "=" <value> |
403 | // |
404 | // Note: "-"+ means 1 or more dashes, but only "-" are "--" valid. |
405 | |
406 | // The first step is to do "lexing" of the option string, i.e. to break |
407 | // it up into parts (like "tokens") according to the syntax above. These |
408 | // parts will be non-overlapping substrings of the option string, and |
409 | // concatenated together, they will be equal to the option string. |
410 | // The literal elements are parts on their own. |
411 | // |
412 | // Note that the option string may be malformed, so any of the literal |
413 | // elements in the syntax may be missing. |
414 | |
415 | std::vector<std::string> parts; |
416 | |
417 | auto find_first_of = [](const std::string& str, const std::string& chars, auto start = 0) { |
418 | auto pos = str.find_first_of(chars, start); |
419 | return pos != std::string::npos ? pos : str.size(); |
420 | }; |
421 | auto find_first_not_of = [](const std::string& str, const std::string& chars, auto start = 0) { |
422 | auto pos = str.find_first_not_of(chars, start); |
423 | return pos != std::string::npos ? pos : str.size(); |
424 | }; |
425 | |
426 | // "-"+ |
427 | std::string::size_type pos_start = 0, pos_end = str.size(); |
428 | std::string::size_type pos_at = find_first_not_of(str, "-" , pos_start); |
429 | if (pos_at > 0) { |
430 | parts.push_back(str.substr(pos_start, pos_at)); |
431 | } |
432 | // <option_name>, always present, may be empty string |
433 | pos_start = pos_at; |
434 | pos_at = find_first_of(str, ":=" , pos_start); |
435 | parts.push_back(str.substr(pos_start, pos_at - pos_start)); |
436 | |
437 | // ":" or "=", if any |
438 | pos_start = pos_at; |
439 | char c = pos_start < pos_end ? str[pos_start] : 0; |
440 | if (c != 0) { |
441 | parts.emplace_back(1, c); |
442 | pos_start++; |
443 | } |
444 | // If the character found in the previous step wasn't '=', look for '='. |
445 | if (c == ':') { |
446 | // <type> |
447 | pos_at = find_first_of(str, "=" , pos_start); |
448 | if (pos_at > pos_start) { // if non-empty |
449 | parts.push_back(str.substr(pos_start, pos_at - pos_start)); |
450 | } |
451 | |
452 | // "=" |
453 | if (pos_at < pos_end) { |
454 | parts.emplace_back(1, str[pos_at]); |
455 | pos_start = pos_at + 1; |
456 | } |
457 | } |
458 | if (pos_start < pos_end) { |
459 | // <value> |
460 | parts.push_back(str.substr(pos_start)); |
461 | } |
462 | |
463 | // After breaking up the option string, examine and validate the individual |
464 | // parts. |
465 | |
466 | int part_this = 0, part_end = parts.size(); |
467 | |
468 | const std::string = "while parsing option \"" + str + "\": " ; |
469 | |
470 | // Check for "-" or "--". |
471 | if (part_this < part_end) { |
472 | auto& p = parts[part_this++]; |
473 | if ((p.size() != 1 && p.size() != 2) || p.find_first_not_of('-') != std::string::npos) { |
474 | LOG(ERROR) << error_header << "option must start with \"-\" or \"--\"" ; |
475 | return opt; |
476 | } |
477 | } |
478 | |
479 | // Validate option name. |
480 | if (part_this < part_end) { |
481 | auto& p = parts[part_this++]; |
482 | if (p.empty()) { |
483 | LOG(ERROR) << error_header << "option name must not be empty" ; |
484 | return opt; |
485 | } |
486 | opt.name = std::move(p); |
487 | } |
488 | |
489 | // Check type, if present. |
490 | Option::OptType type = Option::OptType::Invalid; |
491 | if (part_this < part_end) { |
492 | auto& p0 = parts[part_this]; |
493 | if (p0 == ":" ) { |
494 | part_this++; // Only advance if we saw ":". |
495 | if (part_this < part_end) { |
496 | auto& p1 = parts[part_this]; |
497 | ICHECK(!p1.empty()) << "tokenizing error" ; // This shouldn't happen. |
498 | if (p1 != "=" ) { |
499 | part_this++; |
500 | if (p1 == "bool" ) { |
501 | type = Option::OptType::Bool; |
502 | } else if (p1 == "int" ) { |
503 | type = Option::OptType::Int; |
504 | } else if (p1 == "uint" ) { |
505 | type = Option::OptType::UInt; |
506 | } else if (p1 == "string" ) { |
507 | type = Option::OptType::String; |
508 | } |
509 | } |
510 | } |
511 | // If there was ":", there must be a type. |
512 | if (type == Option::OptType::Invalid) { |
513 | LOG(ERROR) << error_header << "invalid type" ; |
514 | return opt; |
515 | } |
516 | } |
517 | } |
518 | |
519 | // Check value, if present. |
520 | std::optional<std::string> value; |
521 | if (part_this < part_end) { |
522 | auto& p0 = parts[part_this]; |
523 | if (p0 == "=" ) { |
524 | part_this++; |
525 | if (part_this < part_end) { |
526 | value = std::move(parts[part_this]); |
527 | } else { |
528 | value = "" ; |
529 | } |
530 | } else { |
531 | // If there are still any parts left to be processed, there must be "=". |
532 | LOG(ERROR) << error_header << "expecting \"=\"" ; |
533 | return opt; |
534 | } |
535 | } |
536 | |
537 | // NOLINTNEXTLINE(runtime/int) |
538 | auto to_integer = [](const std::string& s) -> std::optional<long long> { |
539 | // std::stoll takes "long long" |
540 | long long number; // NOLINT(runtime/int) |
541 | size_t pos; |
542 | try { |
543 | number = std::stoll(s, &pos); |
544 | } catch (...) { |
545 | return std::nullopt; |
546 | } |
547 | if (pos == s.size()) { |
548 | return number; |
549 | } else { |
550 | return std::nullopt; |
551 | } |
552 | }; |
553 | |
554 | auto to_boolean = [&to_integer](const std::string& s) -> std::optional<bool> { |
555 | // Return 0 or 1, if string corresponds to a valid boolean value, |
556 | // otherwise return 2. |
557 | auto ti = to_integer(s); |
558 | if (ti.has_value() && (ti.value() == 0 || ti.value() == 1)) { |
559 | return static_cast<bool>(ti.value()); |
560 | } |
561 | |
562 | std::string lower; |
563 | std::transform(s.begin(), s.end(), std::back_inserter(lower), |
564 | [](unsigned char c) { return std::tolower(c); }); |
565 | if (lower == "true" ) { |
566 | return true; |
567 | } else if (lower == "false" ) { |
568 | return false; |
569 | } |
570 | return std::nullopt; |
571 | }; |
572 | |
573 | if (value.has_value()) { |
574 | if (type == Option::OptType::Int || type == Option::OptType::UInt) { |
575 | auto v = to_integer(value.value()); |
576 | if (!v.has_value()) { |
577 | LOG(ERROR) << error_header << "invalid integer value \"" << value.value() << "\"" ; |
578 | return opt; |
579 | } |
580 | if (type == Option::OptType::Int) { |
581 | opt.value.i = static_cast<int>(v.value()); |
582 | if (opt.value.i != v.value()) { |
583 | LOG(WARNING) << error_header << "value exceeds int range, assuming " << opt.value.i; |
584 | } |
585 | } else { |
586 | // NOLINTNEXTLINE(runtime/int) |
587 | opt.value.u = static_cast<unsigned>(static_cast<unsigned long long>(v.value())); |
588 | if (opt.value.u != static_cast<unsigned long long>(v.value())) { // NOLINT(runtime/int) |
589 | LOG(WARNING) << error_header << "value exceeds int range, assuming " << opt.value.u; |
590 | } |
591 | } |
592 | } else if (type == Option::OptType::String) { |
593 | opt.value.s = std::move(value.value()); |
594 | } else { |
595 | // "type" is either Bool (given explicitly) or Invalid (type not present in string) |
596 | auto v = to_boolean(value.value()); |
597 | if (!v.has_value()) { |
598 | LOG(ERROR) << error_header << "invalid boolean value \"" << value.value() << "\"" ; |
599 | return opt; |
600 | } |
601 | opt.value.b = v.value(); |
602 | type = Option::OptType::Bool; |
603 | } |
604 | } else { |
605 | // Value was not present in string. Assume "true" if "type" is Bool or Invalid |
606 | if (type == Option::OptType::Bool || type == Option::OptType::Invalid) { |
607 | opt.value.b = true; |
608 | type = Option::OptType::Bool; |
609 | } else { |
610 | LOG(ERROR) << error_header << "must have a value" ; |
611 | return opt; |
612 | } |
613 | } |
614 | |
615 | ICHECK(type != Option::OptType::Invalid); |
616 | opt.type = type; |
617 | return opt; |
618 | } |
619 | |
620 | bool LLVMTargetInfo::MatchesGlobalState() const { |
621 | for (const Option& opt : GetCommandLineOptions()) { |
622 | Option current_opt = opt; |
623 | GetOptionValue(¤t_opt); |
624 | ICHECK(current_opt.type != Option::OptType::Invalid); |
625 | switch (current_opt.type) { |
626 | case Option::OptType::Bool: |
627 | if (current_opt.value.b != opt.value.b) return false; |
628 | continue; |
629 | case Option::OptType::Int: |
630 | if (current_opt.value.i != opt.value.i) return false; |
631 | continue; |
632 | case Option::OptType::UInt: |
633 | if (current_opt.value.u != opt.value.u) return false; |
634 | continue; |
635 | case Option::OptType::String: |
636 | if (current_opt.value.s != opt.value.s) return false; |
637 | continue; |
638 | default:; // NOLINT(whitespace/semicolon) |
639 | } |
640 | } |
641 | return true; |
642 | } |
643 | |
644 | void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { |
645 | llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions(); |
646 | llvm::cl::Option* base_op = options[opt->name]; |
647 | |
648 | if (opt->type == Option::OptType::Bool) { |
649 | auto* bool_op = static_cast<llvm::cl::opt<bool>*>(base_op); |
650 | opt->value.b = bool_op->getValue(); |
651 | } else if (opt->type == Option::OptType::Int) { |
652 | auto* int_op = static_cast<llvm::cl::opt<int>*>(base_op); |
653 | opt->value.i = int_op->getValue(); |
654 | } else if (opt->type == Option::OptType::UInt) { |
655 | auto* uint_op = static_cast<llvm::cl::opt<unsigned>*>(base_op); |
656 | opt->value.u = uint_op->getValue(); |
657 | } else if (opt->type == Option::OptType::String) { |
658 | auto* str_op = static_cast<llvm::cl::opt<std::string>*>(base_op); |
659 | opt->value.s = str_op->getValue(); |
660 | } else { |
661 | opt->type = Option::OptType::Invalid; |
662 | } |
663 | } |
664 | |
665 | // LLVMTarget |
666 | |
667 | bool LLVMTarget::modified_llvm_state_ = false; |
668 | |
669 | LLVMTarget::LLVMTarget(LLVMInstance& instance, const LLVMTargetInfo& target_info) |
670 | : LLVMTargetInfo(target_info), instance_(instance), ctx_(instance.GetContext()) { |
671 | // Populate the list of saved options with the current values. |
672 | for (const Option& opt : GetCommandLineOptions()) { |
673 | GetOptionValue(&saved_llvm_options_.emplace_back(opt)); |
674 | } |
675 | |
676 | if (modified_llvm_state_) { |
677 | ICHECK(!ApplyLLVMOptions(true)); |
678 | } else { |
679 | modified_llvm_state_ = ApplyLLVMOptions(true); |
680 | } |
681 | } |
682 | |
683 | LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) |
684 | : LLVMTarget(instance, LLVMTargetInfo(instance, target)) {} |
685 | |
686 | LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str) |
687 | : LLVMTarget(scope, Target(target_str)) {} |
688 | |
689 | LLVMTarget::~LLVMTarget() { |
690 | // Revert all applied LLVM options. |
691 | if (ApplyLLVMOptions(false)) { |
692 | modified_llvm_state_ = false; |
693 | } |
694 | } |
695 | |
696 | llvm::LLVMContext* LLVMTarget::GetContext() const { |
697 | ICHECK(!ctx_.expired()) << "LLVM scope has been deleted" ; |
698 | return ctx_.lock().get(); |
699 | } |
700 | |
701 | std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) { |
702 | if (llvm::Metadata* tvm_target = module.getModuleFlag("tvm_target" )) { |
703 | auto* mdstr = llvm::cast<llvm::MDString>(tvm_target); |
704 | llvm::StringRef meta = mdstr->getString(); |
705 | if (meta.startswith("llvm" )) { |
706 | return meta.str(); |
707 | } |
708 | } |
709 | return "llvm -mtriple " + module.getTargetTriple(); |
710 | } |
711 | |
712 | void LLVMTarget::SetTargetMetadata(llvm::Module* module) const { |
713 | module->addModuleFlag(llvm::Module::Warning, "tvm_target" , |
714 | llvm::MDString::get(*GetContext(), str())); |
715 | } |
716 | |
717 | bool LLVMTarget::ApplyLLVMOptions(bool apply_otherwise_revert, bool dry_run) { |
718 | llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions(); |
719 | bool changed = false; |
720 | |
721 | #define HANDLE_OPTION_VALUE(option, new_val, saved_val) \ |
722 | do { \ |
723 | auto current = (option)->getValue(); \ |
724 | auto replacement = apply_otherwise_revert ? (new_val) : (saved_val); \ |
725 | if (current != replacement) { \ |
726 | changed = true; \ |
727 | if (!dry_run) { \ |
728 | (option)->setValue(replacement); \ |
729 | } \ |
730 | } \ |
731 | } while (false); |
732 | |
733 | const auto& new_options = GetCommandLineOptions(); |
734 | for (size_t i = 0, e = saved_llvm_options_.size(); i != e; ++i) { |
735 | const Option& new_opt = new_options[i]; |
736 | const Option& saved_opt = saved_llvm_options_[i]; |
737 | |
738 | llvm::cl::Option* base_op = options[new_opt.name]; |
739 | |
740 | if (new_opt.type == Option::OptType::Bool) { |
741 | auto* bool_op = static_cast<llvm::cl::opt<bool>*>(base_op); |
742 | HANDLE_OPTION_VALUE(bool_op, new_opt.value.b, saved_opt.value.b); |
743 | } else if (new_opt.type == Option::OptType::Int) { |
744 | auto* int_op = static_cast<llvm::cl::opt<int>*>(base_op); |
745 | HANDLE_OPTION_VALUE(int_op, new_opt.value.i, saved_opt.value.i); |
746 | } else if (new_opt.type == Option::OptType::UInt) { |
747 | auto* uint_op = static_cast<llvm::cl::opt<unsigned>*>(base_op); |
748 | HANDLE_OPTION_VALUE(uint_op, new_opt.value.u, saved_opt.value.u); |
749 | } else if (new_opt.type == Option::OptType::String) { |
750 | auto* str_op = static_cast<llvm::cl::opt<std::string>*>(base_op); |
751 | HANDLE_OPTION_VALUE(str_op, new_opt.value.s, saved_opt.value.s); |
752 | } else { |
753 | LOG(FATAL) << "unexpected type in option " << new_opt; |
754 | } |
755 | |
756 | if (dry_run && changed) { |
757 | return true; |
758 | } |
759 | } |
760 | |
761 | #undef HANDLE_OPTION_VALUE |
762 | |
763 | return changed; |
764 | } |
765 | |
766 | } // namespace codegen |
767 | } // namespace tvm |
768 | |
769 | #endif // TVM_LLVM_VERSION |
770 | |