/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <algorithm>
#include <map>
#include <unordered_map>
#include <vector>

#include "../utils.h"

namespace tvm {
namespace tir {
/*!
 * \brief Check if the instruction is an annotation with key `meta_schedule_parallel`
 * \param inst The instruction to be checked
 * \return Whether the instruction is an annotation with key `meta_schedule_parallel`
 */
bool IsAnnotateWithParallel(const Instruction& inst) {
  static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
  if (!inst->kind.same_as(inst_annotate)) {
    return false;
  }
  ICHECK_EQ(inst->attrs.size(), 1);
  String ann_key = Downcast<String>(inst->attrs[0]);
  return ann_key == attr::meta_schedule_parallel;
}
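
// Illustrative note: a matching instruction in a traced schedule would come from an
// annotation such as (Python trace syntax; the names `b0` and `v0` are hypothetical):
//   sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=v0)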

/*!
 * \brief Replace the annotation value of an annotation instruction
 * \param inst The instruction whose annotation value is to be replaced
 * \param ann_val The new annotation value
 * \return A copy of the instruction with the new annotation value
 */
Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) {
  ICHECK_EQ(inst->inputs.size(), 2);
  return Instruction(/*kind=*/inst->kind,  //
                     /*inputs=*/{inst->inputs[0], Integer(ann_val)},  //
                     /*attrs=*/inst->attrs,
                     /*outputs=*/inst->outputs);
}
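
// For example, if `inst` annotates block `b0` with value 64, then
// `ReplaceAnnValue(inst, 32)` returns an otherwise identical instruction whose
// inputs are `(b0, 32)`; the kind, attrs and outputs are reused as-is.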

/*!
 * \brief Get the output of the GetBlock instruction
 * \param inst The instruction to be checked
 * \return The block output of the GetBlock instruction, or nullptr if the
 * instruction is not GetBlock
 */
const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
  static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock");
  if (!inst->kind.same_as(inst_get_block)) {
    return nullptr;
  }
  ICHECK_EQ(inst->outputs.size(), 1);
  const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode);
  return block;
}
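
// For example, for a trace instruction like `b0 = sch.get_block(name="C", func_name="main")`
// (Python trace syntax; names hypothetical), this helper returns the `BlockRVNode` behind
// `b0`; for any instruction that is not `GetBlock`, it returns nullptr.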

/*!
 * \brief Analyze the parallel structure
 * \param self The schedule state
 * \param block_name The name of the root block
 * \param func_name The name of the PrimFunc
 * \param limit The upper limit of parallelism
 * \return The parallel structure
 */
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
                                                  const String& block_name,
                                                  const String& func_name, int64_t limit) {
  Array<StmtSRef> block_srefs =
      tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
  ICHECK_EQ(block_srefs.size(), 1);
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]);
  ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
  std::vector<std::vector<int64_t>> results;
  results.reserve(info.realizes.size());
  for (const BlockRealize& realize : info.realizes) {
    // Step 1. Extract static loop extents for spatial loops
    std::vector<int64_t> loop_extents;
    const ForNode* loop = nullptr;
    for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent;
         (loop = loop_sref->StmtAs<ForNode>()) != nullptr;  //
         loop_sref = loop_sref->parent) {
      int64_t loop_extent = -1;
      if (const auto* ext = GetLoopIntExtent(loop)) {
        if (!info.non_spatial_vars.count(loop->loop_var.get())) {
          loop_extent = *ext;
        }
      }
      if (loop_extent != -1) {
        loop_extents.push_back(loop_extent);
      } else {
        loop_extents.clear();
      }
    }
    // Step 2. Take the prefix product of loop extents, from outermost to innermost
    if (!loop_extents.empty()) {
      results.emplace_back();
      std::vector<int64_t>& result = results.back();
      result.reserve(loop_extents.size());
      int64_t prod_extent = 1;
      for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) {
        result.push_back(prod_extent *= *it);
        if (prod_extent >= limit) {
          break;
        }
      }
    }
  }
  return results;
}
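
// Worked example (illustrative numbers): suppose one subtree has spatial loops with static
// extents [2, 3, 8], outermost to innermost, and `limit` is 16. The prefix products are
// accumulated outermost-first: 2, then 2 * 3 = 6, then 6 * 8 = 48; since 48 >= 16 the scan
// stops, and the subtree contributes {2, 6, 48} to the result.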

/*!
 * \brief Get the number of loops to be fused for each subtree
 * \param loop_extent_prods The parallel structure for each subtree
 * \param limit The upper limit of parallelism
 * \return The number of loops to be fused for each subtree
 */
std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
                                  int64_t limit) {
  std::vector<int> results;
  results.reserve(loop_extent_prods.size());
  for (const std::vector<int64_t>& prods : loop_extent_prods) {
    int n = prods.size();
    int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin();
    if (i > 0 && prods[i - 1] == limit) {
      --i;
    }
    if (i != n) {
      ++i;
    }
    results.push_back(i);
  }
  return results;
}
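
// Worked example (illustrative numbers): for `prods = {2, 6, 48}` and `limit = 6`,
// `std::upper_bound` lands on index 2 (the first product greater than 6), and since
// `prods[1] == 6` hits the limit exactly, fusing the outermost 2 loops yields a parallel
// extent of exactly 6; the function therefore records 2 for this subtree.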

}  // namespace tir
}  // namespace tvm

namespace tvm {
namespace meta_schedule {

using tir::Instruction;
using tir::Trace;

/*! \brief Create a Mutator that mutates the parallel extent */
class MutateParallelNode : public MutatorNode {
 public:
  /*!
   * \brief The maximum number of jobs to be launched per CPU core.
   * It sets the upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`.
   * Use -1 to disable parallelism.
   */
  int64_t max_jobs_per_core;
  /*! \brief The maximum parallel extent, i.e. `num_cores * max_jobs_per_core`. */
  int max_parallel_extent_;
  /*! \brief JSON representation of the workload */
  std::string json_mod_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_jobs_per_core", &max_jobs_per_core);
    // `max_parallel_extent_` is not visited.
    // `json_mod_` is not visited.
  }

  static constexpr const char* _type_key = "meta_schedule.MutateParallel";
  TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode);

 public:
  struct Candidate;
  // Inherit from `MutatorNode`
  void InitializeWithTuneContext(const TuneContext& context) final {
    Target target = context->target.value();
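    // For example (illustrative numbers): a 16-core target with `max_jobs_per_core = 4`
    // yields `max_parallel_extent_ = 64`.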
    this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core;
    this->json_mod_ = SaveJSON(context->mod.value());
  }
  // Inherit from `MutatorNode`
  Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
  // Inherit from `MutatorNode`
  Mutator Clone() const final {
    ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>(*this);
    return Mutator(n);
  }
};

/*! \brief The candidate to be mutated */
struct MutateParallelNode::Candidate {
  /*! \brief The annotation instruction */
  Instruction inst;
  /*! \brief The current parallel extent */
  int64_t parallel_extent;
  /*! \brief The name of the root block */
  String block_name;
  /*! \brief The name of the PrimFunc */
  String func_name;
};

/*!
 * \brief Find a random annotation instruction that sets the parallel extent,
 * and fill in the candidate with the instruction and its context
 * \param trace The trace to be mutated
 * \param rand_state The random state
 * \param candidate The candidate to be filled in
 * \return Whether a candidate is found
 */
bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
                          MutateParallelNode::Candidate* candidate) {
  using tir::BlockRVNode;
  using tir::InstructionNode;
  std::unordered_map<const BlockRVNode*, const InstructionNode*> get_block_insts;
  std::vector<const InstructionNode*> ann_insts;
  get_block_insts.reserve(trace->insts.size());
  ann_insts.reserve(trace->insts.size());
  for (const Instruction& inst : trace->insts) {
    if (tir::IsAnnotateWithParallel(inst)) {
      ann_insts.push_back(inst.get());
    }
    if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) {
      get_block_insts[block_rv] = inst.get();
    }
  }
  int n_ann_insts = ann_insts.size();
  if (n_ann_insts == 0) {
    return false;
  }
  const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
  ICHECK_EQ(ann_inst->inputs.size(), 2);
  const InstructionNode* get_block_inst =
      get_block_insts.at(Downcast<tir::BlockRV>(ann_inst->inputs[0]).get());
  ICHECK_EQ(get_block_inst->attrs.size(), 2);
  candidate->inst = GetRef<Instruction>(ann_inst);
  candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value;
  candidate->block_name = Downcast<String>(get_block_inst->attrs[0]);
  candidate->func_name = Downcast<String>(get_block_inst->attrs[1]);
  return true;
}
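
// Sketch of what `FindParallelDecision` extracts (Python trace syntax; names and values
// hypothetical): given a trace containing
//   b0 = sch.get_block(name="C", func_name="main")
//   sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=64)
// the candidate is filled with the annotate instruction, parallel_extent = 64,
// block_name = "C", and func_name = "main".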

Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
  // Step 1. Find a parallel decision.
  Candidate candidate;
  if (!FindParallelDecision(trace, rand_state, &candidate)) {
    return NullOpt;
  }
  // Step 2. Replay the instructions to recover loop extents
  tir::Schedule sch = tir::Schedule::Traced(                  //
      /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)),  //
      /*rand_state=*/ForkSeed(rand_state),                    //
      /*debug_mode=*/0,
      /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
  trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
  // Step 3. Find all possible parallel plans
  std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
      sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
  std::unordered_map<int64_t, std::vector<int>> limit2plan;
  std::map<std::vector<int>, int64_t> plan2limit;
  for (const std::vector<int64_t>& prods : loop_extent_prods) {
    for (int64_t limit : prods) {
      if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
        std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
        limit2plan[limit] = plan;
        plan2limit[plan] = limit;
      }
    }
  }
  // Step 4. Find the original plan and remove it from the candidate plans
  std::vector<int> original_plan =
      tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
  auto it = plan2limit.find(original_plan);
  if (it != plan2limit.end()) {
    plan2limit.erase(it);
  }
  // Step 5. Pick a new plan
  int n_plans = plan2limit.size();
  if (n_plans == 0) {
    return NullOpt;
  }
  it = plan2limit.begin();
  for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
    ++it;
  }
  int64_t limit = it->second;
  // Step 6. Assemble a new trace
  Array<Instruction> insts;
  insts.reserve(trace->insts.size());
  for (const Instruction& inst : trace->insts) {
    if (inst.same_as(candidate.inst)) {
      insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
    } else if (inst->kind->IsPostproc()) {
      break;
    } else {
      insts.push_back(inst);
    }
  }
  return Trace(insts, trace->decisions);
}
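
// End-to-end sketch (illustrative numbers): if the sampled candidate annotates parallel
// extent 2 and `AnalyzeParallel` yields prefix products {2, 6, 48} under the cap
// `max_parallel_extent_ = 16`, the feasible limits are 2 and 6. The plan derived from the
// original extent 2 is erased, the remaining plan (limit 6) is sampled, and the returned
// trace is identical except that the annotation value becomes 6, with postprocessing
// instructions stripped.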

Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) {
  ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>();
  n->max_jobs_per_core = max_jobs_per_core;
  return Mutator(n);
}
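
// Usage sketch (hypothetical parameter value): create a mutator that allows up to
// 4 jobs per CPU core; the tuning pipeline then invokes its `Apply` on candidate traces.
//   Mutator mutator = Mutator::MutateParallel(/*max_jobs_per_core=*/4);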

TVM_REGISTER_NODE_TYPE(MutateParallelNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel);

}  // namespace meta_schedule
}  // namespace tvm