1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_META_SCHEDULE_COST_MODEL_H_ |
21 | #define TVM_META_SCHEDULE_COST_MODEL_H_ |
22 | |
23 | #include <tvm/meta_schedule/arg_info.h> |
24 | #include <tvm/meta_schedule/measure_candidate.h> |
25 | #include <tvm/meta_schedule/runner.h> |
26 | #include <tvm/node/reflection.h> |
27 | #include <tvm/runtime/container/array.h> |
28 | #include <tvm/runtime/container/string.h> |
29 | #include <tvm/runtime/object.h> |
30 | #include <tvm/runtime/packed_func.h> |
31 | #include <tvm/tir/schedule/schedule.h> |
32 | |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace meta_schedule { |
37 | |
38 | class TuneContext; |
39 | |
40 | /*! \brief Cost model. */ |
41 | class CostModelNode : public runtime::Object { |
42 | public: |
43 | /*! \brief Virtual destructor. */ |
44 | virtual ~CostModelNode() = default; |
45 | |
46 | void VisitAttrs(tvm::AttrVisitor* v) {} |
47 | |
48 | /*! |
49 | * \brief Load the cost model from given file location. |
50 | * \param path The file path. |
51 | */ |
52 | virtual void Load(const String& path) = 0; |
53 | |
54 | /*! |
55 | * \brief Save the cost model to given file location. |
56 | * \param path The file path. |
57 | */ |
58 | virtual void Save(const String& path) = 0; |
59 | |
60 | /*! |
61 | * \brief Update the cost model given running results. |
62 | * \param context The tuning context. |
63 | * \param candidates The measure candidates. |
64 | * \param results The running results of the measure candidates. |
65 | */ |
66 | virtual void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates, |
67 | const Array<RunnerResult>& results) = 0; |
68 | |
69 | /*! |
70 | * \brief Predict the normalized score (the larger the better) of given measure candidates. |
71 | * \param context The tuning context. |
72 | * \param candidates The measure candidates. |
73 | * \return The predicted normalized score. |
74 | */ |
75 | virtual std::vector<double> Predict(const TuneContext& context, |
76 | const Array<MeasureCandidate>& candidates) = 0; |
77 | |
78 | static constexpr const char* _type_key = "meta_schedule.CostModel" ; |
79 | TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); |
80 | }; |
81 | |
82 | /*! \brief The cost model with customized methods on the python-side. */ |
83 | class PyCostModelNode : public CostModelNode { |
84 | public: |
85 | /*! |
86 | * \brief Load the cost model from given file location. |
87 | * \param path The file path. |
88 | */ |
89 | using FLoad = runtime::TypedPackedFunc<void(String)>; |
90 | /*! |
91 | * \brief Save the cost model to given file location. |
92 | * \param path The file path. |
93 | */ |
94 | using FSave = runtime::TypedPackedFunc<void(String)>; |
95 | /*! |
96 | * \brief Update the cost model given running results. |
97 | * \param context The tuning context. |
98 | * \param candidates The measure candidates. |
99 | * \param results The running results of the measure candidates. |
100 | * \return Whether cost model was updated successfully. |
101 | */ |
102 | using FUpdate = runtime::TypedPackedFunc<void(const TuneContext&, const Array<MeasureCandidate>&, |
103 | const Array<RunnerResult>&)>; |
104 | /*! |
105 | * \brief Predict the running results of given measure candidates. |
106 | * \param context The tuning context. |
107 | * \param candidates The measure candidates. |
108 | * \param p_addr The address to save the estimated running results. |
109 | */ |
110 | using FPredict = runtime::TypedPackedFunc<void(const TuneContext&, const Array<MeasureCandidate>&, |
111 | void* p_addr)>; |
112 | /*! |
113 | * \brief Get the cost model as string with name. |
114 | * \return The string representation of the cost model. |
115 | */ |
116 | using FAsString = runtime::TypedPackedFunc<String()>; |
117 | |
118 | /*! \brief The packed function to the `Load` function. */ |
119 | FLoad f_load; |
120 | /*! \brief The packed function to the `Save` function. */ |
121 | FSave f_save; |
122 | /*! \brief The packed function to the `Update` function. */ |
123 | FUpdate f_update; |
124 | /*! \brief The packed function to the `Predict` function. */ |
125 | FPredict f_predict; |
126 | /*! \brief The packed function to the `AsString` function. */ |
127 | FAsString f_as_string; |
128 | |
129 | void VisitAttrs(tvm::AttrVisitor* v) { |
130 | // `f_load` is not visited |
131 | // `f_save` is not visited |
132 | // `f_update` is not visited |
133 | // `f_predict` is not visited |
134 | // `f_as_string` is not visited |
135 | } |
136 | |
137 | void Load(const String& path); |
138 | void Save(const String& path); |
139 | void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates, |
140 | const Array<RunnerResult>& results); |
141 | std::vector<double> Predict(const TuneContext& context, |
142 | const Array<MeasureCandidate>& candidates); |
143 | |
144 | static constexpr const char* _type_key = "meta_schedule.PyCostModel" ; |
145 | TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); |
146 | }; |
147 | |
148 | /*! |
149 | * \brief Managed reference to CostModelNode |
150 | * \sa CostModelNode |
151 | */ |
152 | class CostModel : public runtime::ObjectRef { |
153 | public: |
154 | /*! |
155 | * \brief Create a feature extractor with customized methods on the python-side. |
156 | * \param f_load The packed function of `Load`. |
157 | * \param f_save The packed function of `Save`. |
158 | * \param f_update The packed function of `Update`. |
159 | * \param f_predict The packed function of `Predict`. |
160 | * \param f_as_string The packed function of `AsString`. |
161 | * \return The feature extractor created. |
162 | */ |
163 | TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, // |
164 | PyCostModelNode::FSave f_save, // |
165 | PyCostModelNode::FUpdate f_update, // |
166 | PyCostModelNode::FPredict f_predict, // |
167 | PyCostModelNode::FAsString f_as_string); |
168 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode); |
169 | }; |
170 | |
171 | } // namespace meta_schedule |
172 | } // namespace tvm |
173 | |
174 | #endif // TVM_META_SCHEDULE_COST_MODEL_H_ |
175 | |