1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24/******** Constructors ********/
25
26BuilderInput::BuilderInput(IRModule mod, Target target,
27 Optional<Map<String, runtime::NDArray>> params) {
28 ObjectPtr<BuilderInputNode> n = make_object<BuilderInputNode>();
29 n->mod = std::move(mod);
30 n->target = std::move(target);
31 n->params = std::move(params);
32 data_ = std::move(n);
33}
34
35BuilderResult::BuilderResult(Optional<String> artifact_path, Optional<String> error_msg) {
36 ObjectPtr<BuilderResultNode> n = make_object<BuilderResultNode>();
37 n->artifact_path = std::move(artifact_path);
38 n->error_msg = std::move(error_msg);
39 data_ = std::move(n);
40}
41
42Builder Builder::PyBuilder(BuilderNode::FBuild f_build) {
43 ObjectPtr<PyBuilderNode> n = make_object<PyBuilderNode>();
44 n->f_build = std::move(f_build);
45 return Builder(std::move(n));
46}
47
48/******** FFI ********/
49
// Register the builder-related node types with TVM's object system.
// BuilderNode uses TVM_REGISTER_OBJECT_TYPE instead of TVM_REGISTER_NODE_TYPE.
// NOTE(review): presumably because BuilderNode is an abstract base whose concrete
// subclass here is PyBuilderNode; confirm before changing the macro.
TVM_REGISTER_NODE_TYPE(BuilderInputNode);
TVM_REGISTER_NODE_TYPE(BuilderResultNode);
TVM_REGISTER_OBJECT_TYPE(BuilderNode);
TVM_REGISTER_NODE_TYPE(PyBuilderNode);
54
55TVM_REGISTER_GLOBAL("meta_schedule.BuilderInput")
56 .set_body_typed([](IRModule mod, Target target,
57 Optional<Map<String, runtime::NDArray>> params) -> BuilderInput {
58 return BuilderInput(mod, target, params);
59 });
60
61TVM_REGISTER_GLOBAL("meta_schedule.BuilderResult")
62 .set_body_typed([](Optional<String> artifact_path,
63 Optional<String> error_msg) -> BuilderResult {
64 return BuilderResult(artifact_path, error_msg);
65 });
66
// FFI entry point exposing BuilderNode::Build as a method on Builder objects
// (the Builder instance is expected as the first argument of the packed call).
TVM_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method<Builder>(&BuilderNode::Build);
68
69TVM_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder);
70
71} // namespace meta_schedule
72} // namespace tvm
73