1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * Compile executable modules. |
21 | * \file src/target/target.cc |
22 | */ |
23 | #include <dmlc/thread_local.h> |
24 | #include <tvm/runtime/device_api.h> |
25 | #include <tvm/runtime/logging.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/target/tag.h> |
28 | #include <tvm/target/target.h> |
29 | #include <tvm/target/target_kind.h> |
30 | #include <tvm/tir/expr.h> |
31 | |
32 | #include <algorithm> |
33 | #include <cctype> |
34 | #include <ios> |
35 | #include <sstream> |
36 | #include <stack> |
37 | #include <string> |
38 | #include <unordered_map> |
39 | #include <utility> |
40 | #include <vector> |
41 | |
42 | #include "../runtime/object_internal.h" |
43 | |
44 | namespace tvm { |
45 | |
// Register TargetNode through the TVM_REGISTER_NODE_TYPE macro (the macro itself is
// defined outside this file).
TVM_REGISTER_NODE_TYPE(TargetNode);
47 | |
/*!
 * \brief Static helpers used in this file to parse, stringify and construct targets.
 *
 * The public section exposes thin forwarders to Target/TargetNode plus the parsing and
 * construction entry points. The private section holds the string utilities that
 * implement the quote (') and escape (\) rules of the raw target-string format.
 */
class TargetInternal {
 public:
  /*! \brief Forward to Target::EnterWithScope on \p target. */
  static void EnterScope(Target target) { target.EnterWithScope(); }
  /*! \brief Forward to Target::ExitWithScope on \p target. */
  static void ExitScope(Target target) { target.ExitWithScope(); }
  /*! \brief Forward to TargetNode::Export on \p target. */
  static Map<String, ObjectRef> Export(Target target) { return target->Export(); }
  /*!
   * \brief Look up the value-type description registered for attribute \p key of \p kind.
   * \throws Error listing the candidate keys when \p key is not registered.
   */
  static const TargetKindNode::ValueTypeInfo& FindTypeInfo(const TargetKind& kind,
                                                           const std::string& key);
  /*! \brief Render an attribute map as a space-separated "-key=value" string. */
  static Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs);
  /*! \brief Parse the textual attribute value \p str into the type described by \p info. */
  static ObjectRef ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info);
  /*! \brief Check/convert the object \p obj against the type described by \p info. */
  static ObjectRef ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info);
  /*! \brief Build a target from a tag name, a config string starting with '{', or a raw string. */
  static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str);
  /*! \brief Build a target from a config string using the registered "target._load_config_dict". */
  static ObjectPtr<Object> FromConfigString(const String& config_str);
  /*! \brief Build a target from a raw string of the form "kind -key=value ...". */
  static ObjectPtr<Object> FromRawString(const String& target_str);
  /*! \brief Build a target from a config dictionary; the "kind" field is required. */
  static ObjectPtr<Object> FromConfig(Map<String, ObjectRef> config);
  /*! \brief Packed-function entry point dispatching on the number and types of \p args. */
  static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv);
  /*! \brief Return a copy of \p target whose host field is set to \p target_host. */
  static Target WithHost(const Target& target, const Target& target_host) {
    // Copy-construct a fresh node so the input target is left untouched.
    ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
    n->host = target_host;
    return (Target)n;
  }

 private:
  // NOTE(review): QueryDevice is only declared here; its definition is not in the visible
  // part of this file.
  static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target);
  /*! \brief Whether \p str is entirely enclosed in one pair of unescaped quotes. */
  static bool IsQuoted(const std::string& str);
  /*! \brief Wrap \p str in quote characters (no escaping is performed). */
  static std::string Quote(const std::string& str);
  /*! \brief Concatenate \p array, inserting \p separator between consecutive elements. */
  static std::string JoinString(const std::vector<std::string>& array, char separator);
  /*! \brief Split \p str on unquoted \p separator characters, dropping empty words. */
  static std::vector<std::string> SplitString(const std::string& str, char separator);
  /*! \brief Strip enclosing quotes and resolve escape sequences of \p str. */
  static std::string Interpret(const std::string& str);
  /*! \brief Escape quote and escape characters so that Interpret restores \p str. */
  static std::string Uninterpret(const std::string& str);
  /*! \brief Stringify an integer or string attribute value. */
  static std::string StringifyAtomicType(const ObjectRef& obj);
  /*! \brief Stringify an array of atomic values as a comma-separated list. */
  static std::string StringifyArray(const ArrayNode& array);

  /*! \brief Character that delimits quoted substrings. */
  static constexpr char quote = '\'';
  /*! \brief Character that introduces an escape sequence. */
  static constexpr char escape = '\\';
};
83 | |
84 | /********** Helper functions **********/ |
85 | Target Target::WithHost(const Target& target, const Target& host) { |
86 | return TargetInternal::WithHost(target, host); |
87 | } |
88 | |
89 | void CheckAndUpdateHostConsistency(Target* target, Target* host) { |
90 | *target = Target(*target, *host); |
91 | *host = (*target)->GetHost().value_or(Target()); |
92 | } |
93 | |
94 | void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) { |
95 | Map<Target, IRModule> new_targets; |
96 | for (auto& it : *targets) { |
97 | auto target = it.first; |
98 | CheckAndUpdateHostConsistency(&target, host); |
99 | new_targets.Set(target, it.second); |
100 | } |
101 | *targets = new_targets; |
102 | } |
103 | |
104 | static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) { |
105 | std::vector<String> new_keys; |
106 | for (size_t i = 0; i < keys.size(); ++i) { |
107 | bool found = false; |
108 | for (size_t j = 0; j < i; ++j) { |
109 | if (keys[i] == keys[j]) { |
110 | found = true; |
111 | break; |
112 | } |
113 | } |
114 | if (!found) { |
115 | new_keys.push_back(keys[i]); |
116 | } |
117 | } |
118 | return new_keys; |
119 | } |
120 | |
121 | template <class TObj> |
122 | static const TObj* ObjTypeCheck(const ObjectRef& obj, const std::string& expected_type) { |
123 | const TObj* ptr = obj.as<TObj>(); |
124 | if (ptr == nullptr) { |
125 | std::ostringstream os; |
126 | os << ": Expects type \"" << expected_type << "\", but gets \"" << obj->GetTypeKey() |
127 | << "\" for object: " << obj; |
128 | throw Error(os.str()); |
129 | } |
130 | return ptr; |
131 | } |
132 | |
133 | static TargetKind GetTargetKind(const String& name) { |
134 | Optional<TargetKind> kind = TargetKind::Get(name); |
135 | if (!kind.defined()) { |
136 | throw Error(": Target kind \"" + name + "\" is not defined" ); |
137 | } |
138 | return kind.value(); |
139 | } |
140 | |
141 | static std::string RemovePrefixDashes(const std::string& s) { |
142 | int n_dashes = 0; |
143 | int len = s.length(); |
144 | for (; n_dashes < len && s[n_dashes] == '-'; ++n_dashes) { |
145 | } |
146 | if (n_dashes == 0) { |
147 | throw Error(": Attribute keys should start with '-', not an attribute key: " + s); |
148 | } |
149 | if (n_dashes >= len) { |
150 | throw Error(": Not an attribute key: " + s); |
151 | } |
152 | return s.substr(n_dashes); |
153 | } |
154 | |
155 | bool TargetInternal::IsQuoted(const std::string& str) { |
156 | std::string::size_type start = 0, end = str.size(); |
157 | if (end < 2 || str[start] != quote || str[end - 1] != quote) { |
158 | return false; |
159 | } |
160 | bool escaping = false; |
161 | for (auto i = start + 1, e = end - 1; i < e; ++i) { |
162 | if (escaping) { |
163 | escaping = false; |
164 | } else if (str[i] == escape) { |
165 | escaping = true; |
166 | } else if (str[i] == quote) { |
167 | return false; |
168 | } |
169 | } |
170 | // If the reduced string ends with \, then the terminating quote is escaped. |
171 | return !escaping; |
172 | } |
173 | |
174 | std::string TargetInternal::Quote(const std::string& str) { |
175 | std::string result(1, quote); |
176 | result.append(str); |
177 | result.push_back(quote); |
178 | return result; |
179 | } |
180 | |
181 | std::string TargetInternal::JoinString(const std::vector<std::string>& array, char separator) { |
182 | std::string result; |
183 | ICHECK(separator != quote && separator != escape) |
184 | << "string join separator cannot be " << quote << " or " << escape; |
185 | |
186 | bool is_first = true; |
187 | for (const auto& s : array) { |
188 | if (!is_first) { |
189 | result.push_back(separator); |
190 | } |
191 | result.append(s); |
192 | is_first = false; |
193 | } |
194 | |
195 | return result; |
196 | } |
197 | |
198 | std::vector<std::string> TargetInternal::SplitString(const std::string& str, char separator) { |
199 | std::vector<std::string> output; |
200 | |
201 | const char* start = str.data(); |
202 | const char* end = start + str.size(); |
203 | const char* pos = start; |
204 | |
205 | std::stringstream current_word; |
206 | |
207 | auto finish_word = [&]() { |
208 | std::string word = current_word.str(); |
209 | if (word.size()) { |
210 | output.push_back(word); |
211 | current_word.str("" ); |
212 | } |
213 | }; |
214 | |
215 | bool pos_quoted = false; |
216 | |
217 | while (pos < end) { |
218 | if ((*pos == separator) && !pos_quoted) { |
219 | finish_word(); |
220 | pos++; |
221 | } else if (*pos == escape && pos + 1 < end) { |
222 | current_word << escape; |
223 | current_word << pos[1]; |
224 | pos += 2; |
225 | } else if (*pos == quote) { |
226 | current_word << quote; |
227 | pos_quoted = !pos_quoted; |
228 | pos++; |
229 | } else { |
230 | current_word << *pos; |
231 | pos++; |
232 | } |
233 | } |
234 | |
235 | ICHECK(!pos_quoted) << "Mismatched quotes '' in string" ; |
236 | |
237 | finish_word(); |
238 | |
239 | return output; |
240 | } |
241 | |
242 | std::string TargetInternal::Interpret(const std::string& str) { |
243 | // String interpretation deals with quotes (') and escapes(\). |
244 | // - An escape character must be followed by another character forming an |
245 | // "escape sequence". (Trailing escape is not allowed.) An escape prevents |
246 | // interpretation of the character that follows. This happens regardless of |
247 | // whether the escape sequence appears within quoted substring or not. |
248 | // - A quote character, when interpreted, marks the beginning or the end of a |
249 | // quoted substring. (A quoted substring cannot contain unescaped quotes.) |
250 | // - Any other character, when interpreted, represents itself. |
251 | // |
252 | // Interpretation happens in two steps: |
253 | // 1. If the entire string is quoted, the quotes are removed first, and the |
254 | // resulting string is treated as unquoted. |
255 | // 2. Each character or escape sequence is interpreted, and the result is copied |
256 | // to the result. When not inside a quoted substring, the interpretation of an |
257 | // escape sequence is the escaped character, otherwise it is the entire escape |
258 | // sequence. |
259 | // |
260 | // Examples: |
261 | // blah -> blah Nothing happened |
262 | // 'blah' -> blah Enclosing quotes removed |
263 | // 'bl'ah -> 'bl'ah Non-enclosing quotes remain |
264 | // '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes |
265 | // interpreted. |
266 | // '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above. |
267 | // |
268 | // Note that |
269 | // '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah' |
270 | |
271 | std::string result; |
272 | if (str.empty()) { |
273 | return result; |
274 | } |
275 | |
276 | // Check if the entire string is enclosed in quotes ''. If so, strip the quotes |
277 | // and treat the string as unquoted (so that escapes are interpreted). Doing that |
278 | // will allow '\'foo\'' to become 'foo', instead of \'foo\'. |
279 | std::string::size_type start = 0, end = str.size(); |
280 | if (IsQuoted(str)) { |
281 | start++; |
282 | end--; |
283 | } |
284 | |
285 | bool inside_quote = false; |
286 | bool escaping = false; |
287 | |
288 | for (auto i = start, e = end; i < e; ++i) { |
289 | std::string::value_type c = str[i]; |
290 | if (escaping) { |
291 | escaping = false; |
292 | } else if (c == escape) { |
293 | escaping = true; |
294 | if (!inside_quote) { |
295 | continue; |
296 | } |
297 | } else if (c == quote) { |
298 | inside_quote = !inside_quote; |
299 | } |
300 | result.push_back(c); |
301 | } |
302 | |
303 | return result; |
304 | } |
305 | |
306 | std::string TargetInternal::Uninterpret(const std::string& str) { |
307 | // Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str. |
308 | std::string result; |
309 | |
310 | for (std::string::size_type i = 0, e = str.size(); i < e; ++i) { |
311 | std::string::value_type c = str[i]; |
312 | if (c == escape || c == quote) { |
313 | result.push_back(escape); |
314 | } |
315 | result.push_back(c); |
316 | } |
317 | |
318 | return result; |
319 | } |
320 | |
321 | static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, |
322 | std::string* value) { |
323 | std::string::size_type pos; |
324 | std::string& result_k = *key; |
325 | std::string& result_v = *value; |
326 | if ((pos = s.find_first_of('=')) != std::string::npos) { |
327 | // case 1. --key=value |
328 | result_k = s.substr(0, pos); |
329 | result_v = s.substr(pos + 1); |
330 | if (result_k.empty() || result_v.empty()) { |
331 | throw Error(": Empty attribute key or value in \"" + s + "\"" ); |
332 | } |
333 | return 1; |
334 | } else if (!s_next.empty() && s_next[0] != '-') { |
335 | // case 2. --key value |
336 | result_k = s; |
337 | result_v = s_next; |
338 | return 2; |
339 | } |
340 | // case 3. --boolean-key |
341 | result_k = s; |
342 | result_v = "1" ; |
343 | return 1; |
344 | } |
345 | |
346 | const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKind& kind, |
347 | const std::string& key) { |
348 | auto it = kind->key2vtype_.find(key); |
349 | if (it == kind->key2vtype_.end()) { |
350 | std::ostringstream os; |
351 | os << ": Cannot recognize \'" << key << "\'. Candidates are: " ; |
352 | bool is_first = true; |
353 | for (const auto& kv : kind->key2vtype_) { |
354 | if (is_first) { |
355 | is_first = false; |
356 | } else { |
357 | os << ", " ; |
358 | } |
359 | os << kv.first; |
360 | } |
361 | throw Error(os.str()); |
362 | } |
363 | return it->second; |
364 | } |
365 | |
366 | /********** Parsing **********/ |
367 | |
368 | ObjectRef TargetInternal::ParseType(const std::string& str, |
369 | const TargetKindNode::ValueTypeInfo& info) { |
370 | std::string interp_str = Interpret(str); |
371 | if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
372 | // Parsing integer |
373 | std::istringstream is(interp_str); |
374 | int v; |
375 | if (!(is >> v)) { |
376 | std::string lower(interp_str.size(), '\x0'); |
377 | std::transform(interp_str.begin(), interp_str.end(), lower.begin(), |
378 | [](unsigned char c) { return std::tolower(c); }); |
379 | // Bool is a subclass of IntImm, so allow textual boolean values. |
380 | if (lower == "true" ) { |
381 | v = 1; |
382 | } else if (lower == "false" ) { |
383 | v = 0; |
384 | } else { |
385 | throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); |
386 | } |
387 | } |
388 | return Integer(v); |
389 | } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
390 | // Parsing string, strip leading/trailing spaces, and enclosing quotes if any |
391 | auto start = interp_str.find_first_not_of(' '); |
392 | auto end = interp_str.find_last_not_of(' '); |
393 | if (start == std::string::npos || end == std::string::npos) { |
394 | // The whole string is made of spaces. |
395 | return String(); |
396 | } |
397 | return String(interp_str.substr(start, (end - start + 1))); |
398 | |
399 | } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
400 | // Parsing target |
401 | return Target(TargetInternal::FromString(interp_str)); |
402 | } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { |
403 | // Parsing array |
404 | std::vector<ObjectRef> result; |
405 | for (const std::string& substr : SplitString(interp_str, ',')) { |
406 | try { |
407 | ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); |
408 | result.push_back(parsed); |
409 | } catch (const Error& e) { |
410 | std::string index = "[" + std::to_string(result.size()) + "]" ; |
411 | throw Error(index + e.what()); |
412 | } |
413 | } |
414 | return Array<ObjectRef>(result); |
415 | } |
416 | throw Error(": Unsupported type \"" + info.type_key + |
417 | "\" for parsing from string: " + interp_str); |
418 | } |
419 | |
420 | ObjectRef TargetInternal::ParseType(const ObjectRef& obj, |
421 | const TargetKindNode::ValueTypeInfo& info) { |
422 | if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
423 | // Parsing integer |
424 | return GetRef<Integer>(ObjTypeCheck<IntImmNode>(obj, "Integer" )); |
425 | } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
426 | // Parsing string |
427 | return GetRef<String>(ObjTypeCheck<StringObj>(obj, "String" )); |
428 | } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
429 | // Parsing target |
430 | if (const auto* ptr = obj.as<TargetNode>()) { |
431 | return GetRef<Target>(ptr); |
432 | } else if (const auto* ptr = obj.as<StringObj>()) { |
433 | return Target(TargetInternal::FromString(GetRef<String>(ptr))); |
434 | } else if (const auto* ptr = obj.as<MapNode>()) { |
435 | for (const auto& kv : *ptr) { |
436 | if (!kv.first->IsInstance<StringObj>()) { |
437 | throw Error(": Target object requires key of dict to be str, but get: " + |
438 | kv.first->GetTypeKey()); |
439 | } |
440 | } |
441 | Map<String, ObjectRef> config = GetRef<Map<String, ObjectRef>>(ptr); |
442 | return Target(TargetInternal::FromConfig({config.begin(), config.end()})); |
443 | } |
444 | throw Error(": Expect type 'dict' or 'str' to construct Target, but get: " + obj->GetTypeKey()); |
445 | } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { |
446 | // Parsing array |
447 | const auto* array = ObjTypeCheck<ArrayNode>(obj, "Array" ); |
448 | std::vector<ObjectRef> result; |
449 | for (const ObjectRef& e : *array) { |
450 | try { |
451 | result.push_back(TargetInternal::ParseType(e, *info.key)); |
452 | } catch (const Error& e) { |
453 | std::string index = '[' + std::to_string(result.size()) + ']'; |
454 | throw Error(index + e.what()); |
455 | } |
456 | } |
457 | return Array<ObjectRef>(result); |
458 | } else if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) { |
459 | // Parsing map |
460 | const auto* map = ObjTypeCheck<MapNode>(obj, "Map" ); |
461 | std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result; |
462 | for (const auto& kv : *map) { |
463 | ObjectRef key, val; |
464 | try { |
465 | key = TargetInternal::ParseType(kv.first, *info.key); |
466 | } catch (const Error& e) { |
467 | std::ostringstream os; |
468 | os << "'s key \"" << key << "\"" << e.what(); |
469 | throw Error(os.str()); |
470 | } |
471 | try { |
472 | val = TargetInternal::ParseType(kv.second, *info.val); |
473 | } catch (const Error& e) { |
474 | std::ostringstream os; |
475 | os << "[\"" << key << "\"]" << e.what(); |
476 | throw Error(os.str()); |
477 | } |
478 | result[key] = val; |
479 | } |
480 | return Map<ObjectRef, ObjectRef>(result); |
481 | } |
482 | if (info.type_index != obj->type_index()) { |
483 | std::ostringstream os; |
484 | os << ": Parsing type \"" << info.type_key |
485 | << "\" is not supported for the given object of type \"" << obj->GetTypeKey() |
486 | << "\". The object is: " << obj; |
487 | throw Error(os.str()); |
488 | } |
489 | return obj; |
490 | } |
491 | |
492 | /********** Stringifying **********/ |
493 | |
494 | std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { |
495 | if (const auto* p = obj.as<IntImmNode>()) { |
496 | return std::to_string(p->value); |
497 | } |
498 | if (const auto* p = obj.as<StringObj>()) { |
499 | auto s = static_cast<std::string>(GetRef<String>(p)); |
500 | auto u = Uninterpret(s); |
501 | if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { |
502 | u = Quote(u); |
503 | } |
504 | return u; |
505 | } |
506 | LOG(FATAL) << "Cannot stringify this object" ; |
507 | } |
508 | |
509 | std::string TargetInternal::StringifyArray(const ArrayNode& array) { |
510 | std::vector<std::string> elements; |
511 | |
512 | for (const ObjectRef& item : array) { |
513 | std::string s = StringifyAtomicType(item); |
514 | std::string u = Uninterpret(s); |
515 | if (u.find_first_of(',') != std::string::npos && !IsQuoted(u)) { |
516 | u = Quote(u); |
517 | } |
518 | elements.push_back(u); |
519 | } |
520 | |
521 | return JoinString(elements, ','); |
522 | } |
523 | |
524 | Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) { |
525 | std::ostringstream os; |
526 | std::vector<String> keys; |
527 | for (const auto& kv : attrs) { |
528 | keys.push_back(kv.first); |
529 | } |
530 | std::sort(keys.begin(), keys.end()); |
531 | std::vector<std::string> result; |
532 | |
533 | for (const auto& key : keys) { |
534 | const ObjectRef& obj = attrs[key]; |
535 | std::string value; |
536 | if (const auto* array = obj.as<ArrayNode>()) { |
537 | value = String(StringifyArray(*array)); |
538 | } else { |
539 | value = StringifyAtomicType(obj); |
540 | } |
541 | if (!value.empty()) { |
542 | result.push_back("-" + key + "=" + value); |
543 | } |
544 | } |
545 | return String(JoinString(result, ' ')); |
546 | } |
547 | |
548 | const std::string& TargetNode::str() const { |
549 | if (str_repr_.empty()) { |
550 | std::ostringstream os; |
551 | os << kind->name; |
552 | if (!this->keys.empty()) { |
553 | os << " -keys=" ; |
554 | bool is_first = true; |
555 | for (const String& s : keys) { |
556 | if (is_first) { |
557 | is_first = false; |
558 | } else { |
559 | os << ','; |
560 | } |
561 | os << s; |
562 | } |
563 | } |
564 | if (Optional<String> attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { |
565 | os << ' ' << attrs_str.value(); |
566 | } |
567 | |
568 | str_repr_ = os.str(); |
569 | } |
570 | return str_repr_; |
571 | } |
572 | |
573 | /********** Small member methods **********/ |
574 | |
575 | Target::Target(const String& tag_or_config_or_target_str) { |
576 | ObjectPtr<Object> target; |
577 | try { |
578 | target = TargetInternal::FromString(tag_or_config_or_target_str); |
579 | } catch (const Error& e) { |
580 | LOG(FATAL) << "ValueError" << e.what() |
581 | << ". Target creation from string failed: " << tag_or_config_or_target_str; |
582 | } |
583 | data_ = std::move(target); |
584 | } |
585 | |
586 | Target::Target(const Map<String, ObjectRef>& config) { |
587 | ObjectPtr<Object> target; |
588 | try { |
589 | target = TargetInternal::FromConfig({config.begin(), config.end()}); |
590 | } catch (const Error& e) { |
591 | LOG(FATAL) << "ValueError" << e.what() |
592 | << ". Target creation from config dict failed: " << config; |
593 | } |
594 | data_ = std::move(target); |
595 | } |
596 | |
597 | Target::Target(Target target, Target host) { |
598 | ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get()); |
599 | n->host = std::move(host); |
600 | data_ = std::move(n); |
601 | } |
602 | |
603 | Target::Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<String> keys, |
604 | Map<String, ObjectRef> attrs) { |
605 | auto data = runtime::make_object<TargetNode>(); |
606 | data->kind = std::move(kind); |
607 | data->host = std::move(host); |
608 | data->tag = std::move(tag); |
609 | data->keys = std::move(keys); |
610 | data->attrs = std::move(attrs); |
611 | data_ = std::move(data); |
612 | } |
613 | |
614 | bool Target::IsExternalCodegen() const { |
615 | TargetKindAttrMap<Bool> is_external_codegen_map = |
616 | TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen); |
617 | TargetKindAttrMap<FTVMRelayToTIR> relay_to_tir_map = |
618 | TargetKind::GetAttrMap<FTVMRelayToTIR>(tvm::attr::kRelayToTIR); |
619 | return is_external_codegen_map.get(get()->kind, Bool(false)) || |
620 | relay_to_tir_map.count(get()->kind); |
621 | } |
622 | |
623 | bool Target::IsExternalCodegenFor(const Target& that) const { |
624 | return get()->GetTargetDeviceType() == that->GetTargetDeviceType() && IsExternalCodegen() && |
625 | !that.IsExternalCodegen(); |
626 | } |
627 | |
628 | std::vector<std::string> TargetNode::GetKeys() const { |
629 | std::vector<std::string> result; |
630 | for (auto& expr : keys) { |
631 | result.push_back(expr); |
632 | } |
633 | return result; |
634 | } |
635 | |
636 | std::unordered_set<std::string> TargetNode::GetLibs() const { |
637 | Optional<Array<String>> libs = this->GetAttr<Array<String>>("libs" ); |
638 | if (!libs.defined()) { |
639 | return {}; |
640 | } |
641 | std::unordered_set<std::string> result; |
642 | for (const auto& item : libs.value()) { |
643 | result.insert(item); |
644 | } |
645 | return result; |
646 | } |
647 | |
648 | Map<String, ObjectRef> TargetNode::Export() const { |
649 | Map<String, ObjectRef> result = { |
650 | {"kind" , this->kind->name}, |
651 | {"tag" , this->tag}, |
652 | {"keys" , this->keys}, |
653 | }; |
654 | if (this->host.defined()) { |
655 | result.Set("host" , this->GetHost().value_or(Target())->Export()); |
656 | } |
657 | for (const auto& kv : attrs) { |
658 | result.Set(kv.first, kv.second); |
659 | } |
660 | return result; |
661 | } |
662 | |
663 | Optional<Target> TargetNode::GetHost() const { |
664 | return GetRef<Optional<Target>>(this->host.as<TargetNode>()); |
665 | } |
666 | |
667 | int TargetNode::GetTargetDeviceType() const { |
668 | if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type" )) { |
669 | return Downcast<Integer>(device_type)->value; |
670 | } |
671 | return kind->default_device_type; |
672 | } |
673 | |
674 | String TargetNode::ToDebugString() const { |
675 | std::ostringstream os; |
676 | os << "Target(" ; |
677 | os << "id=" << std::hex << reinterpret_cast<size_t>(this); |
678 | os << ", kind='" << kind->name << "'" ; |
679 | if (!tag.empty()) { |
680 | os << ", tag='" << tag << "'" ; |
681 | } |
682 | if (!keys.empty()) { |
683 | os << ", keys={" ; |
684 | bool first = true; |
685 | for (const auto& key : keys) { |
686 | if (!first) { |
687 | os << ", " ; |
688 | } |
689 | os << "'" << key << "'" ; |
690 | first = false; |
691 | } |
692 | os << "}" ; |
693 | } |
694 | if (!attrs.empty()) { |
695 | os << ", attrs={" ; |
696 | bool first = true; |
697 | for (const auto& pair : attrs) { |
698 | if (!first) { |
699 | os << ", " ; |
700 | } |
701 | os << "'" << pair.first << "': " << pair.second; |
702 | first = false; |
703 | } |
704 | os << "}" ; |
705 | } |
706 | if (host.defined()) { |
707 | os << ", host=" << GetHost().value()->ToDebugString(); |
708 | } |
709 | os << ")" ; |
710 | return os.str(); |
711 | } |
712 | |
713 | bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { |
714 | return equal(kind.get(), other->kind.get()) && equal(host, other->host) && |
715 | equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); |
716 | } |
717 | |
/*!
 * \brief Structural hash: feeds kind, host, tag, keys and attrs to the reducer, i.e. the
 *        same fields (in the same order) that SEqualReduce compares.
 */
void TargetNode::SHashReduce(SHashReducer hash_reduce) const {
  hash_reduce(kind.get());
  hash_reduce(host);
  hash_reduce(tag);
  hash_reduce(keys);
  hash_reduce(attrs);
}
725 | |
/*! \brief Entry to hold the Target context stack. */
struct TVMTargetThreadLocalEntry {
  /*!
   * \brief Stack of entered target scopes; the top element is the current target.
   *        Pushed by Target::EnterWithScope and popped by Target::ExitWithScope.
   */
  std::stack<Target> context_stack;
};

/*! \brief Thread local store to hold the Target context stack. */
using TVMTargetThreadLocalStore = dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>;
734 | |
735 | void Target::EnterWithScope() { |
736 | TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
737 | entry->context_stack.push(*this); |
738 | } |
739 | |
/*!
 * \brief Pop this target from the context stack held in the thread-local store.
 *
 * Scopes must be exited in reverse order of entry: the stack must be non-empty and its top
 * element must be this very target (checked with same_as).
 */
void Target::ExitWithScope() {
  TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
  // Exiting without a matching EnterWithScope is a programming error.
  ICHECK(!entry->context_stack.empty());
  // Only the innermost scope may be exited.
  ICHECK(entry->context_stack.top().same_as(*this));
  entry->context_stack.pop();
}
746 | |
747 | Target Target::Current(bool allow_not_defined) { |
748 | TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); |
749 | if (entry->context_stack.size() > 0) { |
750 | return entry->context_stack.top(); |
751 | } |
752 | ICHECK(allow_not_defined) |
753 | << "Target context required. Please set it by constructing a TargetContext" ; |
754 | |
755 | return Target(); |
756 | } |
757 | |
758 | /********** Creation **********/ |
759 | |
760 | void TargetInternal::ConstructorDispatcher(TVMArgs args, TVMRetValue* rv) { |
761 | if (args.num_args == 1) { |
762 | const auto& arg = args[0]; |
763 | if (arg.IsObjectRef<Target>()) { |
764 | *rv = Target(arg.AsObjectRef<Target>()); |
765 | } else if (String::CanConvertFrom(arg)) { |
766 | *rv = Target(arg.operator String()); |
767 | } else if (arg.IsObjectRef<Map<String, ObjectRef>>()) { |
768 | *rv = Target(arg.operator Map<String, ObjectRef>()); |
769 | } else if (arg.type_code() == kTVMObjectHandle) { |
770 | ObjectRef obj = arg; |
771 | LOG(FATAL) << "TypeError: Cannot create target with type: " << obj->GetTypeKey(); |
772 | } else { |
773 | LOG(FATAL) << "TypeError: Cannot create target with type: " |
774 | << runtime::ArgTypeCode2Str(arg.type_code()); |
775 | } |
776 | return; |
777 | } else if (args.num_args == 2) { |
778 | if (args[0].IsObjectRef<Target>() && args[1].IsObjectRef<Target>()) { |
779 | Target target = args[0]; |
780 | Target host = args[1]; |
781 | *rv = Target(target, host); |
782 | } else { |
783 | LOG(FATAL) << "ValueError: Invalid type of arguments. Expect 2 Target arguments." ; |
784 | } |
785 | return; |
786 | } |
787 | LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " |
788 | << args.num_args; |
789 | } |
790 | |
791 | ObjectPtr<Object> TargetInternal::FromString(const String& tag_or_config_or_target_str) { |
792 | if (Optional<Target> target = TargetTag::Get(tag_or_config_or_target_str)) { |
793 | Target value = target.value(); |
794 | return runtime::ObjectInternal::MoveObjectPtr(&value); |
795 | } |
796 | if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') { |
797 | return TargetInternal::FromConfigString(tag_or_config_or_target_str); |
798 | } |
799 | return TargetInternal::FromRawString(tag_or_config_or_target_str); |
800 | } |
801 | |
802 | ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) { |
803 | const auto* loader = tvm::runtime::Registry::Get("target._load_config_dict" ); |
804 | ICHECK(loader) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " |
805 | "if the python module is properly loaded" ; |
806 | Optional<Map<String, ObjectRef>> config = (*loader)(config_str); |
807 | if (!config.defined()) { |
808 | throw Error(": Cannot load config dict with python JSON loader" ); |
809 | } |
810 | return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); |
811 | } |
812 | |
813 | ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) { |
814 | ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string" ; |
815 | // Split the string by empty spaces |
816 | std::vector<std::string> options = SplitString(std::string(target_str), ' '); |
817 | std::string name = options[0]; |
818 | // Create the target config |
819 | std::unordered_map<String, ObjectRef> config = {{"kind" , String(name)}}; |
820 | TargetKind kind = GetTargetKind(name); |
821 | for (size_t iter = 1, end = options.size(); iter < end;) { |
822 | std::string key, value; |
823 | try { |
824 | // Parse key-value pair |
825 | std::string s_next = (iter + 1 < options.size()) ? options[iter + 1] : "" ; |
826 | iter += ParseKVPair(RemovePrefixDashes(options[iter]), s_next, &key, &value); |
827 | } catch (const Error& e) { |
828 | throw Error(": Error when parsing target" + std::string(e.what())); |
829 | } |
830 | try { |
831 | // check if `key` has been used |
832 | if (config.count(key)) { |
833 | throw Error(": The key \"" + key + "\" appears more than once" ); |
834 | } |
835 | config[key] = TargetInternal::ParseType(value, TargetInternal::FindTypeInfo(kind, key)); |
836 | } catch (const Error& e) { |
837 | throw Error(": Error when parsing target[\"" + key + "\"]" + e.what()); |
838 | } |
839 | } |
840 | return TargetInternal::FromConfig(config); |
841 | } |
842 | |
843 | ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) { |
844 | const String kKind = "kind" ; |
845 | const String kTag = "tag" ; |
846 | const String kKeys = "keys" ; |
847 | const String kDeviceName = "device" ; |
848 | const String kHost = "host" ; |
849 | const String kFeatures = "features" ; |
850 | ObjectPtr<TargetNode> target = make_object<TargetNode>(); |
851 | |
852 | ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser" ; |
853 | |
854 | // parse 'kind' |
855 | if (config.count(kKind)) { |
856 | if (const auto* kind = config[kKind].as<StringObj>()) { |
857 | target->kind = GetTargetKind(GetRef<String>(kind)); |
858 | ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) |
859 | << "Cannot use both set_attrs_preprocessor and set_target_parser" ; |
860 | |
861 | // Run JSON Parser over JSON input |
862 | if (target->kind->target_parser != nullptr) { |
863 | VLOG(9) << "TargetInternal::FromConfig - Running target_parser" ; |
864 | config = target->kind->target_parser(config); |
865 | if (config.count(kFeatures)) { |
866 | target->features = Downcast<Map<String, ObjectRef>>(config[kFeatures]); |
867 | config.erase(kFeatures); |
868 | } |
869 | } |
870 | |
871 | config.erase(kKind); |
872 | } else { |
873 | throw Error(": Expect type of field \"kind\" is String, but get type: " + |
874 | config[kKind]->GetTypeKey()); |
875 | } |
876 | } else { |
877 | throw Error(": Field \"kind\" is not found" ); |
878 | } |
879 | // parse "tag" |
880 | if (config.count(kTag)) { |
881 | if (const auto* tag = config[kTag].as<StringObj>()) { |
882 | target->tag = GetRef<String>(tag); |
883 | config.erase(kTag); |
884 | } else { |
885 | throw Error(": Expect type of field \"tag\" is String, but get type: " + |
886 | config[kTag]->GetTypeKey()); |
887 | } |
888 | } else { |
889 | target->tag = "" ; |
890 | } |
891 | // parse "keys" |
892 | { |
893 | std::vector<String> keys; |
894 | bool has_user_keys = config.count(kKeys); |
895 | if (has_user_keys) { |
896 | // user provided keys |
897 | if (const auto* cfg_keys = config[kKeys].as<ArrayNode>()) { |
898 | for (const ObjectRef& e : *cfg_keys) { |
899 | if (const auto* key = e.as<StringObj>()) { |
900 | keys.push_back(GetRef<String>(key)); |
901 | } else { |
902 | throw Error( |
903 | ": Expect 'keys' to be an array of strings, but it " |
904 | "contains an element of type: " + |
905 | e->GetTypeKey()); |
906 | } |
907 | } |
908 | } else { |
909 | throw Error(": Expect type of field \"keys\" is Array, but get type: " + |
910 | config[kKeys]->GetTypeKey()); |
911 | } |
912 | } |
913 | // add device name |
914 | if (config.count(kDeviceName)) { |
915 | if (const auto* device = config.at(kDeviceName).as<StringObj>()) { |
916 | keys.push_back(GetRef<String>(device)); |
917 | } |
918 | } |
919 | if (!has_user_keys) { |
920 | // add default keys |
921 | for (const auto& key : target->kind->default_keys) { |
922 | keys.push_back(key); |
923 | } |
924 | } |
925 | // de-duplicate keys |
926 | target->keys = DeduplicateKeys(keys); |
927 | config.erase(kKeys); |
928 | } |
929 | // parse host |
930 | if (config.count(kHost)) { |
931 | target->host = PackedFunc(ConstructorDispatcher)(config[kHost]).AsObjectRef<Target>(); |
932 | config.erase(kHost); |
933 | } else { |
934 | target->host = NullOpt; |
935 | } |
936 | // parse attrs |
937 | std::unordered_map<String, ObjectRef> attrs; |
938 | for (const auto& cfg_kv : config) { |
939 | const String& key = cfg_kv.first; |
940 | const ObjectRef& value = cfg_kv.second; |
941 | try { |
942 | const TargetKindNode::ValueTypeInfo& info = TargetInternal::FindTypeInfo(target->kind, key); |
943 | attrs[key] = TargetInternal::ParseType(value, info); |
944 | } catch (const Error& e) { |
945 | throw Error(": Error when parsing target[\"" + key + "\"]" + e.what()); |
946 | } |
947 | } |
948 | |
949 | // If requested, query attributes from the device. User-specified |
950 | // parameters take precedence over queried parameters. |
951 | if (attrs.count("from_device" )) { |
952 | int device_id = Downcast<Integer>(attrs.at("from_device" )).IntValue(); |
953 | attrs.erase("from_device" ); |
954 | auto device_params = QueryDevice(device_id, target.get()); |
955 | |
956 | for (const auto& kv : device_params) { |
957 | if (attrs.count(kv.first) == 0) { |
958 | attrs[kv.first] = kv.second; |
959 | } |
960 | } |
961 | } |
962 | |
963 | // set default attribute values if they do not exist |
964 | for (const auto& kv : target->kind->key2default_) { |
965 | if (!attrs.count(kv.first)) { |
966 | attrs[kv.first] = kv.second; |
967 | } |
968 | } |
969 | // do extra pre-processing |
970 | if (target->kind->preprocessor != nullptr) { |
971 | target->attrs = target->kind->preprocessor(Map<String, ObjectRef>(attrs)); |
972 | } else { |
973 | target->attrs = attrs; |
974 | } |
975 | |
976 | return target; |
977 | } // namespace tvm |
978 | |
979 | std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id, |
980 | const TargetNode* target) { |
981 | std::unordered_map<String, ObjectRef> output; |
982 | |
983 | Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()), device_id}; |
984 | |
985 | auto api = runtime::DeviceAPI::Get(device, true); |
986 | if (!api) { |
987 | LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id " |
988 | << device_id << ", but support for this runtime wasn't enabled at compile-time. " |
989 | << "Using default target parameters." ; |
990 | return output; |
991 | } |
992 | |
993 | TVMRetValue ret; |
994 | api->GetAttr(device, runtime::kExist, &ret); |
995 | bool device_exists = ret; |
996 | if (!device_exists) { |
997 | ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name |
998 | << " from device_id " << device_id << ", but device_id " << device_id |
999 | << " doesn't exist. Using default target parameters." ; |
1000 | return output; |
1001 | } |
1002 | |
1003 | for (const auto& kv : target->kind->key2vtype_) { |
1004 | const String& key = kv.first; |
1005 | const TargetKindNode::ValueTypeInfo& type_info = kv.second; |
1006 | |
1007 | TVMRetValue ret; |
1008 | api->GetTargetProperty(device, key, &ret); |
1009 | |
1010 | switch (ret.type_code()) { |
1011 | case kTVMNullptr: |
1012 | // Nothing returned for this parameter, move on to the next one. |
1013 | continue; |
1014 | |
1015 | case kTVMArgInt: |
1016 | if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
1017 | output[key] = Integer(static_cast<int64_t>(ret)); |
1018 | } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { |
1019 | output[key] = Bool(static_cast<bool>(ret)); |
1020 | } else { |
1021 | LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key |
1022 | << "', but received integer from device api" ; |
1023 | } |
1024 | break; |
1025 | |
1026 | case kTVMStr: |
1027 | ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) |
1028 | << "Expected " << type_info.type_key << " parameter for attribute '" << key |
1029 | << "', but received string from device api" ; |
1030 | output[key] = String(ret.operator std::string()); |
1031 | break; |
1032 | |
1033 | default: |
1034 | LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key |
1035 | << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api" ; |
1036 | break; |
1037 | } |
1038 | } |
1039 | |
1040 | return output; |
1041 | } |
1042 | |
1043 | /********** Registry **********/ |
1044 | |
1045 | TVM_REGISTER_GLOBAL("target.Target" ).set_body(TargetInternal::ConstructorDispatcher); |
1046 | TVM_REGISTER_GLOBAL("target.TargetEnterScope" ).set_body_typed(TargetInternal::EnterScope); |
1047 | TVM_REGISTER_GLOBAL("target.TargetExitScope" ).set_body_typed(TargetInternal::ExitScope); |
1048 | TVM_REGISTER_GLOBAL("target.TargetCurrent" ).set_body_typed(Target::Current); |
1049 | TVM_REGISTER_GLOBAL("target.TargetExport" ).set_body_typed(TargetInternal::Export); |
1050 | TVM_REGISTER_GLOBAL("target.WithHost" ).set_body_typed(TargetInternal::WithHost); |
1051 | TVM_REGISTER_GLOBAL("target.TargetGetDeviceType" ).set_body_typed([](const Target& target) { |
1052 | return target->GetTargetDeviceType(); |
1053 | }); |
1054 | TVM_REGISTER_GLOBAL("target.TargetGetFeature" ) |
1055 | .set_body_typed([](const Target& target, const String& feature_key) { |
1056 | return target->GetFeature<ObjectRef>(feature_key); |
1057 | }); |
1058 | |
1059 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
1060 | .set_dispatch<TargetNode>([](const ObjectRef& obj, ReprPrinter* p) { |
1061 | p->stream << Downcast<Target>(obj)->str(); |
1062 | }); |
1063 | |
1064 | } // namespace tvm |
1065 | |