1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../module_equality.h" |
20 | #include "../utils.h" |
21 | |
22 | namespace tvm { |
23 | namespace meta_schedule { |
24 | |
25 | /******** Workload ********/ |
26 | |
27 | Workload::Workload(IRModule mod) { |
28 | ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>(); |
29 | n->mod = mod; |
30 | n->shash = ModuleEquality::Create("structural" )->Hash(mod); |
31 | data_ = std::move(n); |
32 | } |
33 | |
34 | Workload::Workload(IRModule mod, Workload::THashCode shash) { |
35 | ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>(); |
36 | n->mod = mod; |
37 | n->shash = shash; |
38 | data_ = std::move(n); |
39 | } |
40 | |
41 | ObjectRef WorkloadNode::AsJSON() const { |
42 | // Convert `this->mod` to JSON |
43 | std::string json_mod = tvm::SaveJSON(this->mod); |
44 | // Dump the JSON string to base64 |
45 | std::string b64_mod = Base64Encode(json_mod); |
46 | // Output |
47 | return Array<ObjectRef>{SHash2Str(this->shash), String(b64_mod)}; |
48 | } |
49 | |
// Reconstruct a Workload from the JSON produced by WorkloadNode::AsJSON.
// Expects a 2-element array [shash-as-string, base64(module-json)]; any
// malformed input aborts with LOG(FATAL) rather than returning an error.
Workload Workload::FromJSON(const ObjectRef& json_obj) {
  IRModule mod{nullptr};
  THashCode shash = 0;
  try {
    const ArrayNode* json_array = json_obj.as<ArrayNode>();
    CHECK(json_array && json_array->size() == 2);
    // Load json[0] => shash
    String str_shash = Downcast<String>(json_array->at(0));
    // Load json[1] => mod
    {
      String b64_mod = Downcast<String>(json_array->at(1));
      std::string json_mod = Base64Decode(b64_mod);
      mod = Downcast<IRModule>(LoadJSON(json_mod));
      // The hash was serialized as a decimal string; parse it back into THashCode.
      std::stringstream(str_shash) >> shash;
    }
  } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
    LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
               << "\nThe error is: " << e.what();
  }
  return Workload(mod, shash);
}
71 | |
72 | /******** TuningRecord ********/ |
73 | |
74 | TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optional<Array<FloatImm>> run_secs, |
75 | Optional<Target> target, Optional<Array<ArgInfo>> args_info) { |
76 | ObjectPtr<TuningRecordNode> n = make_object<TuningRecordNode>(); |
77 | n->trace = trace; |
78 | n->workload = workload; |
79 | n->run_secs = run_secs; |
80 | n->target = target; |
81 | n->args_info = args_info; |
82 | this->data_ = n; |
83 | } |
84 | |
85 | bool WorkloadEqual::operator()(const Workload& a, const Workload& b) const { |
86 | return a->shash == b->shash && mod_eq_.Equal(a->mod, b->mod); |
87 | } |
88 | |
89 | MeasureCandidate TuningRecordNode::AsMeasureCandidate() const { |
90 | tir::Schedule sch = |
91 | tir::Schedule::Traced(workload->mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail); |
92 | trace->ApplyToSchedule(sch, false, nullptr); |
93 | return MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true)); |
94 | } |
95 | |
96 | ObjectRef TuningRecordNode::AsJSON() const { |
97 | Optional<Array<ObjectRef>> json_args_info{nullptr}; |
98 | Optional<ObjectRef> json_target{nullptr}; |
99 | if (args_info.defined()) { |
100 | Array<ObjectRef> info; |
101 | info.reserve(args_info.value().size()); |
102 | for (const ArgInfo& arg_info : args_info.value()) { |
103 | info.push_back(arg_info->AsJSON()); |
104 | } |
105 | json_args_info = info; |
106 | } |
107 | if (target.defined()) { |
108 | json_target = target.value()->Export(); |
109 | } |
110 | return Array<ObjectRef>{trace->AsJSON(false), // |
111 | run_secs, // |
112 | json_target, // |
113 | json_args_info}; |
114 | } |
115 | |
// Reconstruct a TuningRecord from the JSON produced by TuningRecordNode::AsJSON.
// `json_obj` must be the 4-element array [trace, run_secs, target, args_info];
// `workload` supplies the module the trace is replayed on. Malformed input
// aborts with LOG(FATAL). Note the trace (json[0]) is deliberately parsed
// last: it requires building a schedule of the workload module, which is the
// most expensive step.
TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
  tir::Trace trace{nullptr};
  Optional<Array<FloatImm>> run_secs{nullptr};
  Optional<Target> target{nullptr};
  Optional<Array<ArgInfo>> args_info{nullptr};
  try {
    const ArrayNode* json_array = json_obj.as<ArrayNode>();
    CHECK(json_array && json_array->size() == 4);
    // Load json[1] => run_secs
    if (json_array->at(1).defined()) {
      run_secs = AsFloatArray(json_array->at(1));
    }
    // Load json[2] => target
    if (json_array->at(2).defined()) {
      target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
    }
    // Load json[3] => args_info
    if (json_array->at(3).defined()) {
      const ArrayNode* json_args_info = json_array->at(3).as<ArrayNode>();
      Array<ArgInfo> info;
      info.reserve(json_args_info->size());
      for (const ObjectRef& json_arg_info : *json_args_info) {
        info.push_back(ArgInfo::FromJSON(json_arg_info));
      }
      args_info = info;
    }
    // Load json[0] => trace
    {
      const ObjectRef& json_trace = json_array->at(0);
      // Replay the serialized instructions onto a fresh schedule, then take
      // the resulting trace (which now carries the recorded decisions).
      tir::Schedule sch =
          tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
      tir::Trace::ApplyJSONToSchedule(json_trace, sch);
      trace = sch->trace().value();
    }
  } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
    LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
               << "\nThe error is: " << e.what();
  }
  return TuningRecord(trace, workload, run_secs, target, args_info);
}
157 | |
158 | /******** Database ********/ |
159 | DatabaseNode::DatabaseNode(String mod_eq_name) { mod_eq_ = ModuleEquality::Create(mod_eq_name); } |
// Out-of-line defaulted destructor: keeps the (presumably incomplete-in-header)
// ModuleEquality destructor call in this translation unit.
DatabaseNode::~DatabaseNode() = default;
161 | |
162 | Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target, |
163 | const String& workload_name) { |
164 | if (!this->HasWorkload(mod)) { |
165 | return NullOpt; |
166 | } |
167 | Array<TuningRecord> records = this->GetTopK(this->CommitWorkload(mod), 1); |
168 | if (records.empty()) { |
169 | return NullOpt; |
170 | } |
171 | ICHECK_EQ(records.size(), 1); |
172 | return records[0]; |
173 | } |
174 | |
175 | Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target, |
176 | const String& workload_name) { |
177 | if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name)) { |
178 | TuningRecord record = opt_record.value(); |
179 | tir::Schedule sch = |
180 | tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, |
181 | /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); |
182 | record->trace->ApplyToSchedule(sch, false); |
183 | return sch; |
184 | } else { |
185 | return NullOpt; |
186 | } |
187 | } |
188 | |
189 | Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, |
190 | const String& workload_name) { |
191 | if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target, workload_name)) { |
192 | return opt_sch.value()->mod(); |
193 | } else { |
194 | return NullOpt; |
195 | } |
196 | } |
197 | |
198 | std::vector<Database>* ThreadLocalDatabases() { |
199 | static thread_local std::vector<Database> tls; |
200 | return &tls; |
201 | } |
202 | |
// Push this database onto the thread-local scope stack (scope entry).
void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); }
204 | |
// Pop the innermost database off the thread-local scope stack (scope exit).
// Must be balanced with a preceding EnterWithScope on the same thread.
void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); }
206 | |
207 | Optional<Database> Database::Current() { |
208 | std::vector<Database>* tls = ThreadLocalDatabases(); |
209 | if (tls->empty()) { |
210 | return NullOpt; |
211 | } else { |
212 | return tls->back(); |
213 | } |
214 | } |
215 | |
216 | /******** PyDatabase ********/ |
// Forward the module-equality backend name to the DatabaseNode base; the
// packed-func members are filled in afterwards by Database::PyDatabase.
PyDatabaseNode::PyDatabaseNode(String mod_eq_name) : DatabaseNode(mod_eq_name) {}
218 | |
// Factory for a Python-backed database: each `f_*` packed function implements
// the corresponding DatabaseNode virtual method on the Python side.
// `mod_eq_name` selects the module-equality backend for workload dedup.
// Any of the functions may be left undefined if the Python subclass does not
// override the corresponding method -- TODO confirm which are mandatory.
Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
                              PyDatabaseNode::FCommitWorkload f_commit_workload,
                              PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
                              PyDatabaseNode::FGetTopK f_get_top_k,
                              PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
                              PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
                              PyDatabaseNode::FQuerySchedule f_query_schedule,
                              PyDatabaseNode::FQueryIRModule f_query_ir_module,
                              PyDatabaseNode::FSize f_size, String mod_eq_name) {
  ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>(mod_eq_name);
  n->f_has_workload = f_has_workload;
  n->f_commit_workload = f_commit_workload;
  n->f_commit_tuning_record = f_commit_tuning_record;
  n->f_get_top_k = f_get_top_k;
  n->f_get_all_tuning_records = f_get_all_tuning_records;
  n->f_query_tuning_record = f_query_tuning_record;
  n->f_query_schedule = f_query_schedule;
  n->f_query_ir_module = f_query_ir_module;
  n->f_size = f_size;
  return Database(n);
}
240 | |
241 | /******** FFI ********/ |
242 | |
// Node/object type registrations (enable reflection and FFI marshalling).
TVM_REGISTER_NODE_TYPE(WorkloadNode);
TVM_REGISTER_NODE_TYPE(TuningRecordNode);
TVM_REGISTER_OBJECT_TYPE(DatabaseNode);
TVM_REGISTER_NODE_TYPE(PyDatabaseNode);
// Workload construction and (de)serialization.
TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) {
  return Workload(mod);
});
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
    .set_body_method<Workload>(&WorkloadNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
// TuningRecord construction, conversion, and (de)serialization.
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
    .set_body_typed([](tir::Trace trace, Workload workload, Optional<Array<FloatImm>> run_secs,
                       Optional<Target> target, Optional<Array<ArgInfo>> args_info) {
      return TuningRecord(trace, workload, run_secs, target, args_info);
    });
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate")
    .set_body_method<TuningRecord>(&TuningRecordNode::AsMeasureCandidate);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
    .set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
// Thread-local database scoping.
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope")
    .set_body_method(&Database::EnterWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope")
    .set_body_method(&Database::ExitWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current);
// Database CRUD / query surface.
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
    .set_body_method<Database>(&DatabaseNode::HasWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
    .set_body_method<Database>(&DatabaseNode::CommitWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
    .set_body_method<Database>(&DatabaseNode::CommitTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
    .set_body_method<Database>(&DatabaseNode::GetTopK);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords")
    .set_body_method<Database>(&DatabaseNode::GetAllTuningRecords);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord")
    .set_body_method<Database>(&DatabaseNode::QueryTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule")
    .set_body_method<Database>(&DatabaseNode::QuerySchedule);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule")
    .set_body_method<Database>(&DatabaseNode::QueryIRModule);
// Python-subclass database factory.
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);
286 | |
287 | } // namespace meta_schedule |
288 | } // namespace tvm |
289 | |