1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/measure_record.cc
22 * \brief Json serialization format for dumping and loading tuning records.
23 */
24
#include <dmlc/json.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>

#include <cstdint>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils.h"
38
39// Json serialization handler for MeasureInput, MeasureResult
40// (and recursively for SearchTask, State, Step, ...)
41namespace dmlc {
42namespace json {
43
44template <>
45struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> {
46 inline static void Write(dmlc::JSONWriter* writer,
47 const ::tvm::Array<::tvm::auto_scheduler::Stage>& data) {
48 writer->BeginArray(false);
49 writer->EndArray();
50 }
51 inline static void Read(dmlc::JSONReader* reader,
52 ::tvm::Array<::tvm::auto_scheduler::Stage>* data) {
53 bool s;
54 reader->BeginArray();
55 s = reader->NextArrayItem();
56 ICHECK(!s);
57 }
58};
59
60template <>
61struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
62 inline static void Write(dmlc::JSONWriter* writer,
63 const ::tvm::Array<::tvm::auto_scheduler::Step>& data) {
64 writer->BeginArray(false);
65 for (const auto& step : data) {
66 writer->WriteArraySeperator();
67 writer->BeginArray(false);
68 step->WriteToRecord(writer);
69 writer->EndArray();
70 }
71 writer->EndArray();
72 }
73
74 inline static void Read(dmlc::JSONReader* reader,
75 ::tvm::Array<::tvm::auto_scheduler::Step>* data) {
76 bool s;
77 reader->BeginArray();
78 data->clear();
79 while (reader->NextArrayItem()) {
80 reader->BeginArray();
81 data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
82 s = reader->NextArrayItem();
83 ICHECK(!s);
84 }
85 }
86};
87
88template <>
89struct Handler<::tvm::auto_scheduler::StateNode> {
90 inline static void Write(dmlc::JSONWriter* writer, const ::tvm::auto_scheduler::StateNode& data) {
91 writer->BeginArray(false);
92 writer->WriteArrayItem(data.stages);
93 writer->WriteArrayItem(data.transform_steps);
94 writer->EndArray();
95 }
96 inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) {
97 bool s;
98 reader->BeginArray();
99 s = reader->NextArrayItem();
100 ICHECK(s);
101 reader->Read(&data->stages);
102 s = reader->NextArrayItem();
103 ICHECK(s);
104 reader->Read(&data->transform_steps);
105 s = reader->NextArrayItem();
106 ICHECK(!s);
107 }
108};
109
110template <>
111struct Handler<::tvm::auto_scheduler::HardwareParamsNode> {
112 inline static void Write(dmlc::JSONWriter* writer,
113 const ::tvm::auto_scheduler::HardwareParamsNode& data) {
114 writer->BeginArray(false);
115 writer->WriteArrayItem(data.num_cores);
116 writer->WriteArrayItem(data.vector_unit_bytes);
117 writer->WriteArrayItem(data.cache_line_bytes);
118 writer->WriteArrayItem(data.max_shared_memory_per_block);
119 writer->WriteArrayItem(data.max_local_memory_per_block);
120 writer->WriteArrayItem(data.max_threads_per_block);
121 writer->WriteArrayItem(data.max_vthread_extent);
122 writer->WriteArrayItem(data.warp_size);
123 writer->EndArray();
124 }
125 inline static void Read(dmlc::JSONReader* reader,
126 ::tvm::auto_scheduler::HardwareParamsNode* data) {
127 bool s;
128 reader->BeginArray();
129 s = reader->NextArrayItem();
130 CHECK(s);
131 reader->Read(&data->num_cores);
132 s = reader->NextArrayItem();
133 CHECK(s);
134 reader->Read(&data->vector_unit_bytes);
135 s = reader->NextArrayItem();
136 CHECK(s);
137 reader->Read(&data->cache_line_bytes);
138 s = reader->NextArrayItem();
139 CHECK(s);
140 reader->Read(&data->max_shared_memory_per_block);
141 s = reader->NextArrayItem();
142 CHECK(s);
143 reader->Read(&data->max_local_memory_per_block);
144 s = reader->NextArrayItem();
145 CHECK(s);
146 reader->Read(&data->max_threads_per_block);
147 s = reader->NextArrayItem();
148 CHECK(s);
149 reader->Read(&data->max_vthread_extent);
150 s = reader->NextArrayItem();
151 CHECK(s);
152 reader->Read(&data->warp_size);
153 s = reader->NextArrayItem();
154 CHECK(!s);
155 }
156};
157
158template <>
159struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
160 inline static void Write(dmlc::JSONWriter* writer,
161 const ::tvm::auto_scheduler::SearchTaskNode& data) {
162 writer->BeginArray(false);
163 writer->WriteArrayItem(std::string(data.workload_key));
164 writer->WriteArrayItem(data.target->str());
165 writer->WriteArrayItem(*data.hardware_params.get());
166 ::tvm::Target target = data.target;
167 ::tvm::Target target_host = data.target_host;
168 ::tvm::CheckAndUpdateHostConsistency(&target, &target_host);
169 if (target_host.defined()) {
170 writer->WriteArrayItem(target_host->str());
171 } else {
172 writer->WriteArrayItem(std::string(""));
173 }
174 writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option));
175 writer->WriteArraySeperator();
176 writer->BeginArray(false);
177 for (const auto& i : data.task_input_names) {
178 writer->WriteArrayItem(std::string(i));
179 }
180 writer->EndArray();
181 writer->EndArray();
182 }
183 inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
184 bool s;
185 std::string str_value;
186 int int_value;
187 auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>();
188 reader->BeginArray();
189 s = reader->NextArrayItem();
190 ICHECK(s);
191 reader->Read(&str_value);
192 data->workload_key = std::move(str_value);
193 s = reader->NextArrayItem();
194 ICHECK(s);
195 reader->Read(&str_value);
196 data->target = ::tvm::Target(str_value);
197 s = reader->NextArrayItem();
198 if (s) {
199 reader->Read(hardware_params_node.get());
200 s = reader->NextArrayItem();
201 data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
202 if (s) {
203 reader->Read(&str_value);
204 if (!str_value.empty()) {
205 data->target_host = ::tvm::Target(str_value);
206 ::tvm::CheckAndUpdateHostConsistency(&data->target, &data->target_host);
207 }
208 s = reader->NextArrayItem();
209 ICHECK(s);
210 reader->Read(&int_value);
211 data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value);
212 s = reader->NextArrayItem();
213 if (s) {
214 reader->BeginArray();
215 s = reader->NextArrayItem();
216 while (s) {
217 reader->Read(&str_value);
218 data->task_input_names.push_back(str_value);
219 s = reader->NextArrayItem();
220 }
221 // Process the end of array
222 s = reader->NextArrayItem();
223 }
224 ICHECK(!s);
225 }
226 }
227 }
228};
229
230template <>
231struct Handler<::tvm::auto_scheduler::MeasureInputNode> {
232 inline static void Write(dmlc::JSONWriter* writer,
233 const ::tvm::auto_scheduler::MeasureInputNode& data) {
234 writer->BeginArray(false);
235 writer->WriteArrayItem(*data.task.operator->());
236 writer->WriteArrayItem(*data.state.operator->());
237 writer->EndArray();
238 }
239 inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) {
240 auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>();
241 auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>();
242 state_node->concrete = true;
243
244 bool s;
245 reader->BeginArray();
246 s = reader->NextArrayItem();
247 ICHECK(s);
248 reader->Read(task_node.get());
249 s = reader->NextArrayItem();
250 ICHECK(s);
251 reader->Read(state_node.get());
252 s = reader->NextArrayItem();
253 ICHECK(!s);
254
255 data->task = ::tvm::auto_scheduler::SearchTask(task_node);
256 data->state = ::tvm::auto_scheduler::State(state_node);
257 }
258};
259
260template <>
261struct Handler<::tvm::auto_scheduler::MeasureResultNode> {
262 inline static void Write(dmlc::JSONWriter* writer,
263 const ::tvm::auto_scheduler::MeasureResultNode& data) {
264 writer->BeginArray(false);
265 writer->WriteArraySeperator();
266 writer->BeginArray(false);
267 for (const auto& x : data.costs) {
268 auto pf = x.as<::tvm::tir::FloatImmNode>();
269 ICHECK(pf != nullptr) << "Cost can only contain float values";
270 writer->WriteArrayItem(pf->value);
271 }
272 writer->EndArray();
273 writer->WriteArrayItem(data.error_no);
274 writer->WriteArrayItem(data.all_cost);
275 writer->WriteArrayItem(static_cast<int>((data.timestamp)));
276 writer->EndArray();
277 }
278 inline static void Read(dmlc::JSONReader* reader,
279 ::tvm::auto_scheduler::MeasureResultNode* data) {
280 std::vector<double> double_list;
281 bool s;
282 reader->BeginArray();
283 s = reader->NextArrayItem();
284 ICHECK(s);
285 reader->Read(&double_list);
286 data->costs.clear();
287 for (const auto& i : double_list) {
288 data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i));
289 }
290 s = reader->NextArrayItem();
291 ICHECK(s);
292 reader->Read(&data->error_no);
293 s = reader->NextArrayItem();
294 ICHECK(s);
295 reader->Read(&data->all_cost);
296 s = reader->NextArrayItem();
297 ICHECK(s);
298 reader->Read(&data->timestamp);
299 s = reader->NextArrayItem();
300 ICHECK(!s);
301 }
302};
303
304} // namespace json
305} // namespace dmlc
306
307namespace tvm {
308namespace auto_scheduler {
309
// Register the record writer/reader node types with TVM's object type system.
TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
312
313RecordToFile::RecordToFile(String filename) {
314 auto node = make_object<RecordToFileNode>();
315 node->filename = std::move(filename);
316 data_ = std::move(node);
317}
318
319void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
320 const Array<MeasureResult>& results, const std::string log_version) {
321 dmlc::JSONWriter writer(os);
322 for (size_t i = 0; i < inputs.size(); ++i) {
323 writer.BeginObject(false);
324 writer.WriteObjectKeyValue("i", *inputs[i].operator->());
325 writer.WriteObjectKeyValue("r", *results[i].operator->());
326 writer.WriteObjectKeyValue("v", log_version);
327 writer.EndObject();
328 *os << "\n";
329 }
330}
331
332void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
333 std::string* log_version) {
334 std::istringstream ss(str);
335 dmlc::JSONReader reader(&ss);
336 std::string key;
337
338 reader.BeginObject();
339 while (reader.NextObjectItem(&key)) {
340 if (key == "i") {
341 reader.Read(inp);
342 } else if (key == "r") {
343 reader.Read(res);
344 } else if (key == "v") {
345 reader.Read(log_version);
346 } else {
347 LOG(FATAL) << "Invalid key in json log: " << key;
348 }
349 }
350}
351
352void RecordToFileNode::Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
353 const Array<MeasureResult>& results) {
354 std::ofstream ofs(filename, std::ofstream::app);
355 WriteMeasureRecords(&ofs, inputs, results);
356}
357
358RecordReader::RecordReader(String filename) {
359 auto node = make_object<RecordReaderNode>();
360 node->filename = filename;
361 node->infile.open(filename, std::ifstream::in);
362 data_ = std::move(node);
363}
364
365RecordReaderNode::~RecordReaderNode() { infile.close(); }
366
367bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) {
368 std::string log_version;
369
370 while (std::getline(infile, cur_line_)) {
371 if (cur_line_[0] == '#' || cur_line_[0] == ' ') {
372 // skip comment lines begin with '#' or ' '
373 continue;
374 }
375 ReadMeasureRecord(cur_line_, inp, res, &log_version);
376 return true;
377 }
378
379 return false;
380}
381
382std::pair<Array<MeasureInput>, Array<MeasureResult>> RecordReaderNode::ReadLines(int max_size,
383 int skip_size) {
384 auto inp = make_object<MeasureInputNode>();
385 auto res = make_object<MeasureResultNode>();
386 Array<MeasureInput> inputs;
387 Array<MeasureResult> results;
388
389 while (ReadNext(inp.get(), res.get())) {
390 if (skip_size > 0) {
391 skip_size--;
392 continue;
393 }
394
395 inputs.push_back(inp->copy());
396 results.push_back(res->copy());
397
398 if (max_size > 0 && static_cast<int>(inputs.size()) >= max_size) {
399 break;
400 }
401 }
402
403 return std::make_pair(inputs, results);
404}
405
406TVM_REGISTER_GLOBAL("auto_scheduler.RecordToFile").set_body_typed([](const String& filename) {
407 return RecordToFile(filename);
408});
409
410TVM_REGISTER_GLOBAL("auto_scheduler.RecordReader").set_body_typed([](const String& filename) {
411 return RecordReader(filename);
412});
413
414TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadLines")
415 .set_body_typed([](RecordReader reader, int size, int skip_size) {
416 const auto& res = reader->ReadLines(size, skip_size);
417 return Array<ObjectRef>{res.first, res.second};
418 });
419
420TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](RecordReader reader) {
421 auto inp = make_object<MeasureInputNode>();
422 auto res = make_object<MeasureResultNode>();
423 if (reader->ReadNext(inp.get(), res.get())) {
424 return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
425 } else {
426 return Array<ObjectRef>();
427 }
428});
429
430TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const std::string& str) {
431 auto inp = make_object<MeasureInputNode>();
432 auto res = make_object<MeasureResultNode>();
433 std::string log_version;
434 ReadMeasureRecord(str, inp.get(), res.get(), &log_version);
435 return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
436});
437
438TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords")
439 .set_body_typed([](MeasureInput inp, MeasureResult res) {
440 auto inps = Array<MeasureInput>({inp});
441 auto ress = Array<MeasureResult>({res});
442 std::ostringstream ss;
443 WriteMeasureRecords(&ss, inps, ress);
444 return String(ss.str());
445 });
446
447TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords")
448 .set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) {
449 std::ofstream ofs(filename, std::ofstream::app);
450 WriteMeasureRecords(&ofs, in, res);
451 });
452
453TVM_REGISTER_GLOBAL("auto_scheduler.SerializeMeasureInput")
454 .set_body_typed([](const MeasureInput& input) {
455 std::ostringstream os;
456 dmlc::JSONWriter writer(&os);
457 writer.Write(*input.get());
458 return os.str();
459 });
460
461TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeMeasureInput").set_body_typed([](String json) {
462 std::istringstream ss(json);
463 dmlc::JSONReader reader(&ss);
464 auto inp = make_object<MeasureInputNode>();
465 reader.Read(inp.get());
466 return ObjectRef(inp);
467});
468
469TVM_REGISTER_GLOBAL("auto_scheduler.SerializeSearchTask")
470 .set_body_typed([](const SearchTask& search_task) {
471 std::ostringstream os;
472 dmlc::JSONWriter writer(&os);
473 writer.Write(*search_task.get());
474 return os.str();
475 });
476
477TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeSearchTask").set_body_typed([](String json) {
478 std::istringstream ss(json);
479 dmlc::JSONReader reader(&ss);
480 auto search_task = make_object<SearchTaskNode>();
481 reader.Read(search_task.get());
482 return ObjectRef(search_task);
483});
484
485} // namespace auto_scheduler
486} // namespace tvm
487