1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <algorithm> |
20 | #include <unordered_map> |
21 | |
22 | #include "../utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | /*! |
28 | * \brief Check if the instruction is annotation with `meta_schedule_parallel` |
29 | * \param inst The instruction to be checked |
30 | * \return Whether the instruction is annotation with `meta_schedule_parallel` |
31 | */ |
32 | bool IsAnnotateWithParallel(const Instruction& inst) { |
33 | static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate" ); |
34 | if (!inst->kind.same_as(inst_annotate)) { |
35 | return false; |
36 | } |
37 | ICHECK_EQ(inst->attrs.size(), 1); |
38 | String ann_key = Downcast<String>(inst->attrs[0]); |
39 | return ann_key == attr::meta_schedule_parallel; |
40 | } |
41 | |
/*!
 * \brief Replace the annotation value
 * \param inst The instruction to be replaced
 * \param ann_val The new annotation value
 * \return The replaced instruction
 */
Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) {
  ICHECK_EQ(inst->inputs.size(), 2);
  // Re-create the instruction, swapping only the annotation-value input
  // (inputs[1]); kind, attrs and outputs are carried over unchanged.
  Array<ObjectRef> new_inputs{inst->inputs[0], Integer(ann_val)};
  return Instruction(inst->kind, new_inputs, inst->attrs, inst->outputs);
}
55 | |
56 | /*! |
57 | * \brief Get the output of the instruction Get-Block |
58 | * \param inst The instruction to be checked |
59 | * \return The output of the instruction Get-Block |
60 | */ |
61 | const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { |
62 | static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock" ); |
63 | if (!inst->kind.same_as(inst_get_block)) { |
64 | return nullptr; |
65 | } |
66 | ICHECK_EQ(inst->outputs.size(), 1); |
67 | const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode); |
68 | return block; |
69 | } |
70 | |
/*!
 * \brief Analyze the parallel structure
 * \param self The schedule state
 * \param block_name The name of the root block
 * \param func_name The name of the PrimFunc
 * \param limit The uplimit of the parallelism
 * \return The parallel structure: for each block realization under the scope,
 *         the prefix products of its outer spatial loop extents
 *         (outermost-first), truncated once the product reaches `limit`
 */
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
                                                  const String& block_name, const String& func_name,
                                                  int64_t limit) {
  Array<StmtSRef> block_srefs =
      tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
  // The block name is expected to be unique within the PrimFunc.
  ICHECK_EQ(block_srefs.size(), 1);
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]);
  ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
  std::vector<std::vector<int64_t>> results;
  results.reserve(info.realizes.size());
  for (const BlockRealize& realize : info.realizes) {
    // Step 1. Extract static loop extents for spatial loops, walking sref
    // parents from the block's innermost enclosing loop outwards.
    std::vector<int64_t> loop_extents;
    const ForNode* loop = nullptr;
    for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent;
         (loop = loop_sref->StmtAs<ForNode>()) != nullptr;  //
         loop_sref = loop_sref->parent) {
      // -1 marks a loop that is either non-spatial or has a dynamic extent.
      int64_t loop_extent = -1;
      if (const auto* ext = GetLoopIntExtent(loop)) {
        if (!info.non_spatial_vars.count(loop->loop_var.get())) {
          loop_extent = *ext;
        }
      }
      if (loop_extent != -1) {
        loop_extents.push_back(loop_extent);
      } else {
        // A dynamic or non-spatial loop breaks the parallelizable run: discard
        // everything collected so far (all loops nested inside it), keeping
        // only whatever is collected above it on later iterations.
        loop_extents.clear();
      }
    }
    // Step 2. Take the prefix product of loop extents. `loop_extents` was
    // collected inner-to-outer, so iterate in reverse to produce products in
    // outermost-first order; stop once the product reaches `limit`.
    if (!loop_extents.empty()) {
      results.emplace_back();
      std::vector<int64_t>& result = results.back();
      result.reserve(loop_extents.size());
      int64_t prod_extent = 1;
      for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) {
        result.push_back(prod_extent *= *it);
        if (prod_extent >= limit) {
          break;
        }
      }
    }
  }
  return results;
}
124 | |
/*!
 * \brief Get the number of parallelizable loops for each subtree
 * \param loop_extent_prods The parallel structure for each subtree, i.e. the
 *        sorted prefix products of its loop extents
 * \param limit The uplimit of the parallelism
 * \return The number of parallelizable loops for each subtree
 */
std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
                                  int64_t limit) {
  std::vector<int> num_fused;
  num_fused.reserve(loop_extent_prods.size());
  for (const auto& prods : loop_extent_prods) {
    // Position after the last product not exceeding `limit`.
    auto pos = std::upper_bound(prods.begin(), prods.end(), limit);
    // If the preceding product hits `limit` exactly, step back so the
    // subsequent "+1" does not over-count.
    if (pos != prods.begin() && *(pos - 1) == limit) {
      --pos;
    }
    int count = static_cast<int>(pos - prods.begin());
    // Fuse one more loop (partially) unless every loop is already covered.
    if (pos != prods.end()) {
      ++count;
    }
    num_fused.push_back(count);
  }
  return num_fused;
}
148 | |
149 | } // namespace tir |
150 | } // namespace tvm |
151 | |
152 | namespace tvm { |
153 | namespace meta_schedule { |
154 | |
155 | using tir::Instruction; |
156 | using tir::Trace; |
157 | |
/*! \brief Create a Mutator that mutates the parallel extent */
class MutateParallelNode : public MutatorNode {
 public:
  /*!
   * \brief The maximum number of jobs to be launched per CPU core.
   * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`.
   * Use -1 to disable parallelism.
   */
  int64_t max_jobs_per_core;
  /*! \brief The maximum parallel extent, i.e. `num_cores * max_jobs_per_core`;
   *  computed in `InitializeWithTuneContext`. */
  int max_parallel_extent_;
  /*! \brief JSON representation of the workload, used to replay traces in `Apply` */
  std::string json_mod_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_jobs_per_core" , &max_jobs_per_core);
    // `max_parallel_extent_` is not visited.
    // `json_mod` is not visited.
  }

  static constexpr const char* _type_key = "meta_schedule.MutateParallel" ;
  TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode);

 public:
  /*! \brief The candidate to be mutated; defined after the class body */
  struct Candidate;
  // Inherit from `MutatorNode`
  // NOTE(review): assumes `context->target` and `context->mod` are defined;
  // `.value()` fails otherwise.
  void InitializeWithTuneContext(const TuneContext& context) final {
    Target target = context->target.value();
    this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core;
    this->json_mod_ = SaveJSON(context->mod.value());
  }
  // Inherit from `MutatorNode`
  Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
  // Inherit from `MutatorNode`
  Mutator Clone() const final {
    ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>(*this);
    return Mutator(n);
  }
};
197 | |
/*! \brief The candidate to be mutated */
struct MutateParallelNode::Candidate {
  /*! \brief The annotation instruction (an `Annotate` with key `meta_schedule_parallel`) */
  Instruction inst;
  /*! \brief The current parallel extent, i.e. the annotation's integer value */
  int64_t parallel_extent;
  /*! \brief The name of the root block */
  String block_name;
  /*! \brief The name of the PrimFunc */
  String func_name;
};
209 | |
/*!
 * \brief Get an instruction that annotates the maximum parallel extent
 * \param trace The trace to be mutated
 * \param rand_state The random state
 * \param candidate The candidate to be mutated
 * \return Whether a decision is found
 */
bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
                          MutateParallelNode::Candidate* candidate) {
  using tir::BlockRVNode;
  using tir::InstructionNode;
  // One pass over the trace: map each GetBlock output to its producing
  // instruction, and collect every parallel-annotation instruction.
  std::unordered_map<const BlockRVNode*, const InstructionNode*> block2inst;
  std::vector<const InstructionNode*> parallel_anns;
  block2inst.reserve(trace->insts.size());
  parallel_anns.reserve(trace->insts.size());
  for (const Instruction& inst : trace->insts) {
    if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) {
      block2inst[block_rv] = inst.get();
    }
    if (tir::IsAnnotateWithParallel(inst)) {
      parallel_anns.push_back(inst.get());
    }
  }
  if (parallel_anns.empty()) {
    return false;
  }
  // Sample one annotation uniformly at random.
  int num_anns = static_cast<int>(parallel_anns.size());
  const InstructionNode* ann_inst = parallel_anns[tir::SampleInt(rand_state, 0, num_anns)];
  // Annotate takes (block, value) as inputs.
  ICHECK_EQ(ann_inst->inputs.size(), 2);
  // Trace back to the GetBlock that produced the annotated block.
  const BlockRVNode* annotated_block = Downcast<tir::BlockRV>(ann_inst->inputs[0]).get();
  const InstructionNode* get_block_inst = block2inst.at(annotated_block);
  // GetBlock carries (block_name, func_name) as attrs.
  ICHECK_EQ(get_block_inst->attrs.size(), 2);
  candidate->inst = GetRef<Instruction>(ann_inst);
  candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value;
  candidate->block_name = Downcast<String>(get_block_inst->attrs[0]);
  candidate->func_name = Downcast<String>(get_block_inst->attrs[1]);
  return true;
}
248 | |
Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
  // Step 1. Find a parallel decision: sample one `meta_schedule_parallel`
  // annotation from the trace and extract its block/func names and extent.
  Candidate candidate;
  if (!FindParallelDecision(trace, rand_state, &candidate)) {
    return NullOpt;
  }
  // Step 2. Replay the instructions to recover loop extents
  tir::Schedule sch = tir::Schedule::Traced(                    //
      /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)),    //
      /*rand_state=*/ForkSeed(rand_state),                      //
      /*debug_mode=*/0,
      /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
  trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
  // Step 3. Find all possible parallel plans.
  // Each candidate `limit` (a loop-extent prefix product not exceeding the
  // parallelism cap) induces a fusion plan; `limit2plan` deduplicates limits,
  // `plan2limit` keeps one limit per distinct plan in deterministic order.
  std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
      sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
  std::unordered_map<int64_t, std::vector<int>> limit2plan;
  std::map<std::vector<int>, int64_t> plan2limit;
  for (const std::vector<int64_t>& prods : loop_extent_prods) {
    for (int64_t limit : prods) {
      if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
        std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
        limit2plan[limit] = plan;
        plan2limit[plan] = limit;
      }
    }
  }
  // Step 4. Compute the original parallel plan and remove it from the pool,
  // so the mutation always produces a different plan.
  std::vector<int> original_plan =
      tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
  auto it = plan2limit.find(original_plan);
  if (it != plan2limit.end()) {
    plan2limit.erase(it);
  }
  // Step 5. Pick a new plan uniformly at random.
  int n_plans = plan2limit.size();
  if (n_plans == 0) {
    return NullOpt;
  }
  it = plan2limit.begin();
  // std::map iterators are not random-access; advance step by step.
  for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
    ++it;
  }
  int64_t limit = it->second;
  // Step 6. Assemble a new trace with the annotation value replaced.
  Array<Instruction> insts;
  insts.reserve(trace->insts.size());
  for (const Instruction& inst : trace->insts) {
    if (inst.same_as(candidate.inst)) {
      insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
    } else if (inst->kind->IsPostproc()) {
      // Postprocessing instructions (and anything after) are stripped.
      break;
    } else {
      insts.push_back(inst);
    }
  }
  return Trace(insts, trace->decisions);
}
307 | |
/*! \brief Construct a MutateParallel mutator with the given per-core job cap. */
Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) {
  auto node = make_object<MutateParallelNode>();
  node->max_jobs_per_core = max_jobs_per_core;
  return Mutator(node);
}
313 | |
// Register the node type and expose the mutator constructor through the FFI.
TVM_REGISTER_NODE_TYPE(MutateParallelNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel" ).set_body_typed(Mutator::MutateParallel);
316 | |
317 | } // namespace meta_schedule |
318 | } // namespace tvm |
319 | |