1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/cost_model.h
22 * \brief Cost models that estimate the performance of programs
23 */
24
25#ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_
26#define TVM_AUTO_SCHEDULER_COST_MODEL_H_
27
28#include <tvm/auto_scheduler/compute_dag.h>
29#include <tvm/auto_scheduler/measure.h>
30#include <tvm/node/node.h>
31#include <tvm/runtime/packed_func.h>
32
33#include <vector>
34
35namespace tvm {
36namespace auto_scheduler {
37
38using runtime::PackedFunc;
39using runtime::TypedPackedFunc;
40
41/*! \brief The base class for cost model */
42class CostModelNode : public Object {
43 public:
44 /*!
45 * \brief Update the cost model according to new measurement results (training data).
46 * \param inputs The measure inputs
47 * \param results The measure results
48 */
49 virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0;
50
51 /*!
52 * \brief Predict the scores of states
53 * \param task The search task of states
54 * \param states The input states
55 * \param scores The predicted scores for all states
56 */
57 virtual void Predict(const SearchTask& task, const Array<State>& states,
58 std::vector<float>* scores) = 0;
59
60 /*!
61 * \brief Predict the scores of all stages in states. This is the breakdown version of `Predict`
62 * \param task The search task
63 * \param states The input states
64 * \param state_scores The predicted scores for all states
65 * \param stage_scores The predicted scores for all stages in all stages
66 */
67 virtual void PredictStages(const SearchTask& task, const Array<State>& states,
68 std::vector<float>* state_scores,
69 std::vector<std::vector<float>>* stage_scores) {
70 LOG(FATAL) << "Not implemented";
71 }
72
73 /*!
74 * \brief Default virtual destructor
75 */
76 virtual ~CostModelNode() {}
77
78 static constexpr const char* _type_key = "auto_scheduler.CostModel";
79 TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object);
80};
81
82/*!
83 * \brief Managed reference to CostModelNode.
84 * \sa CostModelNode
85 */
86class CostModel : public ObjectRef {
87 public:
88 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode);
89};
90
91/*! \brief The cost model returning random value for all predictions */
92class RandomModelNode : public CostModelNode {
93 public:
94 /*! \brief Pointer to a random number generator function */
95 const TypedPackedFunc<void(size_t, void*)>* random_number_func;
96
97 void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
98
99 void Predict(const SearchTask& task, const Array<State>& states,
100 std::vector<float>* scores) final;
101
102 static constexpr const char* _type_key = "auto_scheduler.RandomModel";
103 TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode);
104};
105
106/*!
107 * \brief Managed reference to RandomModelNode.
108 * \sa RandomModelNode
109 */
110class RandomModel : public CostModel {
111 public:
112 RandomModel();
113 explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : CostModel(n) {}
114
115 RandomModelNode* operator->() const { return static_cast<RandomModelNode*>(data_.get()); }
116
117 TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel);
118 using ContainerType = RandomModelNode;
119};
120
121/*! \brief A wrapper for cost model defined by python code
122 * This class will call functions defined in the python */
123class PythonBasedModelNode : public CostModelNode {
124 public:
125 /*! \brief Pointer to the update function in python */
126 PackedFunc update_func;
127 /*! \brief Pointer to the predict function in python */
128 PackedFunc predict_func;
129 /*! \brief Pointer to the predict function in python */
130 PackedFunc predict_stage_func;
131
132 void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
133
134 void Predict(const SearchTask& task, const Array<State>& states,
135 std::vector<float>* scores) final;
136
137 void PredictStages(const SearchTask& task, const Array<State>& states,
138 std::vector<float>* state_scores,
139 std::vector<std::vector<float>>* stage_scores) final;
140
141 static constexpr const char* _type_key = "auto_scheduler.PythonBasedModel";
142 TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode);
143};
144
145/*!
146 * \brief Managed reference to PythonBasedModelNode.
147 * \sa PythonBasedModelNode
148 */
149class PythonBasedModel : public CostModel {
150 public:
151 /*!
152 * \brief The constructor.
153 * \param update_func The pointer to the update function defined in python
154 * \param predict_func The pointer to the prediction function defined in python
155 * \param predict_stage_func The pointer to the prediction function defined in python
156 */
157 PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func);
158
159 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, PythonBasedModelNode);
160};
161
162} // namespace auto_scheduler
163} // namespace tvm
164
165#endif // TVM_AUTO_SCHEDULER_COST_MODEL_H_
166