1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info) { |
25 | ObjectPtr<MeasureCandidateNode> n = make_object<MeasureCandidateNode>(); |
26 | n->sch = sch; |
27 | n->args_info = args_info; |
28 | data_ = std::move(n); |
29 | } |
30 | |
31 | void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) { |
32 | ICHECK(f_initialize_with_tune_context != nullptr) |
33 | << "PySearchStrategy's InitializeWithTuneContext method not implemented!" ; |
34 | f_initialize_with_tune_context(context); |
35 | } |
36 | |
37 | void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter, |
38 | const Array<tir::Schedule>& design_spaces, |
39 | const Optional<Database>& database, |
40 | const Optional<CostModel>& cost_model) { |
41 | ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!" ; |
42 | f_pre_tuning(max_trials, num_trials_per_iter, design_spaces, database, cost_model); |
43 | } |
44 | |
45 | void PySearchStrategyNode::PostTuning() { |
46 | ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!" ; |
47 | f_post_tuning(); |
48 | } |
49 | |
50 | Optional<Array<MeasureCandidate>> PySearchStrategyNode::GenerateMeasureCandidates() { |
51 | ICHECK(f_generate_measure_candidates != nullptr) |
52 | << "PySearchStrategy's GenerateMeasureCandidates method not implemented!" ; |
53 | return f_generate_measure_candidates(); |
54 | } |
55 | |
56 | void PySearchStrategyNode::NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
57 | const Array<RunnerResult>& results) { |
58 | ICHECK(f_notify_runner_results != nullptr) |
59 | << "PySearchStrategy's NotifyRunnerResults method not implemented!" ; |
60 | f_notify_runner_results(measure_candidates, results); |
61 | } |
62 | |
63 | SearchStrategy PySearchStrategyNode::Clone() const { |
64 | ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!" ; |
65 | return f_clone(); |
66 | } |
67 | |
68 | SearchStrategy SearchStrategy::PySearchStrategy( |
69 | PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // |
70 | PySearchStrategyNode::FPreTuning f_pre_tuning, // |
71 | PySearchStrategyNode::FPostTuning f_post_tuning, // |
72 | PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // |
73 | PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, // |
74 | PySearchStrategyNode::FClone f_clone) { |
75 | ObjectPtr<PySearchStrategyNode> n = make_object<PySearchStrategyNode>(); |
76 | n->f_initialize_with_tune_context = f_initialize_with_tune_context; |
77 | n->f_pre_tuning = f_pre_tuning; |
78 | n->f_post_tuning = f_post_tuning; |
79 | n->f_generate_measure_candidates = f_generate_measure_candidates; |
80 | n->f_notify_runner_results = f_notify_runner_results; |
81 | n->f_clone = f_clone; |
82 | return SearchStrategy(n); |
83 | } |
84 | |
// Register the node types with TVM's object system. SearchStrategyNode uses
// TVM_REGISTER_OBJECT_TYPE rather than TVM_REGISTER_NODE_TYPE, presumably because
// it is an abstract base that cannot be instantiated directly (confirm in the header).
TVM_REGISTER_NODE_TYPE(MeasureCandidateNode);
TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode);
TVM_REGISTER_NODE_TYPE(PySearchStrategyNode);

// Global functions exposed under the "meta_schedule." prefix.
// Constructor for MeasureCandidate.
TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate" )
    .set_body_typed([](tir::Schedule sch, Array<ArgInfo> args_info) -> MeasureCandidate {
      return MeasureCandidate(sch, args_info);
    });
// Factory for a callback-backed search strategy (see SearchStrategy::PySearchStrategy).
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy" )
    .set_body_typed(SearchStrategy::PySearchStrategy);
// The remaining entries bind SearchStrategyNode member functions via set_body_method,
// with a SearchStrategy object as the receiver.
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext" )
    .set_body_method<SearchStrategy>(&SearchStrategyNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning" )
    .set_body_method<SearchStrategy>(&SearchStrategyNode::PreTuning);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning" )
    .set_body_method<SearchStrategy>(&SearchStrategyNode::PostTuning);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates" )
    .set_body_method<SearchStrategy>(&SearchStrategyNode::GenerateMeasureCandidates);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults" )
    .set_body_method<SearchStrategy>(&SearchStrategyNode::NotifyRunnerResults);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone" )
    .set_body_method<SearchStrategy>(&SearchStrategyNode::Clone);
107 | |
108 | } // namespace meta_schedule |
109 | } // namespace tvm |
110 | |