1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
/*!
 * Target construction, parsing and stringification.
 * \file src/target/target.cc
 */
23#include <dmlc/thread_local.h>
24#include <tvm/runtime/device_api.h>
25#include <tvm/runtime/logging.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/target/tag.h>
28#include <tvm/target/target.h>
29#include <tvm/target/target_kind.h>
30#include <tvm/tir/expr.h>
31
32#include <algorithm>
33#include <cctype>
34#include <ios>
35#include <sstream>
36#include <stack>
37#include <string>
38#include <unordered_map>
39#include <utility>
40#include <vector>
41
42#include "../runtime/object_internal.h"
43
44namespace tvm {
45
// Register TargetNode through TVM's node-type registration macro.
// NOTE(review): presumably this enables reflection/creation by type key — confirm against the macro.
TVM_REGISTER_NODE_TYPE(TargetNode);
47
48class TargetInternal {
49 public:
50 static void EnterScope(Target target) { target.EnterWithScope(); }
51 static void ExitScope(Target target) { target.ExitWithScope(); }
52 static Map<String, ObjectRef> Export(Target target) { return target->Export(); }
53 static const TargetKindNode::ValueTypeInfo& FindTypeInfo(const TargetKind& kind,
54 const std::string& key);
55 static Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs);
56 static ObjectRef ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info);
57 static ObjectRef ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info);
58 static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str);
59 static ObjectPtr<Object> FromConfigString(const String& config_str);
60 static ObjectPtr<Object> FromRawString(const String& target_str);
61 static ObjectPtr<Object> FromConfig(Map<String, ObjectRef> config);
62 static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv);
63 static Target WithHost(const Target& target, const Target& target_host) {
64 ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
65 n->host = target_host;
66 return (Target)n;
67 }
68
69 private:
70 static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target);
71 static bool IsQuoted(const std::string& str);
72 static std::string Quote(const std::string& str);
73 static std::string JoinString(const std::vector<std::string>& array, char separator);
74 static std::vector<std::string> SplitString(const std::string& str, char separator);
75 static std::string Interpret(const std::string& str);
76 static std::string Uninterpret(const std::string& str);
77 static std::string StringifyAtomicType(const ObjectRef& obj);
78 static std::string StringifyArray(const ArrayNode& array);
79
80 static constexpr char quote = '\'';
81 static constexpr char escape = '\\';
82};
83
84/********** Helper functions **********/
85Target Target::WithHost(const Target& target, const Target& host) {
86 return TargetInternal::WithHost(target, host);
87}
88
89void CheckAndUpdateHostConsistency(Target* target, Target* host) {
90 *target = Target(*target, *host);
91 *host = (*target)->GetHost().value_or(Target());
92}
93
94void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) {
95 Map<Target, IRModule> new_targets;
96 for (auto& it : *targets) {
97 auto target = it.first;
98 CheckAndUpdateHostConsistency(&target, host);
99 new_targets.Set(target, it.second);
100 }
101 *targets = new_targets;
102}
103
104static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) {
105 std::vector<String> new_keys;
106 for (size_t i = 0; i < keys.size(); ++i) {
107 bool found = false;
108 for (size_t j = 0; j < i; ++j) {
109 if (keys[i] == keys[j]) {
110 found = true;
111 break;
112 }
113 }
114 if (!found) {
115 new_keys.push_back(keys[i]);
116 }
117 }
118 return new_keys;
119}
120
121template <class TObj>
122static const TObj* ObjTypeCheck(const ObjectRef& obj, const std::string& expected_type) {
123 const TObj* ptr = obj.as<TObj>();
124 if (ptr == nullptr) {
125 std::ostringstream os;
126 os << ": Expects type \"" << expected_type << "\", but gets \"" << obj->GetTypeKey()
127 << "\" for object: " << obj;
128 throw Error(os.str());
129 }
130 return ptr;
131}
132
133static TargetKind GetTargetKind(const String& name) {
134 Optional<TargetKind> kind = TargetKind::Get(name);
135 if (!kind.defined()) {
136 throw Error(": Target kind \"" + name + "\" is not defined");
137 }
138 return kind.value();
139}
140
141static std::string RemovePrefixDashes(const std::string& s) {
142 int n_dashes = 0;
143 int len = s.length();
144 for (; n_dashes < len && s[n_dashes] == '-'; ++n_dashes) {
145 }
146 if (n_dashes == 0) {
147 throw Error(": Attribute keys should start with '-', not an attribute key: " + s);
148 }
149 if (n_dashes >= len) {
150 throw Error(": Not an attribute key: " + s);
151 }
152 return s.substr(n_dashes);
153}
154
155bool TargetInternal::IsQuoted(const std::string& str) {
156 std::string::size_type start = 0, end = str.size();
157 if (end < 2 || str[start] != quote || str[end - 1] != quote) {
158 return false;
159 }
160 bool escaping = false;
161 for (auto i = start + 1, e = end - 1; i < e; ++i) {
162 if (escaping) {
163 escaping = false;
164 } else if (str[i] == escape) {
165 escaping = true;
166 } else if (str[i] == quote) {
167 return false;
168 }
169 }
170 // If the reduced string ends with \, then the terminating quote is escaped.
171 return !escaping;
172}
173
174std::string TargetInternal::Quote(const std::string& str) {
175 std::string result(1, quote);
176 result.append(str);
177 result.push_back(quote);
178 return result;
179}
180
181std::string TargetInternal::JoinString(const std::vector<std::string>& array, char separator) {
182 std::string result;
183 ICHECK(separator != quote && separator != escape)
184 << "string join separator cannot be " << quote << " or " << escape;
185
186 bool is_first = true;
187 for (const auto& s : array) {
188 if (!is_first) {
189 result.push_back(separator);
190 }
191 result.append(s);
192 is_first = false;
193 }
194
195 return result;
196}
197
198std::vector<std::string> TargetInternal::SplitString(const std::string& str, char separator) {
199 std::vector<std::string> output;
200
201 const char* start = str.data();
202 const char* end = start + str.size();
203 const char* pos = start;
204
205 std::stringstream current_word;
206
207 auto finish_word = [&]() {
208 std::string word = current_word.str();
209 if (word.size()) {
210 output.push_back(word);
211 current_word.str("");
212 }
213 };
214
215 bool pos_quoted = false;
216
217 while (pos < end) {
218 if ((*pos == separator) && !pos_quoted) {
219 finish_word();
220 pos++;
221 } else if (*pos == escape && pos + 1 < end) {
222 current_word << escape;
223 current_word << pos[1];
224 pos += 2;
225 } else if (*pos == quote) {
226 current_word << quote;
227 pos_quoted = !pos_quoted;
228 pos++;
229 } else {
230 current_word << *pos;
231 pos++;
232 }
233 }
234
235 ICHECK(!pos_quoted) << "Mismatched quotes '' in string";
236
237 finish_word();
238
239 return output;
240}
241
242std::string TargetInternal::Interpret(const std::string& str) {
243 // String interpretation deals with quotes (') and escapes(\).
244 // - An escape character must be followed by another character forming an
245 // "escape sequence". (Trailing escape is not allowed.) An escape prevents
246 // interpretation of the character that follows. This happens regardless of
247 // whether the escape sequence appears within quoted substring or not.
248 // - A quote character, when interpreted, marks the beginning or the end of a
249 // quoted substring. (A quoted substring cannot contain unescaped quotes.)
250 // - Any other character, when interpreted, represents itself.
251 //
252 // Interpretation happens in two steps:
253 // 1. If the entire string is quoted, the quotes are removed first, and the
254 // resulting string is treated as unquoted.
255 // 2. Each character or escape sequence is interpreted, and the result is copied
256 // to the result. When not inside a quoted substring, the interpretation of an
257 // escape sequence is the escaped character, otherwise it is the entire escape
258 // sequence.
259 //
260 // Examples:
261 // blah -> blah Nothing happened
262 // 'blah' -> blah Enclosing quotes removed
263 // 'bl'ah -> 'bl'ah Non-enclosing quotes remain
264 // '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes
265 // interpreted.
266 // '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above.
267 //
268 // Note that
269 // '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah'
270
271 std::string result;
272 if (str.empty()) {
273 return result;
274 }
275
276 // Check if the entire string is enclosed in quotes ''. If so, strip the quotes
277 // and treat the string as unquoted (so that escapes are interpreted). Doing that
278 // will allow '\'foo\'' to become 'foo', instead of \'foo\'.
279 std::string::size_type start = 0, end = str.size();
280 if (IsQuoted(str)) {
281 start++;
282 end--;
283 }
284
285 bool inside_quote = false;
286 bool escaping = false;
287
288 for (auto i = start, e = end; i < e; ++i) {
289 std::string::value_type c = str[i];
290 if (escaping) {
291 escaping = false;
292 } else if (c == escape) {
293 escaping = true;
294 if (!inside_quote) {
295 continue;
296 }
297 } else if (c == quote) {
298 inside_quote = !inside_quote;
299 }
300 result.push_back(c);
301 }
302
303 return result;
304}
305
306std::string TargetInternal::Uninterpret(const std::string& str) {
307 // Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str.
308 std::string result;
309
310 for (std::string::size_type i = 0, e = str.size(); i < e; ++i) {
311 std::string::value_type c = str[i];
312 if (c == escape || c == quote) {
313 result.push_back(escape);
314 }
315 result.push_back(c);
316 }
317
318 return result;
319}
320
321static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key,
322 std::string* value) {
323 std::string::size_type pos;
324 std::string& result_k = *key;
325 std::string& result_v = *value;
326 if ((pos = s.find_first_of('=')) != std::string::npos) {
327 // case 1. --key=value
328 result_k = s.substr(0, pos);
329 result_v = s.substr(pos + 1);
330 if (result_k.empty() || result_v.empty()) {
331 throw Error(": Empty attribute key or value in \"" + s + "\"");
332 }
333 return 1;
334 } else if (!s_next.empty() && s_next[0] != '-') {
335 // case 2. --key value
336 result_k = s;
337 result_v = s_next;
338 return 2;
339 }
340 // case 3. --boolean-key
341 result_k = s;
342 result_v = "1";
343 return 1;
344}
345
346const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKind& kind,
347 const std::string& key) {
348 auto it = kind->key2vtype_.find(key);
349 if (it == kind->key2vtype_.end()) {
350 std::ostringstream os;
351 os << ": Cannot recognize \'" << key << "\'. Candidates are: ";
352 bool is_first = true;
353 for (const auto& kv : kind->key2vtype_) {
354 if (is_first) {
355 is_first = false;
356 } else {
357 os << ", ";
358 }
359 os << kv.first;
360 }
361 throw Error(os.str());
362 }
363 return it->second;
364}
365
366/********** Parsing **********/
367
368ObjectRef TargetInternal::ParseType(const std::string& str,
369 const TargetKindNode::ValueTypeInfo& info) {
370 std::string interp_str = Interpret(str);
371 if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
372 // Parsing integer
373 std::istringstream is(interp_str);
374 int v;
375 if (!(is >> v)) {
376 std::string lower(interp_str.size(), '\x0');
377 std::transform(interp_str.begin(), interp_str.end(), lower.begin(),
378 [](unsigned char c) { return std::tolower(c); });
379 // Bool is a subclass of IntImm, so allow textual boolean values.
380 if (lower == "true") {
381 v = 1;
382 } else if (lower == "false") {
383 v = 0;
384 } else {
385 throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str);
386 }
387 }
388 return Integer(v);
389 } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
390 // Parsing string, strip leading/trailing spaces, and enclosing quotes if any
391 auto start = interp_str.find_first_not_of(' ');
392 auto end = interp_str.find_last_not_of(' ');
393 if (start == std::string::npos || end == std::string::npos) {
394 // The whole string is made of spaces.
395 return String();
396 }
397 return String(interp_str.substr(start, (end - start + 1)));
398
399 } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
400 // Parsing target
401 return Target(TargetInternal::FromString(interp_str));
402 } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
403 // Parsing array
404 std::vector<ObjectRef> result;
405 for (const std::string& substr : SplitString(interp_str, ',')) {
406 try {
407 ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
408 result.push_back(parsed);
409 } catch (const Error& e) {
410 std::string index = "[" + std::to_string(result.size()) + "]";
411 throw Error(index + e.what());
412 }
413 }
414 return Array<ObjectRef>(result);
415 }
416 throw Error(": Unsupported type \"" + info.type_key +
417 "\" for parsing from string: " + interp_str);
418}
419
420ObjectRef TargetInternal::ParseType(const ObjectRef& obj,
421 const TargetKindNode::ValueTypeInfo& info) {
422 if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
423 // Parsing integer
424 return GetRef<Integer>(ObjTypeCheck<IntImmNode>(obj, "Integer"));
425 } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
426 // Parsing string
427 return GetRef<String>(ObjTypeCheck<StringObj>(obj, "String"));
428 } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
429 // Parsing target
430 if (const auto* ptr = obj.as<TargetNode>()) {
431 return GetRef<Target>(ptr);
432 } else if (const auto* ptr = obj.as<StringObj>()) {
433 return Target(TargetInternal::FromString(GetRef<String>(ptr)));
434 } else if (const auto* ptr = obj.as<MapNode>()) {
435 for (const auto& kv : *ptr) {
436 if (!kv.first->IsInstance<StringObj>()) {
437 throw Error(": Target object requires key of dict to be str, but get: " +
438 kv.first->GetTypeKey());
439 }
440 }
441 Map<String, ObjectRef> config = GetRef<Map<String, ObjectRef>>(ptr);
442 return Target(TargetInternal::FromConfig({config.begin(), config.end()}));
443 }
444 throw Error(": Expect type 'dict' or 'str' to construct Target, but get: " + obj->GetTypeKey());
445 } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
446 // Parsing array
447 const auto* array = ObjTypeCheck<ArrayNode>(obj, "Array");
448 std::vector<ObjectRef> result;
449 for (const ObjectRef& e : *array) {
450 try {
451 result.push_back(TargetInternal::ParseType(e, *info.key));
452 } catch (const Error& e) {
453 std::string index = '[' + std::to_string(result.size()) + ']';
454 throw Error(index + e.what());
455 }
456 }
457 return Array<ObjectRef>(result);
458 } else if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) {
459 // Parsing map
460 const auto* map = ObjTypeCheck<MapNode>(obj, "Map");
461 std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result;
462 for (const auto& kv : *map) {
463 ObjectRef key, val;
464 try {
465 key = TargetInternal::ParseType(kv.first, *info.key);
466 } catch (const Error& e) {
467 std::ostringstream os;
468 os << "'s key \"" << key << "\"" << e.what();
469 throw Error(os.str());
470 }
471 try {
472 val = TargetInternal::ParseType(kv.second, *info.val);
473 } catch (const Error& e) {
474 std::ostringstream os;
475 os << "[\"" << key << "\"]" << e.what();
476 throw Error(os.str());
477 }
478 result[key] = val;
479 }
480 return Map<ObjectRef, ObjectRef>(result);
481 }
482 if (info.type_index != obj->type_index()) {
483 std::ostringstream os;
484 os << ": Parsing type \"" << info.type_key
485 << "\" is not supported for the given object of type \"" << obj->GetTypeKey()
486 << "\". The object is: " << obj;
487 throw Error(os.str());
488 }
489 return obj;
490}
491
492/********** Stringifying **********/
493
494std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) {
495 if (const auto* p = obj.as<IntImmNode>()) {
496 return std::to_string(p->value);
497 }
498 if (const auto* p = obj.as<StringObj>()) {
499 auto s = static_cast<std::string>(GetRef<String>(p));
500 auto u = Uninterpret(s);
501 if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) {
502 u = Quote(u);
503 }
504 return u;
505 }
506 LOG(FATAL) << "Cannot stringify this object";
507}
508
509std::string TargetInternal::StringifyArray(const ArrayNode& array) {
510 std::vector<std::string> elements;
511
512 for (const ObjectRef& item : array) {
513 std::string s = StringifyAtomicType(item);
514 std::string u = Uninterpret(s);
515 if (u.find_first_of(',') != std::string::npos && !IsQuoted(u)) {
516 u = Quote(u);
517 }
518 elements.push_back(u);
519 }
520
521 return JoinString(elements, ',');
522}
523
524Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) {
525 std::ostringstream os;
526 std::vector<String> keys;
527 for (const auto& kv : attrs) {
528 keys.push_back(kv.first);
529 }
530 std::sort(keys.begin(), keys.end());
531 std::vector<std::string> result;
532
533 for (const auto& key : keys) {
534 const ObjectRef& obj = attrs[key];
535 std::string value;
536 if (const auto* array = obj.as<ArrayNode>()) {
537 value = String(StringifyArray(*array));
538 } else {
539 value = StringifyAtomicType(obj);
540 }
541 if (!value.empty()) {
542 result.push_back("-" + key + "=" + value);
543 }
544 }
545 return String(JoinString(result, ' '));
546}
547
548const std::string& TargetNode::str() const {
549 if (str_repr_.empty()) {
550 std::ostringstream os;
551 os << kind->name;
552 if (!this->keys.empty()) {
553 os << " -keys=";
554 bool is_first = true;
555 for (const String& s : keys) {
556 if (is_first) {
557 is_first = false;
558 } else {
559 os << ',';
560 }
561 os << s;
562 }
563 }
564 if (Optional<String> attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) {
565 os << ' ' << attrs_str.value();
566 }
567
568 str_repr_ = os.str();
569 }
570 return str_repr_;
571}
572
573/********** Small member methods **********/
574
575Target::Target(const String& tag_or_config_or_target_str) {
576 ObjectPtr<Object> target;
577 try {
578 target = TargetInternal::FromString(tag_or_config_or_target_str);
579 } catch (const Error& e) {
580 LOG(FATAL) << "ValueError" << e.what()
581 << ". Target creation from string failed: " << tag_or_config_or_target_str;
582 }
583 data_ = std::move(target);
584}
585
586Target::Target(const Map<String, ObjectRef>& config) {
587 ObjectPtr<Object> target;
588 try {
589 target = TargetInternal::FromConfig({config.begin(), config.end()});
590 } catch (const Error& e) {
591 LOG(FATAL) << "ValueError" << e.what()
592 << ". Target creation from config dict failed: " << config;
593 }
594 data_ = std::move(target);
595}
596
597Target::Target(Target target, Target host) {
598 ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
599 n->host = std::move(host);
600 data_ = std::move(n);
601}
602
603Target::Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<String> keys,
604 Map<String, ObjectRef> attrs) {
605 auto data = runtime::make_object<TargetNode>();
606 data->kind = std::move(kind);
607 data->host = std::move(host);
608 data->tag = std::move(tag);
609 data->keys = std::move(keys);
610 data->attrs = std::move(attrs);
611 data_ = std::move(data);
612}
613
614bool Target::IsExternalCodegen() const {
615 TargetKindAttrMap<Bool> is_external_codegen_map =
616 TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen);
617 TargetKindAttrMap<FTVMRelayToTIR> relay_to_tir_map =
618 TargetKind::GetAttrMap<FTVMRelayToTIR>(tvm::attr::kRelayToTIR);
619 return is_external_codegen_map.get(get()->kind, Bool(false)) ||
620 relay_to_tir_map.count(get()->kind);
621}
622
623bool Target::IsExternalCodegenFor(const Target& that) const {
624 return get()->GetTargetDeviceType() == that->GetTargetDeviceType() && IsExternalCodegen() &&
625 !that.IsExternalCodegen();
626}
627
628std::vector<std::string> TargetNode::GetKeys() const {
629 std::vector<std::string> result;
630 for (auto& expr : keys) {
631 result.push_back(expr);
632 }
633 return result;
634}
635
636std::unordered_set<std::string> TargetNode::GetLibs() const {
637 Optional<Array<String>> libs = this->GetAttr<Array<String>>("libs");
638 if (!libs.defined()) {
639 return {};
640 }
641 std::unordered_set<std::string> result;
642 for (const auto& item : libs.value()) {
643 result.insert(item);
644 }
645 return result;
646}
647
648Map<String, ObjectRef> TargetNode::Export() const {
649 Map<String, ObjectRef> result = {
650 {"kind", this->kind->name},
651 {"tag", this->tag},
652 {"keys", this->keys},
653 };
654 if (this->host.defined()) {
655 result.Set("host", this->GetHost().value_or(Target())->Export());
656 }
657 for (const auto& kv : attrs) {
658 result.Set(kv.first, kv.second);
659 }
660 return result;
661}
662
663Optional<Target> TargetNode::GetHost() const {
664 return GetRef<Optional<Target>>(this->host.as<TargetNode>());
665}
666
667int TargetNode::GetTargetDeviceType() const {
668 if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
669 return Downcast<Integer>(device_type)->value;
670 }
671 return kind->default_device_type;
672}
673
674String TargetNode::ToDebugString() const {
675 std::ostringstream os;
676 os << "Target(";
677 os << "id=" << std::hex << reinterpret_cast<size_t>(this);
678 os << ", kind='" << kind->name << "'";
679 if (!tag.empty()) {
680 os << ", tag='" << tag << "'";
681 }
682 if (!keys.empty()) {
683 os << ", keys={";
684 bool first = true;
685 for (const auto& key : keys) {
686 if (!first) {
687 os << ", ";
688 }
689 os << "'" << key << "'";
690 first = false;
691 }
692 os << "}";
693 }
694 if (!attrs.empty()) {
695 os << ", attrs={";
696 bool first = true;
697 for (const auto& pair : attrs) {
698 if (!first) {
699 os << ", ";
700 }
701 os << "'" << pair.first << "': " << pair.second;
702 first = false;
703 }
704 os << "}";
705 }
706 if (host.defined()) {
707 os << ", host=" << GetHost().value()->ToDebugString();
708 }
709 os << ")";
710 return os.str();
711}
712
713bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const {
714 return equal(kind.get(), other->kind.get()) && equal(host, other->host) &&
715 equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs);
716}
717
718void TargetNode::SHashReduce(SHashReducer hash_reduce) const {
719 hash_reduce(kind.get());
720 hash_reduce(host);
721 hash_reduce(tag);
722 hash_reduce(keys);
723 hash_reduce(attrs);
724}
725
/*!
 * \brief Entry to hold the Target context stack.
 *
 * Target::EnterWithScope pushes onto this stack, Target::ExitWithScope pops
 * from it, and Target::Current reads its top element.
 */
struct TVMTargetThreadLocalEntry {
  /*! \brief Stack of entered targets; the top element is the current target context. */
  std::stack<Target> context_stack;
};

/*! \brief Thread local store to hold the Target context stack. */
using TVMTargetThreadLocalStore = dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>;
734
735void Target::EnterWithScope() {
736 TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
737 entry->context_stack.push(*this);
738}
739
740void Target::ExitWithScope() {
741 TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
742 ICHECK(!entry->context_stack.empty());
743 ICHECK(entry->context_stack.top().same_as(*this));
744 entry->context_stack.pop();
745}
746
747Target Target::Current(bool allow_not_defined) {
748 TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
749 if (entry->context_stack.size() > 0) {
750 return entry->context_stack.top();
751 }
752 ICHECK(allow_not_defined)
753 << "Target context required. Please set it by constructing a TargetContext";
754
755 return Target();
756}
757
758/********** Creation **********/
759
760void TargetInternal::ConstructorDispatcher(TVMArgs args, TVMRetValue* rv) {
761 if (args.num_args == 1) {
762 const auto& arg = args[0];
763 if (arg.IsObjectRef<Target>()) {
764 *rv = Target(arg.AsObjectRef<Target>());
765 } else if (String::CanConvertFrom(arg)) {
766 *rv = Target(arg.operator String());
767 } else if (arg.IsObjectRef<Map<String, ObjectRef>>()) {
768 *rv = Target(arg.operator Map<String, ObjectRef>());
769 } else if (arg.type_code() == kTVMObjectHandle) {
770 ObjectRef obj = arg;
771 LOG(FATAL) << "TypeError: Cannot create target with type: " << obj->GetTypeKey();
772 } else {
773 LOG(FATAL) << "TypeError: Cannot create target with type: "
774 << runtime::ArgTypeCode2Str(arg.type_code());
775 }
776 return;
777 } else if (args.num_args == 2) {
778 if (args[0].IsObjectRef<Target>() && args[1].IsObjectRef<Target>()) {
779 Target target = args[0];
780 Target host = args[1];
781 *rv = Target(target, host);
782 } else {
783 LOG(FATAL) << "ValueError: Invalid type of arguments. Expect 2 Target arguments.";
784 }
785 return;
786 }
787 LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: "
788 << args.num_args;
789}
790
791ObjectPtr<Object> TargetInternal::FromString(const String& tag_or_config_or_target_str) {
792 if (Optional<Target> target = TargetTag::Get(tag_or_config_or_target_str)) {
793 Target value = target.value();
794 return runtime::ObjectInternal::MoveObjectPtr(&value);
795 }
796 if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') {
797 return TargetInternal::FromConfigString(tag_or_config_or_target_str);
798 }
799 return TargetInternal::FromRawString(tag_or_config_or_target_str);
800}
801
802ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) {
803 const auto* loader = tvm::runtime::Registry::Get("target._load_config_dict");
804 ICHECK(loader) << "AttributeError: \"target._load_config_dict\" is not registered. Please check "
805 "if the python module is properly loaded";
806 Optional<Map<String, ObjectRef>> config = (*loader)(config_str);
807 if (!config.defined()) {
808 throw Error(": Cannot load config dict with python JSON loader");
809 }
810 return TargetInternal::FromConfig({config.value().begin(), config.value().end()});
811}
812
813ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
814 ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string";
815 // Split the string by empty spaces
816 std::vector<std::string> options = SplitString(std::string(target_str), ' ');
817 std::string name = options[0];
818 // Create the target config
819 std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}};
820 TargetKind kind = GetTargetKind(name);
821 for (size_t iter = 1, end = options.size(); iter < end;) {
822 std::string key, value;
823 try {
824 // Parse key-value pair
825 std::string s_next = (iter + 1 < options.size()) ? options[iter + 1] : "";
826 iter += ParseKVPair(RemovePrefixDashes(options[iter]), s_next, &key, &value);
827 } catch (const Error& e) {
828 throw Error(": Error when parsing target" + std::string(e.what()));
829 }
830 try {
831 // check if `key` has been used
832 if (config.count(key)) {
833 throw Error(": The key \"" + key + "\" appears more than once");
834 }
835 config[key] = TargetInternal::ParseType(value, TargetInternal::FindTypeInfo(kind, key));
836 } catch (const Error& e) {
837 throw Error(": Error when parsing target[\"" + key + "\"]" + e.what());
838 }
839 }
840 return TargetInternal::FromConfig(config);
841}
842
/*!
 * \brief Build a target from a config dictionary.
 *
 * Processing order:
 *  1. "kind" (required, string) selects the target kind. If the kind has a
 *     target_parser, it rewrites the config and may supply "features".
 *  2. "tag" (optional, string).
 *  3. "keys": user-provided keys, then the "device" value if it is a string,
 *     then the kind's default keys when the user gave none; de-duplicated.
 *  4. "host" (optional) is constructed through ConstructorDispatcher.
 *  5. All remaining entries are attributes, converted with ParseType.
 *  6. "from_device", if present, fills in attributes queried from the device
 *     without overriding user-specified ones.
 *  7. Missing attributes receive the kind's defaults, and the kind's
 *     preprocessor (if any) is applied.
 *
 * \param config The target configuration; must not contain "features".
 * \return The constructed target node.
 * \throws Error on a missing or ill-typed field, or an attribute that cannot be parsed.
 */
ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
  const String kKind = "kind";
  const String kTag = "tag";
  const String kKeys = "keys";
  const String kDeviceName = "device";
  const String kHost = "host";
  const String kFeatures = "features";
  ObjectPtr<TargetNode> target = make_object<TargetNode>();

  ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser";

  // parse 'kind'
  if (config.count(kKind)) {
    if (const auto* kind = config[kKind].as<StringObj>()) {
      target->kind = GetTargetKind(GetRef<String>(kind));
      ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr))
          << "Cannot use both set_attrs_preprocessor and set_target_parser";

      // Run JSON Parser over JSON input
      if (target->kind->target_parser != nullptr) {
        VLOG(9) << "TargetInternal::FromConfig - Running target_parser";
        config = target->kind->target_parser(config);
        // The parser may return "features"; move them out of the config so
        // they are not treated as an attribute below.
        if (config.count(kFeatures)) {
          target->features = Downcast<Map<String, ObjectRef>>(config[kFeatures]);
          config.erase(kFeatures);
        }
      }

      config.erase(kKind);
    } else {
      throw Error(": Expect type of field \"kind\" is String, but get type: " +
                  config[kKind]->GetTypeKey());
    }
  } else {
    throw Error(": Field \"kind\" is not found");
  }
  // parse "tag"
  if (config.count(kTag)) {
    if (const auto* tag = config[kTag].as<StringObj>()) {
      target->tag = GetRef<String>(tag);
      config.erase(kTag);
    } else {
      throw Error(": Expect type of field \"tag\" is String, but get type: " +
                  config[kTag]->GetTypeKey());
    }
  } else {
    target->tag = "";
  }
  // parse "keys"
  {
    std::vector<String> keys;
    bool has_user_keys = config.count(kKeys);
    if (has_user_keys) {
      // user provided keys
      if (const auto* cfg_keys = config[kKeys].as<ArrayNode>()) {
        for (const ObjectRef& e : *cfg_keys) {
          if (const auto* key = e.as<StringObj>()) {
            keys.push_back(GetRef<String>(key));
          } else {
            throw Error(
                ": Expect 'keys' to be an array of strings, but it "
                "contains an element of type: " +
                e->GetTypeKey());
          }
        }
      } else {
        throw Error(": Expect type of field \"keys\" is Array, but get type: " +
                    config[kKeys]->GetTypeKey());
      }
    }
    // add device name (left in the config, so it is also parsed as an attribute below)
    if (config.count(kDeviceName)) {
      if (const auto* device = config.at(kDeviceName).as<StringObj>()) {
        keys.push_back(GetRef<String>(device));
      }
    }
    if (!has_user_keys) {
      // add default keys
      for (const auto& key : target->kind->default_keys) {
        keys.push_back(key);
      }
    }
    // de-duplicate keys
    target->keys = DeduplicateKeys(keys);
    config.erase(kKeys);
  }
  // parse host
  if (config.count(kHost)) {
    target->host = PackedFunc(ConstructorDispatcher)(config[kHost]).AsObjectRef<Target>();
    config.erase(kHost);
  } else {
    target->host = NullOpt;
  }
  // parse attrs: everything still left in the config
  std::unordered_map<String, ObjectRef> attrs;
  for (const auto& cfg_kv : config) {
    const String& key = cfg_kv.first;
    const ObjectRef& value = cfg_kv.second;
    try {
      const TargetKindNode::ValueTypeInfo& info = TargetInternal::FindTypeInfo(target->kind, key);
      attrs[key] = TargetInternal::ParseType(value, info);
    } catch (const Error& e) {
      throw Error(": Error when parsing target[\"" + key + "\"]" + e.what());
    }
  }

  // If requested, query attributes from the device. User-specified
  // parameters take precedence over queried parameters.
  if (attrs.count("from_device")) {
    int device_id = Downcast<Integer>(attrs.at("from_device")).IntValue();
    attrs.erase("from_device");
    auto device_params = QueryDevice(device_id, target.get());

    for (const auto& kv : device_params) {
      if (attrs.count(kv.first) == 0) {
        attrs[kv.first] = kv.second;
      }
    }
  }

  // set default attribute values if they do not exist
  for (const auto& kv : target->kind->key2default_) {
    if (!attrs.count(kv.first)) {
      attrs[kv.first] = kv.second;
    }
  }
  // do extra pre-processing
  if (target->kind->preprocessor != nullptr) {
    target->attrs = target->kind->preprocessor(Map<String, ObjectRef>(attrs));
  } else {
    target->attrs = attrs;
  }

  return target;
}
978
979std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
980 const TargetNode* target) {
981 std::unordered_map<String, ObjectRef> output;
982
983 Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()), device_id};
984
985 auto api = runtime::DeviceAPI::Get(device, true);
986 if (!api) {
987 LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id "
988 << device_id << ", but support for this runtime wasn't enabled at compile-time. "
989 << "Using default target parameters.";
990 return output;
991 }
992
993 TVMRetValue ret;
994 api->GetAttr(device, runtime::kExist, &ret);
995 bool device_exists = ret;
996 if (!device_exists) {
997 ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name
998 << " from device_id " << device_id << ", but device_id " << device_id
999 << " doesn't exist. Using default target parameters.";
1000 return output;
1001 }
1002
1003 for (const auto& kv : target->kind->key2vtype_) {
1004 const String& key = kv.first;
1005 const TargetKindNode::ValueTypeInfo& type_info = kv.second;
1006
1007 TVMRetValue ret;
1008 api->GetTargetProperty(device, key, &ret);
1009
1010 switch (ret.type_code()) {
1011 case kTVMNullptr:
1012 // Nothing returned for this parameter, move on to the next one.
1013 continue;
1014
1015 case kTVMArgInt:
1016 if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
1017 output[key] = Integer(static_cast<int64_t>(ret));
1018 } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
1019 output[key] = Bool(static_cast<bool>(ret));
1020 } else {
1021 LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
1022 << "', but received integer from device api";
1023 }
1024 break;
1025
1026 case kTVMStr:
1027 ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
1028 << "Expected " << type_info.type_key << " parameter for attribute '" << key
1029 << "', but received string from device api";
1030 output[key] = String(ret.operator std::string());
1031 break;
1032
1033 default:
1034 LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
1035 << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
1036 break;
1037 }
1038 }
1039
1040 return output;
1041}
1042
1043/********** Registry **********/
1044
1045TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher);
1046TVM_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope);
1047TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope);
1048TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current);
1049TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export);
1050TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost);
1051TVM_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) {
1052 return target->GetTargetDeviceType();
1053});
1054TVM_REGISTER_GLOBAL("target.TargetGetFeature")
1055 .set_body_typed([](const Target& target, const String& feature_key) {
1056 return target->GetFeature<ObjectRef>(feature_key);
1057 });
1058
1059TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
1060 .set_dispatch<TargetNode>([](const ObjectRef& obj, ReprPrinter* p) {
1061 p->stream << Downcast<Target>(obj)->str();
1062 });
1063
1064} // namespace tvm
1065