1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_META_SCHEDULE_UTILS_H_ |
20 | #define TVM_META_SCHEDULE_UTILS_H_ |
21 | |
#include <dmlc/memory_io.h>
#include <tvm/arith/analyzer.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/meta_schedule/feature_extractor.h>
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/profiler.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/meta_schedule/task_scheduler.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/node.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/support/parallel_for.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <atomic>
#include <iomanip>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../support/array.h"
#include "../support/base64.h"
#include "../support/nd_int_set.h"
#include "../support/table_printer.h"
#include "../support/utils.h"
#include "../tir/schedule/primitive.h"
#include "../tir/schedule/utils.h"
58 | |
59 | #define TVM_PY_LOG(logging_level, logger) \ |
60 | ::tvm::meta_schedule::PyLogMessage(__FILE__, __LINE__, logger, \ |
61 | PyLogMessage::Level::logging_level) \ |
62 | .stream() |
63 | #define TVM_PY_LOG_CLEAR_SCREEN(logging_func) clear_logging(__FILE__, __LINE__, logging_func) |
64 | |
65 | namespace tvm { |
66 | namespace meta_schedule { |
67 | |
68 | /*! |
69 | * \brief Class to accumulate an log message on the python side. Do not use directly, instead use |
70 | * TVM_PY_LOG(DEBUG), TVM_PY_LOG(INFO), TVM_PY_LOG(WARNING), TVM_PY_ERROR(ERROR). |
71 | * \sa TVM_PY_LOG |
72 | * \sa TVM_PY_LOG_CLEAR_SCREEN |
73 | */ |
74 | class PyLogMessage { |
75 | public: |
76 | enum class Level : int32_t { |
77 | CLEAR = -10, |
78 | DEBUG = 10, |
79 | INFO = 20, |
80 | WARNING = 30, |
81 | ERROR = 40, |
82 | // FATAL not included |
83 | }; |
84 | |
85 | explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, Level logging_level) |
86 | : filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {} |
87 | |
88 | TVM_NO_INLINE ~PyLogMessage() { |
89 | ICHECK(logging_level_ != Level::CLEAR) |
90 | << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN." ; |
91 | if (this->logger_ != nullptr) { |
92 | logger_(static_cast<int>(logging_level_), std::string(filename_), lineno_, stream_.str()); |
93 | } else { |
94 | if (logging_level_ == Level::INFO) { |
95 | runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_INFO).stream() |
96 | << stream_.str(); |
97 | } else if (logging_level_ == Level::WARNING) { |
98 | runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_WARNING).stream() |
99 | << stream_.str(); |
100 | } else if (logging_level_ == Level::ERROR) { |
101 | runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_ERROR).stream() |
102 | << stream_.str(); |
103 | } else if (logging_level_ == Level::DEBUG) { |
104 | runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_DEBUG).stream() |
105 | << stream_.str(); |
106 | } else { |
107 | runtime::detail::LogFatal(filename_, lineno_).stream() << stream_.str(); |
108 | } |
109 | } |
110 | } |
111 | std::ostringstream& stream() { return stream_; } |
112 | |
113 | private: |
114 | const char* filename_; |
115 | int lineno_; |
116 | std::ostringstream stream_; |
117 | PackedFunc logger_; |
118 | Level logging_level_; |
119 | }; |
120 | |
121 | /*! |
122 | * \brief Whether the tuning is running on ipython kernel. |
123 | * \return A boolean indicating whether ipython kernel is used. |
124 | */ |
125 | inline bool using_ipython() { |
126 | bool flag = false; |
127 | const auto* f_using_ipython = runtime::Registry::Get("meta_schedule.using_ipython" ); |
128 | if (f_using_ipython) { |
129 | flag = (*f_using_ipython)(); |
130 | } |
131 | return flag; |
132 | } |
133 | |
134 | /*! |
135 | * \brief Print out the performance table interactively in jupyter notebook. |
136 | * \param str The serialized performance table. |
137 | */ |
138 | inline void print_interactive_table(const String& data) { |
139 | const auto* f_print_interactive_table = |
140 | runtime::Registry::Get("meta_schedule.print_interactive_table" ); |
141 | ICHECK(f_print_interactive_table->defined()) |
142 | << "Cannot find print_interactive_table function in registry." ; |
143 | (*f_print_interactive_table)(data); |
144 | } |
145 | |
146 | /*! |
147 | * \brief A helper function to clear logging output for ipython kernel and console. |
148 | * \param file The file name. |
149 | * \param lineno The line number. |
150 | * \param logging_func The logging function. |
151 | */ |
152 | inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) { |
153 | if (logging_func.defined() && using_ipython()) { |
154 | logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno, "" ); |
155 | } else { |
156 | // this would clear all logging output in the console |
157 | runtime::detail::LogMessage(file, lineno, TVM_LOG_LEVEL_INFO).stream() |
158 | << "\033c\033[3J\033[2J\033[0m\033[H" ; |
159 | } |
160 | } |
161 | |
162 | /*! \brief The type of the random state */ |
163 | using TRandState = support::LinearCongruentialEngine::TRandState; |
164 | |
165 | /*! |
166 | * \brief Get the base64 encoded result of a string. |
167 | * \param str The string to encode. |
168 | * \return The base64 encoded string. |
169 | */ |
170 | inline std::string Base64Encode(std::string str) { |
171 | std::string result; |
172 | dmlc::MemoryStringStream m_stream(&result); |
173 | support::Base64OutStream b64stream(&m_stream); |
174 | static_cast<dmlc::Stream*>(&b64stream)->Write(str); |
175 | b64stream.Finish(); |
176 | return result; |
177 | } |
178 | |
179 | /*! |
180 | * \brief Get the base64 decoded result of a string. |
181 | * \param str The string to decode. |
182 | * \return The base64 decoded string. |
183 | */ |
184 | inline std::string Base64Decode(std::string str) { |
185 | std::string result; |
186 | dmlc::MemoryStringStream m_stream(&str); |
187 | support::Base64InStream b64stream(&m_stream); |
188 | b64stream.InitPosition(); |
189 | static_cast<dmlc::Stream*>(&b64stream)->Read(&result); |
190 | return result; |
191 | } |
192 | |
193 | /*! |
194 | * \brief Parses a json string into a json object. |
195 | * \param json_str The json string. |
196 | * \return The json object |
197 | */ |
198 | ObjectRef JSONLoads(std::string json_str); |
199 | |
200 | /*! |
201 | * \brief Dumps a json object into a json string. |
202 | * \param json_obj The json object. |
203 | * \return The json string |
204 | */ |
205 | std::string JSONDumps(ObjectRef json_obj); |
206 | |
207 | /*! |
208 | * \brief Converts a structural hash code to string |
209 | * \param hash_code The hash code |
210 | * \return The string representation of the hash code |
211 | */ |
212 | inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } |
213 | |
214 | /*! |
215 | * \brief Converts an TVM object to the hex string representation of its structural hash. |
216 | * \param obj The TVM object. |
217 | * \return The hex string representation of the hash code. |
218 | */ |
219 | inline String SHash2Hex(const ObjectRef& obj) { |
220 | std::ostringstream os; |
221 | size_t hash_code = 0; |
222 | if (obj.defined()) { |
223 | hash_code = StructuralHash()(obj); |
224 | } |
225 | os << "0x" << std::setw(16) << std::setfill('0') << std::hex << hash_code; |
226 | return os.str(); |
227 | } |
228 | |
229 | /*! |
230 | * \brief Fork a random state into another, i.e. PRNG splitting. |
231 | * The given random state is also mutated. |
232 | * \param rand_state The random state to be forked |
233 | * \return The forked random state |
234 | */ |
235 | inline support::LinearCongruentialEngine::TRandState ForkSeed( |
236 | support::LinearCongruentialEngine::TRandState* rand_state) { |
237 | return support::LinearCongruentialEngine(rand_state).ForkSeed(); |
238 | } |
239 | |
240 | /*! |
241 | * \brief Fork a random state into another ones, i.e. PRNG splitting. |
242 | * The given random state is also mutated. |
243 | * \param rand_state The random state to be forked |
244 | * \param n The number of forks |
245 | * \return The forked random states |
246 | */ |
247 | inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed( |
248 | support::LinearCongruentialEngine::TRandState* rand_state, int n) { |
249 | std::vector<support::LinearCongruentialEngine::TRandState> results; |
250 | results.reserve(n); |
251 | for (int i = 0; i < n; ++i) { |
252 | results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed()); |
253 | } |
254 | return results; |
255 | } |
256 | |
257 | /*! |
258 | * \brief Get deep copy of an IRModule. |
259 | * \param mod The IRModule to make a deep copy. |
260 | * \return The deep copy of the IRModule. |
261 | */ |
262 | inline IRModule DeepCopyIRModule(IRModule mod) { |
263 | return Downcast<IRModule>(LoadJSON(SaveJSON(mod))); |
264 | } |
265 | |
266 | /*! |
267 | * \brief Concatenate strings |
268 | * \param strs The strings to concatenate |
269 | * \param delim The delimiter |
270 | * \return The concatenated string |
271 | */ |
272 | inline std::string Concat(const Array<String>& strs, const std::string& delim) { |
273 | if (strs.empty()) { |
274 | return "" ; |
275 | } |
276 | std::ostringstream os; |
277 | os << strs[0]; |
278 | for (int i = 1, n = strs.size(); i < n; ++i) { |
279 | os << delim << strs[i]; |
280 | } |
281 | return os.str(); |
282 | } |
283 | |
284 | /*! |
285 | * \brief Get the BlockRV from a block StmtSRef |
286 | * \param sch The schedule |
287 | * \param block_sref The block StmtSRef |
288 | * \param global_var_name The global variable name |
289 | * \return The BlockRV |
290 | */ |
291 | inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, |
292 | const String& global_var_name) { |
293 | const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
294 | return sch->GetBlock(block->name_hint, global_var_name); |
295 | } |
296 | |
297 | /*! |
298 | * \brief A helper data structure that replays a trace and collects failure counts |
299 | * for each postprocessor |
300 | */ |
301 | struct ThreadedTraceApply { |
302 | /*! \brief Constructor */ |
303 | explicit ThreadedTraceApply(const Array<Postproc>& postprocs) |
304 | : n_(postprocs.size()), items_(new Item[n_]) { |
305 | for (int i = 0; i < n_; ++i) { |
306 | items_[i].postproc = postprocs[i]; |
307 | items_[i].fail_counter = 0; |
308 | } |
309 | } |
310 | |
311 | /*! \brief Destructor */ |
312 | ~ThreadedTraceApply() { delete[] items_; } |
313 | |
314 | /*! |
315 | * \brief Apply the trace and postprocessors to an IRModule |
316 | * \param mod The IRModule to be applied |
317 | * \param trace The trace to apply to the IRModule |
318 | * \param rand_state The random seed |
319 | * \return The schedule created, or NullOpt if any postprocessor fails |
320 | */ |
321 | Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace, |
322 | TRandState* rand_state) { |
323 | tir::Schedule sch = |
324 | tir::Schedule::Traced(mod, |
325 | /*rand_state=*/ForkSeed(rand_state), |
326 | /*debug_mode=*/0, |
327 | /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); |
328 | |
329 | trace->ApplyToSchedule(sch, /*remove_postproc=*/true); |
330 | sch->EnterPostproc(); |
331 | |
332 | for (int i = 0; i < n_; ++i) { |
333 | Item& item = items_[i]; |
334 | if (!item.postproc->Apply(sch)) { |
335 | item.fail_counter++; |
336 | return NullOpt; |
337 | } |
338 | } |
339 | return sch; |
340 | } |
341 | |
342 | /*! \brief Returns a string summarizing the failures on each postprocessor */ |
343 | std::string SummarizeFailures() const { |
344 | std::ostringstream os; |
345 | for (int i = 0; i < n_; ++i) { |
346 | const Item& item = items_[i]; |
347 | os << "Postproc #" << i << " [" << item.postproc // |
348 | << "]: " << item.fail_counter.load() << " failure(s)" ; |
349 | if (i != n_ - 1) { |
350 | os << "\n" ; |
351 | } |
352 | } |
353 | return os.str(); |
354 | } |
355 | |
356 | private: |
357 | /*! \brief A helper data structure that stores the fail count for each postprocessor. */ |
358 | struct Item { |
359 | /*! \brief The postprocessor. */ |
360 | Postproc postproc{nullptr}; |
361 | /*! \brief The thread-safe postprocessor failure counter. */ |
362 | std::atomic<int> fail_counter{0}; |
363 | }; |
364 | |
365 | /*! \brief The number of total postprocessors. */ |
366 | int n_; |
367 | /*! \brief The pointer to the list of postprocessor items. */ |
368 | Item* items_; |
369 | }; |
370 | |
371 | /*! |
372 | * \brief Get the number of cores in CPU |
373 | * \param target The target |
374 | * \return The number of cores. |
375 | */ |
376 | inline int GetTargetNumCores(const Target& target) { |
377 | int num_cores = target->GetAttr<Integer>("num-cores" ).value_or(-1).IntValue(); |
378 | if (num_cores == -1) { |
379 | static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count" ); |
380 | ICHECK(f_cpu_count) |
381 | << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\"" ; |
382 | num_cores = (*f_cpu_count)(false); |
383 | LOG(FATAL) |
384 | << "Target does not have attribute \"num-cores\", physical core number must be " |
385 | "defined! For example, on the local machine, the target must be \"llvm -num-cores " |
386 | << num_cores << "\"" ; |
387 | } |
388 | return num_cores; |
389 | } |
390 | |
391 | /*! |
392 | * \brief Get the median of the running time from RunnerResult in millisecond |
393 | * \param results The results from RunnerResult |
394 | * \return The median of the running time in millisecond |
395 | */ |
396 | inline double GetRunMsMedian(const RunnerResult& runner_result) { |
397 | Array<FloatImm> run_secs = runner_result->run_secs.value(); |
398 | ICHECK(!run_secs.empty()); |
399 | std::vector<double> v; |
400 | v.reserve(run_secs.size()); |
401 | std::transform(run_secs.begin(), run_secs.end(), std::back_inserter(v), |
402 | [](const FloatImm& f) -> double { return f->value; }); |
403 | std::sort(v.begin(), v.end()); |
404 | int n = v.size(); |
405 | if (n % 2 == 0) { |
406 | return (v[n / 2 - 1] + v[n / 2]) * 0.5 * 1000.0; |
407 | } else { |
408 | return v[n / 2] * 1000.0; |
409 | } |
410 | } |
411 | |
412 | /*! |
413 | * \brief Convert the given object to an array of floating point numbers |
414 | * \param obj The object to be converted |
415 | * \return The array of floating point numbers |
416 | */ |
417 | inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) { |
418 | const ArrayNode* arr = obj.as<ArrayNode>(); |
419 | ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); |
420 | Array<FloatImm> results; |
421 | results.reserve(arr->size()); |
422 | for (const ObjectRef& elem : *arr) { |
423 | if (const auto* int_imm = elem.as<IntImmNode>()) { |
424 | results.push_back(FloatImm(DataType::Float(32), int_imm->value)); |
425 | } else if (const auto* float_imm = elem.as<FloatImmNode>()) { |
426 | results.push_back(FloatImm(DataType::Float(32), float_imm->value)); |
427 | } else { |
428 | LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey(); |
429 | } |
430 | } |
431 | return results; |
432 | } |
433 | |
434 | /*! |
435 | * \brief Convert the given object to an array of integers |
436 | * \param obj The object to be converted |
437 | * \return The array of integers |
438 | */ |
439 | inline Array<Integer> AsIntArray(const ObjectRef& obj) { |
440 | const ArrayNode* arr = obj.as<ArrayNode>(); |
441 | ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); |
442 | Array<Integer> results; |
443 | results.reserve(arr->size()); |
444 | for (const ObjectRef& elem : *arr) { |
445 | if (const auto* int_imm = elem.as<IntImmNode>()) { |
446 | results.push_back(Integer(int_imm->value)); |
447 | } else { |
448 | LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); |
449 | } |
450 | } |
451 | return results; |
452 | } |
453 | |
454 | /*! \brief The struct defining comparison function of sorting by mean run seconds. */ |
455 | struct SortTuningRecordByMeanRunSecs { |
456 | static const constexpr double kMaxMeanTime = 1e10; |
457 | |
458 | static double Mean(const Array<FloatImm>& a) { |
459 | if (a.empty()) { |
460 | return kMaxMeanTime; |
461 | } |
462 | double sum = 0.0; |
463 | for (const FloatImm& i : a) { |
464 | sum += i->value; |
465 | } |
466 | return sum / a.size(); |
467 | } |
468 | |
469 | bool operator()(const TuningRecord& a, const TuningRecord& b) const { |
470 | double a_time = Mean(a->run_secs.value_or({})); |
471 | double b_time = Mean(b->run_secs.value_or({})); |
472 | return a_time < b_time; |
473 | } |
474 | }; |
475 | |
476 | /*! |
477 | * \brief The helper function to clone schedule rules, postprocessors, and mutators. |
478 | * \param src The source space generator. |
479 | * \param dst The destination space generator. |
480 | */ |
481 | inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { |
482 | if (src->sch_rules.defined()) { |
483 | Array<ScheduleRule> original = src->sch_rules.value(); |
484 | Array<ScheduleRule> sch_rules; |
485 | sch_rules.reserve(original.size()); |
486 | for (const ScheduleRule& sch_rule : original) { |
487 | sch_rules.push_back(sch_rule->Clone()); |
488 | } |
489 | dst->sch_rules = std::move(sch_rules); |
490 | } |
491 | if (src->postprocs.defined()) { |
492 | Array<Postproc> original = src->postprocs.value(); |
493 | Array<Postproc> postprocs; |
494 | postprocs.reserve(original.size()); |
495 | for (const Postproc& postproc : original) { |
496 | postprocs.push_back(postproc->Clone()); |
497 | } |
498 | dst->postprocs = std::move(postprocs); |
499 | } |
500 | if (src->mutator_probs.defined()) { |
501 | Map<Mutator, FloatImm> original = src->mutator_probs.value(); |
502 | Map<Mutator, FloatImm> mutator_probs; |
503 | for (const auto& kv : original) { |
504 | mutator_probs.Set(kv.first->Clone(), kv.second); |
505 | } |
506 | dst->mutator_probs = std::move(mutator_probs); |
507 | } |
508 | } |
509 | |
510 | /*! \brief Returns true if the given target is one of the supported gpu targets. */ |
511 | inline bool IsGPUTarget(const std::string& target_name) { |
512 | static const std::unordered_set<std::string> gpu_targets{"cuda" , "rocm" , "vulkan" , "metal" }; |
513 | return gpu_targets.count(target_name); |
514 | } |
515 | |
516 | /*! |
517 | * \brief Create an AutoInline schedule rule for the given target. |
518 | * \param target_name The name of the target ("llvm", "cuda", etc.) |
519 | * \return The AutoInline schedule rule for the given target. |
520 | */ |
521 | inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { |
522 | Array<ScheduleRule> rules{nullptr}; |
523 | if (target_name == "llvm" ) { |
524 | rules = ScheduleRule::DefaultLLVM(); |
525 | } else if (target_name == "hexagon" ) { |
526 | rules = ScheduleRule::DefaultHexagon(); |
527 | } else if (target_name == "c" ) { |
528 | rules = ScheduleRule::DefaultMicro(); |
529 | } else if (IsGPUTarget(target_name)) { |
530 | rules = ScheduleRule::DefaultCUDA(); |
531 | } else { |
532 | LOG(FATAL) << "ValueError: Unsupported target: " << target_name; |
533 | } |
534 | for (const ScheduleRule& rule : rules) { |
535 | if (rule->GetTypeKey() == "meta_schedule.AutoInline" ) { |
536 | return rule; |
537 | } |
538 | } |
539 | LOG(FATAL) << "ValueError: AutoInline rule is not found in the default rules for target: " |
540 | << target_name; |
541 | throw; |
542 | } |
543 | |
544 | /*! |
545 | * \brief Summarize the run time of the given FloatImm array. |
546 | * \param arr The array of FloatImm. |
547 | * \return The summary of the values in the given array. |
548 | */ |
549 | inline double Sum(const Array<FloatImm>& arr) { |
550 | double sum = 0; |
551 | for (const FloatImm& f : arr) { |
552 | sum += f->value; |
553 | } |
554 | return sum; |
555 | } |
556 | |
557 | } // namespace meta_schedule |
558 | } // namespace tvm |
559 | |
560 | #endif // TVM_META_SCHEDULE_UTILS_H_ |
561 | |