1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_DATABASE_H_
20#define TVM_META_SCHEDULE_DATABASE_H_
21
22#include <tvm/ir/expr.h>
23#include <tvm/ir/module.h>
24#include <tvm/meta_schedule/arg_info.h>
25#include <tvm/node/reflection.h>
26#include <tvm/runtime/container/array.h>
27#include <tvm/runtime/container/string.h>
28#include <tvm/runtime/object.h>
29#include <tvm/runtime/packed_func.h>
30#include <tvm/target/target.h>
31#include <tvm/tir/schedule/schedule.h>
32#include <tvm/tir/schedule/trace.h>
33
34#include <memory>
35
36namespace tvm {
37namespace meta_schedule {
38
39class ModuleEquality;
40
41/*! \brief A workload, i.e. an IRModule and its structural hash. */
42class WorkloadNode : public runtime::Object {
43 public:
44 /*! \brief The type of structural hash */
45 using THashCode = size_t;
46 /*! \brief The workload's IRModule. */
47 IRModule mod;
48 /*! \brief The workload's structural hash. */
49 THashCode shash;
50
51 void VisitAttrs(tvm::AttrVisitor* v) {
52 v->Visit("mod", &mod);
53 // `shash` is not visited because TVM FFI doesn't support uint64_t
54 }
55
56 static constexpr const char* _type_key = "meta_schedule.Workload";
57 TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object);
58
59 /*!
60 * \brief Export the workload to a JSON string.
61 * \return An array containing the structural hash and the base64 json string.
62 */
63 ObjectRef AsJSON() const;
64};
65
66/*!
67 * \brief Managed reference to WorkloadNode.
68 * \sa WorkloadNode
69 */
70class Workload : public runtime::ObjectRef {
71 public:
72 using THashCode = WorkloadNode::THashCode;
73 /*!
74 * \brief Constructor of Workload.
75 * \param mod The workload's IRModule.
76 */
77 TVM_DLL explicit Workload(IRModule mod);
78 /*!
79 * \brief Constructor of Workload.
80 * \param mod The workload's IRModule.
81 * \param shash The workload's structural hash.
82 */
83 TVM_DLL explicit Workload(IRModule mod, THashCode shash);
84 /*!
85 * \brief Create a workload from a json object.
86 * \param json_obj The json object.
87 * \return The created workload.
88 */
89 TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
90
91 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode);
92};
93
94/*! \brief The hash method for Workload */
95struct WorkloadHash {
96 size_t operator()(const Workload& a) const { return a->shash; }
97};
98
99/*! \brief The equality check for Workload */
100struct WorkloadEqual {
101 explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
102
103 bool operator()(const Workload& a, const Workload& b) const;
104
105 private:
106 /*! \brief The module equality testing and hashing method */
107 const ModuleEquality& mod_eq_;
108};
109
110/*! \brief The class of measure candidates. */
111class MeasureCandidate;
112
113/*! \brief The class of tuning records. */
114class TuningRecordNode : public runtime::Object {
115 public:
116 /*! \brief The trace tuned. */
117 tir::Trace trace;
118 /*! \brief The workload. */
119 Workload workload{nullptr};
120 /*! \brief The profiling result in seconds. */
121 Optional<Array<FloatImm>> run_secs;
122 /*! \brief The target for tuning. */
123 Optional<Target> target;
124 /*! \brief The argument information. */
125 Optional<Array<ArgInfo>> args_info;
126
127 void VisitAttrs(tvm::AttrVisitor* v) {
128 v->Visit("trace", &trace);
129 v->Visit("workload", &workload);
130 v->Visit("run_secs", &run_secs);
131 v->Visit("target", &target);
132 v->Visit("args_info", &args_info);
133 }
134
135 static constexpr const char* _type_key = "meta_schedule.TuningRecord";
136 TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object);
137
138 /*! \brief Construct the measure candidate given the initial IR module and trace
139 * stored in the tuning record. */
140 MeasureCandidate AsMeasureCandidate() const;
141 /*!
142 * \brief Export the tuning record to a JSON string.
143 * \return An array containing the trace, running secs, serialized target, and
144 * argument information.
145 */
146 ObjectRef AsJSON() const;
147};
148
149/*!
150 * \brief The managed reference of TuningRecordNode.
151 * \sa TuningRecordNode
152 */
153class TuningRecord : public runtime::ObjectRef {
154 public:
155 /*!
156 \brief Constructor of a tuning record.
157 \param trace The trace of the tuning record.
158 \param workload The workload of the tuning record.
159 \param run_secs The running time of the tuning record.
160 \param target The target of the tuning record.
161 \param args_info The argument information of the tuning record.
162 */
163 TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
164 Optional<Array<FloatImm>> run_secs, Optional<Target> target,
165 Optional<Array<ArgInfo>> args_info);
166 /*!
167 * \brief Create a tuning record from a json object.
168 * \param json_obj The json object.
169 * \param workload The workload.
170 * \return The tuning record created.
171 */
172 TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
173 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode);
174};
175
176/* \brief The abstract interface of database. */
177class DatabaseNode : public runtime::Object {
178 public:
179 /*!
180 * \brief Constructor
181 * \param mod_eq_name A string to specify the module equality testing and hashing method.
182 * It must be one of the followings:
183 * - "structural": Use StructuralEqual/Hash
184 * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
185 * equality testing and hashing.
186 * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
187 * given module. The "ignore-ndarray" varint is used for the extracted blocks
188 * or in case no anchor block is found.
189 * For the definition of the anchor block, see tvm/tir/analysis.h.
190 */
191 explicit DatabaseNode(String mod_eq_name = "structural");
192
193 /*! \brief Default destructor */
194 virtual ~DatabaseNode();
195 /*!
196 * \brief Check if the database has the given workload.
197 * \param mod The IRModule to be searched for.
198 * \return Whether the database has the given workload.
199 */
200 virtual bool HasWorkload(const IRModule& mod) = 0;
201 /*!
202 * \brief Look up or add workload to the database if missing.
203 * \param mod The IRModule to be searched for or added.
204 * \return The workload corresponding to the given IRModule.
205 */
206 virtual Workload CommitWorkload(const IRModule& mod) = 0;
207 /*!
208 * \brief Add a tuning record to the database.
209 * \param record The tuning record to be added.
210 */
211 virtual void CommitTuningRecord(const TuningRecord& record) = 0;
212 /*!
213 * \brief Get the top K tuning records of given workload from the database.
214 * \param workload The workload to be searched for.
215 * \param top_k The number of top records to be returned.
216 * \return An array of top K tuning records for the given workload.
217 */
218 virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
219 /*!
220 * \brief Get all tuning records from the database.
221 * \return An Array of all the tuning records in the database.
222 */
223 virtual Array<TuningRecord> GetAllTuningRecords() = 0;
224 /*!
225 * \brief Get the size of the database.
226 * \return The size of the database.
227 */
228 virtual int64_t Size() = 0;
229 /*!
230 * \brief Query the best record of the given workload from the database.
231 * \param mod The IRModule to be searched for.
232 * \param target The target to be searched for.
233 * \param workload_name The name of the workload to be searched for.
234 * \return The best record of the given workload; NullOpt if not found.
235 */
236 virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
237 const String& workload_name);
238 /*!
239 * \brief Query the best schedule of the given workload from the database.
240 * \param mod The IRModule to be searched for.
241 * \param target The target to be searched for.
242 * \param workload_name The name of the workload to be searched for.
243 * \return The schedule in the best schedule of the given workload; NullOpt if not found.
244 */
245 virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
246 const String& workload_name);
247 /*!
248 * \brief Query the best IRModule of the given workload from the database.
249 * \param mod The IRModule to be searched for.
250 * \param target The target to be searched for.
251 * \param workload_name The name of the workload to be searched for.
252 * \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
253 */
254 virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
255 const String& workload_name);
256
257 /*! \brief Return a reference to the owned module equality method instance. */
258 const ModuleEquality& GetModuleEquality() const {
259 ICHECK(mod_eq_);
260 return *mod_eq_;
261 }
262
263 static constexpr const char* _type_key = "meta_schedule.Database";
264 TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
265
266 private:
267 /*! \brief The module equality testing and hashing method */
268 std::unique_ptr<ModuleEquality> mod_eq_;
269};
270
271/*! \brief The database with customized methods on the python-side. */
272class PyDatabaseNode : public DatabaseNode {
273 public:
274 /*!
275 * \brief Constructor
276 * \param mod_eq_name A string to specify the module equality testing and hashing method.
277 * It must be one of the followings:
278 * - "structural": Use StructuralEqual/Hash
279 * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
280 * equality testing and hashing.
281 * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a
282 * given module. The "ignore-ndarray" varint is used for the extracted blocks
283 * or in case no anchor block is found.
284 * For the definition of the anchor block, see tvm/tir/analysis.h.
285 */
286 explicit PyDatabaseNode(String mod_eq_name = "structural");
287
288 /*!
289 * \brief The function type of `HasWorkload` method.
290 * \param mod The IRModule to be searched for.
291 * \return Whether the database has the given workload.
292 */
293 using FHasWorkload = runtime::TypedPackedFunc<bool(const IRModule&)>;
294 /*!
295 * \brief The function type of `CommitWorkload` method.
296 * \param mod The IRModule to be searched for or added.
297 * \return The workload corresponding to the given IRModule.
298 */
299 using FCommitWorkload = runtime::TypedPackedFunc<Workload(const IRModule&)>;
300 /*!
301 * \brief The function type of `CommitTuningRecord` method.
302 * \param record The tuning record to be added.
303 */
304 using FCommitTuningRecord = runtime::TypedPackedFunc<void(const TuningRecord&)>;
305 /*!
306 * \brief The function type of `GetTopK` method.
307 * \param workload The workload to be searched for.
308 * \param top_k The number of top records to be returned.
309 * \return An array of top K tuning records for the given workload.
310 */
311 using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const Workload&, int)>;
312 /*!
313 * \brief The function type of `GetAllTuningRecords` method.
314 * \return An Array of all the tuning records in the database.
315 */
316 using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
317 /*!
318 * \brief The function type of `QueryTuningRecord` method.
319 * \param mod The IRModule to be searched for.
320 * \param target The target to be searched for.
321 * \param workload_name The name of the workload to be searched for.
322 * \return The best record of the given workload; NullOpt if not found.
323 */
324 using FQueryTuningRecord = runtime::TypedPackedFunc<Optional<TuningRecord>(
325 const IRModule&, const Target&, const String&)>;
326 /*!
327 * \brief The function type of `QuerySchedule` method.
328 * \param mod The IRModule to be searched for.
329 * \param target The target to be searched for.
330 * \param workload_name The name of the workload to be searched for.
331 * \return The schedule in the best schedule of the given workload; NullOpt if not found.
332 */
333 using FQuerySchedule = runtime::TypedPackedFunc<Optional<tir::Schedule>(
334 const IRModule&, const Target&, const String&)>;
335 /*!
336 * \brief The function type of `QueryIRModule` method.
337 * \param mod The IRModule to be searched for.
338 * \param target The target to be searched for.
339 * \param workload_name The name of the workload to be searched for.
340 * \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
341 */
342 using FQueryIRModule =
343 runtime::TypedPackedFunc<Optional<IRModule>(const IRModule&, const Target&, const String&)>;
344 /*!
345 * \brief The function type of `Size` method.
346 * \return The size of the database.
347 */
348 using FSize = runtime::TypedPackedFunc<int64_t()>;
349
350 /*! \brief The packed function to the `HasWorkload` function. */
351 FHasWorkload f_has_workload;
352 /*! \brief The packed function to the `CommitWorkload` function. */
353 FCommitWorkload f_commit_workload;
354 /*! \brief The packed function to the `CommitTuningRecord` function. */
355 FCommitTuningRecord f_commit_tuning_record;
356 /*! \brief The packed function to the `GetTopK` function. */
357 FGetTopK f_get_top_k;
358 /*! \brief The packed function to the `GetAllTuningRecords` function. */
359 FGetAllTuningRecords f_get_all_tuning_records;
360 /*! \brief The packed function to the `QueryTuningRecord` function. */
361 FQueryTuningRecord f_query_tuning_record;
362 /*! \brief The packed function to the `QuerySchedule` function. */
363 FQuerySchedule f_query_schedule;
364 /*! \brief The packed function to the `QueryIRModule` function. */
365 FQueryIRModule f_query_ir_module;
366 /*! \brief The packed function to the `Size` function. */
367 FSize f_size;
368
369 void VisitAttrs(tvm::AttrVisitor* v) {
370 // PackedFuncs are all not visited, because the reflection system doesn't take care of them,
371 // so it cannot be accessible on the python side. If there is such need from the future,
372 // we can then add corresponding accessor methods to help access on python.
373 // `f_has_workload` is not visited
374 // `f_commit_workload` is not visited
375 // `f_commit_tuning_record` is not visited
376 // `f_get_top_k` is not visited
377 // `f_get_all_tuning_records` is not visited
378 // `f_query_tuning_record` is not visited
379 // `f_query_schedule` is not visited
380 // `f_query_ir_module` is not visited
381 // `f_size` is not visited
382 }
383
384 bool HasWorkload(const IRModule& mod) final {
385 ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
386 return f_has_workload(mod);
387 }
388
389 Workload CommitWorkload(const IRModule& mod) final {
390 ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
391 return f_commit_workload(mod);
392 }
393
394 void CommitTuningRecord(const TuningRecord& record) final {
395 ICHECK(f_commit_tuning_record != nullptr)
396 << "PyDatabase's CommitTuningRecord method not implemented!";
397 f_commit_tuning_record(record);
398 }
399
400 Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
401 ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
402 return f_get_top_k(workload, top_k);
403 }
404
405 Array<TuningRecord> GetAllTuningRecords() final {
406 ICHECK(f_get_all_tuning_records != nullptr)
407 << "PyDatabase's GetAllTuningRecords method not implemented!";
408 return f_get_all_tuning_records();
409 }
410
411 Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
412 const String& workload_name) final {
413 if (f_query_tuning_record == nullptr) {
414 return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
415 } else {
416 return f_query_tuning_record(mod, target, workload_name);
417 }
418 }
419
420 Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
421 const String& workload_name) final {
422 if (f_query_schedule == nullptr) {
423 return DatabaseNode::QuerySchedule(mod, target, workload_name);
424 } else {
425 return f_query_schedule(mod, target, workload_name);
426 }
427 }
428
429 Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
430 const String& workload_name) final {
431 if (f_query_ir_module == nullptr) {
432 return DatabaseNode::QueryIRModule(mod, target, workload_name);
433 } else {
434 return f_query_ir_module(mod, target, workload_name);
435 }
436 }
437
438 int64_t Size() final {
439 ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
440 return f_size();
441 }
442
443 static constexpr const char* _type_key = "meta_schedule.PyDatabase";
444 TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
445};
446
447/*!
448 * \brief Managed reference to DatabaseNode.
449 * \sa DatabaseNode
450 */
451class Database : public runtime::ObjectRef {
452 public:
453 /*!
454 * \brief An in-memory database.
455 * \param mod_eq_name A string to specify the module equality testing and hashing method.
456 */
457 TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural");
458 /*!
459 * \brief A database for injecting handcrafted schedule functions.
460 * \param schedule_fn The function to do scheduling, which takes a TIR schedule,
461 * and returns a boolean indicating if the schedule is successful.
462 * \param mod_eq_name A string to specify the module equality testing and hashing method.
463 */
464 TVM_DLL static Database ScheduleFnDatabase(
465 runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn, String mod_eq_name = "structural");
466 /*!
467 * \brief Create a default database that uses JSON file for tuning records.
468 * \param path_workload The path to the workload table.
469 * \param path_tuning_record The path to the database table.
470 * \param allow_missing Whether to create new file when the given path is not found.
471 * \param mod_eq_name A string to specify the module equality testing and hashing method.
472 */
473 TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
474 bool allow_missing, String mod_eq_name = "structural");
475 /*!
476 * \brief A database composed of multiple databases, allowing users to guide IR rewriting using
477 * combined knowledge of those databases. To each query, it returns the best record among all the
478 * databases given.
479 * \param databases The list of databases to be combined.
480 * \return The combined database.
481 */
482 TVM_DLL static Database UnionDatabase(Array<Database, void> databases);
483 /*!
484 * \brief A database composed of multiple databases, allowing users to guide IR rewriting using
485 * combined knowledge of those databases. To each query, it returns the record from the first
486 * database that responds to the query.
487 * \param databases The database to be subsetted.
488 * \return The subsetted database.
489 */
490 TVM_DLL static Database OrderedUnionDatabase(Array<Database, void> databases);
491 /*!
492 * \brief Create a database with customized methods on the python-side.
493 * \param f_has_workload The packed function of `HasWorkload`.
494 * \param f_commit_workload The packed function of `CommitWorkload`.
495 * \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
496 * \param f_get_top_k The packed function of `GetTopK`.
497 * \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
498 * \param f_query_tuning_record The packed function of `QueryTuningRecord`.
499 * \param f_query_schedule The packed function of `QuerySchedule`.
500 * \param f_query_ir_module The packed function of `QueryIRModule`.
501 * \param f_size The packed function of `Size`.
502 * \param mod_eq_name A string to specify the module equality testing and hashing method.
503 * \return The created database.
504 */
505 TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
506 PyDatabaseNode::FCommitWorkload f_commit_workload,
507 PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
508 PyDatabaseNode::FGetTopK f_get_top_k,
509 PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
510 PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
511 PyDatabaseNode::FQuerySchedule f_query_schedule,
512 PyDatabaseNode::FQueryIRModule f_query_ir_module,
513 PyDatabaseNode::FSize f_size,
514 String mod_eq_name = "structural");
515 /*! \return The current Database in the scope. */
516 static Optional<Database> Current();
517 /*! \brief Entering the scope of the context manager */
518 void EnterWithScope();
519 /*! \brief Exiting the scope of the context manager */
520 void ExitWithScope();
521
522 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
523};
524
525} // namespace meta_schedule
526} // namespace tvm
527
528#endif // TVM_META_SCHEDULE_DATABASE_H_
529