1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifdef TVM_LLVM_VERSION
21
22#include "llvm_instance.h"
23
24#include <dmlc/base.h>
25#include <llvm/ADT/ArrayRef.h>
26#include <llvm/ADT/StringRef.h>
27#if TVM_LLVM_VERSION >= 150
28#include <llvm/IR/FMF.h>
29#else
30#include <llvm/IR/Operator.h>
31#endif
32#include <llvm/IR/LLVMContext.h>
33#include <llvm/IR/Metadata.h>
34#include <llvm/IR/Module.h>
35#include <llvm/IRReader/IRReader.h>
36#if TVM_LLVM_VERSION >= 140
37#include <llvm/MC/TargetRegistry.h>
38#else
39#include <llvm/Support/TargetRegistry.h>
40#endif
41#include <llvm/Support/CodeGen.h>
42#include <llvm/Support/CommandLine.h>
43#include <llvm/Support/ErrorOr.h>
44#include <llvm/Support/Host.h>
45#include <llvm/Support/MemoryBuffer.h>
46#include <llvm/Support/SourceMgr.h>
47#include <llvm/Support/TargetSelect.h>
48#include <llvm/Support/raw_ostream.h>
49#include <llvm/Target/TargetMachine.h>
50#include <llvm/Target/TargetOptions.h>
51#include <tvm/runtime/container/array.h>
52#include <tvm/runtime/container/map.h>
53#include <tvm/runtime/container/optional.h>
54#include <tvm/runtime/container/string.h>
55#include <tvm/runtime/logging.h>
56#include <tvm/runtime/object.h>
57#include <tvm/target/target.h>
58
#include <algorithm>
#include <atomic>
#include <cctype>
#include <iterator>
#include <memory>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
68
69namespace tvm {
70namespace codegen {
71
namespace {
// Fallback values used when a Target does not specify the corresponding
// attribute. str() omits "-mcpu" / "-opt-level" when the value equals these
// defaults, so changing them changes the canonical target string.
namespace defaults {
// Used when the target has no "mcpu" attribute.
static const char* cpu = "generic";
// Used when the target has no "opt-level" attribute.
static const llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Aggressive;
}  // namespace defaults
}  // namespace
78
79namespace {
80bool InitializeLLVM() {
81 static std::atomic_flag initialized = ATOMIC_FLAG_INIT;
82 if (!initialized.test_and_set()) {
83 llvm::InitializeAllTargetInfos();
84 llvm::InitializeAllTargets();
85 llvm::InitializeAllTargetMCs();
86 llvm::InitializeAllAsmParsers();
87 llvm::InitializeAllAsmPrinters();
88 }
89 return true;
90}
91
92std::string Join(std::string sep, llvm::ArrayRef<std::string> strings) {
93 std::string result;
94 bool is_first = true;
95 for (const std::string& s : strings) {
96 if (!is_first) {
97 result += sep;
98 }
99 result += s;
100 is_first = false;
101 }
102 return result;
103}
104
105} // namespace
106
107// LLVMInstance
108
109LLVMInstance::LLVMInstance() {
110 // Call InitializeLLVM before anything else.
111 static const bool DMLC_ATTRIBUTE_UNUSED init_llvm = InitializeLLVM();
112 ctx_ = std::make_shared<llvm::LLVMContext>();
113}
114
// Defaulted: members (including the shared LLVM context handle) are released by
// their own destructors.
LLVMInstance::~LLVMInstance() = default;
116
117std::unique_ptr<llvm::Module> LLVMInstance::ParseIR(const std::string& llvm_ir) const {
118 auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm_ir, /*BufferName=*/"",
119 /*RequiresNullTerminator=*/false);
120 return ParseBuffer(*buffer);
121}
122
123std::unique_ptr<llvm::Module> LLVMInstance::LoadIR(const std::string& file_name) const {
124 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> maybe_buffer =
125 llvm::MemoryBuffer::getFileAsStream(file_name);
126 if (std::error_code ec = maybe_buffer.getError()) {
127 LOG(FATAL) << ec.message();
128 }
129 return ParseBuffer(**maybe_buffer);
130}
131
132std::unique_ptr<llvm::Module> LLVMInstance::ParseBuffer(const llvm::MemoryBuffer& buffer) const {
133 llvm::SMDiagnostic error;
134 std::unique_ptr<llvm::Module> module = llvm::parseIR(buffer.getMemBufferRef(), error, *ctx_);
135 if (module == nullptr) {
136 std::string message;
137 llvm::raw_string_ostream ostream(message);
138 error.print(/*ProgName=*/nullptr, ostream, /*ShowColors=*/false, /*ShowKindLabel=*/true);
139 LOG(FATAL) << ostream.str();
140 }
141
142 return module;
143}
144
145// LLVMTargetInfo
146
147std::ostream& operator<<(std::ostream& os, const LLVMTargetInfo::Option& opt) {
148 os << '-' << opt.name;
149 switch (opt.type) {
150 case LLVMTargetInfo::Option::OptType::Bool:
151 return os << ":bool=" << (opt.value.b ? "true" : "false");
152 case LLVMTargetInfo::Option::OptType::Int:
153 return os << ":int=" << opt.value.i;
154 case LLVMTargetInfo::Option::OptType::UInt:
155 return os << ":uint=" << opt.value.u;
156 case LLVMTargetInfo::Option::OptType::String:
157 return os << ":string=" << opt.value.s;
158 default:
159 os << ":?(" << static_cast<int>(opt.type) << ")";
160 break;
161 }
162 return os;
163}
164
165LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
166 triple_ = target->GetAttr<String>("mtriple").value_or("default");
167
168 if (triple_.empty() || triple_ == "default") {
169 triple_ = llvm::sys::getDefaultTargetTriple();
170 }
171 cpu_ = target->GetAttr<String>("mcpu").value_or(defaults::cpu);
172
173 if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr")) {
174 for (const String& s : v.value()) {
175 attrs_.push_back(s);
176 }
177 }
178
179 if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("cl-opt")) {
180 llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions();
181 bool parse_error = false;
182 for (const String& s : v.value()) {
183 Option opt = ParseOptionString(s);
184 if (opt.type == Option::OptType::Invalid) {
185 parse_error = true;
186 continue;
187 }
188 if (options.count(opt.name)) {
189 llvm_options_.push_back(opt);
190 } else {
191 // Flag an error, but don't abort. LLVM flags may change, and this would
192 // give the code a chance to run even if the option no longer applies.
193 LOG(ERROR) << "\"" << opt.name << "\" is not an LLVM option, option ignored";
194 }
195 }
196 ICHECK(!parse_error) << "there were errors parsing command-line options";
197 }
198
199 llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default;
200 if (const Optional<String>& v = target->GetAttr<String>("mfloat-abi")) {
201 String value = v.value();
202 if (value == "hard") {
203 float_abi = llvm::FloatABI::Hard;
204 } else if (value == "soft") {
205 float_abi = llvm::FloatABI::Soft;
206 } else {
207 LOG(FATAL) << "invalid -mfloat-abi option " << value;
208 }
209 }
210
211 // Target options
212
213#if TVM_LLVM_VERSION < 50
214 target_options_.LessPreciseFPMADOption = true;
215#endif
216 // In clang, these are fed from LangOpts which describe language specific features
217 // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags
218 target_options_.AllowFPOpFusion = llvm::FPOpFusion::Fast;
219 target_options_.UnsafeFPMath = false;
220 target_options_.NoInfsFPMath = false;
221 target_options_.NoNaNsFPMath = true;
222 target_options_.FloatABIType = float_abi;
223 if (const Optional<String>& v = target->GetAttr<String>("mabi")) {
224 target_options_.MCOptions.ABIName = v.value();
225 }
226
227 auto maybe_level = target->GetAttr<Integer>("opt-level");
228
229 if (maybe_level.defined()) {
230 int level = maybe_level.value()->value;
231 if (level <= 0) {
232 opt_level_ = llvm::CodeGenOpt::None;
233 } else if (level == 1) {
234 opt_level_ = llvm::CodeGenOpt::Less;
235 } else if (level == 2) {
236 opt_level_ = llvm::CodeGenOpt::Default;
237 } else {
238 // level >= 3
239 opt_level_ = llvm::CodeGenOpt::Aggressive;
240 }
241 } else {
242 opt_level_ = defaults::opt_level;
243 }
244
245 target_options_.UseInitArray = true;
246
247 // Fast math options
248
249 auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool {
250 return target->GetAttr<Bool>(flag.str()).value_or(Bool(false));
251 };
252 if (GetBoolFlag("fast-math")) {
253#if TVM_LLVM_VERSION >= 60
254 fast_math_flags_.setFast();
255#else
256 fast_math_flags_.setUnsafeAlgebra();
257#endif
258 } else {
259#if TVM_LLVM_VERSION >= 50
260 // This option was added in 5.x, and has a boolean argument,
261 // unlike the rest of options at the time.
262 fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract"));
263#endif
264#if TVM_LLVM_VERSION >= 70
265 fast_math_flags_.setNoNaNs(GetBoolFlag("fast-math-nnan"));
266 fast_math_flags_.setNoInfs(GetBoolFlag("fast-math-ninf"));
267 fast_math_flags_.setNoSignedZeros(GetBoolFlag("fast-math-nsz"));
268 fast_math_flags_.setAllowReciprocal(GetBoolFlag("fast-math-arcp"));
269 fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract"));
270 fast_math_flags_.setAllowReassoc(GetBoolFlag("fast-math-reassoc"));
271 fast_math_flags_.setApproxFunc(GetBoolFlag("fast-math-afn"));
272#else
273 // LLVM 4.x, 5.x, and 6.x
274 if (GetBoolFlag("fast-math-nnan")) fast_math_flags_.setNoNaNs();
275 if (GetBoolFlag("fast-math-ninf")) fast_math_flags_.setNoInfs();
276 if (GetBoolFlag("fast-math-nsz")) fast_math_flags_.setNoSignedZeros();
277 if (GetBoolFlag("fast-math-arcp")) fast_math_flags_.setAllowReciprocal();
278#if TVM_LLVM_VERSION >= 60
279 if (GetBoolFlag("fast-math-reassoc")) fast_math_flags_.setAllowReassoc();
280 if (GetBoolFlag("fast-math-afn")) fast_math_flags_.setApproxFunc();
281#endif
282#endif
283 }
284}
285
// Convenience overload: parse `target_str` into a Target and delegate to the
// Target-based constructor.
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& scope, const std::string& target_str)
    : LLVMTargetInfo(scope, Target(target_str)) {}
288
// Defaulted: members (including any cached target machine) are released by
// their own destructors.
LLVMTargetInfo::~LLVMTargetInfo() = default;
290
291llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) {
292 if (target_machine_) return target_machine_.get();
293
294 std::string error;
295 if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) {
296 llvm::TargetMachine* tm =
297 llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_,
298 reloc_model_, code_model_, opt_level_);
299 target_machine_ = std::unique_ptr<llvm::TargetMachine>(tm);
300 }
301 if (!allow_missing) {
302 ICHECK(target_machine_ != nullptr) << error;
303 }
304 return target_machine_.get();
305}
306
307std::string LLVMTargetInfo::GetTargetFeatureString() const { //
308 return Join(",", attrs_);
309}
310
311std::string LLVMTargetInfo::str() const {
312 std::ostringstream os;
313 os << "llvm";
314 if (!triple_.empty()) {
315 os << " -mtriple=" << triple_;
316 }
317 if (!cpu_.empty() && cpu_ != defaults::cpu) {
318 os << " -mcpu=" << cpu_;
319 }
320 if (!attrs_.empty()) {
321 os << " -mattr=" << GetTargetFeatureString();
322 }
323
324 switch (target_options_.FloatABIType) {
325 case llvm::FloatABI::Soft:
326 os << " -mfloat-abi=soft";
327 break;
328 case llvm::FloatABI::Hard:
329 os << " -mfloat-abi=hard";
330 break;
331 case llvm::FloatABI::Default:
332 break;
333 }
334 if (!target_options_.MCOptions.ABIName.empty()) {
335 os << " -mabi=" << target_options_.MCOptions.ABIName;
336 }
337
338 bool do_individual = true;
339#if TVM_LLVM_VERSION >= 60
340 if (fast_math_flags_.isFast()) {
341 os << " -fast-math";
342 do_individual = false;
343 }
344#else
345 if (fast_math_flags_.unsafeAlgebra()) {
346 os << " -fast-math";
347 do_individual = false;
348 }
349#endif
350
351 if (do_individual) {
352 if (fast_math_flags_.noNaNs()) os << " -fast-math-nnan";
353 if (fast_math_flags_.noInfs()) os << " -fast-math-ninf";
354 if (fast_math_flags_.noSignedZeros()) os << " -fast-math-nsz";
355 if (fast_math_flags_.allowReciprocal()) os << " -fast-math-arcp";
356#if TVM_LLVM_VERSION >= 50
357 if (fast_math_flags_.allowContract()) os << " -fast-math-contract";
358#endif
359#if TVM_LLVM_VERSION >= 60
360 if (fast_math_flags_.allowReassoc()) os << " -fast-math-reassoc";
361 if (fast_math_flags_.approxFunc()) os << " -fast-math-afn";
362#endif
363 }
364
365 if (opt_level_ != defaults::opt_level) {
366 os << " -opt-level=";
367 switch (opt_level_) {
368 case llvm::CodeGenOpt::None:
369 os << "0";
370 break;
371 case llvm::CodeGenOpt::Less:
372 os << "1";
373 break;
374 case llvm::CodeGenOpt::Default:
375 os << "2";
376 break;
377 case llvm::CodeGenOpt::Aggressive:
378 os << "3";
379 break;
380 }
381 }
382
383 if (size_t num = llvm_options_.size(); num > 0) {
384 os << " -cl-opt=";
385 std::vector<std::string> opts;
386 for (const Option& opt : llvm_options_) {
387 std::stringstream os;
388 os << opt;
389 opts.emplace_back(os.str());
390 }
391 auto* quote = num > 1 ? "'" : "";
392 os << quote << Join(",", opts) << quote;
393 }
394
395 return os.str();
396}
397
398LLVMTargetInfo::Option LLVMTargetInfo::ParseOptionString(const std::string& str) {
399 Option opt;
400 opt.type = Option::OptType::Invalid;
401
402 // Option string: "-"+ <option_name> ":" <type> "=" <value>
403 //
404 // Note: "-"+ means 1 or more dashes, but only "-" are "--" valid.
405
406 // The first step is to do "lexing" of the option string, i.e. to break
407 // it up into parts (like "tokens") according to the syntax above. These
408 // parts will be non-overlapping substrings of the option string, and
409 // concatenated together, they will be equal to the option string.
410 // The literal elements are parts on their own.
411 //
412 // Note that the option string may be malformed, so any of the literal
413 // elements in the syntax may be missing.
414
415 std::vector<std::string> parts;
416
417 auto find_first_of = [](const std::string& str, const std::string& chars, auto start = 0) {
418 auto pos = str.find_first_of(chars, start);
419 return pos != std::string::npos ? pos : str.size();
420 };
421 auto find_first_not_of = [](const std::string& str, const std::string& chars, auto start = 0) {
422 auto pos = str.find_first_not_of(chars, start);
423 return pos != std::string::npos ? pos : str.size();
424 };
425
426 // "-"+
427 std::string::size_type pos_start = 0, pos_end = str.size();
428 std::string::size_type pos_at = find_first_not_of(str, "-", pos_start);
429 if (pos_at > 0) {
430 parts.push_back(str.substr(pos_start, pos_at));
431 }
432 // <option_name>, always present, may be empty string
433 pos_start = pos_at;
434 pos_at = find_first_of(str, ":=", pos_start);
435 parts.push_back(str.substr(pos_start, pos_at - pos_start));
436
437 // ":" or "=", if any
438 pos_start = pos_at;
439 char c = pos_start < pos_end ? str[pos_start] : 0;
440 if (c != 0) {
441 parts.emplace_back(1, c);
442 pos_start++;
443 }
444 // If the character found in the previous step wasn't '=', look for '='.
445 if (c == ':') {
446 // <type>
447 pos_at = find_first_of(str, "=", pos_start);
448 if (pos_at > pos_start) { // if non-empty
449 parts.push_back(str.substr(pos_start, pos_at - pos_start));
450 }
451
452 // "="
453 if (pos_at < pos_end) {
454 parts.emplace_back(1, str[pos_at]);
455 pos_start = pos_at + 1;
456 }
457 }
458 if (pos_start < pos_end) {
459 // <value>
460 parts.push_back(str.substr(pos_start));
461 }
462
463 // After breaking up the option string, examine and validate the individual
464 // parts.
465
466 int part_this = 0, part_end = parts.size();
467
468 const std::string error_header = "while parsing option \"" + str + "\": ";
469
470 // Check for "-" or "--".
471 if (part_this < part_end) {
472 auto& p = parts[part_this++];
473 if ((p.size() != 1 && p.size() != 2) || p.find_first_not_of('-') != std::string::npos) {
474 LOG(ERROR) << error_header << "option must start with \"-\" or \"--\"";
475 return opt;
476 }
477 }
478
479 // Validate option name.
480 if (part_this < part_end) {
481 auto& p = parts[part_this++];
482 if (p.empty()) {
483 LOG(ERROR) << error_header << "option name must not be empty";
484 return opt;
485 }
486 opt.name = std::move(p);
487 }
488
489 // Check type, if present.
490 Option::OptType type = Option::OptType::Invalid;
491 if (part_this < part_end) {
492 auto& p0 = parts[part_this];
493 if (p0 == ":") {
494 part_this++; // Only advance if we saw ":".
495 if (part_this < part_end) {
496 auto& p1 = parts[part_this];
497 ICHECK(!p1.empty()) << "tokenizing error"; // This shouldn't happen.
498 if (p1 != "=") {
499 part_this++;
500 if (p1 == "bool") {
501 type = Option::OptType::Bool;
502 } else if (p1 == "int") {
503 type = Option::OptType::Int;
504 } else if (p1 == "uint") {
505 type = Option::OptType::UInt;
506 } else if (p1 == "string") {
507 type = Option::OptType::String;
508 }
509 }
510 }
511 // If there was ":", there must be a type.
512 if (type == Option::OptType::Invalid) {
513 LOG(ERROR) << error_header << "invalid type";
514 return opt;
515 }
516 }
517 }
518
519 // Check value, if present.
520 std::optional<std::string> value;
521 if (part_this < part_end) {
522 auto& p0 = parts[part_this];
523 if (p0 == "=") {
524 part_this++;
525 if (part_this < part_end) {
526 value = std::move(parts[part_this]);
527 } else {
528 value = "";
529 }
530 } else {
531 // If there are still any parts left to be processed, there must be "=".
532 LOG(ERROR) << error_header << "expecting \"=\"";
533 return opt;
534 }
535 }
536
537 // NOLINTNEXTLINE(runtime/int)
538 auto to_integer = [](const std::string& s) -> std::optional<long long> {
539 // std::stoll takes "long long"
540 long long number; // NOLINT(runtime/int)
541 size_t pos;
542 try {
543 number = std::stoll(s, &pos);
544 } catch (...) {
545 return std::nullopt;
546 }
547 if (pos == s.size()) {
548 return number;
549 } else {
550 return std::nullopt;
551 }
552 };
553
554 auto to_boolean = [&to_integer](const std::string& s) -> std::optional<bool> {
555 // Return 0 or 1, if string corresponds to a valid boolean value,
556 // otherwise return 2.
557 auto ti = to_integer(s);
558 if (ti.has_value() && (ti.value() == 0 || ti.value() == 1)) {
559 return static_cast<bool>(ti.value());
560 }
561
562 std::string lower;
563 std::transform(s.begin(), s.end(), std::back_inserter(lower),
564 [](unsigned char c) { return std::tolower(c); });
565 if (lower == "true") {
566 return true;
567 } else if (lower == "false") {
568 return false;
569 }
570 return std::nullopt;
571 };
572
573 if (value.has_value()) {
574 if (type == Option::OptType::Int || type == Option::OptType::UInt) {
575 auto v = to_integer(value.value());
576 if (!v.has_value()) {
577 LOG(ERROR) << error_header << "invalid integer value \"" << value.value() << "\"";
578 return opt;
579 }
580 if (type == Option::OptType::Int) {
581 opt.value.i = static_cast<int>(v.value());
582 if (opt.value.i != v.value()) {
583 LOG(WARNING) << error_header << "value exceeds int range, assuming " << opt.value.i;
584 }
585 } else {
586 // NOLINTNEXTLINE(runtime/int)
587 opt.value.u = static_cast<unsigned>(static_cast<unsigned long long>(v.value()));
588 if (opt.value.u != static_cast<unsigned long long>(v.value())) { // NOLINT(runtime/int)
589 LOG(WARNING) << error_header << "value exceeds int range, assuming " << opt.value.u;
590 }
591 }
592 } else if (type == Option::OptType::String) {
593 opt.value.s = std::move(value.value());
594 } else {
595 // "type" is either Bool (given explicitly) or Invalid (type not present in string)
596 auto v = to_boolean(value.value());
597 if (!v.has_value()) {
598 LOG(ERROR) << error_header << "invalid boolean value \"" << value.value() << "\"";
599 return opt;
600 }
601 opt.value.b = v.value();
602 type = Option::OptType::Bool;
603 }
604 } else {
605 // Value was not present in string. Assume "true" if "type" is Bool or Invalid
606 if (type == Option::OptType::Bool || type == Option::OptType::Invalid) {
607 opt.value.b = true;
608 type = Option::OptType::Bool;
609 } else {
610 LOG(ERROR) << error_header << "must have a value";
611 return opt;
612 }
613 }
614
615 ICHECK(type != Option::OptType::Invalid);
616 opt.type = type;
617 return opt;
618}
619
620bool LLVMTargetInfo::MatchesGlobalState() const {
621 for (const Option& opt : GetCommandLineOptions()) {
622 Option current_opt = opt;
623 GetOptionValue(&current_opt);
624 ICHECK(current_opt.type != Option::OptType::Invalid);
625 switch (current_opt.type) {
626 case Option::OptType::Bool:
627 if (current_opt.value.b != opt.value.b) return false;
628 continue;
629 case Option::OptType::Int:
630 if (current_opt.value.i != opt.value.i) return false;
631 continue;
632 case Option::OptType::UInt:
633 if (current_opt.value.u != opt.value.u) return false;
634 continue;
635 case Option::OptType::String:
636 if (current_opt.value.s != opt.value.s) return false;
637 continue;
638 default:; // NOLINT(whitespace/semicolon)
639 }
640 }
641 return true;
642}
643
644void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const {
645 llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions();
646 llvm::cl::Option* base_op = options[opt->name];
647
648 if (opt->type == Option::OptType::Bool) {
649 auto* bool_op = static_cast<llvm::cl::opt<bool>*>(base_op);
650 opt->value.b = bool_op->getValue();
651 } else if (opt->type == Option::OptType::Int) {
652 auto* int_op = static_cast<llvm::cl::opt<int>*>(base_op);
653 opt->value.i = int_op->getValue();
654 } else if (opt->type == Option::OptType::UInt) {
655 auto* uint_op = static_cast<llvm::cl::opt<unsigned>*>(base_op);
656 opt->value.u = uint_op->getValue();
657 } else if (opt->type == Option::OptType::String) {
658 auto* str_op = static_cast<llvm::cl::opt<std::string>*>(base_op);
659 opt->value.s = str_op->getValue();
660 } else {
661 opt->type = Option::OptType::Invalid;
662 }
663}
664
665// LLVMTarget
666
// Process-wide flag: set when an LLVMTarget's constructor changed LLVM's global
// command-line option values, cleared when a destructor's revert changes them back.
bool LLVMTarget::modified_llvm_state_ = false;
668
669LLVMTarget::LLVMTarget(LLVMInstance& instance, const LLVMTargetInfo& target_info)
670 : LLVMTargetInfo(target_info), instance_(instance), ctx_(instance.GetContext()) {
671 // Populate the list of saved options with the current values.
672 for (const Option& opt : GetCommandLineOptions()) {
673 GetOptionValue(&saved_llvm_options_.emplace_back(opt));
674 }
675
676 if (modified_llvm_state_) {
677 ICHECK(!ApplyLLVMOptions(true));
678 } else {
679 modified_llvm_state_ = ApplyLLVMOptions(true);
680 }
681}
682
// Build the LLVMTargetInfo for `target` first, then delegate to the
// LLVMTargetInfo-based constructor.
LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target)
    : LLVMTarget(instance, LLVMTargetInfo(instance, target)) {}
685
// Convenience overload: parse `target_str` into a Target and delegate to the
// Target-based constructor.
LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str)
    : LLVMTarget(scope, Target(target_str)) {}
688
689LLVMTarget::~LLVMTarget() {
690 // Revert all applied LLVM options.
691 if (ApplyLLVMOptions(false)) {
692 modified_llvm_state_ = false;
693 }
694}
695
696llvm::LLVMContext* LLVMTarget::GetContext() const {
697 ICHECK(!ctx_.expired()) << "LLVM scope has been deleted";
698 return ctx_.lock().get();
699}
700
701std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) {
702 if (llvm::Metadata* tvm_target = module.getModuleFlag("tvm_target")) {
703 auto* mdstr = llvm::cast<llvm::MDString>(tvm_target);
704 llvm::StringRef meta = mdstr->getString();
705 if (meta.startswith("llvm")) {
706 return meta.str();
707 }
708 }
709 return "llvm -mtriple " + module.getTargetTriple();
710}
711
712void LLVMTarget::SetTargetMetadata(llvm::Module* module) const {
713 module->addModuleFlag(llvm::Module::Warning, "tvm_target",
714 llvm::MDString::get(*GetContext(), str()));
715}
716
717bool LLVMTarget::ApplyLLVMOptions(bool apply_otherwise_revert, bool dry_run) {
718 llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions();
719 bool changed = false;
720
721#define HANDLE_OPTION_VALUE(option, new_val, saved_val) \
722 do { \
723 auto current = (option)->getValue(); \
724 auto replacement = apply_otherwise_revert ? (new_val) : (saved_val); \
725 if (current != replacement) { \
726 changed = true; \
727 if (!dry_run) { \
728 (option)->setValue(replacement); \
729 } \
730 } \
731 } while (false);
732
733 const auto& new_options = GetCommandLineOptions();
734 for (size_t i = 0, e = saved_llvm_options_.size(); i != e; ++i) {
735 const Option& new_opt = new_options[i];
736 const Option& saved_opt = saved_llvm_options_[i];
737
738 llvm::cl::Option* base_op = options[new_opt.name];
739
740 if (new_opt.type == Option::OptType::Bool) {
741 auto* bool_op = static_cast<llvm::cl::opt<bool>*>(base_op);
742 HANDLE_OPTION_VALUE(bool_op, new_opt.value.b, saved_opt.value.b);
743 } else if (new_opt.type == Option::OptType::Int) {
744 auto* int_op = static_cast<llvm::cl::opt<int>*>(base_op);
745 HANDLE_OPTION_VALUE(int_op, new_opt.value.i, saved_opt.value.i);
746 } else if (new_opt.type == Option::OptType::UInt) {
747 auto* uint_op = static_cast<llvm::cl::opt<unsigned>*>(base_op);
748 HANDLE_OPTION_VALUE(uint_op, new_opt.value.u, saved_opt.value.u);
749 } else if (new_opt.type == Option::OptType::String) {
750 auto* str_op = static_cast<llvm::cl::opt<std::string>*>(base_op);
751 HANDLE_OPTION_VALUE(str_op, new_opt.value.s, saved_opt.value.s);
752 } else {
753 LOG(FATAL) << "unexpected type in option " << new_opt;
754 }
755
756 if (dry_run && changed) {
757 return true;
758 }
759 }
760
761#undef HANDLE_OPTION_VALUE
762
763 return changed;
764}
765
766} // namespace codegen
767} // namespace tvm
768
769#endif // TVM_LLVM_VERSION
770