1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <algorithm>
#include <fstream>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "../module_equality.h"
#include "../utils.h"
25 | |
26 | namespace tvm { |
27 | namespace meta_schedule { |
28 | |
29 | /*! |
30 | * \brief Read lines from a json file. |
31 | * \param path The path to the json file. |
32 | * \param num_lines The number of threads used to concurrently parse the lines. |
33 | * \param allow_missing Whether to create new file when the given path is not found. |
34 | * \return An array containing lines read from the json file. |
35 | */ |
36 | std::vector<ObjectRef> JSONFileReadLines(const String& path, int num_threads, bool allow_missing) { |
37 | std::ifstream is(path); |
38 | if (is.good()) { |
39 | std::vector<String> json_strs; |
40 | for (std::string str; std::getline(is, str);) { |
41 | json_strs.push_back(str); |
42 | } |
43 | int n = json_strs.size(); |
44 | std::vector<ObjectRef> json_objs; |
45 | json_objs.resize(n); |
46 | support::parallel_for_dynamic(0, n, num_threads, [&](int thread_id, int task_id) { |
47 | json_objs[task_id] = JSONLoads(json_strs[task_id]); |
48 | }); |
49 | return json_objs; |
50 | } |
51 | CHECK(allow_missing) << "ValueError: File doesn't exist: " << path; |
52 | std::ofstream os(path); |
53 | CHECK(os.good()) << "ValueError: Cannot create new file: " << path; |
54 | return {}; |
55 | } |
56 | |
57 | /*! |
58 | * \brief Append a line to a json file. |
59 | * \param path The path to the json file. |
60 | * \param line The line to append. |
61 | */ |
62 | void JSONFileAppendLine(const String& path, const std::string& line) { |
63 | std::ofstream os(path, std::ofstream::app); |
64 | CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; |
65 | os << line << std::endl; |
66 | } |
67 | |
68 | /*! \brief The default database implementation, which mimics two database tables with two files. */ |
69 | class JSONDatabaseNode : public DatabaseNode { |
70 | public: |
71 | explicit JSONDatabaseNode(String mod_eq_name = "structural" ) |
72 | : DatabaseNode(mod_eq_name), |
73 | workloads2idx_(/*bucket_count*/ 0, WorkloadHash(), WorkloadEqual(GetModuleEquality())) {} |
74 | |
75 | /*! \brief The path to the workload table */ |
76 | String path_workload; |
77 | /*! \brief The path to the tuning record table */ |
78 | String path_tuning_record; |
79 | /*! \brief All the workloads in the database */ |
80 | std::unordered_map<Workload, int, WorkloadHash, WorkloadEqual> workloads2idx_; |
81 | /*! \brief All the tuning records in the database */ |
82 | std::multiset<TuningRecord, SortTuningRecordByMeanRunSecs> tuning_records_; |
83 | |
84 | void VisitAttrs(tvm::AttrVisitor* v) { |
85 | v->Visit("path_workload" , &path_workload); |
86 | v->Visit("path_tuning_record" , &path_tuning_record); |
87 | // `workloads2idx_` is not visited |
88 | // `tuning_records_` is not visited |
89 | } |
90 | |
91 | static constexpr const char* _type_key = "meta_schedule.JSONDatabase" ; |
92 | TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); |
93 | |
94 | public: |
95 | bool HasWorkload(const IRModule& mod) { |
96 | return workloads2idx_.find(Workload(mod, GetModuleEquality().Hash(mod))) != |
97 | workloads2idx_.end(); |
98 | } |
99 | |
100 | Workload CommitWorkload(const IRModule& mod) { |
101 | // Try to insert `mod` into `workloads_` |
102 | auto [it, inserted] = |
103 | this->workloads2idx_.emplace(Workload(mod, GetModuleEquality().Hash(mod)), -1); |
104 | Workload workload = it->first; |
105 | // If `mod` is new in `workloads2idx_`, append it to the workload file |
106 | if (inserted) { |
107 | it->second = static_cast<int>(this->workloads2idx_.size()) - 1; |
108 | JSONFileAppendLine(this->path_workload, JSONDumps(workload->AsJSON())); |
109 | } |
110 | return it->first; |
111 | } |
112 | |
113 | void CommitTuningRecord(const TuningRecord& record) { |
114 | this->tuning_records_.insert(record); |
115 | JSONFileAppendLine(this->path_tuning_record, |
116 | JSONDumps(Array<ObjectRef>{ |
117 | /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), |
118 | /*tuning_record=*/record->AsJSON() // |
119 | })); |
120 | } |
121 | |
122 | Array<TuningRecord> GetTopK(const Workload& workload, int top_k) { |
123 | CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative" ; |
124 | if (top_k == 0) { |
125 | return {}; |
126 | } |
127 | Array<TuningRecord> results; |
128 | results.reserve(top_k); |
129 | for (const TuningRecord& record : this->tuning_records_) { |
130 | auto run_secs = record->run_secs; |
131 | if (!run_secs.defined() || run_secs.value().empty() || |
132 | std::all_of(run_secs.value().begin(), run_secs.value().end(), |
133 | // kMaxMeanTime(1e10) is used as a stub for undefined measurement times. |
134 | [](tvm::FloatImm v) { |
135 | return v.defined() && |
136 | v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime; |
137 | })) { |
138 | continue; |
139 | } |
140 | if (record->workload.same_as(workload) || |
141 | WorkloadEqual(GetModuleEquality())(record->workload, workload)) { |
142 | results.push_back(record); |
143 | if (results.size() == static_cast<size_t>(top_k)) { |
144 | break; |
145 | } |
146 | } |
147 | } |
148 | if (results.size() < static_cast<size_t>(top_k)) { |
149 | LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not " |
150 | "enough valid records in the database for this workload." ; |
151 | } |
152 | return results; |
153 | } |
154 | |
155 | Array<TuningRecord> GetAllTuningRecords() { |
156 | Array<TuningRecord> results; |
157 | results.reserve(Size()); |
158 | for (const TuningRecord& record : this->tuning_records_) { |
159 | results.push_back(record); |
160 | } |
161 | return results; |
162 | } |
163 | |
164 | int64_t Size() { return tuning_records_.size(); } |
165 | }; |
166 | |
167 | Database Database::JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing, |
168 | String mod_eq_name) { |
169 | int num_threads = std::thread::hardware_concurrency(); |
170 | ObjectPtr<JSONDatabaseNode> n = make_object<JSONDatabaseNode>(mod_eq_name); |
171 | // Load `n->workloads2idx_` from `path_workload` |
172 | std::vector<Workload> workloads; |
173 | { |
174 | std::vector<ObjectRef> json_objs = JSONFileReadLines(path_workload, num_threads, allow_missing); |
175 | int n_objs = json_objs.size(); |
176 | n->workloads2idx_.reserve(n_objs); |
177 | workloads.reserve(n_objs); |
178 | for (int i = 0; i < n_objs; ++i) { |
179 | Workload workload = Workload::FromJSON(json_objs[i]); |
180 | auto recalc_hash = n->GetModuleEquality().Hash(workload->mod); |
181 | CHECK_EQ(recalc_hash, workload->shash) |
182 | << "ValueError: Module hash changed. Given: " << workload->shash |
183 | << "; Recalculated: " << recalc_hash; |
184 | n->workloads2idx_.emplace(workload, i); |
185 | workloads.push_back(workload); |
186 | } |
187 | } |
188 | // Load `n->tuning_records_` from `path_tuning_record` |
189 | { |
190 | std::vector<ObjectRef> json_objs = |
191 | JSONFileReadLines(path_tuning_record, num_threads, allow_missing); |
192 | std::vector<TuningRecord> records; |
193 | records.resize(json_objs.size(), TuningRecord{nullptr}); |
194 | support::parallel_for_dynamic( |
195 | 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { |
196 | const ObjectRef& json_obj = json_objs[task_id]; |
197 | Workload workload{nullptr}; |
198 | try { |
199 | const ArrayNode* arr = json_obj.as<ArrayNode>(); |
200 | ICHECK_EQ(arr->size(), 2); |
201 | workload = workloads[Downcast<Integer>(arr->at(0)).IntValue()]; |
202 | records[task_id] = TuningRecord::FromJSON(arr->at(1), workload); |
203 | } catch (std::runtime_error& e) { |
204 | LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) |
205 | << " of file " << path_tuning_record << ". The workload is:\n" |
206 | << (workload.defined() ? workload->mod->Script() : "(null)" ) |
207 | << "\nThe JSONObject of TuningRecord is:\n" |
208 | << json_obj << "\nThe error message is:\n" |
209 | << e.what(); |
210 | } |
211 | }); |
212 | for (const TuningRecord& record : records) { |
213 | n->tuning_records_.insert(record); |
214 | } |
215 | } |
216 | n->path_workload = path_workload; |
217 | n->path_tuning_record = path_tuning_record; |
218 | return Database(n); |
219 | } |
220 | |
221 | TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); |
222 | TVM_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase" ).set_body_typed(Database::JSONDatabase); |
223 | |
224 | } // namespace meta_schedule |
225 | } // namespace tvm |
226 | |