1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_UTILS_H_
20#define TVM_META_SCHEDULE_UTILS_H_
21
22#include <dmlc/memory_io.h>
23#include <tvm/arith/analyzer.h>
24#include <tvm/meta_schedule/arg_info.h>
25#include <tvm/meta_schedule/builder.h>
26#include <tvm/meta_schedule/cost_model.h>
27#include <tvm/meta_schedule/database.h>
28#include <tvm/meta_schedule/extracted_task.h>
29#include <tvm/meta_schedule/feature_extractor.h>
30#include <tvm/meta_schedule/measure_callback.h>
31#include <tvm/meta_schedule/profiler.h>
32#include <tvm/meta_schedule/runner.h>
33#include <tvm/meta_schedule/schedule_rule.h>
34#include <tvm/meta_schedule/search_strategy.h>
35#include <tvm/meta_schedule/space_generator.h>
36#include <tvm/meta_schedule/task_scheduler.h>
37#include <tvm/meta_schedule/tune_context.h>
38#include <tvm/node/node.h>
39#include <tvm/node/serialization.h>
40#include <tvm/runtime/container/optional.h>
41#include <tvm/support/parallel_for.h>
42#include <tvm/tir/schedule/schedule.h>
43#include <tvm/tir/transform.h>
44
#include <algorithm>
#include <atomic>
#include <iomanip>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
50
51#include "../support/array.h"
52#include "../support/base64.h"
53#include "../support/nd_int_set.h"
54#include "../support/table_printer.h"
55#include "../support/utils.h"
56#include "../tir/schedule/primitive.h"
57#include "../tir/schedule/utils.h"
58
// Builds a temporary PyLogMessage and returns its stream; the accumulated text is
// emitted when the temporary is destroyed at the end of the full expression.
// `logging_level` must be one of DEBUG/INFO/WARNING/ERROR (CLEAR is reserved for
// TVM_PY_LOG_CLEAR_SCREEN below).
#define TVM_PY_LOG(logging_level, logger)                                \
  ::tvm::meta_schedule::PyLogMessage(__FILE__, __LINE__, logger,         \
                                     PyLogMessage::Level::logging_level) \
      .stream()
// Clears the logging output (ipython cell or console), tagging the request with
// the call site's file/line.
#define TVM_PY_LOG_CLEAR_SCREEN(logging_func) clear_logging(__FILE__, __LINE__, logging_func)
64
65namespace tvm {
66namespace meta_schedule {
67
68/*!
69 * \brief Class to accumulate an log message on the python side. Do not use directly, instead use
70 * TVM_PY_LOG(DEBUG), TVM_PY_LOG(INFO), TVM_PY_LOG(WARNING), TVM_PY_ERROR(ERROR).
71 * \sa TVM_PY_LOG
72 * \sa TVM_PY_LOG_CLEAR_SCREEN
73 */
74class PyLogMessage {
75 public:
76 enum class Level : int32_t {
77 CLEAR = -10,
78 DEBUG = 10,
79 INFO = 20,
80 WARNING = 30,
81 ERROR = 40,
82 // FATAL not included
83 };
84
85 explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, Level logging_level)
86 : filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {}
87
88 TVM_NO_INLINE ~PyLogMessage() {
89 ICHECK(logging_level_ != Level::CLEAR)
90 << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN.";
91 if (this->logger_ != nullptr) {
92 logger_(static_cast<int>(logging_level_), std::string(filename_), lineno_, stream_.str());
93 } else {
94 if (logging_level_ == Level::INFO) {
95 runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_INFO).stream()
96 << stream_.str();
97 } else if (logging_level_ == Level::WARNING) {
98 runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_WARNING).stream()
99 << stream_.str();
100 } else if (logging_level_ == Level::ERROR) {
101 runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_ERROR).stream()
102 << stream_.str();
103 } else if (logging_level_ == Level::DEBUG) {
104 runtime::detail::LogMessage(filename_, lineno_, TVM_LOG_LEVEL_DEBUG).stream()
105 << stream_.str();
106 } else {
107 runtime::detail::LogFatal(filename_, lineno_).stream() << stream_.str();
108 }
109 }
110 }
111 std::ostringstream& stream() { return stream_; }
112
113 private:
114 const char* filename_;
115 int lineno_;
116 std::ostringstream stream_;
117 PackedFunc logger_;
118 Level logging_level_;
119};
120
121/*!
122 * \brief Whether the tuning is running on ipython kernel.
123 * \return A boolean indicating whether ipython kernel is used.
124 */
125inline bool using_ipython() {
126 bool flag = false;
127 const auto* f_using_ipython = runtime::Registry::Get("meta_schedule.using_ipython");
128 if (f_using_ipython) {
129 flag = (*f_using_ipython)();
130 }
131 return flag;
132}
133
134/*!
135 * \brief Print out the performance table interactively in jupyter notebook.
136 * \param str The serialized performance table.
137 */
138inline void print_interactive_table(const String& data) {
139 const auto* f_print_interactive_table =
140 runtime::Registry::Get("meta_schedule.print_interactive_table");
141 ICHECK(f_print_interactive_table->defined())
142 << "Cannot find print_interactive_table function in registry.";
143 (*f_print_interactive_table)(data);
144}
145
146/*!
147 * \brief A helper function to clear logging output for ipython kernel and console.
148 * \param file The file name.
149 * \param lineno The line number.
150 * \param logging_func The logging function.
151 */
152inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) {
153 if (logging_func.defined() && using_ipython()) {
154 logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno, "");
155 } else {
156 // this would clear all logging output in the console
157 runtime::detail::LogMessage(file, lineno, TVM_LOG_LEVEL_INFO).stream()
158 << "\033c\033[3J\033[2J\033[0m\033[H";
159 }
160}
161
/*! \brief The type of the random state, an alias of the linear congruential engine's state. */
using TRandState = support::LinearCongruentialEngine::TRandState;
164
165/*!
166 * \brief Get the base64 encoded result of a string.
167 * \param str The string to encode.
168 * \return The base64 encoded string.
169 */
170inline std::string Base64Encode(std::string str) {
171 std::string result;
172 dmlc::MemoryStringStream m_stream(&result);
173 support::Base64OutStream b64stream(&m_stream);
174 static_cast<dmlc::Stream*>(&b64stream)->Write(str);
175 b64stream.Finish();
176 return result;
177}
178
179/*!
180 * \brief Get the base64 decoded result of a string.
181 * \param str The string to decode.
182 * \return The base64 decoded string.
183 */
184inline std::string Base64Decode(std::string str) {
185 std::string result;
186 dmlc::MemoryStringStream m_stream(&str);
187 support::Base64InStream b64stream(&m_stream);
188 b64stream.InitPosition();
189 static_cast<dmlc::Stream*>(&b64stream)->Read(&result);
190 return result;
191}
192
193/*!
194 * \brief Parses a json string into a json object.
195 * \param json_str The json string.
196 * \return The json object
197 */
198ObjectRef JSONLoads(std::string json_str);
199
200/*!
201 * \brief Dumps a json object into a json string.
202 * \param json_obj The json object.
203 * \return The json string
204 */
205std::string JSONDumps(ObjectRef json_obj);
206
207/*!
208 * \brief Converts a structural hash code to string
209 * \param hash_code The hash code
210 * \return The string representation of the hash code
211 */
212inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); }
213
214/*!
215 * \brief Converts an TVM object to the hex string representation of its structural hash.
216 * \param obj The TVM object.
217 * \return The hex string representation of the hash code.
218 */
219inline String SHash2Hex(const ObjectRef& obj) {
220 std::ostringstream os;
221 size_t hash_code = 0;
222 if (obj.defined()) {
223 hash_code = StructuralHash()(obj);
224 }
225 os << "0x" << std::setw(16) << std::setfill('0') << std::hex << hash_code;
226 return os.str();
227}
228
229/*!
230 * \brief Fork a random state into another, i.e. PRNG splitting.
231 * The given random state is also mutated.
232 * \param rand_state The random state to be forked
233 * \return The forked random state
234 */
235inline support::LinearCongruentialEngine::TRandState ForkSeed(
236 support::LinearCongruentialEngine::TRandState* rand_state) {
237 return support::LinearCongruentialEngine(rand_state).ForkSeed();
238}
239
240/*!
241 * \brief Fork a random state into another ones, i.e. PRNG splitting.
242 * The given random state is also mutated.
243 * \param rand_state The random state to be forked
244 * \param n The number of forks
245 * \return The forked random states
246 */
247inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
248 support::LinearCongruentialEngine::TRandState* rand_state, int n) {
249 std::vector<support::LinearCongruentialEngine::TRandState> results;
250 results.reserve(n);
251 for (int i = 0; i < n; ++i) {
252 results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed());
253 }
254 return results;
255}
256
257/*!
258 * \brief Get deep copy of an IRModule.
259 * \param mod The IRModule to make a deep copy.
260 * \return The deep copy of the IRModule.
261 */
262inline IRModule DeepCopyIRModule(IRModule mod) {
263 return Downcast<IRModule>(LoadJSON(SaveJSON(mod)));
264}
265
266/*!
267 * \brief Concatenate strings
268 * \param strs The strings to concatenate
269 * \param delim The delimiter
270 * \return The concatenated string
271 */
272inline std::string Concat(const Array<String>& strs, const std::string& delim) {
273 if (strs.empty()) {
274 return "";
275 }
276 std::ostringstream os;
277 os << strs[0];
278 for (int i = 1, n = strs.size(); i < n; ++i) {
279 os << delim << strs[i];
280 }
281 return os.str();
282}
283
284/*!
285 * \brief Get the BlockRV from a block StmtSRef
286 * \param sch The schedule
287 * \param block_sref The block StmtSRef
288 * \param global_var_name The global variable name
289 * \return The BlockRV
290 */
291inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
292 const String& global_var_name) {
293 const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
294 return sch->GetBlock(block->name_hint, global_var_name);
295}
296
297/*!
298 * \brief A helper data structure that replays a trace and collects failure counts
299 * for each postprocessor
300 */
301struct ThreadedTraceApply {
302 /*! \brief Constructor */
303 explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
304 : n_(postprocs.size()), items_(new Item[n_]) {
305 for (int i = 0; i < n_; ++i) {
306 items_[i].postproc = postprocs[i];
307 items_[i].fail_counter = 0;
308 }
309 }
310
311 /*! \brief Destructor */
312 ~ThreadedTraceApply() { delete[] items_; }
313
314 /*!
315 * \brief Apply the trace and postprocessors to an IRModule
316 * \param mod The IRModule to be applied
317 * \param trace The trace to apply to the IRModule
318 * \param rand_state The random seed
319 * \return The schedule created, or NullOpt if any postprocessor fails
320 */
321 Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace,
322 TRandState* rand_state) {
323 tir::Schedule sch =
324 tir::Schedule::Traced(mod,
325 /*rand_state=*/ForkSeed(rand_state),
326 /*debug_mode=*/0,
327 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
328
329 trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
330 sch->EnterPostproc();
331
332 for (int i = 0; i < n_; ++i) {
333 Item& item = items_[i];
334 if (!item.postproc->Apply(sch)) {
335 item.fail_counter++;
336 return NullOpt;
337 }
338 }
339 return sch;
340 }
341
342 /*! \brief Returns a string summarizing the failures on each postprocessor */
343 std::string SummarizeFailures() const {
344 std::ostringstream os;
345 for (int i = 0; i < n_; ++i) {
346 const Item& item = items_[i];
347 os << "Postproc #" << i << " [" << item.postproc //
348 << "]: " << item.fail_counter.load() << " failure(s)";
349 if (i != n_ - 1) {
350 os << "\n";
351 }
352 }
353 return os.str();
354 }
355
356 private:
357 /*! \brief A helper data structure that stores the fail count for each postprocessor. */
358 struct Item {
359 /*! \brief The postprocessor. */
360 Postproc postproc{nullptr};
361 /*! \brief The thread-safe postprocessor failure counter. */
362 std::atomic<int> fail_counter{0};
363 };
364
365 /*! \brief The number of total postprocessors. */
366 int n_;
367 /*! \brief The pointer to the list of postprocessor items. */
368 Item* items_;
369};
370
371/*!
372 * \brief Get the number of cores in CPU
373 * \param target The target
374 * \return The number of cores.
375 */
376inline int GetTargetNumCores(const Target& target) {
377 int num_cores = target->GetAttr<Integer>("num-cores").value_or(-1).IntValue();
378 if (num_cores == -1) {
379 static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count");
380 ICHECK(f_cpu_count)
381 << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\"";
382 num_cores = (*f_cpu_count)(false);
383 LOG(FATAL)
384 << "Target does not have attribute \"num-cores\", physical core number must be "
385 "defined! For example, on the local machine, the target must be \"llvm -num-cores "
386 << num_cores << "\"";
387 }
388 return num_cores;
389}
390
391/*!
392 * \brief Get the median of the running time from RunnerResult in millisecond
393 * \param results The results from RunnerResult
394 * \return The median of the running time in millisecond
395 */
396inline double GetRunMsMedian(const RunnerResult& runner_result) {
397 Array<FloatImm> run_secs = runner_result->run_secs.value();
398 ICHECK(!run_secs.empty());
399 std::vector<double> v;
400 v.reserve(run_secs.size());
401 std::transform(run_secs.begin(), run_secs.end(), std::back_inserter(v),
402 [](const FloatImm& f) -> double { return f->value; });
403 std::sort(v.begin(), v.end());
404 int n = v.size();
405 if (n % 2 == 0) {
406 return (v[n / 2 - 1] + v[n / 2]) * 0.5 * 1000.0;
407 } else {
408 return v[n / 2] * 1000.0;
409 }
410}
411
412/*!
413 * \brief Convert the given object to an array of floating point numbers
414 * \param obj The object to be converted
415 * \return The array of floating point numbers
416 */
417inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) {
418 const ArrayNode* arr = obj.as<ArrayNode>();
419 ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey();
420 Array<FloatImm> results;
421 results.reserve(arr->size());
422 for (const ObjectRef& elem : *arr) {
423 if (const auto* int_imm = elem.as<IntImmNode>()) {
424 results.push_back(FloatImm(DataType::Float(32), int_imm->value));
425 } else if (const auto* float_imm = elem.as<FloatImmNode>()) {
426 results.push_back(FloatImm(DataType::Float(32), float_imm->value));
427 } else {
428 LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey();
429 }
430 }
431 return results;
432}
433
434/*!
435 * \brief Convert the given object to an array of integers
436 * \param obj The object to be converted
437 * \return The array of integers
438 */
439inline Array<Integer> AsIntArray(const ObjectRef& obj) {
440 const ArrayNode* arr = obj.as<ArrayNode>();
441 ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey();
442 Array<Integer> results;
443 results.reserve(arr->size());
444 for (const ObjectRef& elem : *arr) {
445 if (const auto* int_imm = elem.as<IntImmNode>()) {
446 results.push_back(Integer(int_imm->value));
447 } else {
448 LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey();
449 }
450 }
451 return results;
452}
453
454/*! \brief The struct defining comparison function of sorting by mean run seconds. */
455struct SortTuningRecordByMeanRunSecs {
456 static const constexpr double kMaxMeanTime = 1e10;
457
458 static double Mean(const Array<FloatImm>& a) {
459 if (a.empty()) {
460 return kMaxMeanTime;
461 }
462 double sum = 0.0;
463 for (const FloatImm& i : a) {
464 sum += i->value;
465 }
466 return sum / a.size();
467 }
468
469 bool operator()(const TuningRecord& a, const TuningRecord& b) const {
470 double a_time = Mean(a->run_secs.value_or({}));
471 double b_time = Mean(b->run_secs.value_or({}));
472 return a_time < b_time;
473 }
474};
475
476/*!
477 * \brief The helper function to clone schedule rules, postprocessors, and mutators.
478 * \param src The source space generator.
479 * \param dst The destination space generator.
480 */
481inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) {
482 if (src->sch_rules.defined()) {
483 Array<ScheduleRule> original = src->sch_rules.value();
484 Array<ScheduleRule> sch_rules;
485 sch_rules.reserve(original.size());
486 for (const ScheduleRule& sch_rule : original) {
487 sch_rules.push_back(sch_rule->Clone());
488 }
489 dst->sch_rules = std::move(sch_rules);
490 }
491 if (src->postprocs.defined()) {
492 Array<Postproc> original = src->postprocs.value();
493 Array<Postproc> postprocs;
494 postprocs.reserve(original.size());
495 for (const Postproc& postproc : original) {
496 postprocs.push_back(postproc->Clone());
497 }
498 dst->postprocs = std::move(postprocs);
499 }
500 if (src->mutator_probs.defined()) {
501 Map<Mutator, FloatImm> original = src->mutator_probs.value();
502 Map<Mutator, FloatImm> mutator_probs;
503 for (const auto& kv : original) {
504 mutator_probs.Set(kv.first->Clone(), kv.second);
505 }
506 dst->mutator_probs = std::move(mutator_probs);
507 }
508}
509
/*! \brief Returns true if the given target is one of the supported gpu targets. */
inline bool IsGPUTarget(const std::string& target_name) {
  static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm", "vulkan", "metal"};
  return gpu_targets.find(target_name) != gpu_targets.end();
}
515
516/*!
517 * \brief Create an AutoInline schedule rule for the given target.
518 * \param target_name The name of the target ("llvm", "cuda", etc.)
519 * \return The AutoInline schedule rule for the given target.
520 */
521inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) {
522 Array<ScheduleRule> rules{nullptr};
523 if (target_name == "llvm") {
524 rules = ScheduleRule::DefaultLLVM();
525 } else if (target_name == "hexagon") {
526 rules = ScheduleRule::DefaultHexagon();
527 } else if (target_name == "c") {
528 rules = ScheduleRule::DefaultMicro();
529 } else if (IsGPUTarget(target_name)) {
530 rules = ScheduleRule::DefaultCUDA();
531 } else {
532 LOG(FATAL) << "ValueError: Unsupported target: " << target_name;
533 }
534 for (const ScheduleRule& rule : rules) {
535 if (rule->GetTypeKey() == "meta_schedule.AutoInline") {
536 return rule;
537 }
538 }
539 LOG(FATAL) << "ValueError: AutoInline rule is not found in the default rules for target: "
540 << target_name;
541 throw;
542}
543
544/*!
545 * \brief Summarize the run time of the given FloatImm array.
546 * \param arr The array of FloatImm.
547 * \return The summary of the values in the given array.
548 */
549inline double Sum(const Array<FloatImm>& arr) {
550 double sum = 0;
551 for (const FloatImm& f : arr) {
552 sum += f->value;
553 }
554 return sum;
555}
556
557} // namespace meta_schedule
558} // namespace tvm
559
560#endif // TVM_META_SCHEDULE_UTILS_H_
561