1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24class ScheduleFnDatabaseNode : public DatabaseNode {
25 public:
26 explicit ScheduleFnDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {}
27
28 runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn;
29
30 void VisitAttrs(AttrVisitor* v) {
31 // `schedule_fn` is not visited.
32 }
33
34 static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase";
35 TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode);
36
37 public:
38 Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
39 const String& workload_name) final {
40 if (Optional<tir::Schedule> sch = this->QuerySchedule(mod, target, workload_name)) {
41 return TuningRecord(sch.value()->trace().value(),
42 /*workload=*/Workload(mod, 0), //
43 /*run_secs=*/NullOpt, //
44 /*target=*/target, //
45 /*arg_info=*/NullOpt);
46 }
47 return NullOpt;
48 }
49
50 Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
51 const String& workload_name) final {
52 tir::Schedule sch =
53 tir::Schedule::Traced(WithAttr<IRModule>(mod, "task_name", workload_name),
54 /*rand_state=*/-1,
55 /*debug_mode=*/0,
56 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
57 if (!schedule_fn(sch)) {
58 return NullOpt;
59 }
60 return sch;
61 }
62
63 bool HasWorkload(const IRModule& mod) final {
64 LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.HasWorkload";
65 throw;
66 }
67
68 Workload CommitWorkload(const IRModule& mod) final {
69 LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitWorkload";
70 throw;
71 }
72
73 void CommitTuningRecord(const TuningRecord& record) final {
74 LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitTuningRecord";
75 throw;
76 }
77
78 Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
79 LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK";
80 throw;
81 }
82
83 Array<TuningRecord> GetAllTuningRecords() final {
84 LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetAllTuningRecords";
85 throw;
86 }
87
88 int64_t Size() final {
89 LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.size";
90 throw;
91 }
92};
93
94Database Database::ScheduleFnDatabase(runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn,
95 String mod_eq_name) {
96 ObjectPtr<ScheduleFnDatabaseNode> n = make_object<ScheduleFnDatabaseNode>(mod_eq_name);
97 n->schedule_fn = std::move(schedule_fn);
98 return Database(n);
99}
100
// Register the node type (keyed by its `_type_key`, "meta_schedule.ScheduleFnDatabase").
TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode);
// Expose the factory as the global packed function "meta_schedule.DatabaseScheduleFnDatabase"
// (presumably called from the Python-side ScheduleFnDatabase wrapper — confirm).
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase")
    .set_body_typed(Database::ScheduleFnDatabase);
104
105} // namespace meta_schedule
106} // namespace tvm
107