1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_RUNNER_H_
20#define TVM_META_SCHEDULE_RUNNER_H_
21
22#include <tvm/ir/expr.h>
23#include <tvm/meta_schedule/arg_info.h>
24#include <tvm/node/reflection.h>
25#include <tvm/runtime/container/array.h>
26#include <tvm/runtime/container/optional.h>
27#include <tvm/runtime/container/string.h>
28#include <tvm/runtime/object.h>
29#include <tvm/runtime/packed_func.h>
30
31namespace tvm {
32namespace meta_schedule {
33
34/*! \brief Runner's input containing path of artifact, type of device and argument info. */
35class RunnerInputNode : public runtime::Object {
36 public:
37 /*! \brief The path to the built artifact. */
38 String artifact_path;
39 /*! \brief The type of device. */
40 String device_type;
41 /*! \brief The argument information. */
42 Array<ArgInfo> args_info;
43
44 void VisitAttrs(tvm::AttrVisitor* v) {
45 v->Visit("artifact_path", &artifact_path);
46 v->Visit("device_type", &device_type);
47 v->Visit("args_info", &args_info);
48 }
49
50 static constexpr const char* _type_key = "meta_schedule.RunnerInput";
51 TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object);
52};
53
54/*!
55 * \brief Managed reference to RunnerInputNode
56 * \sa RunnerInputNode
57 */
58class RunnerInput : public runtime::ObjectRef {
59 public:
60 /*!
61 * \brief Constructor of RunnerInput
62 * \param artifact_path The path to the built artifact.
63 * \param device_type The type of device.
64 * \param args_info The argument information.
65 */
66 TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array<ArgInfo> args_info);
67 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
68};
69
70/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
71class RunnerResultNode : public runtime::Object {
72 public:
73 /*! \brief The run time in seconds.*/
74 Optional<Array<FloatImm>> run_secs;
75 /*! \brief The error message, if any. */
76 Optional<String> error_msg;
77
78 void VisitAttrs(tvm::AttrVisitor* v) {
79 v->Visit("run_secs", &run_secs);
80 v->Visit("error_msg", &error_msg);
81 }
82
83 static constexpr const char* _type_key = "meta_schedule.RunnerResult";
84 TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object);
85};
86
87/*!
88 * \brief Managed reference to RunnerResultNode
89 * \sa RunnerResultNode
90 */
91class RunnerResult : public runtime::ObjectRef {
92 public:
93 /*!
94 * \brief Constructor
95 * \brief The run time in seconds.
96 * \brief The error message, if any.
97 */
98 TVM_DLL explicit RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg);
99 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode);
100};
101
102/*!
103 * \brief A class to asynchronously fetch runner's output.
104 * \note The API design is consistent with python's concurrent.futures.Future:
105 * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
106 */
107class RunnerFutureNode : public runtime::Object {
108 public:
109 /*!
110 * \brief The function type to check whether the runner has finished.
111 * \return Whether the runner's output is ready.
112 */
113 using FDone = runtime::TypedPackedFunc<bool()>;
114 /*!
115 * \brief The function type to fetch runner output if it is ready.
116 * \return The runner's output.
117 */
118 using FResult = runtime::TypedPackedFunc<RunnerResult()>;
119
120 /*! \brief The packed function to check whether the runner has finished. */
121 FDone f_done;
122 /*! \brief The packed function to fetch runner output if it is ready. */
123 FResult f_result;
124
125 void VisitAttrs(tvm::AttrVisitor* v) {
126 // `f_done` is not visited
127 // `f_result` is not visited
128 }
129
130 /*!
131 * \brief Check whether the runner has finished.
132 * \return A boolean indicating whether the runner has finished.
133 */
134 bool Done() const {
135 ICHECK(f_done != nullptr) << "PyRunnerFuture's Done method not implemented!";
136 return f_done();
137 }
138 /*!
139 * \brief Fetch the runner's output if it is ready.
140 * \return The runner's output.
141 */
142 RunnerResult Result() const {
143 ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!";
144 return f_result();
145 }
146
147 static constexpr const char* _type_key = "meta_schedule.RunnerFuture";
148 TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object);
149};
150
151/*!
152 * \brief Managed reference to RunnerFutureNode
153 * \sa RunnerFutureNode
154 */
155class RunnerFuture : public runtime::ObjectRef {
156 public:
157 using FDone = RunnerFutureNode::FDone;
158 using FResult = RunnerFutureNode::FResult;
159
160 /*!
161 * \brief Constructor of RunnerFuture
162 * \param f_done The packed function to check whether the runner has finished.
163 * \param f_result The packed function to fetch runner output if it is ready.
164 */
165 TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result);
166 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef,
167 RunnerFutureNode);
168};
169
170/*! \brief The abstract runner interface. */
171class RunnerNode : public runtime::Object {
172 public:
173 /*!
174 * \brief The function type to run the built artifacts and get runner futures.
175 * \param input The runner's inputs.
176 * \return The runner futures.
177 * \sa RunnerFuture
178 */
179 using FRun = runtime::TypedPackedFunc<Array<RunnerFuture>(Array<RunnerInput>)>;
180
181 /*! \brief Default destructor */
182 virtual ~RunnerNode() = default;
183
184 /*!
185 * \brief Run the built artifact and get runner futures.
186 * \param runner_inputs The runner's inputs.
187 * \return The runner futures.
188 */
189 virtual Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) = 0;
190
191 static constexpr const char* _type_key = "meta_schedule.Runner";
192 TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object);
193};
194
195/*!
196 * \brief Managed reference to RunnerNode
197 * \sa RunnerNode
198 */
199class Runner : public runtime::ObjectRef {
200 public:
201 using FRun = RunnerNode::FRun;
202
203 /*!
204 * \brief Create a runner with customized build method on the python-side.
205 * \param f_run The packed function to run the built artifacts and get runner futures.
206 * \return The runner created.
207 */
208 TVM_DLL static Runner PyRunner(FRun f_run);
209 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode);
210};
211
212/*! \brief An abstract runner with customized build method on the python-side. */
213class PyRunnerNode : public RunnerNode {
214 public:
215 /*! \brief The packed function to run the built artifacts and get runner futures. */
216 FRun f_run;
217
218 void VisitAttrs(tvm::AttrVisitor* v) {
219 // `f_run` is not visited
220 }
221
222 Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final {
223 ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!";
224 return f_run(runner_inputs);
225 }
226
227 static constexpr const char* _type_key = "meta_schedule.PyRunner";
228 TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode);
229};
230
231} // namespace meta_schedule
232} // namespace tvm
233
234#endif // TVM_META_SCHEDULE_RUNNER_H_
235