1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | class ScheduleFnDatabaseNode : public DatabaseNode { |
25 | public: |
26 | explicit ScheduleFnDatabaseNode(String mod_eq_name = "structural" ) : DatabaseNode(mod_eq_name) {} |
27 | |
28 | runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn; |
29 | |
30 | void VisitAttrs(AttrVisitor* v) { |
31 | // `schedule_fn` is not visited. |
32 | } |
33 | |
34 | static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase" ; |
35 | TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode); |
36 | |
37 | public: |
38 | Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target, |
39 | const String& workload_name) final { |
40 | if (Optional<tir::Schedule> sch = this->QuerySchedule(mod, target, workload_name)) { |
41 | return TuningRecord(sch.value()->trace().value(), |
42 | /*workload=*/Workload(mod, 0), // |
43 | /*run_secs=*/NullOpt, // |
44 | /*target=*/target, // |
45 | /*arg_info=*/NullOpt); |
46 | } |
47 | return NullOpt; |
48 | } |
49 | |
50 | Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target, |
51 | const String& workload_name) final { |
52 | tir::Schedule sch = |
53 | tir::Schedule::Traced(WithAttr<IRModule>(mod, "task_name" , workload_name), |
54 | /*rand_state=*/-1, |
55 | /*debug_mode=*/0, |
56 | /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); |
57 | if (!schedule_fn(sch)) { |
58 | return NullOpt; |
59 | } |
60 | return sch; |
61 | } |
62 | |
63 | bool HasWorkload(const IRModule& mod) final { |
64 | LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.HasWorkload" ; |
65 | throw; |
66 | } |
67 | |
68 | Workload CommitWorkload(const IRModule& mod) final { |
69 | LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitWorkload" ; |
70 | throw; |
71 | } |
72 | |
73 | void CommitTuningRecord(const TuningRecord& record) final { |
74 | LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitTuningRecord" ; |
75 | throw; |
76 | } |
77 | |
78 | Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final { |
79 | LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK" ; |
80 | throw; |
81 | } |
82 | |
83 | Array<TuningRecord> GetAllTuningRecords() final { |
84 | LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetAllTuningRecords" ; |
85 | throw; |
86 | } |
87 | |
88 | int64_t Size() final { |
89 | LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.size" ; |
90 | throw; |
91 | } |
92 | }; |
93 | |
94 | Database Database::ScheduleFnDatabase(runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn, |
95 | String mod_eq_name) { |
96 | ObjectPtr<ScheduleFnDatabaseNode> n = make_object<ScheduleFnDatabaseNode>(mod_eq_name); |
97 | n->schedule_fn = std::move(schedule_fn); |
98 | return Database(n); |
99 | } |
100 | |
101 | TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); |
102 | TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase" ) |
103 | .set_body_typed(Database::ScheduleFnDatabase); |
104 | |
105 | } // namespace meta_schedule |
106 | } // namespace tvm |
107 | |