1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <set>
20#include <thread>
21#include <unordered_map>
22
23#include "../module_equality.h"
24#include "../utils.h"
25
26namespace tvm {
27namespace meta_schedule {
28
29/*!
30 * \brief Read lines from a json file.
31 * \param path The path to the json file.
32 * \param num_lines The number of threads used to concurrently parse the lines.
33 * \param allow_missing Whether to create new file when the given path is not found.
34 * \return An array containing lines read from the json file.
35 */
36std::vector<ObjectRef> JSONFileReadLines(const String& path, int num_threads, bool allow_missing) {
37 std::ifstream is(path);
38 if (is.good()) {
39 std::vector<String> json_strs;
40 for (std::string str; std::getline(is, str);) {
41 json_strs.push_back(str);
42 }
43 int n = json_strs.size();
44 std::vector<ObjectRef> json_objs;
45 json_objs.resize(n);
46 support::parallel_for_dynamic(0, n, num_threads, [&](int thread_id, int task_id) {
47 json_objs[task_id] = JSONLoads(json_strs[task_id]);
48 });
49 return json_objs;
50 }
51 CHECK(allow_missing) << "ValueError: File doesn't exist: " << path;
52 std::ofstream os(path);
53 CHECK(os.good()) << "ValueError: Cannot create new file: " << path;
54 return {};
55}
56
57/*!
58 * \brief Append a line to a json file.
59 * \param path The path to the json file.
60 * \param line The line to append.
61 */
62void JSONFileAppendLine(const String& path, const std::string& line) {
63 std::ofstream os(path, std::ofstream::app);
64 CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path;
65 os << line << std::endl;
66}
67
68/*! \brief The default database implementation, which mimics two database tables with two files. */
69class JSONDatabaseNode : public DatabaseNode {
70 public:
71 explicit JSONDatabaseNode(String mod_eq_name = "structural")
72 : DatabaseNode(mod_eq_name),
73 workloads2idx_(/*bucket_count*/ 0, WorkloadHash(), WorkloadEqual(GetModuleEquality())) {}
74
75 /*! \brief The path to the workload table */
76 String path_workload;
77 /*! \brief The path to the tuning record table */
78 String path_tuning_record;
79 /*! \brief All the workloads in the database */
80 std::unordered_map<Workload, int, WorkloadHash, WorkloadEqual> workloads2idx_;
81 /*! \brief All the tuning records in the database */
82 std::multiset<TuningRecord, SortTuningRecordByMeanRunSecs> tuning_records_;
83
84 void VisitAttrs(tvm::AttrVisitor* v) {
85 v->Visit("path_workload", &path_workload);
86 v->Visit("path_tuning_record", &path_tuning_record);
87 // `workloads2idx_` is not visited
88 // `tuning_records_` is not visited
89 }
90
91 static constexpr const char* _type_key = "meta_schedule.JSONDatabase";
92 TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode);
93
94 public:
95 bool HasWorkload(const IRModule& mod) {
96 return workloads2idx_.find(Workload(mod, GetModuleEquality().Hash(mod))) !=
97 workloads2idx_.end();
98 }
99
100 Workload CommitWorkload(const IRModule& mod) {
101 // Try to insert `mod` into `workloads_`
102 auto [it, inserted] =
103 this->workloads2idx_.emplace(Workload(mod, GetModuleEquality().Hash(mod)), -1);
104 Workload workload = it->first;
105 // If `mod` is new in `workloads2idx_`, append it to the workload file
106 if (inserted) {
107 it->second = static_cast<int>(this->workloads2idx_.size()) - 1;
108 JSONFileAppendLine(this->path_workload, JSONDumps(workload->AsJSON()));
109 }
110 return it->first;
111 }
112
113 void CommitTuningRecord(const TuningRecord& record) {
114 this->tuning_records_.insert(record);
115 JSONFileAppendLine(this->path_tuning_record,
116 JSONDumps(Array<ObjectRef>{
117 /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)),
118 /*tuning_record=*/record->AsJSON() //
119 }));
120 }
121
122 Array<TuningRecord> GetTopK(const Workload& workload, int top_k) {
123 CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative";
124 if (top_k == 0) {
125 return {};
126 }
127 Array<TuningRecord> results;
128 results.reserve(top_k);
129 for (const TuningRecord& record : this->tuning_records_) {
130 auto run_secs = record->run_secs;
131 if (!run_secs.defined() || run_secs.value().empty() ||
132 std::all_of(run_secs.value().begin(), run_secs.value().end(),
133 // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
134 [](tvm::FloatImm v) {
135 return v.defined() &&
136 v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
137 })) {
138 continue;
139 }
140 if (record->workload.same_as(workload) ||
141 WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
142 results.push_back(record);
143 if (results.size() == static_cast<size_t>(top_k)) {
144 break;
145 }
146 }
147 }
148 if (results.size() < static_cast<size_t>(top_k)) {
149 LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
150 "enough valid records in the database for this workload.";
151 }
152 return results;
153 }
154
155 Array<TuningRecord> GetAllTuningRecords() {
156 Array<TuningRecord> results;
157 results.reserve(Size());
158 for (const TuningRecord& record : this->tuning_records_) {
159 results.push_back(record);
160 }
161 return results;
162 }
163
164 int64_t Size() { return tuning_records_.size(); }
165};
166
167Database Database::JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing,
168 String mod_eq_name) {
169 int num_threads = std::thread::hardware_concurrency();
170 ObjectPtr<JSONDatabaseNode> n = make_object<JSONDatabaseNode>(mod_eq_name);
171 // Load `n->workloads2idx_` from `path_workload`
172 std::vector<Workload> workloads;
173 {
174 std::vector<ObjectRef> json_objs = JSONFileReadLines(path_workload, num_threads, allow_missing);
175 int n_objs = json_objs.size();
176 n->workloads2idx_.reserve(n_objs);
177 workloads.reserve(n_objs);
178 for (int i = 0; i < n_objs; ++i) {
179 Workload workload = Workload::FromJSON(json_objs[i]);
180 auto recalc_hash = n->GetModuleEquality().Hash(workload->mod);
181 CHECK_EQ(recalc_hash, workload->shash)
182 << "ValueError: Module hash changed. Given: " << workload->shash
183 << "; Recalculated: " << recalc_hash;
184 n->workloads2idx_.emplace(workload, i);
185 workloads.push_back(workload);
186 }
187 }
188 // Load `n->tuning_records_` from `path_tuning_record`
189 {
190 std::vector<ObjectRef> json_objs =
191 JSONFileReadLines(path_tuning_record, num_threads, allow_missing);
192 std::vector<TuningRecord> records;
193 records.resize(json_objs.size(), TuningRecord{nullptr});
194 support::parallel_for_dynamic(
195 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) {
196 const ObjectRef& json_obj = json_objs[task_id];
197 Workload workload{nullptr};
198 try {
199 const ArrayNode* arr = json_obj.as<ArrayNode>();
200 ICHECK_EQ(arr->size(), 2);
201 workload = workloads[Downcast<Integer>(arr->at(0)).IntValue()];
202 records[task_id] = TuningRecord::FromJSON(arr->at(1), workload);
203 } catch (std::runtime_error& e) {
204 LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1)
205 << " of file " << path_tuning_record << ". The workload is:\n"
206 << (workload.defined() ? workload->mod->Script() : "(null)")
207 << "\nThe JSONObject of TuningRecord is:\n"
208 << json_obj << "\nThe error message is:\n"
209 << e.what();
210 }
211 });
212 for (const TuningRecord& record : records) {
213 n->tuning_records_.insert(record);
214 }
215 }
216 n->path_workload = path_workload;
217 n->path_tuning_record = path_tuning_record;
218 return Database(n);
219}
220
// Register JSONDatabaseNode as a node type (presumably enabling reflection over the fields
// visited in VisitAttrs — confirm against TVM_REGISTER_NODE_TYPE).
TVM_REGISTER_NODE_TYPE(JSONDatabaseNode);
// Expose Database::JSONDatabase as the packed global "meta_schedule.DatabaseJSONDatabase".
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase);
223
224} // namespace meta_schedule
225} // namespace tvm
226