1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../module_equality.h" |
20 | #include "../utils.h" |
21 | |
22 | namespace tvm { |
23 | namespace meta_schedule { |
24 | |
25 | class MemoryDatabaseNode : public DatabaseNode { |
26 | public: |
27 | explicit MemoryDatabaseNode(String mod_eq_name = "structural" ) : DatabaseNode(mod_eq_name) {} |
28 | |
29 | Array<TuningRecord> records; |
30 | Array<Workload> workloads; |
31 | |
32 | void VisitAttrs(AttrVisitor* v) { |
33 | v->Visit("records" , &records); |
34 | v->Visit("workloads" , &workloads); |
35 | } |
36 | |
37 | static constexpr const char* _type_key = "meta_schedule.MemoryDatabase" ; |
38 | TVM_DECLARE_FINAL_OBJECT_INFO(MemoryDatabaseNode, DatabaseNode); |
39 | |
40 | public: |
41 | bool HasWorkload(const IRModule& mod) final { |
42 | for (const auto& workload : workloads) { |
43 | if (GetModuleEquality().Equal(workload->mod, mod)) { |
44 | return true; |
45 | } |
46 | } |
47 | return false; |
48 | } |
49 | |
50 | Workload CommitWorkload(const IRModule& mod) final { |
51 | for (const auto& workload : workloads) { |
52 | if (GetModuleEquality().Equal(workload->mod, mod)) { |
53 | return workload; |
54 | } |
55 | } |
56 | Workload workload(mod, GetModuleEquality().Hash(mod)); |
57 | workloads.push_back(workload); |
58 | return workload; |
59 | } |
60 | |
61 | void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); } |
62 | |
63 | Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final { |
64 | CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative" ; |
65 | if (top_k == 0) { |
66 | return {}; |
67 | } |
68 | std::vector<TuningRecord> results; |
69 | results.reserve(records.size()); |
70 | for (const TuningRecord& record : records) { |
71 | auto run_secs = record->run_secs; |
72 | if (!run_secs.defined() || run_secs.value().empty() || |
73 | std::all_of(run_secs.value().begin(), run_secs.value().end(), |
74 | // kMaxMeanTime(1e10) is used as a stub for undefined measurement times. |
75 | [](tvm::FloatImm v) { |
76 | return v.defined() && |
77 | v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime; |
78 | })) { |
79 | continue; |
80 | } |
81 | if (record->workload.same_as(workload) || |
82 | WorkloadEqual(GetModuleEquality())(record->workload, workload)) { |
83 | results.emplace_back(record); |
84 | } |
85 | } |
86 | std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); |
87 | if (results.size() > static_cast<size_t>(top_k)) { |
88 | return {results.begin(), results.begin() + top_k}; |
89 | } else { |
90 | if (results.size() < static_cast<size_t>(top_k)) { |
91 | LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not " |
92 | "enough valid records in the database for this workload." ; |
93 | } |
94 | return results; |
95 | } |
96 | } |
97 | |
98 | Array<TuningRecord> GetAllTuningRecords() final { return records; } |
99 | |
100 | int64_t Size() final { return records.size(); } |
101 | }; |
102 | |
103 | Database Database::MemoryDatabase(String mod_eq_name) { |
104 | ObjectPtr<MemoryDatabaseNode> n = make_object<MemoryDatabaseNode>(mod_eq_name); |
105 | n->records.clear(); |
106 | n->workloads.clear(); |
107 | return Database(n); |
108 | } |
109 | |
110 | TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); |
111 | TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase" ) |
112 | .set_body_typed(Database::MemoryDatabase); |
113 | |
114 | } // namespace meta_schedule |
115 | } // namespace tvm |
116 | |