1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/auto_scheduler/measure_record.h
 * \brief JSON serialization format for dumping and loading measurement records.
23 */
24
25#ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
26#define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
27
28#include <tvm/auto_scheduler/measure.h>
29
30#include <fstream>
31#include <string>
32#include <utility>
33
34namespace tvm {
35namespace auto_scheduler {
36
/*! \brief Version tag of the auto_scheduler record log format; default `log_version` when writing records. */
const std::string AUTO_SCHEDULER_LOG_VERSION("v0.6");  // NOLINT(*)
38
39/*! \brief Callback for logging the input and results of measurements to file */
40class RecordToFileNode : public MeasureCallbackNode {
41 public:
42 /*! \brief The name of output file. */
43 String filename;
44
45 void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
46 const Array<MeasureResult>& results) final;
47
48 static constexpr const char* _type_key = "auto_scheduler.RecordToFile";
49 TVM_DECLARE_FINAL_OBJECT_INFO(RecordToFileNode, MeasureCallbackNode);
50};
51
52/*!
53 * \brief Managed reference to RecordToFileNode.
54 * \sa RecordToFileNode
55 */
56class RecordToFile : public MeasureCallback {
57 public:
58 /*!
59 * \brief The constructor.
60 * \param filename The name of output file
61 */
62 explicit RecordToFile(String filename);
63
64 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordToFile, MeasureCallback, RecordToFileNode);
65};
66
67/*! \brief Log reader to load step logs from a file.*/
68class RecordReaderNode : public Object {
69 public:
70 /*! \brief The name of input file. */
71 String filename;
72 /*! \brief The reading file stream. */
73 std::ifstream infile;
74
75 ~RecordReaderNode();
76
77 /*!
78 * \brief Read next line in the log file.
79 * \param inp A pointer to a MeasureInputNode, this is used as output.
80 * \param res A pointer to a MeasureResultNode, this is used as output.
81 * \return Whether the read is successful. */
82 bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res);
83
84 /*!
85 * \brief Read multiple lines from the log file.
86 * \param max_size The maximum number of lines. -1 means read all lines.
87 * \param skip_size Skip the first n lines.
88 * \return The MeasureInputs and MeasureResults loaded from the log file.
89 */
90 std::pair<Array<MeasureInput>, Array<MeasureResult>> ReadLines(int max_size = -1,
91 int skip_size = 0);
92
93 static constexpr const char* _type_key = "auto_scheduler.RecordReader";
94 TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object);
95
96 private:
97 /*! \brief A string storing the current line. */
98 std::string cur_line_;
99};
100
101/*!
102 * \brief Managed reference to RecordReaderNode.
103 * \sa RecordReaderNode
104 */
105class RecordReader : public ObjectRef {
106 public:
107 /*!
108 * \brief The constructor.
109 * \param filename The name of input file
110 */
111 explicit RecordReader(String filename);
112
113 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordReader, ObjectRef, RecordReaderNode);
114};
115
116/*!
117 * \brief Append measure records to an output stream.
118 * \param os A pointer to a output stream.
119 * \param inputs The MeasureInputs to be written.
120 * \param results The MeasureResults to be written.
121 * \param log_version The log version for the given record.
122 */
123void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
124 const Array<MeasureResult>& results,
125 const std::string log_version = AUTO_SCHEDULER_LOG_VERSION);
126
127/*!
128 * \brief Read one measure record from a string.
129 * \param str The record string to be parsed.
130 * \param inp A pointer to a MeasureInputNode used to store the return value.
131 * \param res A pointer to a MeasureResultNode used to store the return value.
132 * \param log_version A pointer to a string used to store the log version.
133 */
134void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
135 std::string* log_version);
136
137} // namespace auto_scheduler
138} // namespace tvm
139
140#endif // TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
141