1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../module_equality.h"
20#include "../utils.h"
21
22namespace tvm {
23namespace meta_schedule {
24
25/******** Workload ********/
26
27Workload::Workload(IRModule mod) {
28 ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>();
29 n->mod = mod;
30 n->shash = ModuleEquality::Create("structural")->Hash(mod);
31 data_ = std::move(n);
32}
33
34Workload::Workload(IRModule mod, Workload::THashCode shash) {
35 ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>();
36 n->mod = mod;
37 n->shash = shash;
38 data_ = std::move(n);
39}
40
41ObjectRef WorkloadNode::AsJSON() const {
42 // Convert `this->mod` to JSON
43 std::string json_mod = tvm::SaveJSON(this->mod);
44 // Dump the JSON string to base64
45 std::string b64_mod = Base64Encode(json_mod);
46 // Output
47 return Array<ObjectRef>{SHash2Str(this->shash), String(b64_mod)};
48}
49
50Workload Workload::FromJSON(const ObjectRef& json_obj) {
51 IRModule mod{nullptr};
52 THashCode shash = 0;
53 try {
54 const ArrayNode* json_array = json_obj.as<ArrayNode>();
55 CHECK(json_array && json_array->size() == 2);
56 // Load json[0] => shash
57 String str_shash = Downcast<String>(json_array->at(0));
58 // Load json[1] => mod
59 {
60 String b64_mod = Downcast<String>(json_array->at(1));
61 std::string json_mod = Base64Decode(b64_mod);
62 mod = Downcast<IRModule>(LoadJSON(json_mod));
63 std::stringstream(str_shash) >> shash;
64 }
65 } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error
66 LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
67 << "\nThe error is: " << e.what();
68 }
69 return Workload(mod, shash);
70}
71
72/******** TuningRecord ********/
73
74TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optional<Array<FloatImm>> run_secs,
75 Optional<Target> target, Optional<Array<ArgInfo>> args_info) {
76 ObjectPtr<TuningRecordNode> n = make_object<TuningRecordNode>();
77 n->trace = trace;
78 n->workload = workload;
79 n->run_secs = run_secs;
80 n->target = target;
81 n->args_info = args_info;
82 this->data_ = n;
83}
84
85bool WorkloadEqual::operator()(const Workload& a, const Workload& b) const {
86 return a->shash == b->shash && mod_eq_.Equal(a->mod, b->mod);
87}
88
89MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
90 tir::Schedule sch =
91 tir::Schedule::Traced(workload->mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail);
92 trace->ApplyToSchedule(sch, false, nullptr);
93 return MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true));
94}
95
96ObjectRef TuningRecordNode::AsJSON() const {
97 Optional<Array<ObjectRef>> json_args_info{nullptr};
98 Optional<ObjectRef> json_target{nullptr};
99 if (args_info.defined()) {
100 Array<ObjectRef> info;
101 info.reserve(args_info.value().size());
102 for (const ArgInfo& arg_info : args_info.value()) {
103 info.push_back(arg_info->AsJSON());
104 }
105 json_args_info = info;
106 }
107 if (target.defined()) {
108 json_target = target.value()->Export();
109 }
110 return Array<ObjectRef>{trace->AsJSON(false), //
111 run_secs, //
112 json_target, //
113 json_args_info};
114}
115
116TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
117 tir::Trace trace{nullptr};
118 Optional<Array<FloatImm>> run_secs{nullptr};
119 Optional<Target> target{nullptr};
120 Optional<Array<ArgInfo>> args_info{nullptr};
121 try {
122 const ArrayNode* json_array = json_obj.as<ArrayNode>();
123 CHECK(json_array && json_array->size() == 4);
124 // Load json[1] => run_secs
125 if (json_array->at(1).defined()) {
126 run_secs = AsFloatArray(json_array->at(1));
127 }
128 // Load json[2] => target
129 if (json_array->at(2).defined()) {
130 target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
131 }
132 // Load json[3] => args_info
133 if (json_array->at(3).defined()) {
134 const ArrayNode* json_args_info = json_array->at(3).as<ArrayNode>();
135 Array<ArgInfo> info;
136 info.reserve(json_args_info->size());
137 for (const ObjectRef& json_arg_info : *json_args_info) {
138 info.push_back(ArgInfo::FromJSON(json_arg_info));
139 }
140 args_info = info;
141 }
142 // Load json[0] => trace
143 {
144 const ObjectRef& json_trace = json_array->at(0);
145 tir::Schedule sch =
146 tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0,
147 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
148 tir::Trace::ApplyJSONToSchedule(json_trace, sch);
149 trace = sch->trace().value();
150 }
151 } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error
152 LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
153 << "\nThe error is: " << e.what();
154 }
155 return TuningRecord(trace, workload, run_secs, target, args_info);
156}
157
158/******** Database ********/
159DatabaseNode::DatabaseNode(String mod_eq_name) { mod_eq_ = ModuleEquality::Create(mod_eq_name); }
160DatabaseNode::~DatabaseNode() = default;
161
162Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target,
163 const String& workload_name) {
164 if (!this->HasWorkload(mod)) {
165 return NullOpt;
166 }
167 Array<TuningRecord> records = this->GetTopK(this->CommitWorkload(mod), 1);
168 if (records.empty()) {
169 return NullOpt;
170 }
171 ICHECK_EQ(records.size(), 1);
172 return records[0];
173}
174
175Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target,
176 const String& workload_name) {
177 if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name)) {
178 TuningRecord record = opt_record.value();
179 tir::Schedule sch =
180 tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
181 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
182 record->trace->ApplyToSchedule(sch, false);
183 return sch;
184 } else {
185 return NullOpt;
186 }
187}
188
189Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target,
190 const String& workload_name) {
191 if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target, workload_name)) {
192 return opt_sch.value()->mod();
193 } else {
194 return NullOpt;
195 }
196}
197
198std::vector<Database>* ThreadLocalDatabases() {
199 static thread_local std::vector<Database> tls;
200 return &tls;
201}
202
203void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); }
204
205void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); }
206
207Optional<Database> Database::Current() {
208 std::vector<Database>* tls = ThreadLocalDatabases();
209 if (tls->empty()) {
210 return NullOpt;
211 } else {
212 return tls->back();
213 }
214}
215
216/******** PyDatabase ********/
217PyDatabaseNode::PyDatabaseNode(String mod_eq_name) : DatabaseNode(mod_eq_name) {}
218
219Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
220 PyDatabaseNode::FCommitWorkload f_commit_workload,
221 PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
222 PyDatabaseNode::FGetTopK f_get_top_k,
223 PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
224 PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
225 PyDatabaseNode::FQuerySchedule f_query_schedule,
226 PyDatabaseNode::FQueryIRModule f_query_ir_module,
227 PyDatabaseNode::FSize f_size, String mod_eq_name) {
228 ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>(mod_eq_name);
229 n->f_has_workload = f_has_workload;
230 n->f_commit_workload = f_commit_workload;
231 n->f_commit_tuning_record = f_commit_tuning_record;
232 n->f_get_top_k = f_get_top_k;
233 n->f_get_all_tuning_records = f_get_all_tuning_records;
234 n->f_query_tuning_record = f_query_tuning_record;
235 n->f_query_schedule = f_query_schedule;
236 n->f_query_ir_module = f_query_ir_module;
237 n->f_size = f_size;
238 return Database(n);
239}
240
241/******** FFI ********/
242
// Register the object types defined in this file with TVM's type registry.
// NOTE(review): DatabaseNode uses TVM_REGISTER_OBJECT_TYPE instead of TVM_REGISTER_NODE_TYPE,
// presumably because it is an abstract base that is never constructed directly; confirm.
TVM_REGISTER_NODE_TYPE(WorkloadNode);
TVM_REGISTER_NODE_TYPE(TuningRecordNode);
TVM_REGISTER_OBJECT_TYPE(DatabaseNode);
TVM_REGISTER_NODE_TYPE(PyDatabaseNode);
// Workload: construction (hash computed from the module) and JSON round-trip.
TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) {
  return Workload(mod);
});
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
    .set_body_method<Workload>(&WorkloadNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
// TuningRecord: construction, replay into a MeasureCandidate, and JSON round-trip.
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
    .set_body_typed([](tir::Trace trace, Workload workload, Optional<Array<FloatImm>> run_secs,
                       Optional<Target> target, Optional<Array<ArgInfo>> args_info) {
      return TuningRecord(trace, workload, run_secs, target, args_info);
    });
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate")
    .set_body_method<TuningRecord>(&TuningRecordNode::AsMeasureCandidate);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
    .set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
// Thread-local database scope: enter/exit push/pop the per-thread stack, Current reads its top.
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope")
    .set_body_method(&Database::EnterWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope")
    .set_body_method(&Database::ExitWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current);
// Database interface: each entry takes the Database handle as its first argument and forwards
// the remaining arguments to the named DatabaseNode method.
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
    .set_body_method<Database>(&DatabaseNode::HasWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
    .set_body_method<Database>(&DatabaseNode::CommitWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
    .set_body_method<Database>(&DatabaseNode::CommitTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
    .set_body_method<Database>(&DatabaseNode::GetTopK);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords")
    .set_body_method<Database>(&DatabaseNode::GetAllTuningRecords);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord")
    .set_body_method<Database>(&DatabaseNode::QueryTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule")
    .set_body_method<Database>(&DatabaseNode::QuerySchedule);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule")
    .set_body_method<Database>(&DatabaseNode::QueryIRModule);
// Factory for a database whose operations are implemented by packed callbacks.
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);
286
287} // namespace meta_schedule
288} // namespace tvm
289