1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/sketch_policy_rules.cc |
22 | * \brief Rules for generating the sketches, sampling the initial population, and mutating the |
23 | * population in SketchPolicy. |
24 | */ |
25 | |
26 | #include "sketch_policy_rules.h" |
27 | |
28 | #include <set> |
29 | #include <string> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | #include "sketch_policy.h" |
34 | |
35 | namespace tvm { |
36 | namespace auto_scheduler { |
37 | |
// Candidate values for the "auto_unroll_max_step" pragma, sampled uniformly by
// InitUnroll; the GPU list additionally includes 1024.
static std::vector<int> auto_unroll_configs_cpu = {0, 16, 64, 512};
static std::vector<int> auto_unroll_configs_gpu = {0, 16, 64, 512, 1024};
40 | |
41 | /********** Sketch Generation Rule **********/ |
42 | /********** RuleSkipStage **********/ |
43 | |
/*! \brief Fallback condition: unconditionally applicable. */
SketchGenerationRule::ConditionKind RuleSkipStage::MeetCondition(const SketchPolicyNode& policy,
                                                                 const State& state,
                                                                 int stage_id) const {
  // This rule should be the last rule, always return true to decrease the stage index count
  return ConditionKind::kApply;
}
50 | |
51 | std::vector<std::pair<State, int>> RuleSkipStage::Apply(const SketchPolicyNode& policy, |
52 | const State& state, int stage_id) const { |
53 | return {std::make_pair(state, stage_id - 1)}; |
54 | } |
55 | |
56 | /********** RuleAlwaysInline **********/ |
57 | inline bool ShouldAlwaysBeInlined(const SketchPolicyNode& policy, const State& state, |
58 | int stage_id) { |
59 | const SearchTask& task = policy.search_task; |
60 | const Stage& stage = state->stages[stage_id]; |
61 | |
62 | // Check the inline limitation of TE |
63 | if (stage->op_type == StageKind::kPlaceholder || IsOutputOp(task, state, stage_id) || |
64 | HasReduceIter(stage)) { |
65 | return false; |
66 | } |
67 | |
68 | if (IsGPUTask(task)) { // Greedily inline all inlinable ops on gpu |
69 | return true; |
70 | } else { |
71 | // Only always-inline strict-inlinable ops on cpu. |
72 | // The computation location of other ops will be tuned by InitChangeComputeLocation |
73 | // and MutateComputeLocation. |
74 | return IsStrictlyInlineable(task, state, stage_id); |
75 | } |
76 | } |
77 | |
78 | SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const SketchPolicyNode& policy, |
79 | const State& state, |
80 | int stage_id) const { |
81 | return ShouldAlwaysBeInlined(policy, state, stage_id) ? ConditionKind::kApplyAndSkipRest |
82 | : ConditionKind::kSkip; |
83 | } |
84 | |
85 | std::vector<std::pair<State, int>> RuleAlwaysInline::Apply(const SketchPolicyNode& policy, |
86 | const State& state, int stage_id) const { |
87 | State tmp_s = state; |
88 | tmp_s.compute_inline(stage_id); |
89 | return {std::make_pair(std::move(tmp_s), stage_id - 1)}; |
90 | } |
91 | |
92 | /********** RuleMultiLevelTiling **********/ |
93 | |
94 | SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition( |
95 | const SketchPolicyNode& policy, const State& state, int stage_id) const { |
96 | return NeedsMultilevelTiling(policy.search_task, state, stage_id) |
97 | ? ConditionKind::kApplyAndSkipRest |
98 | : ConditionKind::kSkip; |
99 | } |
100 | |
101 | std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolicyNode& policy, |
102 | const State& state, |
103 | int stage_id) const { |
104 | const std::string& multi_level_tiling_structure = |
105 | IsGPUTask(policy.search_task) |
106 | ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure) |
107 | : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure); |
108 | State tmp_s = DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure); |
109 | return {std::make_pair(std::move(tmp_s), stage_id - 1)}; |
110 | } |
111 | |
112 | /********** RuleMultiLevelTilingWithFusion **********/ |
113 | |
114 | SketchGenerationRule::ConditionKind RuleMultiLevelTilingWithFusion::MeetCondition( |
115 | const SketchPolicyNode& policy, const State& state, int stage_id) const { |
116 | if (NeedsMultilevelTiling(policy.search_task, state, stage_id) && |
117 | HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) { |
118 | // Always do fusion for stage with cache_write or is in GPU policy |
119 | return HasCacheWriteStage(state, stage_id) || IsGPUTask(policy.search_task) |
120 | ? ConditionKind::kApplyAndSkipRest |
121 | : ConditionKind::kApply; |
122 | } |
123 | return ConditionKind::kSkip; |
124 | } |
125 | |
std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // MeetCondition guarantees a single elementwise-matched consumer; fetch its id.
  int target_stage_id;
  ICHECK(
      HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id));
  const std::string& multi_level_tiling_structure =
      IsGPUTask(policy.search_task)
          ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
          : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
  std::vector<int> spatial_split_step_ids;
  State base_state =
      DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids);

  std::vector<std::pair<State, int>> ret;
  // Candidate tile levels to fuse the consumer at: level 3 on GPU, levels 1 and 2 on CPU.
  std::vector<int> follow_tiling_levels =
      IsGPUTask(policy.search_task) ? std::vector<int>{3} : std::vector<int>{1, 2};
  for (int level : follow_tiling_levels) {
    // Only fuse at levels marked as space ('s') in the tiling structure string.
    if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
      continue;
    }
    State tmp_s = base_state;
    tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level);
    // Compute at the last iterator of the chosen tile level; each level contributes
    // spatial_split_step_ids.size() iterators in the consumer.
    const Iterator& target_iter =
        tmp_s->stages[target_stage_id]->iters[level * spatial_split_step_ids.size() - 1];
    tmp_s.compute_at(stage_id, target_stage_id, target_iter);
    ret.emplace_back(std::move(tmp_s), stage_id - 1);
  }

  return ret;
}
156 | |
157 | /********** RuleAddCacheRead **********/ |
158 | |
SketchGenerationRule::ConditionKind RuleAddCacheRead::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  const SearchTask& task = policy.search_task;

  // Don't cache_read a stage if it has no consumers.
  // (Multiple consumers are allowed; Apply inserts one cache_read per consumer.)
  const std::set<int>& consumers = GetConsumers(task, state, stage_id);

  if (consumers.size() == 0) return ConditionKind::kSkip;
  // Don't cache_read a stage if its consumer does not need multi-level tiling
  int target_stage_id = *consumers.begin();
  if (!NeedsMultilevelTiling(task, state, target_stage_id)) {
    return ConditionKind::kSkip;
  }

  // Don't cache_read a stage if its consumer does cross-thread reduction
  if (HasCrossThreadReduction(state, target_stage_id)) {
    return ConditionKind::kSkip;
  }

  // Only direct producers can be cache read
  const std::set<int>& producers = GetDirectProducers(task, state, target_stage_id);
  if (producers.find(stage_id) == producers.end()) {
    return ConditionKind::kSkip;
  }

  return ConditionKind::kApplyAndSkipRest;
}
187 | |
std::vector<std::pair<State, int>> RuleAddCacheRead::Apply(const SketchPolicyNode& policy,
                                                           const State& state, int stage_id) const {
  const SearchTask& task = policy.search_task;
  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
  State tmp_s = state;

  // Each cache_read inserts a new stage, shifting the ids of all stages behind
  // it by one; this offset tracks the accumulated shift.
  int target_stage_id_offset = 0;
  for (int orig_target_stage_id : consumers) {
    int target_stage_id = orig_target_stage_id + target_stage_id_offset;

    // Cache read add shared memory
    int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id}, task->compute_dag);
    target_stage_id_offset++;
    target_stage_id++;

    // Attach the new cache stage at the last reduce iterator in the outermost
    // reduce tile of its consumer.
    const auto& share_read_pos =
        GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
    tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
  }

  // Stay on the same stage id so that follow-up rules can still process this stage.
  return {std::make_pair(tmp_s, stage_id)};
}
210 | |
211 | /********** RuleAddCacheWrite **********/ |
212 | |
213 | SketchGenerationRule::ConditionKind RuleAddCacheWrite::MeetCondition(const SketchPolicyNode& policy, |
214 | const State& state, |
215 | int stage_id) const { |
216 | // Add cache write if a stage needs multi-level tiling, but does not have a element-wise |
217 | // matched consumer |
218 | if (NeedsMultilevelTiling(policy.search_task, state, stage_id) && |
219 | !HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) { |
220 | // An apply and skip rule will be handled in RuleMultiLevelTilingWithFusion |
221 | return IsGPUTask(policy.search_task) ? ConditionKind::kApplyAndSkipRest : ConditionKind::kApply; |
222 | } |
223 | |
224 | return ConditionKind::kSkip; |
225 | } |
226 | |
227 | std::vector<std::pair<State, int>> RuleAddCacheWrite::Apply(const SketchPolicyNode& policy, |
228 | const State& state, |
229 | int stage_id) const { |
230 | State tmp_s = state; |
231 | tmp_s.cache_write(stage_id, "local" , policy.search_task->compute_dag); |
232 | return {std::make_pair(std::move(tmp_s), stage_id)}; |
233 | } |
234 | |
235 | /********** RuleAddRfactor **********/ |
236 | |
237 | SketchGenerationRule::ConditionKind RuleAddRfactor::MeetCondition(const SketchPolicyNode& policy, |
238 | const State& state, |
239 | int stage_id) const { |
240 | return (NeedsRfactor(policy.search_task, state, stage_id) && !HasCacheWriteStage(state, stage_id)) |
241 | ? ConditionKind::kApply |
242 | : ConditionKind::kSkip; |
243 | } |
244 | |
std::vector<std::pair<State, int>> RuleAddRfactor::Apply(const SketchPolicyNode& policy,
                                                         const State& state, int stage_id) const {
  // Fuse all reduction iters
  Array<Iterator> space_iters, reduce_iters;
  Iterator fused_reduce_iter;
  State base_state =
      FuseAllReductionIterators(state, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);

  // TODO(merrymercy): We can do more analysis here to generate less and more efficient sketches.
  // In some cases, we only need rfactor for more parallel
  // In some cases, we only need rfactor for vectorization.
  // Now we will generate two versions and let the search figure out the better one.

  // Split the fused reduction iterator in two, then rfactor on either half:
  // split_res[0] -> rfactor for parallelism, split_res[1] -> rfactor for vectorization.
  const auto& split_res = base_state.split(stage_id, fused_reduce_iter, {Integer(1)});
  int factor_axis_id = static_cast<int>(space_iters.size());
  std::vector<std::pair<State, int>> ret;
  for (const auto& split_iter : split_res) {
    State tmp_s = base_state;
    int rstage_id =
        tmp_s.rfactor(stage_id, split_iter, factor_axis_id, policy.search_task->compute_dag);

    // reorder the space iterator to innermost for vectorization
    if (split_iter == split_res[1]) {
      // Move the iterator at position space_iters.size() to the innermost position,
      // keeping the relative order of all other iterators.
      Array<Iterator> new_order;
      for (size_t i = 0; i < tmp_s->stages[rstage_id]->iters.size(); ++i) {
        if (i != space_iters.size()) {
          new_order.push_back(tmp_s->stages[rstage_id]->iters[i]);
        }
      }
      new_order.push_back(tmp_s->stages[rstage_id]->iters[space_iters.size()]);
      tmp_s.reorder(rstage_id, new_order);
    }

    ret.emplace_back(std::move(tmp_s), rstage_id - 1);
  }

  return ret;
}
284 | |
285 | /********** RuleSimplifyComputeWithConstTensor **********/ |
286 | |
287 | SketchGenerationRule::ConditionKind RuleSimplifyComputeWithConstTensor::MeetCondition( |
288 | const SketchPolicyNode& policy, const State& state, int stage_id) const { |
289 | return state->stages[stage_id]->op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices) |
290 | ? ConditionKind::kApplyAndSkipRest |
291 | : ConditionKind::kSkip; |
292 | } |
293 | |
std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // Names of the iterators that index constant tensors, declared by the op's attrs.
  std::set<std::string> const_tensor_indices = GetIterNameSetParam(
      state->stages[stage_id]->op->attrs, SearchPolicyKey::simplify_const_tensor_indices);

  State tmp_s = state;
  Array<Array<Iterator>> tiled_outer_iters;
  Array<Iterator> unrolled_inner_iters;

  // Currently set to 2
  size_t tile_level = 2;

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (const_tensor_indices.count(iter->name)) {
      // unroll indices of const tensors
      unrolled_inner_iters.push_back(tmp_s.unroll(stage_id, iter));
    } else {
      // tile other space indices
      ICHECK(iter->iter_kind == IteratorKind::kSpatial);
      // Split with undefined (NullOpt) lengths; InitFillTileSize samples them later.
      tiled_outer_iters.push_back(
          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(tile_level - 1, NullOpt)));
    }
  }

  // Reorder to: all level-0 tiles, all level-1 tiles, ..., then the unrolled iterators innermost.
  Array<Iterator> new_order;
  for (size_t i = 0; i < tile_level; ++i) {
    for (size_t j = 0; j < tiled_outer_iters.size(); ++j) {
      new_order.push_back(tiled_outer_iters[j][i]);
    }
  }
  new_order.insert(new_order.end(), unrolled_inner_iters.begin(), unrolled_inner_iters.end());
  tmp_s.reorder(stage_id, new_order);

  return {std::make_pair(tmp_s, stage_id - 1)};
}
330 | |
331 | /********** RuleCrossThreadReduction **********/ |
332 | |
SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  // This rule is registered only for GPU policies.
  ICHECK(IsGPUTask(policy.search_task));

  // If it is an intermediate state created by RuleAddCacheWrite,
  // we just skip it.
  if (HasCacheWriteStage(state, stage_id)) {
    return ConditionKind::kSkip;
  }

  const auto& op = state->stages[stage_id]->op;
  if (op->IsInstance<te::ComputeOpNode>()) {
    // Compute the product of lengths of all space iters and all reduce iters
    auto [cum_space_len, cum_reduce_len] =
        GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

    if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
      // Avoid rfactor if we have enough parallelism on space iters
      if (cum_space_len > policy.search_task->hardware_params->max_threads_per_block) {
        return ConditionKind::kSkip;
      }

      // Apply only when the reduction dominates the spatial extent.
      return cum_space_len < cum_reduce_len ? ConditionKind::kApply : ConditionKind::kSkip;
    } else if (cum_reduce_len > 1) {
      // Try rfactor for other reduction operators, but only when the reduction
      // is long enough to occupy at least a warp.
      return cum_reduce_len > policy.search_task->hardware_params->warp_size ? ConditionKind::kApply
                                                                             : ConditionKind::kSkip;
    }
  }

  return ConditionKind::kSkip;
}
365 | |
std::vector<std::pair<State, int>> RuleCrossThreadReduction::Apply(const SketchPolicyNode& policy,
                                                                   const State& state,
                                                                   int stage_id) const {
  const SearchTask& task = policy.search_task;
  State tmp_s = state;

  // fuse all reduction iters
  Array<Iterator> space_iters, reduce_iters;
  Iterator fused_reduce_iter;
  tmp_s =
      FuseAllReductionIterators(tmp_s, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);

  // Check the opportunity for kernel fusion: a single consumer sharing common
  // outer iterators that does not itself need multi-level tiling.
  bool fusible = false;
  int target_stage_id = GetSingleConsumerId(policy.search_task, tmp_s, stage_id);
  int num_common_outer = -1;
  if (target_stage_id >= 0) {
    num_common_outer =
        GetNumCommonOuterIterator(policy.search_task, tmp_s, stage_id, target_stage_id);
    if (num_common_outer > 0 &&
        !NeedsMultilevelTiling(policy.search_task, state, target_stage_id)) {
      fusible = true;
    }
  }

  if (fusible) {
    const Stage& target_stage = state->stages[target_stage_id];
    std::vector<int> split_step_ids;

    GetSplitStepIds(tmp_s, target_stage_id, &split_step_ids);

    if (split_step_ids.size() == 0) {
      // If the target stage does not have split step,
      // it must be a simple stage without reduce iters.
      // We then should do a split for it.
      ICHECK(!HasReduceIter(target_stage));
      const auto& split_res = tmp_s.split(target_stage_id, target_stage->iters.back(),
                                          {Integer(task->hardware_params->warp_size)});
      tmp_s.bind(target_stage_id, split_res[1], IteratorAnnotation::kThreadX);
      // The split just recorded is the second-to-last transform step
      // (the bind above is the last one).
      split_step_ids.push_back(tmp_s->transform_steps.size() - 2);
    }

    ICHECK_EQ(split_step_ids.size(), 1);

    // Mirror the consumer's split on the fused reduce iterator, bind the inner
    // part to threadIdx.x, and fuse at the last common outer iterator.
    const Iterator& target_iter = tmp_s->stages[target_stage_id]->iters[num_common_outer - 1];
    const auto& split_res = tmp_s.follow_split(stage_id, fused_reduce_iter, split_step_ids[0], 1);
    tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
    tmp_s.compute_at(stage_id, target_stage_id, target_iter);
  } else {
    // No fusion opportunity: split the reduction by the warp size and bind the
    // inner part to threadIdx.x directly.
    const auto& split_res =
        tmp_s.split(stage_id, fused_reduce_iter, {Integer(task->hardware_params->warp_size)});
    tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
  }

  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
422 | |
423 | /********** RuleSpecialComputeLocationGPU **********/ |
424 | |
425 | SketchGenerationRule::ConditionKind RuleSpecialComputeLocationGPU::MeetCondition( |
426 | const SketchPolicyNode& policy, const State& state, int stage_id) const { |
427 | if (GetProducers(policy.search_task, state, stage_id).empty()) { |
428 | return ConditionKind::kSkip; |
429 | } |
430 | |
431 | if (!ShouldAlwaysBeInlined(policy, state, stage_id)) { |
432 | return ConditionKind::kSkip; |
433 | } |
434 | |
435 | const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id); |
436 | if (consumers.size() == 1 && state->stages[*consumers.begin()]->op->attrs.count( |
437 | SearchPolicyKey::simplify_const_tensor_indices)) { |
438 | return ConditionKind::kApplyAndSkipRest; |
439 | } |
440 | |
441 | return ConditionKind::kSkip; |
442 | } |
443 | |
std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  State tmp_s = state;
  const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
  ICHECK_EQ(consumers.size(), 1);

  // Get the last outer space iterator that is not unrolled.
  const Stage& target_stage = state->stages[*consumers.begin()];
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    if (target_stage->iters[i]->annotation == IteratorAnnotation::kUnroll) {
      // The first iterator must not be unrolled, otherwise there is no
      // position left to compute at.
      ICHECK_GT(i, 0);

      // Attach this stage just above the consumer's first unrolled iterator.
      tmp_s.compute_at(stage_id, *consumers.begin(), target_stage->iters[i - 1]);
      break;
    }
  }

  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
463 | |
464 | /********** RuleCustomSketch **********/ |
465 | |
SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  // Delegate the decision to the user-provided packed function.
  auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
  // Type code 0 indicates an integer return value; interpret it as a ConditionKind.
  if (ret.type_code() == 0) {
    return ConditionKind(static_cast<int>(ret));
  } else {
    // Anything else is treated as "apply and skip the rest" with a warning.
    LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
    return ConditionKind::kApplyAndSkipRest;
  }
}
477 | |
478 | std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNode& policy, |
479 | const State& state, int stage_id) const { |
480 | Array<Array<ObjectRef>> apply_ret = |
481 | apply_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id); |
482 | std::vector<std::pair<State, int>> ret; |
483 | for (const auto& item : apply_ret) { |
484 | CHECK_EQ(item.size(), 2); |
485 | auto next = item[1].as<IntImmNode>(); |
486 | ret.emplace_back(Downcast<State>(item[0]), next->value); |
487 | } |
488 | return ret; |
489 | } |
490 | |
491 | /********** Init Population **********/ |
492 | |
/*!
 * \brief Randomly fill concrete tile sizes for every SplitStep whose lengths are
 * still undefined, then mark the state as concrete.
 */
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                             std::mt19937* rand_gen) const {
  SplitFactorizationMemo split_memo;
  int max_innermost_split_factor =
      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

  StateNode* pstate = state->CopyOnWrite();
  // Scan the transformation history and randomly fill tiles size for all SplitStep
  for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
    if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) {
      // Skip steps whose split lengths are already fully specified.
      bool all_defined = true;
      for (const auto& len : ps->lengths) {
        if (!len) {
          all_defined = false;
          break;
        }
      }
      if (all_defined) {
        continue;
      }

      ICHECK(ps->extent);
      int extent = GetIntImm(ps->extent.value());
      // Sample one factorization scheme of `extent` into ps->lengths.size() factors.
      const auto& candidate_lens = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(),
                                                                      max_innermost_split_factor);
      ICHECK(!candidate_lens.empty());
      const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

      // Replace the step in-place with one carrying the sampled lengths.
      pstate->transform_steps.Set(
          step_id,
          SplitStep(ps->stage_id, ps->iter_id, ps->extent,
                    Array<Optional<Integer>>(candidate_lengths.begin(), candidate_lengths.end()),
                    ps->inner_to_outer));
    }
  }
  pstate->concrete = true;

  return ResultKind::kValid;
}
532 | |
533 | PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply( |
534 | SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { |
535 | if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { |
536 | return ResultKind::kValid; |
537 | } |
538 | |
539 | for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { |
540 | const Stage& stage = (*state)->stages[stage_id]; |
541 | // Skip the inlined stages and placeholders |
542 | if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) { |
543 | continue; |
544 | } |
545 | // Skip the tiled stages |
546 | if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { |
547 | continue; |
548 | } |
549 | |
550 | std::vector<std::pair<int, int>> candidates = |
551 | GetComputeLocationCandidates(policy->search_task, *state, stage_id); |
552 | |
553 | int choice = (*rand_gen)() % (candidates.size() + 2); |
554 | |
555 | if (choice == 0) { |
556 | if (!HasReduceIter(stage)) { |
557 | const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter; |
558 | if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) { |
559 | state->compute_inline(stage_id); |
560 | } |
561 | } |
562 | } else if (choice == 1) { |
563 | state->compute_root(stage_id); |
564 | } else { |
565 | choice = choice - 2; |
566 | const Stage& stage = (*state)->stages[candidates[choice].first]; |
567 | state->compute_at(stage_id, candidates[choice].first, |
568 | stage->iters[candidates[choice].second]); |
569 | } |
570 | } |
571 | |
572 | try { |
573 | *state = policy->search_task->compute_dag.InferBound(*state); |
574 | } catch (std::exception& e) { |
575 | return ResultKind::kInvalid; |
576 | } |
577 | return ResultKind::kValid; |
578 | } |
579 | |
/*!
 * \brief Annotate parallelism on CPU: fuse outermost space iterators of each root
 * stage and mark the fused iterator as parallel.
 */
PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state,
                                                         std::mt19937* rand_gen) const {
  // Recursive helper: starting at `iter_offset`, collect fusible outer iterators
  // of `stage_id`, fuse them, and annotate the result as parallel.
  std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
      annotate_parallel;
  annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state,
                                           int stage_id, int iter_offset) {
    const Stage& stage = (*state)->stages[stage_id];

    Array<Iterator> to_fuse;
    int64_t parallel_degree = 1;

    // Try to fuse and parallel the outermost n iterators
    // Stop if we meet reduce iterator or we have enough parallel degree
    size_t iter_id = iter_offset;
    for (; iter_id < stage->iters.size(); ++iter_id) {
      const Iterator& it = stage->iters[iter_id];
      if (it->iter_kind == IteratorKind::kReduction ||
          it->annotation != IteratorAnnotation::kNone) {
        break;
      }
      to_fuse.push_back(it);
      parallel_degree *= GetExtent(it);

      // 16x the number of cores is treated as enough parallelism.
      if (parallel_degree > policy.search_task->hardware_params->num_cores * 16) {
        break;
      }

      // Stop at iterators that serve as compute_at attach points for other stages.
      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
        break;
      }
    }

    if (parallel_degree == 1) {
      // No parallelism gained at this level: recurse into the stages attached at
      // the current iterator, then continue past it.
      auto res =
          (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
      if (res != (*state)->attach_map->iter_to_attached_stages.end()) {
        for (int attached_stage_id : res->second) {
          annotate_parallel(policy, state, attached_stage_id, 0);
        }
        annotate_parallel(policy, state, stage_id, iter_id + 1);
      }
    }

    if (!to_fuse.empty()) {
      if (to_fuse.size() == 1) {
        state->parallel(stage_id, to_fuse[0]);
      } else {
        Iterator fused_iter = state->fuse(stage_id, to_fuse);
        state->parallel(stage_id, fused_iter);
      }
    }
  };

  // Annotate every root, non-placeholder stage.
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    if (stage->compute_at != ComputeAtKind::kRoot || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    annotate_parallel(*policy, state, stage_id, 0);
  }

  return ResultKind::kValid;
}
644 | |
645 | PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state, |
646 | std::mt19937* rand_gen) const { |
647 | std::vector<int>& auto_unroll_configs = |
648 | IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; |
649 | for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { |
650 | const Stage& stage = (*state)->stages[stage_id]; |
651 | // Skip the inlined stage and placeholder stage |
652 | if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) { |
653 | continue; |
654 | } |
655 | |
656 | // Handle always_unroll_inner attr |
657 | if (stage->op->attrs.count(SearchPolicyKey::always_unroll_inner)) { |
658 | const auto& to_unroll_name_set = |
659 | GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::always_unroll_inner); |
660 | |
661 | // Unroll the space iterators and reduce iterators listed in the attrs in the innermost |
662 | // tile |
663 | std::set<std::string> visited_names; |
664 | for (int n = static_cast<int>(stage->iters.size()) - 1; n >= 0; n--) { |
665 | const Iterator& it = stage->iters[n]; |
666 | |
667 | // If we meet two iterators that come from a same original iterator, |
668 | // then we are out of the innermost tile |
669 | size_t size_before = visited_names.size(); |
670 | ExtractOriginalIterators(it->name, &visited_names); |
671 | if (size_before == visited_names.size()) { |
672 | break; |
673 | } |
674 | |
675 | std::set<std::string> name; |
676 | ExtractOriginalIterators(it->name, &name); |
677 | if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) { |
678 | if (it->annotation == IteratorAnnotation::kNone) { |
679 | state->unroll(stage_id, it); |
680 | } |
681 | } |
682 | } |
683 | } |
684 | |
685 | if (HasReduceIter(stage)) { |
686 | // Use auto unroll for multi level tiled stage |
687 | int value = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]; |
688 | state->pragma(stage_id, (*state)->stages[stage_id]->iters[0], |
689 | std::string("auto_unroll_max_step" ) + "$" + std::to_string(value)); |
690 | } |
691 | } |
692 | |
693 | return ResultKind::kValid; |
694 | } |
695 | |
/*!
 * \brief Fuse and vectorize a random number of innermost fusible space iterators
 * of each non-inlined, non-placeholder stage.
 */
PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
                                                              State* state,
                                                              std::mt19937* rand_gen) const {
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    // Skip the inlined stage and placeholder stage
    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    // Try to fuse and vectorize the space iterators in the inner most tile
    int64_t cum_length_prod = 1;

    // Count, from the innermost iterator outwards, how many iterators are fusible.
    int num_fusible = 0;
    while (num_fusible < static_cast<int>(stage->iters.size())) {
      int iter_id = static_cast<int>(stage->iters.size()) - 1 - num_fusible;
      // Stop if this iterator has been a compute at attach point
      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
        break;
      }

      const Iterator& it = stage->iters[iter_id];
      // Stop if we meet a reduce iterator or annotated iterator
      if (it->iter_kind == IteratorKind::kReduction ||
          it->annotation != IteratorAnnotation::kNone) {
        break;
      }

      // Stop if the memory access is not continuous (vectorizable)
      // Note: The check is too hard, so we use heuristic here
      if (IsTiled(stage) && num_fusible != 0) {
        // If the stage is tiled, then the memory access must not be continuous
        // for the innermost two iterators
        break;
      }

      // Stop once the fused extent would exceed the max vectorization size.
      cum_length_prod *= GetExtent(it);
      if (cum_length_prod > GetIntParam(policy->params, SketchParamKey::max_vectorize_size)) {
        break;
      }

      num_fusible++;
    }

    if (num_fusible > 1) {
      // Select a random range to fuse
      num_fusible = 1 + (*rand_gen)() % (num_fusible - 1);
    }

    if (num_fusible == 1) {
      state->vectorize(stage_id, stage->iters.back());
    } else if (num_fusible > 1) {
      // Fuse the `num_fusible` innermost iterators, then vectorize the result.
      Array<Iterator> to_fuse(stage->iters.end() + (-num_fusible), stage->iters.end());
      state->vectorize(stage_id, state->fuse(stage_id, to_fuse));
    }
  }

  return ResultKind::kValid;
}
755 | |
756 | PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state, |
757 | std::mt19937* rand_gen) const { |
758 | // Collect all stages that are roots of stages that perform multi-level tiling. |
759 | std::set<int> multi_level_tiling_root_set; |
760 | for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { |
761 | if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { |
762 | const Stage& stage = (*state)->stages[stage_id]; |
763 | if (stage->compute_at == ComputeAtKind::kInlined) { |
764 | continue; |
765 | } else if (stage->compute_at != ComputeAtKind::kIter) { |
766 | // This stage is not multi-level tiled, |
767 | // so it must be produced by RuleCrossThreadReduction. |
768 | ICHECK(HasCrossThreadReduction(*state, stage_id)); |
769 | } else { |
770 | const auto res = (*state)->attach_map->stage_to_attach_iter.find(stage_id); |
771 | ICHECK(res != (*state)->attach_map->stage_to_attach_iter.end()); |
772 | multi_level_tiling_root_set.insert(res->second.first); |
773 | } |
774 | } |
775 | } |
776 | |
777 | *state = policy->search_task->compute_dag.InferBound(*state); |
778 | |
779 | for (int stage_id = (*state)->stages.size() - 1; stage_id >= 0; --stage_id) { |
780 | const Stage& stage = (*state)->stages[stage_id]; |
781 | |
782 | if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) { |
783 | continue; |
784 | } |
785 | |
786 | // Deal with the cross-thread reduction generated by RuleCrossThreadReduction |
787 | if (HasCrossThreadReduction(*state, stage_id)) { |
788 | if (stage->compute_at != ComputeAtKind::kRoot) { |
789 | continue; |
790 | } |
791 | |
792 | Iterator fused_it; |
793 | *state = std::move(FuseAllOuterSpaceIterators(*state, stage_id, &fused_it)); |
794 | state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX); |
795 | continue; |
796 | } |
797 | |
798 | // Skip if this stage has already been annotaed with threadIdx.x |
799 | if (HasAnnotatedIter(stage, IteratorAnnotation::kThreadX)) { |
800 | continue; |
801 | } |
802 | |
803 | if (stage->compute_at == ComputeAtKind::kRoot) { |
804 | // This stage has not been tiled, but in GPU schedule, we must tile the root stage |
805 | // to do thread binding |
806 | if (!multi_level_tiling_root_set.count(stage_id)) { |
807 | Iterator fused_it; |
808 | *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it); |
809 | |
810 | if (GetExtent(fused_it) <= policy->search_task->hardware_params->warp_size) { |
811 | state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX); |
812 | } else { |
813 | // Set threadIdx.x = default_warp_size by default. |
814 | // The later EvolutionarySearch will try more possibility |
815 | const auto& split_its = state->split( |
816 | stage_id, fused_it, {Integer(policy->search_task->hardware_params->warp_size)}); |
817 | state->bind(stage_id, split_its[0], IteratorAnnotation::kBlockX); |
818 | state->bind(stage_id, split_its[1], IteratorAnnotation::kThreadX); |
819 | } |
820 | continue; |
821 | } |
822 | |
823 | // Otherwise, this is a tiled root stage, we assume it should be tiled with 3 space level |
824 | // in the outer iterators. |
825 | // The remaining part deals with the thread binding for multi-level tiled stages |
826 | auto pop = stage->op.as<te::ComputeOpNode>(); |
827 | std::vector<Iterator> to_fuse; |
828 | int total_space_extent = 1; |
829 | for (const auto& i : pop->root_iter_vars()) { |
830 | ICHECK(i->dom.defined()); |
831 | const auto& pint = i->dom->extent.as<IntImmNode>(); |
832 | ICHECK(pint); |
833 | total_space_extent *= pint->value; |
834 | } |
835 | |
836 | bool check_min_thread_extent = true; |
837 | // If the total space extent is too small, disable the check of minimal thread extent |
838 | if (total_space_extent <= policy->search_task->hardware_params->warp_size * 2) { |
839 | check_min_thread_extent = false; |
840 | } |
841 | |
842 | // Fuse the outermost space tile as blockIdx |
843 | for (size_t i = 0; i < pop->axis.size(); i++) { |
844 | const auto& it = (*state)->stages[stage_id]->iters[i]; |
845 | // There may be some iterators that are marked with no split, stop if reaches next |
846 | // tiling level |
847 | if (!StrEndsWith(it->name, ".0" )) { |
848 | break; |
849 | } |
850 | to_fuse.push_back(it); |
851 | } |
852 | const auto& blockidx_it = state->fuse(stage_id, to_fuse); |
853 | state->bind(stage_id, blockidx_it, IteratorAnnotation::kBlockX); |
854 | |
855 | // Fuse the second outermost space tile as vthread |
856 | to_fuse.clear(); |
857 | for (size_t i = 1; i < pop->axis.size() + 1; i++) { |
858 | const auto& it = (*state)->stages[stage_id]->iters[i]; |
859 | // There may be some iterators that are marked with no split, stop if reaches next |
860 | // tiling level |
861 | if (!StrEndsWith(it->name, ".1" )) { |
862 | break; |
863 | } |
864 | to_fuse.push_back((*state)->stages[stage_id]->iters[i]); |
865 | } |
866 | const auto& vthread_it = state->fuse(stage_id, to_fuse); |
867 | if (GetExtent(vthread_it) > policy->search_task->hardware_params->max_vthread_extent) { |
868 | return ResultKind::kInvalid; |
869 | } |
870 | state->bind(stage_id, vthread_it, IteratorAnnotation::kVThread); |
871 | |
872 | // Fuse the third outermost space tile as threadIdx |
873 | to_fuse.clear(); |
874 | for (size_t i = 2; i < pop->axis.size() + 2; i++) { |
875 | const auto& it = (*state)->stages[stage_id]->iters[i]; |
876 | // There may be some iterators that are marked with no split, stop if reaches next |
877 | // tiling level |
878 | if (!StrEndsWith(it->name, ".2" )) { |
879 | break; |
880 | } |
881 | to_fuse.push_back((*state)->stages[stage_id]->iters[i]); |
882 | } |
883 | const auto& threadidx_it = state->fuse(stage_id, to_fuse); |
884 | if (check_min_thread_extent && |
885 | GetExtent(threadidx_it) < policy->search_task->hardware_params->warp_size) { |
886 | return ResultKind::kInvalid; |
887 | } |
888 | state->bind(stage_id, threadidx_it, IteratorAnnotation::kThreadX); |
889 | } else if (stage->compute_at == ComputeAtKind::kIter && |
890 | StrEndsWith(stage->op->name, ".shared" )) { |
891 | // Do cooperative fetching for the cache read stage. |
892 | // Get spatial_split_step_ids from the root stage |
893 | const auto& it = (*state)->attach_map->stage_to_attach_iter.find(stage_id); |
894 | ICHECK(it != (*state)->attach_map->stage_to_attach_iter.end()); |
895 | Array<Integer> spatial_split_step_ids = GetSpatialSplitStepIds(*state, it->second.first); |
896 | |
897 | // Fuse all iterators to do cooperative fetching |
898 | Iterator fused = state->fuse(stage_id, (*state)->stages[stage_id]->iters); |
899 | // Split out an extra iterator for vectorization |
900 | // The later EvolutionarySearch will try more possibility |
901 | const auto& iters0 = state->split(stage_id, fused, {Integer(1)}); |
902 | state->vectorize(stage_id, iters0[1]); |
903 | // Follow split to keep a same thread extent with the root stage |
904 | const auto& iters1 = |
905 | state->follow_fused_split(stage_id, iters0[0], spatial_split_step_ids, 1, true); |
906 | state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX); |
907 | } |
908 | } |
909 | return ResultKind::kValid; |
910 | } |
911 | |
// Mutation rule: pick one SplitStep and move a factor between two of its tile
// sizes, keeping the product (and the innermost-split-factor limit) intact.
// Returns kInvalid when no mutable split step exists or no legal factor move
// is found.
PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  int max_innermost_split_factor =
      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

  // Extract all SplitStep
  std::vector<size_t> split_step_ids;
  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
    if (auto ps = (*state)->transform_steps[i].as<SplitStepNode>()) {
      // Only consider splits whose extent is a known constant.
      if (!ps->extent.defined() || !ps->extent.value()->IsInstance<IntImmNode>()) {
        continue;
      }
      // A missing innermost length counts as "too large" so the step is skipped.
      auto innermost_factor = ps->lengths.back().value_or(max_innermost_split_factor + 1);
      if (GetIntImm(innermost_factor) <= max_innermost_split_factor) {
        split_step_ids.push_back(i);
      }
    }
  }
  if (split_step_ids.empty()) {
    // No tile size could be mutated.
    return ResultKind::kInvalid;
  }

  // Select a SplitStep with extent larger than one to mutate.
  int retry_ct = 0;
  int64_t extent = 1;
  int step_id;
  const SplitStepNode* ps;

  // Retry at most 4x the number of candidates before giving up.
  do {
    step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()];
    ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
    ICHECK(ps != nullptr);
    extent = GetIntImm(ps->extent.value());
    retry_ct += 1;
  } while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && (extent == 1 || extent == 0));

  if (extent <= 1) {
    // Cannot find a step with extent larger than one.
    return ResultKind::kInvalid;
  }

  // Fetch the current tile sizes.
  // lengths[0] is the outermost (inferred) length; lengths[1..] mirror ps->lengths.
  std::vector<int> lengths(ps->lengths.size() + 1, 1);
  for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
    lengths[i + 1] = GetIntImm(ps->lengths[i].value());
  }
  lengths[0] = extent / ElementProduct(lengths);

  // Random permute the tile size order.
  std::vector<int> random_perm;
  RandomPermutation(lengths.size(), &random_perm, rand_gen);

  // Try to divide a factor from one tile size and multiply it to another.
  for (size_t i = 0; i < random_perm.size(); ++i) {
    size_t src_idx = random_perm[i];
    int length = lengths[src_idx];
    if (length <= 1) {
      continue;
    }

    // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]
    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
    const std::vector<int>& factors = policy->split_memo.GetFactors(length);
    ICHECK_GE(factors.size(), 1);

    int divide_factor;
    if (dst_idx == lengths.size() - 1) {
      // Maintain the restriction of hardware_params.max_innermost_split_factor.
      // Scan factors (sorted ascending) down to the largest one that keeps the
      // innermost length within the limit.
      int max_factor_index = static_cast<int>(factors.size()) - 1;
      for (; max_factor_index >= 1; max_factor_index--) {
        if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) {
          break;
        }
      }
      if (max_factor_index == 0) {
        // Failed on this dst_idx, try next one.
        continue;
      }
      // factors[0] is always 1, so sample from index 1 onward.
      divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)];
    } else {
      divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)];
    }

    // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx].
    // Start from j = 1: the outermost length is inferred, not stored in the step.
    Array<Integer> new_lengths;
    for (size_t j = 1; j < lengths.size(); ++j) {
      if (j == src_idx) {
        new_lengths.push_back(Integer(lengths[j] / divide_factor));
      } else if (j == dst_idx) {
        new_lengths.push_back(Integer(lengths[j] * divide_factor));
      } else {
        new_lengths.push_back(Integer(lengths[j]));
      }
    }

    ICHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor);

    // Rewrite the chosen split step in place with the mutated lengths.
    StateNode* pstate = state->CopyOnWrite();
    pstate->transform_steps.Set(
        step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
                           Array<Optional<Integer>>(new_lengths.begin(), new_lengths.end()),
                           ps->inner_to_outer));
    return ResultKind::kValid;
  }
  return ResultKind::kInvalid;
}
1019 | |
1020 | PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state, |
1021 | std::mt19937* rand_gen) const { |
1022 | // Extract all auto_unroll_max_step pragma steps. |
1023 | std::vector<int> pragma_steps; |
1024 | for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { |
1025 | if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) { |
1026 | if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step" )) { |
1027 | pragma_steps.push_back(i); |
1028 | } |
1029 | } |
1030 | } |
1031 | if (pragma_steps.empty()) { |
1032 | return ResultKind::kInvalid; |
1033 | } |
1034 | |
1035 | std::vector<int>& auto_unroll_configs = |
1036 | IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; |
1037 | |
1038 | // Randomly pick up an auto unroll pragma step |
1039 | auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()]; |
1040 | auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>(); |
1041 | ICHECK(ps); |
1042 | |
1043 | // Mutate its value to a random candidates |
1044 | int val = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]; |
1045 | StateNode* pstate = state->CopyOnWrite(); |
1046 | pstate->transform_steps.Set( |
1047 | step_id, PragmaStep(ps->stage_id, ps->iter_id, |
1048 | std::string("auto_unroll_max_step" ) + "$" + std::to_string(val))); |
1049 | Stage new_stage = pstate->stages[ps->stage_id]; |
1050 | new_stage.CopyOnWrite()->attrs.auto_unroll_max_step = val; |
1051 | pstate->stages.Set(ps->stage_id, new_stage); |
1052 | return ResultKind::kValid; |
1053 | } |
1054 | |
1055 | PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, |
1056 | State* state, |
1057 | std::mt19937* rand_gen) const { |
1058 | if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { |
1059 | return ResultKind::kInvalid; |
1060 | } |
1061 | |
1062 | // Extract all compute_at steps. |
1063 | std::vector<int> compute_at_steps; |
1064 | for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { |
1065 | if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) { |
1066 | int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id; |
1067 | |
1068 | if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) { |
1069 | continue; |
1070 | } |
1071 | |
1072 | if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) { |
1073 | continue; |
1074 | } |
1075 | compute_at_steps.push_back(s); |
1076 | } |
1077 | } |
1078 | if (compute_at_steps.empty()) { |
1079 | return ResultKind::kInvalid; |
1080 | } |
1081 | |
1082 | // Randomly pick one step |
1083 | size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()]; |
1084 | auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>(); |
1085 | int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id; |
1086 | ICHECK(ps != nullptr); |
1087 | |
1088 | // Randomly pick a new computation location |
1089 | std::vector<std::pair<int, int>> candidates = |
1090 | GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc); |
1091 | if (candidates.empty()) { |
1092 | return ResultKind::kInvalid; |
1093 | } |
1094 | int choice = (*rand_gen)() % (candidates.size()); |
1095 | int new_compute_at_stage_id = candidates[choice].first; |
1096 | int new_compute_at_iter_id = candidates[choice].second; |
1097 | |
1098 | // Replay a new state. |
1099 | State tmp_s = policy->search_task->compute_dag->init_state; |
1100 | for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { |
1101 | if (s == step_id) { |
1102 | tmp_s.CopyOnWrite()->transform_steps.push_back( |
1103 | ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id)); |
1104 | } else { |
1105 | tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]); |
1106 | } |
1107 | try { |
1108 | StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag); |
1109 | } catch (Error& e) { |
1110 | return ResultKind::kInvalid; |
1111 | } |
1112 | } |
1113 | |
1114 | *state = tmp_s; |
1115 | return ResultKind::kValid; |
1116 | } |
1117 | |
// Mutation rule: change how many outer loops are fused before a parallel
// annotation, then replay the state. Returns kInvalid when no suitable
// parallel step exists, the mutation is a no-op, or the replay fails.
PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  // This mutation rule only focuses on a case that parallel was added to
  // the outermost loop and the loop is generated by fusing other loops.
  // In short, we mutate the fusion step before the parallel step.

  // Extract all parallel steps.
  std::vector<int> parallel_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    auto ps = (*state)->transform_steps[s].as<AnnotationStepNode>();
    if (!ps || ps->annotation != IteratorAnnotation::kParallel) {
      continue;
    }

    // Skip non-outermost loop or the parallel step without fusion beforehand.
    if (ps->iter_id != 0 || s == 0 || !(*state)->transform_steps[s - 1].as<FuseStepNode>()) {
      continue;
    }
    // The preceding fusion must start at the outermost iterator.
    auto fuse_step = (*state)->transform_steps[s - 1].as<FuseStepNode>();
    if (fuse_step->fused_ids[0] != 0) {
      continue;
    }

    parallel_steps.push_back(s);
  }
  if (parallel_steps.empty()) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one parallel step.
  size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];

  // Replay a new state until the picked fuse step.
  // (step_id >= 1 is guaranteed by the s == 0 skip above.)
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < step_id - 1; ++s) {
    const auto& step = (*state)->transform_steps[s];
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
  }

  // Compute all possible fusion granularities
  auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>();
  int stage_id = fuse_step->stage_id;
  const Stage& stage = tmp_s->stages[stage_id];
  // Count the leading fusable iterators: stop at the first reduction,
  // annotated, or compute-at attach-point iterator.
  size_t max_fusable_iter_id;
  for (max_fusable_iter_id = 0; max_fusable_iter_id < stage->iters.size(); ++max_fusable_iter_id) {
    const Iterator& it = stage->iters[max_fusable_iter_id];
    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
      break;
    }

    if (tmp_s->attach_map->iter_to_attached_stages.count(
            std::make_pair(stage_id, max_fusable_iter_id))) {
      break;
    }
  }

  if (max_fusable_iter_id == 0) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one granularity: fuse iterators [0, fuse_to_iter_id).
  int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1;
  Array<Integer> fused_ids;
  for (int i = 0; i < fuse_to_iter_id; ++i) {
    fused_ids.push_back(i);
  }
  // iter_offset is how much later iterator ids in this stage shift after the
  // new, differently-sized fusion; zero means the mutation changed nothing.
  int iter_offset = fuse_step->fused_ids.back()->value - fused_ids.back()->value;
  if (iter_offset == 0) {
    return ResultKind::kInvalid;
  }

  // Replay the mutated fused and annotation step.
  auto new_fuse_step = FuseStep(stage_id, fused_ids);
  tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step);
  StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag);
  tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]);
  StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag);

  // Replay the rest steps.
  for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) {
    auto step = (*state)->transform_steps[s];
    if (step->stage_id == stage_id) {
      // Since we changed the loop structure, iter ID in later steps to the same stage
      // has to be adjusted.
      if (auto ps = step.as<AnnotationStepNode>()) {
        if (ps->iter_id == 0) {
          step = AnnotationStep(ps->stage_id, 0, ps->annotation);
        } else {
          ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation);
        }
      } else if (auto ps = step.as<PragmaStepNode>()) {
        if (ps->iter_id == 0) {
          step = PragmaStep(ps->stage_id, 0, ps->pragma_type);
        } else {
          ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = PragmaStep(ps->stage_id, ps->iter_id + iter_offset, ps->pragma_type);
        }
      } else {
        // Other step kinds referencing this stage cannot be re-indexed safely.
        return ResultKind::kInvalid;
      }
    }
    if (IsStageNumberChangingStep(step)) {
      // For these steps, we have to update stage_id because these steps will make stage_id
      // out-dated. But here we just simply give up this mutation for simplicity.
      // This is not an issue because this will never happen in normal cases where all these steps
      // are before parallel steps.
      return ResultKind::kInvalid;
    }
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    try {
      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
    } catch (Error& e) {
      return ResultKind::kInvalid;
    }
  }

  *state = tmp_s;
  return ResultKind::kValid;
}
1239 | |
1240 | } // namespace auto_scheduler |
1241 | } // namespace tvm |
1242 | |