1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | /*! \brief A search strategy that generates measure candidates using space generator. */ |
25 | class ReplayFuncNode : public SearchStrategyNode { |
26 | public: |
27 | /*! \brief The state of the search strategy. */ |
28 | struct State { |
29 | /*! \brief The search strategy itself */ |
30 | ReplayFuncNode* self; |
31 | /*! \brief The number of total trials. */ |
32 | int max_trials; |
33 | /*! \brief The number of trials per iteration. */ |
34 | int num_trials_per_iter; |
35 | /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ |
36 | int st; |
37 | /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ |
38 | int ed; |
39 | |
40 | explicit State(ReplayFuncNode* self, int max_trials, int num_trials_per_iter) |
41 | : self(self), |
42 | max_trials(max_trials), |
43 | num_trials_per_iter(num_trials_per_iter), |
44 | st(0), |
45 | ed(num_trials_per_iter) { |
46 | CHECK(self->mod_.defined() && self->space_generator_.defined()) |
47 | << "ValueError: The search strategy has not been initialized." ; |
48 | } |
49 | |
50 | inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates(); |
51 | inline void NotifyRunnerResults(const Array<RunnerResult>& results); |
52 | }; |
53 | |
54 | /*! \brief The random state. -1 means using random number. */ |
55 | TRandState rand_state_ = -1; |
56 | /*! \brief The IRModule to be scheduled from TuneContext. */ |
57 | Optional<IRModule> mod_ = NullOpt; |
58 | /*! \brief The space generator from TuneContext. */ |
59 | Optional<SpaceGenerator> space_generator_ = NullOpt; |
60 | /*! \brief The state of the search strategy. */ |
61 | std::unique_ptr<State> state_ = nullptr; |
62 | |
63 | void VisitAttrs(tvm::AttrVisitor* v) {} |
64 | |
65 | static constexpr const char* _type_key = "meta_schedule.ReplayFunc" ; |
66 | TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); |
67 | |
68 | void InitializeWithTuneContext(const TuneContext& ctx) final { |
69 | CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined" ; |
70 | CHECK(ctx->space_generator.defined()) |
71 | << "ValueError: TuneContext.space_generator is not defined" ; |
72 | if (!ctx->space_generator.value()->postprocs.defined()) { |
73 | TVM_PY_LOG(WARNING, ctx->logger) |
74 | << "`postprocs` is not defined in " << ctx->space_generator.value() |
75 | << ". Please explicitly set `postprocs` to an empty list if you don't want to " |
76 | "apply any post-processing." ; |
77 | } |
78 | this->rand_state_ = ForkSeed(&ctx->rand_state); |
79 | this->mod_ = ctx->mod; |
80 | this->space_generator_ = ctx->space_generator; |
81 | this->state_.reset(); |
82 | } |
83 | |
84 | void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces, |
85 | const Optional<Database>& database, const Optional<CostModel>& cost_model) final { |
86 | CHECK(this->state_ == nullptr) |
87 | << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`." ; |
88 | this->state_ = std::make_unique<State>(this, max_trials, num_trials_per_iter); |
89 | } |
90 | |
91 | void PostTuning() final { |
92 | CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding " |
93 | "`PreTuning`, or `PostTuning` is already invoked." ; |
94 | this->state_.reset(); |
95 | } |
96 | |
97 | Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final { |
98 | ICHECK(this->state_ != nullptr); |
99 | return this->state_->GenerateMeasureCandidates(); |
100 | } |
101 | |
102 | void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
103 | const Array<RunnerResult>& results) final { |
104 | ICHECK(this->state_ != nullptr); |
105 | this->state_->NotifyRunnerResults(results); |
106 | } |
107 | |
108 | SearchStrategy Clone() const final { |
109 | ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>(); |
110 | n->rand_state_ = -1; |
111 | n->mod_ = NullOpt; |
112 | n->space_generator_ = NullOpt; |
113 | n->state_ = nullptr; |
114 | return SearchStrategy(n); |
115 | } |
116 | }; |
117 | |
118 | inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() { |
119 | if (st >= max_trials) { |
120 | return NullOpt; |
121 | } |
122 | ed = std::min(ed, max_trials); |
123 | Array<MeasureCandidate> result; |
124 | IRModule mod = self->mod_.value(); |
125 | Array<Postproc> postprocs = self->space_generator_.value()->postprocs.value_or({}); |
126 | for (int i = st; i < ed; i++) { |
127 | for (;;) { |
128 | Array<tir::Schedule> schs = self->space_generator_.value()->GenerateDesignSpace(mod); |
129 | int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); |
130 | tir::Schedule sch = schs[design_space_index]; |
131 | sch->EnterPostproc(); |
132 | bool failed = false; |
133 | for (const Postproc& proc : postprocs) { |
134 | if (!proc->Apply(sch)) { |
135 | failed = true; |
136 | break; |
137 | } |
138 | } |
139 | if (!failed) { |
140 | Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); |
141 | result.push_back(MeasureCandidate(sch, args_info)); |
142 | break; |
143 | } |
144 | } |
145 | } |
146 | return result; |
147 | } |
148 | |
149 | inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) { |
150 | st += num_trials_per_iter; |
151 | ed += num_trials_per_iter; |
152 | } |
153 | |
154 | SearchStrategy SearchStrategy::ReplayFunc() { |
155 | ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>(); |
156 | return SearchStrategy(n); |
157 | } |
158 | |
// Register ReplayFuncNode (type key "meta_schedule.ReplayFunc") with the TVM node registry.
TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
// Expose the factory SearchStrategy::ReplayFunc as the global function
// "meta_schedule.SearchStrategyReplayFunc".
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
    .set_body_typed(SearchStrategy::ReplayFunc);
162 | |
163 | } // namespace meta_schedule |
164 | } // namespace tvm |
165 | |