1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_META_SCHEDULE_RUNNER_H_ |
20 | #define TVM_META_SCHEDULE_RUNNER_H_ |
21 | |
22 | #include <tvm/ir/expr.h> |
23 | #include <tvm/meta_schedule/arg_info.h> |
24 | #include <tvm/node/reflection.h> |
25 | #include <tvm/runtime/container/array.h> |
26 | #include <tvm/runtime/container/optional.h> |
27 | #include <tvm/runtime/container/string.h> |
28 | #include <tvm/runtime/object.h> |
29 | #include <tvm/runtime/packed_func.h> |
30 | |
31 | namespace tvm { |
32 | namespace meta_schedule { |
33 | |
34 | /*! \brief Runner's input containing path of artifact, type of device and argument info. */ |
35 | class RunnerInputNode : public runtime::Object { |
36 | public: |
37 | /*! \brief The path to the built artifact. */ |
38 | String artifact_path; |
39 | /*! \brief The type of device. */ |
40 | String device_type; |
41 | /*! \brief The argument information. */ |
42 | Array<ArgInfo> args_info; |
43 | |
44 | void VisitAttrs(tvm::AttrVisitor* v) { |
45 | v->Visit("artifact_path" , &artifact_path); |
46 | v->Visit("device_type" , &device_type); |
47 | v->Visit("args_info" , &args_info); |
48 | } |
49 | |
50 | static constexpr const char* _type_key = "meta_schedule.RunnerInput" ; |
51 | TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); |
52 | }; |
53 | |
54 | /*! |
55 | * \brief Managed reference to RunnerInputNode |
56 | * \sa RunnerInputNode |
57 | */ |
58 | class RunnerInput : public runtime::ObjectRef { |
59 | public: |
60 | /*! |
61 | * \brief Constructor of RunnerInput |
62 | * \param artifact_path The path to the built artifact. |
63 | * \param device_type The type of device. |
64 | * \param args_info The argument information. |
65 | */ |
66 | TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array<ArgInfo> args_info); |
67 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode); |
68 | }; |
69 | |
70 | /*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ |
71 | class RunnerResultNode : public runtime::Object { |
72 | public: |
73 | /*! \brief The run time in seconds.*/ |
74 | Optional<Array<FloatImm>> run_secs; |
75 | /*! \brief The error message, if any. */ |
76 | Optional<String> error_msg; |
77 | |
78 | void VisitAttrs(tvm::AttrVisitor* v) { |
79 | v->Visit("run_secs" , &run_secs); |
80 | v->Visit("error_msg" , &error_msg); |
81 | } |
82 | |
83 | static constexpr const char* _type_key = "meta_schedule.RunnerResult" ; |
84 | TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); |
85 | }; |
86 | |
87 | /*! |
88 | * \brief Managed reference to RunnerResultNode |
89 | * \sa RunnerResultNode |
90 | */ |
91 | class RunnerResult : public runtime::ObjectRef { |
92 | public: |
93 | /*! |
94 | * \brief Constructor |
95 | * \brief The run time in seconds. |
96 | * \brief The error message, if any. |
97 | */ |
98 | TVM_DLL explicit RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg); |
99 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); |
100 | }; |
101 | |
102 | /*! |
103 | * \brief A class to asynchronously fetch runner's output. |
104 | * \note The API design is consistent with python's concurrent.futures.Future: |
105 | * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future |
106 | */ |
107 | class RunnerFutureNode : public runtime::Object { |
108 | public: |
109 | /*! |
110 | * \brief The function type to check whether the runner has finished. |
111 | * \return Whether the runner's output is ready. |
112 | */ |
113 | using FDone = runtime::TypedPackedFunc<bool()>; |
114 | /*! |
115 | * \brief The function type to fetch runner output if it is ready. |
116 | * \return The runner's output. |
117 | */ |
118 | using FResult = runtime::TypedPackedFunc<RunnerResult()>; |
119 | |
120 | /*! \brief The packed function to check whether the runner has finished. */ |
121 | FDone f_done; |
122 | /*! \brief The packed function to fetch runner output if it is ready. */ |
123 | FResult f_result; |
124 | |
125 | void VisitAttrs(tvm::AttrVisitor* v) { |
126 | // `f_done` is not visited |
127 | // `f_result` is not visited |
128 | } |
129 | |
130 | /*! |
131 | * \brief Check whether the runner has finished. |
132 | * \return A boolean indicating whether the runner has finished. |
133 | */ |
134 | bool Done() const { |
135 | ICHECK(f_done != nullptr) << "PyRunnerFuture's Done method not implemented!" ; |
136 | return f_done(); |
137 | } |
138 | /*! |
139 | * \brief Fetch the runner's output if it is ready. |
140 | * \return The runner's output. |
141 | */ |
142 | RunnerResult Result() const { |
143 | ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!" ; |
144 | return f_result(); |
145 | } |
146 | |
147 | static constexpr const char* _type_key = "meta_schedule.RunnerFuture" ; |
148 | TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object); |
149 | }; |
150 | |
151 | /*! |
152 | * \brief Managed reference to RunnerFutureNode |
153 | * \sa RunnerFutureNode |
154 | */ |
155 | class RunnerFuture : public runtime::ObjectRef { |
156 | public: |
157 | using FDone = RunnerFutureNode::FDone; |
158 | using FResult = RunnerFutureNode::FResult; |
159 | |
160 | /*! |
161 | * \brief Constructor of RunnerFuture |
162 | * \param f_done The packed function to check whether the runner has finished. |
163 | * \param f_result The packed function to fetch runner output if it is ready. |
164 | */ |
165 | TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result); |
166 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef, |
167 | RunnerFutureNode); |
168 | }; |
169 | |
170 | /*! \brief The abstract runner interface. */ |
171 | class RunnerNode : public runtime::Object { |
172 | public: |
173 | /*! |
174 | * \brief The function type to run the built artifacts and get runner futures. |
175 | * \param input The runner's inputs. |
176 | * \return The runner futures. |
177 | * \sa RunnerFuture |
178 | */ |
179 | using FRun = runtime::TypedPackedFunc<Array<RunnerFuture>(Array<RunnerInput>)>; |
180 | |
181 | /*! \brief Default destructor */ |
182 | virtual ~RunnerNode() = default; |
183 | |
184 | /*! |
185 | * \brief Run the built artifact and get runner futures. |
186 | * \param runner_inputs The runner's inputs. |
187 | * \return The runner futures. |
188 | */ |
189 | virtual Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) = 0; |
190 | |
191 | static constexpr const char* _type_key = "meta_schedule.Runner" ; |
192 | TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object); |
193 | }; |
194 | |
195 | /*! |
196 | * \brief Managed reference to RunnerNode |
197 | * \sa RunnerNode |
198 | */ |
199 | class Runner : public runtime::ObjectRef { |
200 | public: |
201 | using FRun = RunnerNode::FRun; |
202 | |
203 | /*! |
204 | * \brief Create a runner with customized build method on the python-side. |
205 | * \param f_run The packed function to run the built artifacts and get runner futures. |
206 | * \return The runner created. |
207 | */ |
208 | TVM_DLL static Runner PyRunner(FRun f_run); |
209 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode); |
210 | }; |
211 | |
212 | /*! \brief An abstract runner with customized build method on the python-side. */ |
213 | class PyRunnerNode : public RunnerNode { |
214 | public: |
215 | /*! \brief The packed function to run the built artifacts and get runner futures. */ |
216 | FRun f_run; |
217 | |
218 | void VisitAttrs(tvm::AttrVisitor* v) { |
219 | // `f_run` is not visited |
220 | } |
221 | |
222 | Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final { |
223 | ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!" ; |
224 | return f_run(runner_inputs); |
225 | } |
226 | |
227 | static constexpr const char* _type_key = "meta_schedule.PyRunner" ; |
228 | TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); |
229 | }; |
230 | |
231 | } // namespace meta_schedule |
232 | } // namespace tvm |
233 | |
234 | #endif // TVM_META_SCHEDULE_RUNNER_H_ |
235 | |