1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | /*! \brief A search strategy that generates measure candidates using trace and random decisions. */ |
25 | class ReplayTraceNode : public SearchStrategyNode { |
26 | public: |
27 | /*! \brief The state of the search strategy. */ |
28 | struct State { |
29 | /*! \brief The search strategy itself */ |
30 | ReplayTraceNode* self; |
31 | /*! \brief The design spaces. */ |
32 | Array<tir::Trace> design_spaces; |
33 | /*! \brief The number of total trials. */ |
34 | int max_trials; |
35 | /*! \brief The number of trials per iteration. */ |
36 | int num_trials_per_iter; |
37 | /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ |
38 | int st; |
39 | /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ |
40 | int ed; |
41 | |
42 | /*! \brief The module to be tuned. */ |
43 | Array<IRModule> per_thread_mod_{nullptr}; |
44 | |
45 | explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces, int max_trials, |
46 | int num_trials_per_iter) |
47 | : self(self), |
48 | design_spaces(design_spaces), |
49 | max_trials(max_trials), |
50 | num_trials_per_iter(num_trials_per_iter), |
51 | st(0), |
52 | ed(num_trials_per_iter) { |
53 | IRModule mod = self->mod_.value(); |
54 | this->per_thread_mod_.reserve(self->num_threads_); |
55 | for (int i = 0; i < self->num_threads_; i++) { |
56 | this->per_thread_mod_.push_back(DeepCopyIRModule(mod)); |
57 | } |
58 | } |
59 | |
60 | inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates(); |
61 | inline void NotifyRunnerResults(const Array<RunnerResult>& results); |
62 | }; |
63 | |
64 | /*! \brief The max number of failures during trace replaying. */ |
65 | int max_fail_count; |
66 | |
67 | /*! \brief The random state. -1 means using random number. */ |
68 | TRandState rand_state_ = -1; |
69 | /*! \brief The IRModule to be scheduled from TuneContext. */ |
70 | Optional<IRModule> mod_ = NullOpt; |
71 | /*! \brief The number of threads to be used. */ |
72 | int num_threads_ = -1; |
73 | /*! \brief The postprocessors. */ |
74 | Array<Postproc> postprocs_ = {}; |
75 | /*! \brief The state of the search strategy. */ |
76 | std::unique_ptr<State> state_ = nullptr; |
77 | |
78 | void VisitAttrs(tvm::AttrVisitor* v) { |
79 | v->Visit("max_fail_count" , &max_fail_count); |
80 | // `rand_state_` is not visited |
81 | // `mod_` is not visited |
82 | // `num_threads_` is not visited |
83 | // `postprocs_` is not visited |
84 | // `state_` is not visited |
85 | } |
86 | |
87 | static constexpr const char* _type_key = "meta_schedule.ReplayTrace" ; |
88 | TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); |
89 | |
90 | void InitializeWithTuneContext(const TuneContext& ctx) final { |
91 | CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined" ; |
92 | CHECK(ctx->space_generator.defined()) |
93 | << "ValueError: TuneContext.space_generator is not defined" ; |
94 | if (!ctx->space_generator.value()->postprocs.defined()) { |
95 | TVM_PY_LOG(WARNING, ctx->logger) |
96 | << "`postprocs` is not defined in " << ctx->space_generator.value() |
97 | << ". Please explicitly set `postprocs` to an empty list if you don't want to " |
98 | "apply any post-processing." ; |
99 | } |
100 | this->rand_state_ = ForkSeed(&ctx->rand_state); |
101 | this->mod_ = ctx->mod; |
102 | this->num_threads_ = ctx->num_threads; |
103 | this->postprocs_ = ctx->space_generator.value()->postprocs.value_or({}); |
104 | this->state_.reset(); |
105 | } |
106 | |
107 | void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces, |
108 | const Optional<Database>& database, const Optional<CostModel>& cost_model) final { |
109 | ICHECK(!design_spaces.empty()); |
110 | CHECK(this->state_ == nullptr) |
111 | << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`." ; |
112 | Array<tir::Trace> design_space_traces; |
113 | design_space_traces.reserve(design_spaces.size()); |
114 | for (const tir::Schedule& space : design_spaces) { |
115 | design_space_traces.push_back(space->trace().value()->Simplified(true)); |
116 | } |
117 | this->state_ = |
118 | std::make_unique<State>(this, design_space_traces, max_trials, num_trials_per_iter); |
119 | } |
120 | |
121 | void PostTuning() final { |
122 | ICHECK(this->state_ != nullptr); |
123 | this->state_.reset(); |
124 | } |
125 | |
126 | Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final { |
127 | ICHECK(this->state_ != nullptr); |
128 | return this->state_->GenerateMeasureCandidates(); |
129 | } |
130 | |
131 | void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
132 | const Array<RunnerResult>& results) final { |
133 | ICHECK(this->state_ != nullptr); |
134 | this->state_->NotifyRunnerResults(results); |
135 | } |
136 | |
137 | SearchStrategy Clone() const final { |
138 | ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>(); |
139 | n->max_fail_count = this->max_fail_count; |
140 | n->rand_state_ = this->rand_state_; |
141 | n->state_ = nullptr; // cleared the state |
142 | return SearchStrategy(n); |
143 | } |
144 | }; |
145 | |
146 | inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasureCandidates() { |
147 | if (st >= max_trials) { |
148 | return NullOpt; |
149 | } |
150 | ed = std::min(ed, max_trials); |
151 | ICHECK_LT(st, ed); |
152 | std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); |
153 | Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr}); |
154 | ThreadedTraceApply pp(self->postprocs_); |
155 | auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, |
156 | int task_id) -> void { |
157 | TRandState& rand_state = per_thread_rand_state[thread_id]; |
158 | IRModule mod = this->per_thread_mod_[thread_id]; |
159 | |
160 | for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) { |
161 | int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); |
162 | tir::Trace trace = design_spaces[design_space_index]; |
163 | tir::Trace new_trace = tir::Trace(trace->insts, {}); |
164 | if (Optional<tir::Schedule> opt_sch = pp.Apply(mod, new_trace, &rand_state)) { |
165 | tir::Schedule sch = opt_sch.value(); |
166 | Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); |
167 | per_task_result.Set(task_id, MeasureCandidate(sch, args_info)); |
168 | break; |
169 | } |
170 | } |
171 | }; |
172 | support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); |
173 | Array<MeasureCandidate> filtered; |
174 | filtered.reserve(ed - st); |
175 | for (MeasureCandidate result : per_task_result) |
176 | if (result.defined()) { |
177 | filtered.push_back(result); |
178 | } |
179 | return filtered; |
180 | } |
181 | |
182 | inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) { |
183 | st += num_trials_per_iter; |
184 | ed += num_trials_per_iter; |
185 | } |
186 | |
187 | SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { |
188 | ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>(); |
189 | n->max_fail_count = max_fail_count; |
190 | return SearchStrategy(n); |
191 | } |
192 | |
193 | TVM_REGISTER_NODE_TYPE(ReplayTraceNode); |
194 | TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace" ) |
195 | .set_body_typed(SearchStrategy::ReplayTrace); |
196 | |
197 | } // namespace meta_schedule |
198 | } // namespace tvm |
199 | |