1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../module_equality.h"
20#include "../utils.h"
21
22namespace tvm {
23namespace meta_schedule {
24
25class MemoryDatabaseNode : public DatabaseNode {
26 public:
27 explicit MemoryDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {}
28
29 Array<TuningRecord> records;
30 Array<Workload> workloads;
31
32 void VisitAttrs(AttrVisitor* v) {
33 v->Visit("records", &records);
34 v->Visit("workloads", &workloads);
35 }
36
37 static constexpr const char* _type_key = "meta_schedule.MemoryDatabase";
38 TVM_DECLARE_FINAL_OBJECT_INFO(MemoryDatabaseNode, DatabaseNode);
39
40 public:
41 bool HasWorkload(const IRModule& mod) final {
42 for (const auto& workload : workloads) {
43 if (GetModuleEquality().Equal(workload->mod, mod)) {
44 return true;
45 }
46 }
47 return false;
48 }
49
50 Workload CommitWorkload(const IRModule& mod) final {
51 for (const auto& workload : workloads) {
52 if (GetModuleEquality().Equal(workload->mod, mod)) {
53 return workload;
54 }
55 }
56 Workload workload(mod, GetModuleEquality().Hash(mod));
57 workloads.push_back(workload);
58 return workload;
59 }
60
61 void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }
62
63 Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
64 CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative";
65 if (top_k == 0) {
66 return {};
67 }
68 std::vector<TuningRecord> results;
69 results.reserve(records.size());
70 for (const TuningRecord& record : records) {
71 auto run_secs = record->run_secs;
72 if (!run_secs.defined() || run_secs.value().empty() ||
73 std::all_of(run_secs.value().begin(), run_secs.value().end(),
74 // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
75 [](tvm::FloatImm v) {
76 return v.defined() &&
77 v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
78 })) {
79 continue;
80 }
81 if (record->workload.same_as(workload) ||
82 WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
83 results.emplace_back(record);
84 }
85 }
86 std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
87 if (results.size() > static_cast<size_t>(top_k)) {
88 return {results.begin(), results.begin() + top_k};
89 } else {
90 if (results.size() < static_cast<size_t>(top_k)) {
91 LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
92 "enough valid records in the database for this workload.";
93 }
94 return results;
95 }
96 }
97
98 Array<TuningRecord> GetAllTuningRecords() final { return records; }
99
100 int64_t Size() final { return records.size(); }
101};
102
103Database Database::MemoryDatabase(String mod_eq_name) {
104 ObjectPtr<MemoryDatabaseNode> n = make_object<MemoryDatabaseNode>(mod_eq_name);
105 n->records.clear();
106 n->workloads.clear();
107 return Database(n);
108}
109
110TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode);
111TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase")
112 .set_body_typed(Database::MemoryDatabase);
113
114} // namespace meta_schedule
115} // namespace tvm
116