1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
21#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
22
23#include <tvm/meta_schedule/builder.h>
24#include <tvm/meta_schedule/measure_candidate.h>
25#include <tvm/meta_schedule/runner.h>
26#include <tvm/meta_schedule/search_strategy.h>
27#include <tvm/meta_schedule/tune_context.h>
28#include <tvm/node/reflection.h>
29#include <tvm/runtime/container/array.h>
30#include <tvm/runtime/container/string.h>
31#include <tvm/runtime/object.h>
32#include <tvm/runtime/packed_func.h>
33
34namespace tvm {
35namespace meta_schedule {
36
37class TaskScheduler;
38
39/*! \brief Rules to apply after measure results is available. */
40class MeasureCallbackNode : public runtime::Object {
41 public:
42 /*! \brief Virtual destructor. */
43 virtual ~MeasureCallbackNode() = default;
44
45 void VisitAttrs(tvm::AttrVisitor* v) {}
46
47 /*!
48 * \brief Apply a measure callback rule with given arguments.
49 * \param task_scheduler The task scheduler.
50 * \param task_id The id of the task (tune context) to apply measure callbacks.
51 * \param measure_candidates The measure candidates.
52 * \param builder_results The builder results by building the measure candidates.
53 * \param runner_results The runner results by running the built measure candidates.
54 */
55 virtual void Apply(const TaskScheduler& task_scheduler, //
56 int task_id, //
57 const Array<MeasureCandidate>& measure_candidates, //
58 const Array<BuilderResult>& builder_results, //
59 const Array<RunnerResult>& runner_results) = 0;
60
61 static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
62 TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
63};
64
65/*! \brief The measure callback with customized methods on the python-side. */
66class PyMeasureCallbackNode : public MeasureCallbackNode {
67 public:
68 /*!
69 * \brief Apply a measure callback to the given schedule.
70 * \param task_scheduler The task scheduler.
71 * \param tasks The list of tune context to process.
72 * \param measure_candidates The measure candidates.
73 * \param builds The builder results by building the measure candidates.
74 * \param results The runner results by running the built measure candidates.
75 * \return Whether the measure callback was successfully applied.
76 */
77 using FApply =
78 runtime::TypedPackedFunc<void(const TaskScheduler& task_scheduler, //
79 int task_id, //
80 const Array<MeasureCandidate>& measure_candidates, //
81 const Array<BuilderResult>& builds, //
82 const Array<RunnerResult>& results)>;
83 /*!
84 * \brief Get the measure callback function as string with name.
85 * \return The string of the measure callback function.
86 */
87 using FAsString = runtime::TypedPackedFunc<String()>;
88
89 /*! \brief The packed function to the `Apply` function. */
90 FApply f_apply;
91 /*! \brief The packed function to the `AsString` function. */
92 FAsString f_as_string;
93
94 void VisitAttrs(tvm::AttrVisitor* v) {
95 // `f_apply` is not visited
96 // `f_as_string` is not visited
97 }
98
99 void Apply(const TaskScheduler& task_scheduler, //
100 int task_id, //
101 const Array<MeasureCandidate>& measure_candidates, //
102 const Array<BuilderResult>& builds, //
103 const Array<RunnerResult>& results);
104
105 static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
106 TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
107};
108
109/*!
110 * \brief Managed reference to MeasureCallbackNode
111 * \sa MeasureCallbackNode
112 */
113class MeasureCallback : public runtime::ObjectRef {
114 public:
115 /*!
116 * \brief Create a measure callback that adds the measurement results into the database
117 * \return The measure callback created.
118 */
119 TVM_DLL static MeasureCallback AddToDatabase();
120 /*!
121 * \brief Create a measure callback that removes the build artifacts from the disk
122 * \return The measure callback created.
123 */
124 TVM_DLL static MeasureCallback RemoveBuildArtifact();
125 /*!
126 * \brief Create a measure callback that updates the cost model with measurement result.
127 * \return The measure callback created.
128 */
129 TVM_DLL static MeasureCallback UpdateCostModel();
130 /*!
131 * \brief Create a measure callback with customized methods on the python-side.
132 * \param f_apply The packed function of `Apply`.
133 * \param f_as_string The packed function of `AsString`.
134 * \return The measure callback created.
135 */
136 TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
137 PyMeasureCallbackNode::FAsString f_as_string);
138 /*! \brief The default list of measure callbacks. */
139 TVM_DLL static Array<MeasureCallback, void> Default();
140 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
141};
142
143} // namespace meta_schedule
144} // namespace tvm
145
146#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
147