1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/measure_record.cc |
22 | * \brief Json serialization format for dumping and loading tuning records. |
23 | */ |
24 | |
25 | #include <dmlc/json.h> |
26 | #include <tvm/auto_scheduler/loop_state.h> |
27 | #include <tvm/auto_scheduler/measure_record.h> |
28 | #include <tvm/auto_scheduler/transform_step.h> |
29 | #include <tvm/runtime/registry.h> |
30 | |
#include <cstdint>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
36 | |
37 | #include "utils.h" |
38 | |
39 | // Json serialization handler for MeasureInput, MeasureResult |
40 | // (and recursively for SearchTask, State, Step, ...) |
41 | namespace dmlc { |
42 | namespace json { |
43 | |
44 | template <> |
45 | struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> { |
46 | inline static void Write(dmlc::JSONWriter* writer, |
47 | const ::tvm::Array<::tvm::auto_scheduler::Stage>& data) { |
48 | writer->BeginArray(false); |
49 | writer->EndArray(); |
50 | } |
51 | inline static void Read(dmlc::JSONReader* reader, |
52 | ::tvm::Array<::tvm::auto_scheduler::Stage>* data) { |
53 | bool s; |
54 | reader->BeginArray(); |
55 | s = reader->NextArrayItem(); |
56 | ICHECK(!s); |
57 | } |
58 | }; |
59 | |
60 | template <> |
61 | struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { |
62 | inline static void Write(dmlc::JSONWriter* writer, |
63 | const ::tvm::Array<::tvm::auto_scheduler::Step>& data) { |
64 | writer->BeginArray(false); |
65 | for (const auto& step : data) { |
66 | writer->WriteArraySeperator(); |
67 | writer->BeginArray(false); |
68 | step->WriteToRecord(writer); |
69 | writer->EndArray(); |
70 | } |
71 | writer->EndArray(); |
72 | } |
73 | |
74 | inline static void Read(dmlc::JSONReader* reader, |
75 | ::tvm::Array<::tvm::auto_scheduler::Step>* data) { |
76 | bool s; |
77 | reader->BeginArray(); |
78 | data->clear(); |
79 | while (reader->NextArrayItem()) { |
80 | reader->BeginArray(); |
81 | data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader)); |
82 | s = reader->NextArrayItem(); |
83 | ICHECK(!s); |
84 | } |
85 | } |
86 | }; |
87 | |
88 | template <> |
89 | struct Handler<::tvm::auto_scheduler::StateNode> { |
90 | inline static void Write(dmlc::JSONWriter* writer, const ::tvm::auto_scheduler::StateNode& data) { |
91 | writer->BeginArray(false); |
92 | writer->WriteArrayItem(data.stages); |
93 | writer->WriteArrayItem(data.transform_steps); |
94 | writer->EndArray(); |
95 | } |
96 | inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) { |
97 | bool s; |
98 | reader->BeginArray(); |
99 | s = reader->NextArrayItem(); |
100 | ICHECK(s); |
101 | reader->Read(&data->stages); |
102 | s = reader->NextArrayItem(); |
103 | ICHECK(s); |
104 | reader->Read(&data->transform_steps); |
105 | s = reader->NextArrayItem(); |
106 | ICHECK(!s); |
107 | } |
108 | }; |
109 | |
110 | template <> |
111 | struct Handler<::tvm::auto_scheduler::HardwareParamsNode> { |
112 | inline static void Write(dmlc::JSONWriter* writer, |
113 | const ::tvm::auto_scheduler::HardwareParamsNode& data) { |
114 | writer->BeginArray(false); |
115 | writer->WriteArrayItem(data.num_cores); |
116 | writer->WriteArrayItem(data.vector_unit_bytes); |
117 | writer->WriteArrayItem(data.cache_line_bytes); |
118 | writer->WriteArrayItem(data.max_shared_memory_per_block); |
119 | writer->WriteArrayItem(data.max_local_memory_per_block); |
120 | writer->WriteArrayItem(data.max_threads_per_block); |
121 | writer->WriteArrayItem(data.max_vthread_extent); |
122 | writer->WriteArrayItem(data.warp_size); |
123 | writer->EndArray(); |
124 | } |
125 | inline static void Read(dmlc::JSONReader* reader, |
126 | ::tvm::auto_scheduler::HardwareParamsNode* data) { |
127 | bool s; |
128 | reader->BeginArray(); |
129 | s = reader->NextArrayItem(); |
130 | CHECK(s); |
131 | reader->Read(&data->num_cores); |
132 | s = reader->NextArrayItem(); |
133 | CHECK(s); |
134 | reader->Read(&data->vector_unit_bytes); |
135 | s = reader->NextArrayItem(); |
136 | CHECK(s); |
137 | reader->Read(&data->cache_line_bytes); |
138 | s = reader->NextArrayItem(); |
139 | CHECK(s); |
140 | reader->Read(&data->max_shared_memory_per_block); |
141 | s = reader->NextArrayItem(); |
142 | CHECK(s); |
143 | reader->Read(&data->max_local_memory_per_block); |
144 | s = reader->NextArrayItem(); |
145 | CHECK(s); |
146 | reader->Read(&data->max_threads_per_block); |
147 | s = reader->NextArrayItem(); |
148 | CHECK(s); |
149 | reader->Read(&data->max_vthread_extent); |
150 | s = reader->NextArrayItem(); |
151 | CHECK(s); |
152 | reader->Read(&data->warp_size); |
153 | s = reader->NextArrayItem(); |
154 | CHECK(!s); |
155 | } |
156 | }; |
157 | |
158 | template <> |
159 | struct Handler<::tvm::auto_scheduler::SearchTaskNode> { |
160 | inline static void Write(dmlc::JSONWriter* writer, |
161 | const ::tvm::auto_scheduler::SearchTaskNode& data) { |
162 | writer->BeginArray(false); |
163 | writer->WriteArrayItem(std::string(data.workload_key)); |
164 | writer->WriteArrayItem(data.target->str()); |
165 | writer->WriteArrayItem(*data.hardware_params.get()); |
166 | ::tvm::Target target = data.target; |
167 | ::tvm::Target target_host = data.target_host; |
168 | ::tvm::CheckAndUpdateHostConsistency(&target, &target_host); |
169 | if (target_host.defined()) { |
170 | writer->WriteArrayItem(target_host->str()); |
171 | } else { |
172 | writer->WriteArrayItem(std::string("" )); |
173 | } |
174 | writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option)); |
175 | writer->WriteArraySeperator(); |
176 | writer->BeginArray(false); |
177 | for (const auto& i : data.task_input_names) { |
178 | writer->WriteArrayItem(std::string(i)); |
179 | } |
180 | writer->EndArray(); |
181 | writer->EndArray(); |
182 | } |
183 | inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { |
184 | bool s; |
185 | std::string str_value; |
186 | int int_value; |
187 | auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>(); |
188 | reader->BeginArray(); |
189 | s = reader->NextArrayItem(); |
190 | ICHECK(s); |
191 | reader->Read(&str_value); |
192 | data->workload_key = std::move(str_value); |
193 | s = reader->NextArrayItem(); |
194 | ICHECK(s); |
195 | reader->Read(&str_value); |
196 | data->target = ::tvm::Target(str_value); |
197 | s = reader->NextArrayItem(); |
198 | if (s) { |
199 | reader->Read(hardware_params_node.get()); |
200 | s = reader->NextArrayItem(); |
201 | data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node); |
202 | if (s) { |
203 | reader->Read(&str_value); |
204 | if (!str_value.empty()) { |
205 | data->target_host = ::tvm::Target(str_value); |
206 | ::tvm::CheckAndUpdateHostConsistency(&data->target, &data->target_host); |
207 | } |
208 | s = reader->NextArrayItem(); |
209 | ICHECK(s); |
210 | reader->Read(&int_value); |
211 | data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); |
212 | s = reader->NextArrayItem(); |
213 | if (s) { |
214 | reader->BeginArray(); |
215 | s = reader->NextArrayItem(); |
216 | while (s) { |
217 | reader->Read(&str_value); |
218 | data->task_input_names.push_back(str_value); |
219 | s = reader->NextArrayItem(); |
220 | } |
221 | // Process the end of array |
222 | s = reader->NextArrayItem(); |
223 | } |
224 | ICHECK(!s); |
225 | } |
226 | } |
227 | } |
228 | }; |
229 | |
230 | template <> |
231 | struct Handler<::tvm::auto_scheduler::MeasureInputNode> { |
232 | inline static void Write(dmlc::JSONWriter* writer, |
233 | const ::tvm::auto_scheduler::MeasureInputNode& data) { |
234 | writer->BeginArray(false); |
235 | writer->WriteArrayItem(*data.task.operator->()); |
236 | writer->WriteArrayItem(*data.state.operator->()); |
237 | writer->EndArray(); |
238 | } |
239 | inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) { |
240 | auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>(); |
241 | auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>(); |
242 | state_node->concrete = true; |
243 | |
244 | bool s; |
245 | reader->BeginArray(); |
246 | s = reader->NextArrayItem(); |
247 | ICHECK(s); |
248 | reader->Read(task_node.get()); |
249 | s = reader->NextArrayItem(); |
250 | ICHECK(s); |
251 | reader->Read(state_node.get()); |
252 | s = reader->NextArrayItem(); |
253 | ICHECK(!s); |
254 | |
255 | data->task = ::tvm::auto_scheduler::SearchTask(task_node); |
256 | data->state = ::tvm::auto_scheduler::State(state_node); |
257 | } |
258 | }; |
259 | |
260 | template <> |
261 | struct Handler<::tvm::auto_scheduler::MeasureResultNode> { |
262 | inline static void Write(dmlc::JSONWriter* writer, |
263 | const ::tvm::auto_scheduler::MeasureResultNode& data) { |
264 | writer->BeginArray(false); |
265 | writer->WriteArraySeperator(); |
266 | writer->BeginArray(false); |
267 | for (const auto& x : data.costs) { |
268 | auto pf = x.as<::tvm::tir::FloatImmNode>(); |
269 | ICHECK(pf != nullptr) << "Cost can only contain float values" ; |
270 | writer->WriteArrayItem(pf->value); |
271 | } |
272 | writer->EndArray(); |
273 | writer->WriteArrayItem(data.error_no); |
274 | writer->WriteArrayItem(data.all_cost); |
275 | writer->WriteArrayItem(static_cast<int>((data.timestamp))); |
276 | writer->EndArray(); |
277 | } |
278 | inline static void Read(dmlc::JSONReader* reader, |
279 | ::tvm::auto_scheduler::MeasureResultNode* data) { |
280 | std::vector<double> double_list; |
281 | bool s; |
282 | reader->BeginArray(); |
283 | s = reader->NextArrayItem(); |
284 | ICHECK(s); |
285 | reader->Read(&double_list); |
286 | data->costs.clear(); |
287 | for (const auto& i : double_list) { |
288 | data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); |
289 | } |
290 | s = reader->NextArrayItem(); |
291 | ICHECK(s); |
292 | reader->Read(&data->error_no); |
293 | s = reader->NextArrayItem(); |
294 | ICHECK(s); |
295 | reader->Read(&data->all_cost); |
296 | s = reader->NextArrayItem(); |
297 | ICHECK(s); |
298 | reader->Read(&data->timestamp); |
299 | s = reader->NextArrayItem(); |
300 | ICHECK(!s); |
301 | } |
302 | }; |
303 | |
304 | } // namespace json |
305 | } // namespace dmlc |
306 | |
307 | namespace tvm { |
308 | namespace auto_scheduler { |
309 | |
// Register the record writer/reader node types with the TVM object type system.
TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
312 | |
313 | RecordToFile::RecordToFile(String filename) { |
314 | auto node = make_object<RecordToFileNode>(); |
315 | node->filename = std::move(filename); |
316 | data_ = std::move(node); |
317 | } |
318 | |
319 | void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs, |
320 | const Array<MeasureResult>& results, const std::string log_version) { |
321 | dmlc::JSONWriter writer(os); |
322 | for (size_t i = 0; i < inputs.size(); ++i) { |
323 | writer.BeginObject(false); |
324 | writer.WriteObjectKeyValue("i" , *inputs[i].operator->()); |
325 | writer.WriteObjectKeyValue("r" , *results[i].operator->()); |
326 | writer.WriteObjectKeyValue("v" , log_version); |
327 | writer.EndObject(); |
328 | *os << "\n" ; |
329 | } |
330 | } |
331 | |
332 | void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, |
333 | std::string* log_version) { |
334 | std::istringstream ss(str); |
335 | dmlc::JSONReader reader(&ss); |
336 | std::string key; |
337 | |
338 | reader.BeginObject(); |
339 | while (reader.NextObjectItem(&key)) { |
340 | if (key == "i" ) { |
341 | reader.Read(inp); |
342 | } else if (key == "r" ) { |
343 | reader.Read(res); |
344 | } else if (key == "v" ) { |
345 | reader.Read(log_version); |
346 | } else { |
347 | LOG(FATAL) << "Invalid key in json log: " << key; |
348 | } |
349 | } |
350 | } |
351 | |
352 | void RecordToFileNode::Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs, |
353 | const Array<MeasureResult>& results) { |
354 | std::ofstream ofs(filename, std::ofstream::app); |
355 | WriteMeasureRecords(&ofs, inputs, results); |
356 | } |
357 | |
358 | RecordReader::RecordReader(String filename) { |
359 | auto node = make_object<RecordReaderNode>(); |
360 | node->filename = filename; |
361 | node->infile.open(filename, std::ifstream::in); |
362 | data_ = std::move(node); |
363 | } |
364 | |
365 | RecordReaderNode::~RecordReaderNode() { infile.close(); } |
366 | |
367 | bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { |
368 | std::string log_version; |
369 | |
370 | while (std::getline(infile, cur_line_)) { |
371 | if (cur_line_[0] == '#' || cur_line_[0] == ' ') { |
372 | // skip comment lines begin with '#' or ' ' |
373 | continue; |
374 | } |
375 | ReadMeasureRecord(cur_line_, inp, res, &log_version); |
376 | return true; |
377 | } |
378 | |
379 | return false; |
380 | } |
381 | |
382 | std::pair<Array<MeasureInput>, Array<MeasureResult>> RecordReaderNode::ReadLines(int max_size, |
383 | int skip_size) { |
384 | auto inp = make_object<MeasureInputNode>(); |
385 | auto res = make_object<MeasureResultNode>(); |
386 | Array<MeasureInput> inputs; |
387 | Array<MeasureResult> results; |
388 | |
389 | while (ReadNext(inp.get(), res.get())) { |
390 | if (skip_size > 0) { |
391 | skip_size--; |
392 | continue; |
393 | } |
394 | |
395 | inputs.push_back(inp->copy()); |
396 | results.push_back(res->copy()); |
397 | |
398 | if (max_size > 0 && static_cast<int>(inputs.size()) >= max_size) { |
399 | break; |
400 | } |
401 | } |
402 | |
403 | return std::make_pair(inputs, results); |
404 | } |
405 | |
406 | TVM_REGISTER_GLOBAL("auto_scheduler.RecordToFile" ).set_body_typed([](const String& filename) { |
407 | return RecordToFile(filename); |
408 | }); |
409 | |
410 | TVM_REGISTER_GLOBAL("auto_scheduler.RecordReader" ).set_body_typed([](const String& filename) { |
411 | return RecordReader(filename); |
412 | }); |
413 | |
414 | TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadLines" ) |
415 | .set_body_typed([](RecordReader reader, int size, int skip_size) { |
416 | const auto& res = reader->ReadLines(size, skip_size); |
417 | return Array<ObjectRef>{res.first, res.second}; |
418 | }); |
419 | |
420 | TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext" ).set_body_typed([](RecordReader reader) { |
421 | auto inp = make_object<MeasureInputNode>(); |
422 | auto res = make_object<MeasureResultNode>(); |
423 | if (reader->ReadNext(inp.get(), res.get())) { |
424 | return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)}; |
425 | } else { |
426 | return Array<ObjectRef>(); |
427 | } |
428 | }); |
429 | |
430 | TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord" ).set_body_typed([](const std::string& str) { |
431 | auto inp = make_object<MeasureInputNode>(); |
432 | auto res = make_object<MeasureResultNode>(); |
433 | std::string log_version; |
434 | ReadMeasureRecord(str, inp.get(), res.get(), &log_version); |
435 | return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)}; |
436 | }); |
437 | |
438 | TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords" ) |
439 | .set_body_typed([](MeasureInput inp, MeasureResult res) { |
440 | auto inps = Array<MeasureInput>({inp}); |
441 | auto ress = Array<MeasureResult>({res}); |
442 | std::ostringstream ss; |
443 | WriteMeasureRecords(&ss, inps, ress); |
444 | return String(ss.str()); |
445 | }); |
446 | |
447 | TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords" ) |
448 | .set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) { |
449 | std::ofstream ofs(filename, std::ofstream::app); |
450 | WriteMeasureRecords(&ofs, in, res); |
451 | }); |
452 | |
453 | TVM_REGISTER_GLOBAL("auto_scheduler.SerializeMeasureInput" ) |
454 | .set_body_typed([](const MeasureInput& input) { |
455 | std::ostringstream os; |
456 | dmlc::JSONWriter writer(&os); |
457 | writer.Write(*input.get()); |
458 | return os.str(); |
459 | }); |
460 | |
461 | TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeMeasureInput" ).set_body_typed([](String json) { |
462 | std::istringstream ss(json); |
463 | dmlc::JSONReader reader(&ss); |
464 | auto inp = make_object<MeasureInputNode>(); |
465 | reader.Read(inp.get()); |
466 | return ObjectRef(inp); |
467 | }); |
468 | |
469 | TVM_REGISTER_GLOBAL("auto_scheduler.SerializeSearchTask" ) |
470 | .set_body_typed([](const SearchTask& search_task) { |
471 | std::ostringstream os; |
472 | dmlc::JSONWriter writer(&os); |
473 | writer.Write(*search_task.get()); |
474 | return os.str(); |
475 | }); |
476 | |
477 | TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeSearchTask" ).set_body_typed([](String json) { |
478 | std::istringstream ss(json); |
479 | dmlc::JSONReader reader(&ss); |
480 | auto search_task = make_object<SearchTaskNode>(); |
481 | reader.Read(search_task.get()); |
482 | return ObjectRef(search_task); |
483 | }); |
484 | |
485 | } // namespace auto_scheduler |
486 | } // namespace tvm |
487 | |