/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief A search strategy that generates measure candidates by replaying the space generator and randomly picking one of the generated design spaces for each trial. */
class ReplayFuncNode : public SearchStrategyNode {
 public:
  /*! \brief The state of the search strategy. */
  struct State {
    /*! \brief The search strategy itself */
    ReplayFuncNode* self;
    /*! \brief The number of total trials. */
    int max_trials;
    /*! \brief The number of trials per iteration. */
    int num_trials_per_iter;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int st;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int ed;

    explicit State(ReplayFuncNode* self, int max_trials, int num_trials_per_iter)
        : self(self),
          max_trials(max_trials),
          num_trials_per_iter(num_trials_per_iter),
          st(0),
          ed(num_trials_per_iter) {
      CHECK(self->mod_.defined() && self->space_generator_.defined())
          << "ValueError: The search strategy has not been initialized.";
    }

    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
  };

  /*! \brief The random state. -1 means a random seed is used. */
  TRandState rand_state_ = -1;
  /*! \brief The IRModule to be scheduled, provided by the TuneContext. */
  Optional<IRModule> mod_ = NullOpt;
  /*! \brief The space generator, provided by the TuneContext. */
  Optional<SpaceGenerator> space_generator_ = NullOpt;
  /*! \brief The state of the search strategy. */
  std::unique_ptr<State> state_ = nullptr;

  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "meta_schedule.ReplayFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);

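  /*! \brief Capture the module, the space generator, and a forked random seed from the TuneContext; any previous tuning state is dropped. */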
  void InitializeWithTuneContext(const TuneContext& ctx) final {
    CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined";
    CHECK(ctx->space_generator.defined())
        << "ValueError: TuneContext.space_generator is not defined";
    if (!ctx->space_generator.value()->postprocs.defined()) {
      TVM_PY_LOG(WARNING, ctx->logger)
          << "`postprocs` is not defined in " << ctx->space_generator.value()
          << ". Please explicitly set `postprocs` to an empty list if you don't want to "
             "apply any post-processing.";
    }
    this->rand_state_ = ForkSeed(&ctx->rand_state);
    this->mod_ = ctx->mod;
    this->space_generator_ = ctx->space_generator;
    this->state_.reset();
  }

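  /*!
   * \brief Start a tuning round by allocating the trial window `[st, ed)`.
   * The pre-generated design spaces, database, and cost model are ignored:
   * this strategy re-runs the space generator for every candidate instead.
   */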
  void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces,
                 const Optional<Database>& database, const Optional<CostModel>& cost_model) final {
    CHECK(this->state_ == nullptr)
        << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`.";
    this->state_ = std::make_unique<State>(this, max_trials, num_trials_per_iter);
  }

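  /*! \brief Finish the current tuning round and discard its state. */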
  void PostTuning() final {
    CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding "
                                      "`PreTuning`, or `PostTuning` is already invoked.";
    this->state_.reset();
  }

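  /*! \brief Delegate candidate generation to the per-round state. */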
  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
    ICHECK(this->state_ != nullptr);
    return this->state_->GenerateMeasureCandidates();
  }

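  /*! \brief Forward runner results to the per-round state; the measured candidates themselves are not used. */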
  void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
                           const Array<RunnerResult>& results) final {
    ICHECK(this->state_ != nullptr);
    this->state_->NotifyRunnerResults(results);
  }

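  /*! \brief Cloning yields a fresh, uninitialized strategy that must be re-initialized with a TuneContext before use. */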
  SearchStrategy Clone() const final {
    ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
    n->rand_state_ = -1;
    n->mod_ = NullOpt;
    n->space_generator_ = NullOpt;
    n->state_ = nullptr;
    return SearchStrategy(n);
  }
};

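// For each trial in the current window, re-run the space generator on the module, pick one of
// the returned design spaces uniformly at random, and apply all postprocessors. If any
// postprocessor rejects the schedule, the draw is retried until a valid candidate is produced.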
inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
  if (st >= max_trials) {
    return NullOpt;
  }
  ed = std::min(ed, max_trials);
  Array<MeasureCandidate> result;
  IRModule mod = self->mod_.value();
  Array<Postproc> postprocs = self->space_generator_.value()->postprocs.value_or({});
  for (int i = st; i < ed; i++) {
    // Keep drawing until a design space survives all postprocessors.
    for (;;) {
      Array<tir::Schedule> schs = self->space_generator_.value()->GenerateDesignSpace(mod);
      int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
      tir::Schedule sch = schs[design_space_index];
      sch->EnterPostproc();
      bool failed = false;
      for (const Postproc& proc : postprocs) {
        if (!proc->Apply(sch)) {
          failed = true;
          break;
        }
      }
      if (!failed) {
        Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
        result.push_back(MeasureCandidate(sch, args_info));
        break;
      }
    }
  }
  return result;
}

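// Replay does not learn from measurement feedback; results only advance the trial window.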
inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
  st += num_trials_per_iter;
  ed += num_trials_per_iter;
}

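/*! \brief Create an uninitialized ReplayFunc search strategy. */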
SearchStrategy SearchStrategy::ReplayFunc() {
  ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
  return SearchStrategy(n);
}

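// A minimal sketch of how a driver might exercise this strategy; `ctx`, `design_spaces`,
// `Measure`, and the trial counts below are illustrative placeholders, not part of this file:
//
//   SearchStrategy strategy = SearchStrategy::ReplayFunc();
//   strategy->InitializeWithTuneContext(ctx);  // ctx must define `mod` and `space_generator`
//   strategy->PreTuning(/*max_trials=*/64, /*num_trials_per_iter=*/16, design_spaces,
//                       /*database=*/NullOpt, /*cost_model=*/NullOpt);
//   while (Optional<Array<MeasureCandidate>> candidates = strategy->GenerateMeasureCandidates()) {
//     Array<RunnerResult> results = Measure(candidates.value());  // build + run elsewhere
//     strategy->NotifyRunnerResults(candidates.value(), results);
//   }
//   strategy->PostTuning();
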
TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
    .set_body_typed(SearchStrategy::ReplayFunc);

}  // namespace meta_schedule
}  // namespace tvm