1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ |
21 | #define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ |
22 | |
23 | #include <tvm/meta_schedule/builder.h> |
24 | #include <tvm/meta_schedule/measure_candidate.h> |
25 | #include <tvm/meta_schedule/runner.h> |
26 | #include <tvm/meta_schedule/search_strategy.h> |
27 | #include <tvm/meta_schedule/tune_context.h> |
28 | #include <tvm/node/reflection.h> |
29 | #include <tvm/runtime/container/array.h> |
30 | #include <tvm/runtime/container/string.h> |
31 | #include <tvm/runtime/object.h> |
32 | #include <tvm/runtime/packed_func.h> |
33 | |
34 | namespace tvm { |
35 | namespace meta_schedule { |
36 | |
37 | class TaskScheduler; |
38 | |
39 | /*! \brief Rules to apply after measure results is available. */ |
40 | class MeasureCallbackNode : public runtime::Object { |
41 | public: |
42 | /*! \brief Virtual destructor. */ |
43 | virtual ~MeasureCallbackNode() = default; |
44 | |
45 | void VisitAttrs(tvm::AttrVisitor* v) {} |
46 | |
47 | /*! |
48 | * \brief Apply a measure callback rule with given arguments. |
49 | * \param task_scheduler The task scheduler. |
50 | * \param task_id The id of the task (tune context) to apply measure callbacks. |
51 | * \param measure_candidates The measure candidates. |
52 | * \param builder_results The builder results by building the measure candidates. |
53 | * \param runner_results The runner results by running the built measure candidates. |
54 | */ |
55 | virtual void Apply(const TaskScheduler& task_scheduler, // |
56 | int task_id, // |
57 | const Array<MeasureCandidate>& measure_candidates, // |
58 | const Array<BuilderResult>& builder_results, // |
59 | const Array<RunnerResult>& runner_results) = 0; |
60 | |
61 | static constexpr const char* _type_key = "meta_schedule.MeasureCallback" ; |
62 | TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); |
63 | }; |
64 | |
65 | /*! \brief The measure callback with customized methods on the python-side. */ |
66 | class PyMeasureCallbackNode : public MeasureCallbackNode { |
67 | public: |
68 | /*! |
69 | * \brief Apply a measure callback to the given schedule. |
70 | * \param task_scheduler The task scheduler. |
71 | * \param tasks The list of tune context to process. |
72 | * \param measure_candidates The measure candidates. |
73 | * \param builds The builder results by building the measure candidates. |
74 | * \param results The runner results by running the built measure candidates. |
75 | * \return Whether the measure callback was successfully applied. |
76 | */ |
77 | using FApply = |
78 | runtime::TypedPackedFunc<void(const TaskScheduler& task_scheduler, // |
79 | int task_id, // |
80 | const Array<MeasureCandidate>& measure_candidates, // |
81 | const Array<BuilderResult>& builds, // |
82 | const Array<RunnerResult>& results)>; |
83 | /*! |
84 | * \brief Get the measure callback function as string with name. |
85 | * \return The string of the measure callback function. |
86 | */ |
87 | using FAsString = runtime::TypedPackedFunc<String()>; |
88 | |
89 | /*! \brief The packed function to the `Apply` function. */ |
90 | FApply f_apply; |
91 | /*! \brief The packed function to the `AsString` function. */ |
92 | FAsString f_as_string; |
93 | |
94 | void VisitAttrs(tvm::AttrVisitor* v) { |
95 | // `f_apply` is not visited |
96 | // `f_as_string` is not visited |
97 | } |
98 | |
99 | void Apply(const TaskScheduler& task_scheduler, // |
100 | int task_id, // |
101 | const Array<MeasureCandidate>& measure_candidates, // |
102 | const Array<BuilderResult>& builds, // |
103 | const Array<RunnerResult>& results); |
104 | |
105 | static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback" ; |
106 | TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); |
107 | }; |
108 | |
109 | /*! |
110 | * \brief Managed reference to MeasureCallbackNode |
111 | * \sa MeasureCallbackNode |
112 | */ |
113 | class MeasureCallback : public runtime::ObjectRef { |
114 | public: |
115 | /*! |
116 | * \brief Create a measure callback that adds the measurement results into the database |
117 | * \return The measure callback created. |
118 | */ |
119 | TVM_DLL static MeasureCallback AddToDatabase(); |
120 | /*! |
121 | * \brief Create a measure callback that removes the build artifacts from the disk |
122 | * \return The measure callback created. |
123 | */ |
124 | TVM_DLL static MeasureCallback RemoveBuildArtifact(); |
125 | /*! |
126 | * \brief Create a measure callback that updates the cost model with measurement result. |
127 | * \return The measure callback created. |
128 | */ |
129 | TVM_DLL static MeasureCallback UpdateCostModel(); |
130 | /*! |
131 | * \brief Create a measure callback with customized methods on the python-side. |
132 | * \param f_apply The packed function of `Apply`. |
133 | * \param f_as_string The packed function of `AsString`. |
134 | * \return The measure callback created. |
135 | */ |
136 | TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, |
137 | PyMeasureCallbackNode::FAsString f_as_string); |
138 | /*! \brief The default list of measure callbacks. */ |
139 | TVM_DLL static Array<MeasureCallback, void> Default(); |
140 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); |
141 | }; |
142 | |
143 | } // namespace meta_schedule |
144 | } // namespace tvm |
145 | |
146 | #endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ |
147 | |