/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief A search strategy that generates candidates by replaying traces with fresh random decisions. */
class ReplayTraceNode : public SearchStrategyNode {
 public:
  /*! \brief The state of the search strategy. */
  struct State {
    /*! \brief The search strategy itself */
    ReplayTraceNode* self;
    /*! \brief The design spaces. */
    Array<tir::Trace> design_spaces;
    /*! \brief The number of total trials. */
    int max_trials;
    /*! \brief The number of trials per iteration. */
    int num_trials_per_iter;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int st;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int ed;

    /*! \brief Per-thread copies of the module to be tuned. */
    Array<IRModule> per_thread_mod_{nullptr};

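    /*!
     * \brief Construct the replay state. Each worker thread gets its own deep
     * copy of the module so that traces can be applied in parallel without
     * sharing mutable state across threads.
     */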
    explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces, int max_trials,
                   int num_trials_per_iter)
        : self(self),
          design_spaces(design_spaces),
          max_trials(max_trials),
          num_trials_per_iter(num_trials_per_iter),
          st(0),
          ed(num_trials_per_iter) {
      IRModule mod = self->mod_.value();
      this->per_thread_mod_.reserve(self->num_threads_);
      for (int i = 0; i < self->num_threads_; i++) {
        this->per_thread_mod_.push_back(DeepCopyIRModule(mod));
      }
    }

    /*! \brief Generate the next batch of measure candidates, or NullOpt if trials are exhausted. */
    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
    /*! \brief Update the state when the runner reports results for the latest batch. */
    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
  };

  /*! \brief The maximum number of failed replay attempts allowed per candidate. */
  int max_fail_count;

  /*! \brief The random state. -1 means the seed is drawn at random. */
  TRandState rand_state_ = -1;
  /*! \brief The IRModule to be scheduled from TuneContext. */
  Optional<IRModule> mod_ = NullOpt;
  /*! \brief The number of threads to be used. */
  int num_threads_ = -1;
  /*! \brief The postprocessors. */
  Array<Postproc> postprocs_ = {};
  /*! \brief The state of the search strategy. */
  std::unique_ptr<State> state_ = nullptr;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_fail_count", &max_fail_count);
    // `rand_state_` is not visited
    // `mod_` is not visited
    // `num_threads_` is not visited
    // `postprocs_` is not visited
    // `state_` is not visited
  }

  static constexpr const char* _type_key = "meta_schedule.ReplayTrace";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);

  void InitializeWithTuneContext(const TuneContext& ctx) final {
    CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined";
    CHECK(ctx->space_generator.defined())
        << "ValueError: TuneContext.space_generator is not defined";
    if (!ctx->space_generator.value()->postprocs.defined()) {
      TVM_PY_LOG(WARNING, ctx->logger)
          << "`postprocs` is not defined in " << ctx->space_generator.value()
          << ". Please explicitly set `postprocs` to an empty list if you don't want to "
             "apply any post-processing.";
    }
    this->rand_state_ = ForkSeed(&ctx->rand_state);
    this->mod_ = ctx->mod;
    this->num_threads_ = ctx->num_threads;
    this->postprocs_ = ctx->space_generator.value()->postprocs.value_or({});
    this->state_.reset();
  }

  void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces,
                 const Optional<Database>& database, const Optional<CostModel>& cost_model) final {
    ICHECK(!design_spaces.empty());
    CHECK(this->state_ == nullptr)
        << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`.";
    Array<tir::Trace> design_space_traces;
    design_space_traces.reserve(design_spaces.size());
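    // Replay works on traces, so extract the simplified trace of each
    // design-space schedule up front.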
    for (const tir::Schedule& space : design_spaces) {
      design_space_traces.push_back(space->trace().value()->Simplified(true));
    }
    this->state_ =
        std::make_unique<State>(this, design_space_traces, max_trials, num_trials_per_iter);
  }

  void PostTuning() final {
    ICHECK(this->state_ != nullptr);
    this->state_.reset();
  }

  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
    ICHECK(this->state_ != nullptr);
    return this->state_->GenerateMeasureCandidates();
  }

  void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
                           const Array<RunnerResult>& results) final {
    ICHECK(this->state_ != nullptr);
    this->state_->NotifyRunnerResults(results);
  }

  SearchStrategy Clone() const final {
    ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>();
    n->max_fail_count = this->max_fail_count;
    n->rand_state_ = this->rand_state_;
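    // `mod_`, `num_threads_` and `postprocs_` are not copied here; they are
    // re-populated when the clone is initialized with a new `TuneContext`.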
    n->state_ = nullptr;  // clear the state
    return SearchStrategy(n);
  }
};

inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasureCandidates() {
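  // If all trials have been dispatched, return NullOpt to signal that tuning is done.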
  if (st >= max_trials) {
    return NullOpt;
  }
  ed = std::min(ed, max_trials);
  ICHECK_LT(st, ed);
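  // Fork one RNG per worker thread so candidates can be sampled in parallel
  // without contending on the shared random state.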
  std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_);
  Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr});
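  // `ThreadedTraceApply` runs the postprocessors on each replayed trace; an
  // attempt that fails any postprocessor yields no schedule and is retried below.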
  ThreadedTraceApply pp(self->postprocs_);
  auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id,
                                                                        int task_id) -> void {
    TRandState& rand_state = per_thread_rand_state[thread_id];
    IRModule mod = this->per_thread_mod_[thread_id];

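    // Retry up to `max_fail_count` times: each attempt picks a design space at
    // random and replays its trace with freshly sampled decisions.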
    for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) {
      int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
      tir::Trace trace = design_spaces[design_space_index];
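      // Re-create the trace with an empty decision map so that every sampling
      // instruction draws a new random decision when the trace is applied.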
      tir::Trace new_trace = tir::Trace(trace->insts, {});
      if (Optional<tir::Schedule> opt_sch = pp.Apply(mod, new_trace, &rand_state)) {
        tir::Schedule sch = opt_sch.value();
        Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
        per_task_result.Set(task_id, MeasureCandidate(sch, args_info));
        break;
      }
    }
  };
  support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker);
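  // Keep only the slots that produced a schedule; a slot stays undefined when
  // all of its replay attempts failed postprocessing.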
  Array<MeasureCandidate> filtered;
  filtered.reserve(ed - st);
  for (MeasureCandidate result : per_task_result) {
    if (result.defined()) {
      filtered.push_back(result);
    }
  }
  return filtered;
}

inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
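  // The measured results are not inspected; simply slide the `[st, ed)` window
  // to the next batch of trials.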
  st += num_trials_per_iter;
  ed += num_trials_per_iter;
}

SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) {
  ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>();
  n->max_fail_count = max_fail_count;
  return SearchStrategy(n);
}
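
// Example usage (a minimal sketch, not part of the implementation: `ctx`,
// `design_spaces`, and the measurement loop producing `results` are assumed to
// be provided by the surrounding task scheduler):
//
//   SearchStrategy strategy = SearchStrategy::ReplayTrace(/*max_fail_count=*/100);
//   strategy->InitializeWithTuneContext(ctx);
//   strategy->PreTuning(/*max_trials=*/64, /*num_trials_per_iter=*/16, design_spaces,
//                       /*database=*/NullOpt, /*cost_model=*/NullOpt);
//   while (Optional<Array<MeasureCandidate>> candidates = strategy->GenerateMeasureCandidates()) {
//     // ... build and run candidates.value(), collect `results` ...
//     strategy->NotifyRunnerResults(candidates.value(), results);
//   }
//   strategy->PostTuning();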

TVM_REGISTER_NODE_TYPE(ReplayTraceNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace")
    .set_body_typed(SearchStrategy::ReplayTrace);

}  // namespace meta_schedule
}  // namespace tvm