/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | #ifndef TVM_META_SCHEDULE_BUILDER_H_ |
20 | #define TVM_META_SCHEDULE_BUILDER_H_ |
21 | |
22 | #include <tvm/ir/module.h> |
23 | #include <tvm/node/reflection.h> |
24 | #include <tvm/runtime/container/array.h> |
25 | #include <tvm/runtime/container/map.h> |
26 | #include <tvm/runtime/container/optional.h> |
27 | #include <tvm/runtime/container/string.h> |
28 | #include <tvm/runtime/ndarray.h> |
29 | #include <tvm/runtime/object.h> |
30 | #include <tvm/runtime/packed_func.h> |
31 | #include <tvm/target/target.h> |
32 | |
33 | namespace tvm { |
34 | namespace meta_schedule { |
35 | |
36 | /*! \brief The builder's input, containing an IRModule and the target. */ |
37 | class BuilderInputNode : public runtime::Object { |
38 | public: |
39 | /*! \brief The IRModule to be built. */ |
40 | IRModule mod; |
41 | /*! \brief The target to be built for. */ |
42 | Target target; |
43 | /*! \brief Parameters for Relay build module. */ |
44 | Optional<Map<String, runtime::NDArray>> params; |
45 | |
46 | void VisitAttrs(tvm::AttrVisitor* v) { |
47 | v->Visit("mod" , &mod); |
48 | v->Visit("target" , &target); |
49 | v->Visit("params" , ¶ms); |
50 | } |
51 | |
52 | static constexpr const char* _type_key = "meta_schedule.BuilderInput" ; |
53 | TVM_DECLARE_FINAL_OBJECT_INFO(BuilderInputNode, runtime::Object); |
54 | }; |
55 | |
56 | /*! |
57 | * \brief Managed reference to BuilderInputNode |
58 | * \sa BuilderInputNode |
59 | */ |
60 | class BuilderInput : public runtime::ObjectRef { |
61 | public: |
62 | /*! |
63 | * \brief Constructor of BuilderInput. |
64 | * \param mod The IRModule to be built. |
65 | * \param target The target to be built for. |
66 | * \param params Parameters for Relay build module. |
67 | */ |
68 | TVM_DLL explicit BuilderInput(IRModule mod, Target target, |
69 | Optional<Map<String, runtime::NDArray>> params = NullOpt); |
70 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); |
71 | }; |
72 | |
73 | /*! \brief The builder's output, containing the artifact path or error message if any. */ |
74 | class BuilderResultNode : public runtime::Object { |
75 | public: |
76 | /*! \brief The path to the built artifact. */ |
77 | Optional<String> artifact_path; |
78 | /*! \brief The error message if any. */ |
79 | Optional<String> error_msg; |
80 | |
81 | void VisitAttrs(tvm::AttrVisitor* v) { |
82 | v->Visit("artifact_path" , &artifact_path); |
83 | v->Visit("error_msg" , &error_msg); |
84 | } |
85 | |
86 | static constexpr const char* _type_key = "meta_schedule.BuilderResult" ; |
87 | TVM_DECLARE_FINAL_OBJECT_INFO(BuilderResultNode, runtime::Object); |
88 | }; |
89 | |
90 | /*! |
91 | * \brief Managed reference to BuilderResultNode |
92 | * \sa BuilderResultNode |
93 | */ |
94 | class BuilderResult : public runtime::ObjectRef { |
95 | public: |
96 | /*! |
97 | * \brief Constructor of BuilderResult. |
98 | * \param artifact_path The path to the built artifact. |
99 | * \param error_msg The error message if any. |
100 | */ |
101 | TVM_DLL explicit BuilderResult(Optional<String> artifact_path, Optional<String> error_msg); |
102 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderResult, runtime::ObjectRef, BuilderResultNode); |
103 | }; |
104 | |
105 | /*! \brief The abstract builder interface. */ |
106 | class BuilderNode : public runtime::Object { |
107 | public: |
108 | /*! \brief Default destructor */ |
109 | virtual ~BuilderNode() = default; |
110 | /*! |
111 | * \brief Generate the build results from build inputs. |
112 | * \param build_inputs The inputs to be built. |
113 | * \return The build results. |
114 | */ |
115 | virtual Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) = 0; |
116 | /*! |
117 | * \brief The function type of `Build` method. |
118 | * \param build_inputs The inputs to be built. |
119 | * \return The build results. |
120 | */ |
121 | using FBuild = runtime::TypedPackedFunc<Array<BuilderResult>(const Array<BuilderInput>&)>; |
122 | |
123 | static constexpr const char* _type_key = "meta_schedule.Builder" ; |
124 | TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object); |
125 | }; |
126 | |
127 | /*! |
128 | * \brief Managed reference to BuilderNode |
129 | * \sa BuilderNode |
130 | */ |
131 | class Builder : public runtime::ObjectRef { |
132 | public: |
133 | /*! |
134 | * \brief Create a builder with customized build method on the python-side. |
135 | * \param f_build The packed function to the `Build` function.. |
136 | * \return The Builder created. |
137 | */ |
138 | static Builder PyBuilder(BuilderNode::FBuild f_build); |
139 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Builder, runtime::ObjectRef, BuilderNode); |
140 | }; |
141 | |
142 | /*! \brief An abstract builder with customized build method on the python-side. */ |
143 | class PyBuilderNode : public BuilderNode { |
144 | public: |
145 | /*! \brief The packed function to the `Build` function. */ |
146 | FBuild f_build; |
147 | |
148 | void VisitAttrs(tvm::AttrVisitor* v) { |
149 | // `f_build` is not visited |
150 | } |
151 | |
152 | Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final { |
153 | ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!" ; |
154 | return f_build(build_inputs); |
155 | } |
156 | |
157 | static constexpr const char* _type_key = "meta_schedule.PyBuilder" ; |
158 | TVM_DECLARE_FINAL_OBJECT_INFO(PyBuilderNode, BuilderNode); |
159 | }; |
160 | |
161 | } // namespace meta_schedule |
162 | } // namespace tvm |
163 | |
164 | #endif // TVM_META_SCHEDULE_BUILDER_H_ |
165 | |