1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_BUILDER_H_
20#define TVM_META_SCHEDULE_BUILDER_H_
21
22#include <tvm/ir/module.h>
23#include <tvm/node/reflection.h>
24#include <tvm/runtime/container/array.h>
25#include <tvm/runtime/container/map.h>
26#include <tvm/runtime/container/optional.h>
27#include <tvm/runtime/container/string.h>
28#include <tvm/runtime/ndarray.h>
29#include <tvm/runtime/object.h>
30#include <tvm/runtime/packed_func.h>
31#include <tvm/target/target.h>
32
33namespace tvm {
34namespace meta_schedule {
35
36/*! \brief The builder's input, containing an IRModule and the target. */
37class BuilderInputNode : public runtime::Object {
38 public:
39 /*! \brief The IRModule to be built. */
40 IRModule mod;
41 /*! \brief The target to be built for. */
42 Target target;
43 /*! \brief Parameters for Relay build module. */
44 Optional<Map<String, runtime::NDArray>> params;
45
46 void VisitAttrs(tvm::AttrVisitor* v) {
47 v->Visit("mod", &mod);
48 v->Visit("target", &target);
49 v->Visit("params", &params);
50 }
51
52 static constexpr const char* _type_key = "meta_schedule.BuilderInput";
53 TVM_DECLARE_FINAL_OBJECT_INFO(BuilderInputNode, runtime::Object);
54};
55
56/*!
57 * \brief Managed reference to BuilderInputNode
58 * \sa BuilderInputNode
59 */
60class BuilderInput : public runtime::ObjectRef {
61 public:
62 /*!
63 * \brief Constructor of BuilderInput.
64 * \param mod The IRModule to be built.
65 * \param target The target to be built for.
66 * \param params Parameters for Relay build module.
67 */
68 TVM_DLL explicit BuilderInput(IRModule mod, Target target,
69 Optional<Map<String, runtime::NDArray>> params = NullOpt);
70 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode);
71};
72
73/*! \brief The builder's output, containing the artifact path or error message if any. */
74class BuilderResultNode : public runtime::Object {
75 public:
76 /*! \brief The path to the built artifact. */
77 Optional<String> artifact_path;
78 /*! \brief The error message if any. */
79 Optional<String> error_msg;
80
81 void VisitAttrs(tvm::AttrVisitor* v) {
82 v->Visit("artifact_path", &artifact_path);
83 v->Visit("error_msg", &error_msg);
84 }
85
86 static constexpr const char* _type_key = "meta_schedule.BuilderResult";
87 TVM_DECLARE_FINAL_OBJECT_INFO(BuilderResultNode, runtime::Object);
88};
89
90/*!
91 * \brief Managed reference to BuilderResultNode
92 * \sa BuilderResultNode
93 */
94class BuilderResult : public runtime::ObjectRef {
95 public:
96 /*!
97 * \brief Constructor of BuilderResult.
98 * \param artifact_path The path to the built artifact.
99 * \param error_msg The error message if any.
100 */
101 TVM_DLL explicit BuilderResult(Optional<String> artifact_path, Optional<String> error_msg);
102 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderResult, runtime::ObjectRef, BuilderResultNode);
103};
104
105/*! \brief The abstract builder interface. */
106class BuilderNode : public runtime::Object {
107 public:
108 /*! \brief Default destructor */
109 virtual ~BuilderNode() = default;
110 /*!
111 * \brief Generate the build results from build inputs.
112 * \param build_inputs The inputs to be built.
113 * \return The build results.
114 */
115 virtual Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) = 0;
116 /*!
117 * \brief The function type of `Build` method.
118 * \param build_inputs The inputs to be built.
119 * \return The build results.
120 */
121 using FBuild = runtime::TypedPackedFunc<Array<BuilderResult>(const Array<BuilderInput>&)>;
122
123 static constexpr const char* _type_key = "meta_schedule.Builder";
124 TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object);
125};
126
127/*!
128 * \brief Managed reference to BuilderNode
129 * \sa BuilderNode
130 */
131class Builder : public runtime::ObjectRef {
132 public:
133 /*!
134 * \brief Create a builder with customized build method on the python-side.
135 * \param f_build The packed function to the `Build` function..
136 * \return The Builder created.
137 */
138 static Builder PyBuilder(BuilderNode::FBuild f_build);
139 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Builder, runtime::ObjectRef, BuilderNode);
140};
141
142/*! \brief An abstract builder with customized build method on the python-side. */
143class PyBuilderNode : public BuilderNode {
144 public:
145 /*! \brief The packed function to the `Build` function. */
146 FBuild f_build;
147
148 void VisitAttrs(tvm::AttrVisitor* v) {
149 // `f_build` is not visited
150 }
151
152 Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final {
153 ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!";
154 return f_build(build_inputs);
155 }
156
157 static constexpr const char* _type_key = "meta_schedule.PyBuilder";
158 TVM_DECLARE_FINAL_OBJECT_INFO(PyBuilderNode, BuilderNode);
159};
160
161} // namespace meta_schedule
162} // namespace tvm
163
164#endif // TVM_META_SCHEDULE_BUILDER_H_
165