1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/cost_model.h |
22 | * \brief Cost models that estimate the performance of programs |
23 | */ |
24 | |
25 | #ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_ |
26 | #define TVM_AUTO_SCHEDULER_COST_MODEL_H_ |
27 | |
#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/node/node.h>
#include <tvm/runtime/packed_func.h>

#include <utility>
#include <vector>
34 | |
35 | namespace tvm { |
36 | namespace auto_scheduler { |
37 | |
38 | using runtime::PackedFunc; |
39 | using runtime::TypedPackedFunc; |
40 | |
41 | /*! \brief The base class for cost model */ |
42 | class CostModelNode : public Object { |
43 | public: |
44 | /*! |
45 | * \brief Update the cost model according to new measurement results (training data). |
46 | * \param inputs The measure inputs |
47 | * \param results The measure results |
48 | */ |
49 | virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0; |
50 | |
51 | /*! |
52 | * \brief Predict the scores of states |
53 | * \param task The search task of states |
54 | * \param states The input states |
55 | * \param scores The predicted scores for all states |
56 | */ |
57 | virtual void Predict(const SearchTask& task, const Array<State>& states, |
58 | std::vector<float>* scores) = 0; |
59 | |
60 | /*! |
61 | * \brief Predict the scores of all stages in states. This is the breakdown version of `Predict` |
62 | * \param task The search task |
63 | * \param states The input states |
64 | * \param state_scores The predicted scores for all states |
65 | * \param stage_scores The predicted scores for all stages in all stages |
66 | */ |
67 | virtual void PredictStages(const SearchTask& task, const Array<State>& states, |
68 | std::vector<float>* state_scores, |
69 | std::vector<std::vector<float>>* stage_scores) { |
70 | LOG(FATAL) << "Not implemented" ; |
71 | } |
72 | |
73 | /*! |
74 | * \brief Default virtual destructor |
75 | */ |
76 | virtual ~CostModelNode() {} |
77 | |
78 | static constexpr const char* _type_key = "auto_scheduler.CostModel" ; |
79 | TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); |
80 | }; |
81 | |
/*!
 * \brief Managed reference to CostModelNode.
 *
 * This is the reference (handle) type callers pass around. Its constructors, copy/move
 * operations, `operator->` and `ContainerType` all come from the
 * TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS macro. The "MUTABLE" variant is presumably used
 * because the model is updated in place (CostModelNode::Update); confirm against the
 * macro definition.
 * \sa CostModelNode
 */
class CostModel : public ObjectRef {
 public:
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode);
};
90 | |
91 | /*! \brief The cost model returning random value for all predictions */ |
92 | class RandomModelNode : public CostModelNode { |
93 | public: |
94 | /*! \brief Pointer to a random number generator function */ |
95 | const TypedPackedFunc<void(size_t, void*)>* random_number_func; |
96 | |
97 | void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final; |
98 | |
99 | void Predict(const SearchTask& task, const Array<State>& states, |
100 | std::vector<float>* scores) final; |
101 | |
102 | static constexpr const char* _type_key = "auto_scheduler.RandomModel" ; |
103 | TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); |
104 | }; |
105 | |
106 | /*! |
107 | * \brief Managed reference to RandomModelNode. |
108 | * \sa RandomModelNode |
109 | */ |
110 | class RandomModel : public CostModel { |
111 | public: |
112 | RandomModel(); |
113 | explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : CostModel(n) {} |
114 | |
115 | RandomModelNode* operator->() const { return static_cast<RandomModelNode*>(data_.get()); } |
116 | |
117 | TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel); |
118 | using ContainerType = RandomModelNode; |
119 | }; |
120 | |
121 | /*! \brief A wrapper for cost model defined by python code |
122 | * This class will call functions defined in the python */ |
123 | class PythonBasedModelNode : public CostModelNode { |
124 | public: |
125 | /*! \brief Pointer to the update function in python */ |
126 | PackedFunc update_func; |
127 | /*! \brief Pointer to the predict function in python */ |
128 | PackedFunc predict_func; |
129 | /*! \brief Pointer to the predict function in python */ |
130 | PackedFunc predict_stage_func; |
131 | |
132 | void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final; |
133 | |
134 | void Predict(const SearchTask& task, const Array<State>& states, |
135 | std::vector<float>* scores) final; |
136 | |
137 | void PredictStages(const SearchTask& task, const Array<State>& states, |
138 | std::vector<float>* state_scores, |
139 | std::vector<std::vector<float>>* stage_scores) final; |
140 | |
141 | static constexpr const char* _type_key = "auto_scheduler.PythonBasedModel" ; |
142 | TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); |
143 | }; |
144 | |
145 | /*! |
146 | * \brief Managed reference to PythonBasedModelNode. |
147 | * \sa PythonBasedModelNode |
148 | */ |
149 | class PythonBasedModel : public CostModel { |
150 | public: |
151 | /*! |
152 | * \brief The constructor. |
153 | * \param update_func The pointer to the update function defined in python |
154 | * \param predict_func The pointer to the prediction function defined in python |
155 | * \param predict_stage_func The pointer to the prediction function defined in python |
156 | */ |
157 | PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func); |
158 | |
159 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, PythonBasedModelNode); |
160 | }; |
161 | |
162 | } // namespace auto_scheduler |
163 | } // namespace tvm |
164 | |
165 | #endif // TVM_AUTO_SCHEDULER_COST_MODEL_H_ |
166 | |