1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/auto_scheduler/measure_record.h |
 * \brief JSON serialization format for dumping and loading measurement records.
23 | */ |
24 | |
25 | #ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ |
26 | #define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ |
27 | |
28 | #include <tvm/auto_scheduler/measure.h> |
29 | |
30 | #include <fstream> |
31 | #include <string> |
32 | #include <utility> |
33 | |
34 | namespace tvm { |
35 | namespace auto_scheduler { |
36 | |
/*! \brief Default log version tag passed to WriteMeasureRecords when none is given. */
const std::string AUTO_SCHEDULER_LOG_VERSION("v0.6");  // NOLINT(*)
38 | |
39 | /*! \brief Callback for logging the input and results of measurements to file */ |
40 | class RecordToFileNode : public MeasureCallbackNode { |
41 | public: |
42 | /*! \brief The name of output file. */ |
43 | String filename; |
44 | |
45 | void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs, |
46 | const Array<MeasureResult>& results) final; |
47 | |
48 | static constexpr const char* _type_key = "auto_scheduler.RecordToFile" ; |
49 | TVM_DECLARE_FINAL_OBJECT_INFO(RecordToFileNode, MeasureCallbackNode); |
50 | }; |
51 | |
52 | /*! |
53 | * \brief Managed reference to RecordToFileNode. |
54 | * \sa RecordToFileNode |
55 | */ |
56 | class RecordToFile : public MeasureCallback { |
57 | public: |
58 | /*! |
59 | * \brief The constructor. |
60 | * \param filename The name of output file |
61 | */ |
62 | explicit RecordToFile(String filename); |
63 | |
64 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordToFile, MeasureCallback, RecordToFileNode); |
65 | }; |
66 | |
67 | /*! \brief Log reader to load step logs from a file.*/ |
68 | class RecordReaderNode : public Object { |
69 | public: |
70 | /*! \brief The name of input file. */ |
71 | String filename; |
72 | /*! \brief The reading file stream. */ |
73 | std::ifstream infile; |
74 | |
75 | ~RecordReaderNode(); |
76 | |
77 | /*! |
78 | * \brief Read next line in the log file. |
79 | * \param inp A pointer to a MeasureInputNode, this is used as output. |
80 | * \param res A pointer to a MeasureResultNode, this is used as output. |
81 | * \return Whether the read is successful. */ |
82 | bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); |
83 | |
84 | /*! |
85 | * \brief Read multiple lines from the log file. |
86 | * \param max_size The maximum number of lines. -1 means read all lines. |
87 | * \param skip_size Skip the first n lines. |
88 | * \return The MeasureInputs and MeasureResults loaded from the log file. |
89 | */ |
90 | std::pair<Array<MeasureInput>, Array<MeasureResult>> ReadLines(int max_size = -1, |
91 | int skip_size = 0); |
92 | |
93 | static constexpr const char* _type_key = "auto_scheduler.RecordReader" ; |
94 | TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); |
95 | |
96 | private: |
97 | /*! \brief A string storing the current line. */ |
98 | std::string cur_line_; |
99 | }; |
100 | |
101 | /*! |
102 | * \brief Managed reference to RecordReaderNode. |
103 | * \sa RecordReaderNode |
104 | */ |
105 | class RecordReader : public ObjectRef { |
106 | public: |
107 | /*! |
108 | * \brief The constructor. |
109 | * \param filename The name of input file |
110 | */ |
111 | explicit RecordReader(String filename); |
112 | |
113 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordReader, ObjectRef, RecordReaderNode); |
114 | }; |
115 | |
116 | /*! |
117 | * \brief Append measure records to an output stream. |
118 | * \param os A pointer to a output stream. |
119 | * \param inputs The MeasureInputs to be written. |
120 | * \param results The MeasureResults to be written. |
121 | * \param log_version The log version for the given record. |
122 | */ |
123 | void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs, |
124 | const Array<MeasureResult>& results, |
125 | const std::string log_version = AUTO_SCHEDULER_LOG_VERSION); |
126 | |
127 | /*! |
128 | * \brief Read one measure record from a string. |
129 | * \param str The record string to be parsed. |
130 | * \param inp A pointer to a MeasureInputNode used to store the return value. |
131 | * \param res A pointer to a MeasureResultNode used to store the return value. |
132 | * \param log_version A pointer to a string used to store the log version. |
133 | */ |
134 | void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, |
135 | std::string* log_version); |
136 | |
137 | } // namespace auto_scheduler |
138 | } // namespace tvm |
139 | |
140 | #endif // TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ |
141 | |