1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info) {
25 ObjectPtr<MeasureCandidateNode> n = make_object<MeasureCandidateNode>();
26 n->sch = sch;
27 n->args_info = args_info;
28 data_ = std::move(n);
29}
30
31void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) {
32 ICHECK(f_initialize_with_tune_context != nullptr)
33 << "PySearchStrategy's InitializeWithTuneContext method not implemented!";
34 f_initialize_with_tune_context(context);
35}
36
37void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter,
38 const Array<tir::Schedule>& design_spaces,
39 const Optional<Database>& database,
40 const Optional<CostModel>& cost_model) {
41 ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
42 f_pre_tuning(max_trials, num_trials_per_iter, design_spaces, database, cost_model);
43}
44
45void PySearchStrategyNode::PostTuning() {
46 ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
47 f_post_tuning();
48}
49
50Optional<Array<MeasureCandidate>> PySearchStrategyNode::GenerateMeasureCandidates() {
51 ICHECK(f_generate_measure_candidates != nullptr)
52 << "PySearchStrategy's GenerateMeasureCandidates method not implemented!";
53 return f_generate_measure_candidates();
54}
55
56void PySearchStrategyNode::NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
57 const Array<RunnerResult>& results) {
58 ICHECK(f_notify_runner_results != nullptr)
59 << "PySearchStrategy's NotifyRunnerResults method not implemented!";
60 f_notify_runner_results(measure_candidates, results);
61}
62
63SearchStrategy PySearchStrategyNode::Clone() const {
64 ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!";
65 return f_clone();
66}
67
68SearchStrategy SearchStrategy::PySearchStrategy(
69 PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
70 PySearchStrategyNode::FPreTuning f_pre_tuning, //
71 PySearchStrategyNode::FPostTuning f_post_tuning, //
72 PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, //
73 PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, //
74 PySearchStrategyNode::FClone f_clone) {
75 ObjectPtr<PySearchStrategyNode> n = make_object<PySearchStrategyNode>();
76 n->f_initialize_with_tune_context = f_initialize_with_tune_context;
77 n->f_pre_tuning = f_pre_tuning;
78 n->f_post_tuning = f_post_tuning;
79 n->f_generate_measure_candidates = f_generate_measure_candidates;
80 n->f_notify_runner_results = f_notify_runner_results;
81 n->f_clone = f_clone;
82 return SearchStrategy(n);
83}
84
// Register the object types used by the search-strategy API with TVM's type
// registry. SearchStrategyNode is registered with TVM_REGISTER_OBJECT_TYPE
// rather than TVM_REGISTER_NODE_TYPE. NOTE(review): presumably this is because
// it is an abstract base without a reflection creator; confirm against the header.
TVM_REGISTER_NODE_TYPE(MeasureCandidateNode);
TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode);
TVM_REGISTER_NODE_TYPE(PySearchStrategyNode);
88
// Global (FFI-visible) functions for the search-strategy API.

// Constructor for MeasureCandidate: (schedule, args_info) -> MeasureCandidate.
TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate")
    .set_body_typed([](tir::Schedule sch, Array<ArgInfo> args_info) -> MeasureCandidate {
      return MeasureCandidate(sch, args_info);
    });
// Factory for a callback-driven strategy (see SearchStrategy::PySearchStrategy).
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy")
    .set_body_typed(SearchStrategy::PySearchStrategy);
// The remaining entries expose SearchStrategyNode member functions via
// set_body_method<SearchStrategy>. Each registered function takes the
// SearchStrategy reference as its first argument, followed by the method's own
// parameters, and calls the method on the referenced node.
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext")
    .set_body_method<SearchStrategy>(&SearchStrategyNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning")
    .set_body_method<SearchStrategy>(&SearchStrategyNode::PreTuning);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning")
    .set_body_method<SearchStrategy>(&SearchStrategyNode::PostTuning);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates")
    .set_body_method<SearchStrategy>(&SearchStrategyNode::GenerateMeasureCandidates);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults")
    .set_body_method<SearchStrategy>(&SearchStrategyNode::NotifyRunnerResults);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone")
    .set_body_method<SearchStrategy>(&SearchStrategyNode::Clone);
107
108} // namespace meta_schedule
109} // namespace tvm
110