/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/search_policy/sketch_policy_rules.cc
 * \brief Rules for generating the sketches, sampling the initial population, and mutating the
 * population in SketchPolicy.
 */

#include "sketch_policy_rules.h"

#include <set>
#include <string>
#include <utility>
#include <vector>

#include "sketch_policy.h"

namespace tvm {
namespace auto_scheduler {

static std::vector<int> auto_unroll_configs_cpu = {0, 16, 64, 512};
static std::vector<int> auto_unroll_configs_gpu = {0, 16, 64, 512, 1024};
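
// Note: these vectors hold the candidate values for the "auto_unroll_max_step"
// pragma used by InitUnroll and MutateAutoUnroll below. A sampled value such as
// 64 is attached to a stage as the pragma string "auto_unroll_max_step$64".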

/********** Sketch Generation Rule **********/
/********** RuleSkipStage **********/

SketchGenerationRule::ConditionKind RuleSkipStage::MeetCondition(const SketchPolicyNode& policy,
                                                                 const State& state,
                                                                 int stage_id) const {
  // This rule should be the last rule; it always applies, so the stage index count decreases
  return ConditionKind::kApply;
}

std::vector<std::pair<State, int>> RuleSkipStage::Apply(const SketchPolicyNode& policy,
                                                        const State& state, int stage_id) const {
  return {std::make_pair(state, stage_id - 1)};
}

/********** RuleAlwaysInline **********/
inline bool ShouldAlwaysBeInlined(const SketchPolicyNode& policy, const State& state,
                                  int stage_id) {
  const SearchTask& task = policy.search_task;
  const Stage& stage = state->stages[stage_id];

  // Check the inline limitation of TE
  if (stage->op_type == StageKind::kPlaceholder || IsOutputOp(task, state, stage_id) ||
      HasReduceIter(stage)) {
    return false;
  }

  if (IsGPUTask(task)) {  // Greedily inline all inlinable ops on GPU
    return true;
  } else {
    // Only always-inline strictly-inlinable ops on CPU.
    // The computation location of other ops will be tuned by InitChangeComputeLocation
    // and MutateComputeLocation.
    return IsStrictlyInlineable(task, state, stage_id);
  }
}

SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  return ShouldAlwaysBeInlined(policy, state, stage_id) ? ConditionKind::kApplyAndSkipRest
                                                        : ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleAlwaysInline::Apply(const SketchPolicyNode& policy,
                                                           const State& state, int stage_id) const {
  State tmp_s = state;
  tmp_s.compute_inline(stage_id);
  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}

/********** RuleMultiLevelTiling **********/

SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  return NeedsMultilevelTiling(policy.search_task, state, stage_id)
             ? ConditionKind::kApplyAndSkipRest
             : ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolicyNode& policy,
                                                               const State& state,
                                                               int stage_id) const {
  const std::string& multi_level_tiling_structure =
      IsGPUTask(policy.search_task)
          ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
          : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
  State tmp_s = DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure);
  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
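
// Illustrative note for RuleMultiLevelTiling::Apply above (based on
// SketchPolicy's default parameters; the structure strings are read from
// policy.params and may be overridden): the structure is typically "SSRSRS" on
// CPU and "SSSRRSRS" on GPU, where each 'S' adds one tiling level for spatial
// loops and each 'R' adds one for reduction loops. For a matmul
// C[i, j] += A[i, k] * B[k, j], "SSRSRS" produces the loop nest
// (i0, j0, i1, j1, k0, i2, j2, k1, i3, j3).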

/********** RuleMultiLevelTilingWithFusion **********/

SketchGenerationRule::ConditionKind RuleMultiLevelTilingWithFusion::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
      HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
    // Always do fusion for a stage with cache_write, or when using the GPU policy
    return HasCacheWriteStage(state, stage_id) || IsGPUTask(policy.search_task)
               ? ConditionKind::kApplyAndSkipRest
               : ConditionKind::kApply;
  }
  return ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  int target_stage_id;
  ICHECK(
      HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id));
  const std::string& multi_level_tiling_structure =
      IsGPUTask(policy.search_task)
          ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
          : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
  std::vector<int> spatial_split_step_ids;
  State base_state =
      DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids);

  std::vector<std::pair<State, int>> ret;
  std::vector<int> follow_tiling_levels =
      IsGPUTask(policy.search_task) ? std::vector<int>{3} : std::vector<int>{1, 2};
  for (int level : follow_tiling_levels) {
    if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
      continue;
    }
    State tmp_s = base_state;
    tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level);
    const Iterator& target_iter =
        tmp_s->stages[target_stage_id]->iters[level * spatial_split_step_ids.size() - 1];
    tmp_s.compute_at(stage_id, target_stage_id, target_iter);
    ret.emplace_back(std::move(tmp_s), stage_id - 1);
  }

  return ret;
}
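
// Worked example for RuleMultiLevelTilingWithFusion::Apply above: on CPU the
// elementwise consumer may be fused at tiling level 1 or 2. With two spatial
// axes (spatial_split_step_ids.size() == 2) and level == 2, the compute_at
// target is the consumer's iterator at index 2 * 2 - 1 = 3, i.e. the last
// iterator of the second tiling level.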

/********** RuleAddCacheRead **********/

SketchGenerationRule::ConditionKind RuleAddCacheRead::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  const SearchTask& task = policy.search_task;

  // Don't cache_read a stage if it has no consumers
  const std::set<int>& consumers = GetConsumers(task, state, stage_id);

  if (consumers.size() == 0) return ConditionKind::kSkip;
  // Don't cache_read a stage if its consumer does not need multi-level tiling
  int target_stage_id = *consumers.begin();
  if (!NeedsMultilevelTiling(task, state, target_stage_id)) {
    return ConditionKind::kSkip;
  }

  // Don't cache_read a stage if its consumer does cross-thread reduction
  if (HasCrossThreadReduction(state, target_stage_id)) {
    return ConditionKind::kSkip;
  }

  // Only direct producers can be cache read
  const std::set<int>& producers = GetDirectProducers(task, state, target_stage_id);
  if (producers.find(stage_id) == producers.end()) {
    return ConditionKind::kSkip;
  }

  return ConditionKind::kApplyAndSkipRest;
}

std::vector<std::pair<State, int>> RuleAddCacheRead::Apply(const SketchPolicyNode& policy,
                                                           const State& state, int stage_id) const {
  const SearchTask& task = policy.search_task;
  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
  State tmp_s = state;

  int target_stage_id_offset = 0;
  for (int orig_target_stage_id : consumers) {
    int target_stage_id = orig_target_stage_id + target_stage_id_offset;

    // Add a cache read stage that stages the data in shared memory
    int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id}, task->compute_dag);
    target_stage_id_offset++;
    target_stage_id++;

    const auto& share_read_pos =
        GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
    tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
  }

  return {std::make_pair(tmp_s, stage_id)};
}
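
// Note for RuleAddCacheRead::Apply above: each added shared-memory stage is
// computed at the last reduce iterator of the consumer's outermost reduction
// tile, so the data is staged into shared memory once per outer reduction
// iteration. Inserting a stage shifts the indices of all following stages,
// which is what target_stage_id_offset compensates for.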

/********** RuleAddCacheWrite **********/

SketchGenerationRule::ConditionKind RuleAddCacheWrite::MeetCondition(const SketchPolicyNode& policy,
                                                                     const State& state,
                                                                     int stage_id) const {
  // Add cache write if a stage needs multi-level tiling, but does not have an element-wise
  // matched consumer
  if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
      !HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
    // On GPU we always apply and skip the rest; the apply-and-skip-rest case with fusion
    // is handled in RuleMultiLevelTilingWithFusion
    return IsGPUTask(policy.search_task) ? ConditionKind::kApplyAndSkipRest : ConditionKind::kApply;
  }

  return ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleAddCacheWrite::Apply(const SketchPolicyNode& policy,
                                                            const State& state,
                                                            int stage_id) const {
  State tmp_s = state;
  tmp_s.cache_write(stage_id, "local", policy.search_task->compute_dag);
  return {std::make_pair(std::move(tmp_s), stage_id)};
}

/********** RuleAddRfactor **********/

SketchGenerationRule::ConditionKind RuleAddRfactor::MeetCondition(const SketchPolicyNode& policy,
                                                                  const State& state,
                                                                  int stage_id) const {
  return (NeedsRfactor(policy.search_task, state, stage_id) && !HasCacheWriteStage(state, stage_id))
             ? ConditionKind::kApply
             : ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleAddRfactor::Apply(const SketchPolicyNode& policy,
                                                         const State& state, int stage_id) const {
  // Fuse all reduction iters
  Array<Iterator> space_iters, reduce_iters;
  Iterator fused_reduce_iter;
  State base_state =
      FuseAllReductionIterators(state, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);

  // TODO(merrymercy): We can do more analysis here to generate fewer and more efficient sketches.
  // In some cases, we only need rfactor for more parallelism.
  // In some cases, we only need rfactor for vectorization.
  // Now we generate two versions and let the search figure out the better one.

  // Split reduction iters
  const auto& split_res = base_state.split(stage_id, fused_reduce_iter, {Integer(1)});
  int factor_axis_id = static_cast<int>(space_iters.size());
  std::vector<std::pair<State, int>> ret;
  for (const auto& split_iter : split_res) {
    State tmp_s = base_state;
    int rstage_id =
        tmp_s.rfactor(stage_id, split_iter, factor_axis_id, policy.search_task->compute_dag);

    // Reorder the space iterator to innermost for vectorization
    if (split_iter == split_res[1]) {
      Array<Iterator> new_order;
      for (size_t i = 0; i < tmp_s->stages[rstage_id]->iters.size(); ++i) {
        if (i != space_iters.size()) {
          new_order.push_back(tmp_s->stages[rstage_id]->iters[i]);
        }
      }
      new_order.push_back(tmp_s->stages[rstage_id]->iters[space_iters.size()]);
      tmp_s.reorder(rstage_id, new_order);
    }

    ret.emplace_back(std::move(tmp_s), rstage_id - 1);
  }

  return ret;
}
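
// Illustrative note for RuleAddRfactor::Apply above: the fused reduction
// iterator k is split with an inner length of 1, producing k.o and k.i. Two
// sketches are generated, one using k.o and one using k.i as the rfactor
// iterator; in the k.i version the new factor axis is reordered innermost so
// it can be vectorized. The concrete split factor (initially 1) is only a
// placeholder that later tuning (e.g. MutateTileSize) can change.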

/********** RuleSimplifyComputeWithConstTensor **********/

SketchGenerationRule::ConditionKind RuleSimplifyComputeWithConstTensor::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  return state->stages[stage_id]->op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)
             ? ConditionKind::kApplyAndSkipRest
             : ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  std::set<std::string> const_tensor_indices = GetIterNameSetParam(
      state->stages[stage_id]->op->attrs, SearchPolicyKey::simplify_const_tensor_indices);

  State tmp_s = state;
  Array<Array<Iterator>> tiled_outer_iters;
  Array<Iterator> unrolled_inner_iters;

  // Currently set to 2
  size_t tile_level = 2;

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (const_tensor_indices.count(iter->name)) {
      // Unroll indices of const tensors
      unrolled_inner_iters.push_back(tmp_s.unroll(stage_id, iter));
    } else {
      // Tile other space indices
      ICHECK(iter->iter_kind == IteratorKind::kSpatial);
      tiled_outer_iters.push_back(
          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(tile_level - 1, NullOpt)));
    }
  }

  // Reorder them
  Array<Iterator> new_order;
  for (size_t i = 0; i < tile_level; ++i) {
    for (size_t j = 0; j < tiled_outer_iters.size(); ++j) {
      new_order.push_back(tiled_outer_iters[j][i]);
    }
  }
  new_order.insert(new_order.end(), unrolled_inner_iters.begin(), unrolled_inner_iters.end());
  tmp_s.reorder(stage_id, new_order);

  return {std::make_pair(tmp_s, stage_id - 1)};
}
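
// Worked example for RuleSimplifyComputeWithConstTensor::Apply above: with
// spatial iterators i, j and a const-tensor index c, i and j are each split
// into tile_level = 2 levels and c is unrolled, so the reordered loop nest is
// (i.0, j.0, i.1, j.1, c) with c unrolled innermost.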

/********** RuleCrossThreadReduction **********/

SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  ICHECK(IsGPUTask(policy.search_task));

  // If it is an intermediate state created by RuleAddCacheWrite,
  // we just skip it.
  if (HasCacheWriteStage(state, stage_id)) {
    return ConditionKind::kSkip;
  }

  const auto& op = state->stages[stage_id]->op;
  if (op->IsInstance<te::ComputeOpNode>()) {
    // Compute the product of lengths of all space iters and all reduce iters
    auto [cum_space_len, cum_reduce_len] =
        GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

    if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
      // Avoid cross-thread reduction if we have enough parallelism on space iters
      if (cum_space_len > policy.search_task->hardware_params->max_threads_per_block) {
        return ConditionKind::kSkip;
      }

      return cum_space_len < cum_reduce_len ? ConditionKind::kApply : ConditionKind::kSkip;
    } else if (cum_reduce_len > 1) {
      // Try cross-thread reduction for other reduction operators
      return cum_reduce_len > policy.search_task->hardware_params->warp_size ? ConditionKind::kApply
                                                                             : ConditionKind::kSkip;
    }
  }

  return ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleCrossThreadReduction::Apply(const SketchPolicyNode& policy,
                                                                   const State& state,
                                                                   int stage_id) const {
  const SearchTask& task = policy.search_task;
  State tmp_s = state;

  // Fuse all reduction iters
  Array<Iterator> space_iters, reduce_iters;
  Iterator fused_reduce_iter;
  tmp_s =
      FuseAllReductionIterators(tmp_s, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);

  // Check the opportunity for kernel fusion
  bool fusible = false;
  int target_stage_id = GetSingleConsumerId(policy.search_task, tmp_s, stage_id);
  int num_common_outer = -1;
  if (target_stage_id >= 0) {
    num_common_outer =
        GetNumCommonOuterIterator(policy.search_task, tmp_s, stage_id, target_stage_id);
    if (num_common_outer > 0 &&
        !NeedsMultilevelTiling(policy.search_task, state, target_stage_id)) {
      fusible = true;
    }
  }

  if (fusible) {
    const Stage& target_stage = state->stages[target_stage_id];
    std::vector<int> split_step_ids;

    GetSplitStepIds(tmp_s, target_stage_id, &split_step_ids);

    if (split_step_ids.size() == 0) {
      // If the target stage does not have a split step,
      // it must be a simple stage without reduce iters.
      // We then should do a split for it.
      ICHECK(!HasReduceIter(target_stage));
      const auto& split_res = tmp_s.split(target_stage_id, target_stage->iters.back(),
                                          {Integer(task->hardware_params->warp_size)});
      tmp_s.bind(target_stage_id, split_res[1], IteratorAnnotation::kThreadX);
      split_step_ids.push_back(tmp_s->transform_steps.size() - 2);
    }

    ICHECK_EQ(split_step_ids.size(), 1);

    const Iterator& target_iter = tmp_s->stages[target_stage_id]->iters[num_common_outer - 1];
    const auto& split_res = tmp_s.follow_split(stage_id, fused_reduce_iter, split_step_ids[0], 1);
    tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
    tmp_s.compute_at(stage_id, target_stage_id, target_iter);
  } else {
    const auto& split_res =
        tmp_s.split(stage_id, fused_reduce_iter, {Integer(task->hardware_params->warp_size)});
    tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
  }

  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
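
// Worked example for RuleCrossThreadReduction::Apply above (non-fusible case):
// with warp_size = 32 and a fused reduction iterator of extent 512, the split
// yields k.o (extent 16) and k.i (extent 32), and k.i is bound to threadIdx.x
// so that the threads of a block cooperate on the reduction.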

/********** RuleSpecialComputeLocationGPU **********/

SketchGenerationRule::ConditionKind RuleSpecialComputeLocationGPU::MeetCondition(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  if (GetProducers(policy.search_task, state, stage_id).empty()) {
    return ConditionKind::kSkip;
  }

  if (!ShouldAlwaysBeInlined(policy, state, stage_id)) {
    return ConditionKind::kSkip;
  }

  const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
  if (consumers.size() == 1 && state->stages[*consumers.begin()]->op->attrs.count(
                                   SearchPolicyKey::simplify_const_tensor_indices)) {
    return ConditionKind::kApplyAndSkipRest;
  }

  return ConditionKind::kSkip;
}

std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
    const SketchPolicyNode& policy, const State& state, int stage_id) const {
  State tmp_s = state;
  const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
  ICHECK_EQ(consumers.size(), 1);

  // Get the last outer space iterator that is not unrolled.
  const Stage& target_stage = state->stages[*consumers.begin()];
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    if (target_stage->iters[i]->annotation == IteratorAnnotation::kUnroll) {
      ICHECK_GT(i, 0);

      tmp_s.compute_at(stage_id, *consumers.begin(), target_stage->iters[i - 1]);
      break;
    }
  }

  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}

/********** RuleCustomSketch **********/

SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
                                                                    const State& state,
                                                                    int stage_id) const {
  auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
  if (ret.type_code() == 0) {  // A type code of 0 means an integer was returned
    return ConditionKind(static_cast<int>(ret));
  } else {
    LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
    return ConditionKind::kApplyAndSkipRest;
  }
}

std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNode& policy,
                                                           const State& state, int stage_id) const {
  Array<Array<ObjectRef>> apply_ret =
      apply_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
  std::vector<std::pair<State, int>> ret;
  for (const auto& item : apply_ret) {
    CHECK_EQ(item.size(), 2);
    auto next = item[1].as<IntImmNode>();
    ret.emplace_back(Downcast<State>(item[0]), next->value);
  }
  return ret;
}

/********** Init Population **********/

PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                             std::mt19937* rand_gen) const {
  SplitFactorizationMemo split_memo;
  int max_innermost_split_factor =
      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

  StateNode* pstate = state->CopyOnWrite();
  // Scan the transformation history and randomly fill tile sizes for all SplitSteps
  for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
    if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) {
      bool all_defined = true;
      for (const auto& len : ps->lengths) {
        if (!len) {
          all_defined = false;
          break;
        }
      }
      if (all_defined) {
        continue;
      }

      ICHECK(ps->extent);
      int extent = GetIntImm(ps->extent.value());
      const auto& candidate_lens = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(),
                                                                      max_innermost_split_factor);
      ICHECK(!candidate_lens.empty());
      const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

      pstate->transform_steps.Set(
          step_id,
          SplitStep(ps->stage_id, ps->iter_id, ps->extent,
                    Array<Optional<Integer>>(candidate_lengths.begin(), candidate_lengths.end()),
                    ps->inner_to_outer));
    }
  }
  pstate->concrete = true;

  return ResultKind::kValid;
}
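
// Illustrative example for InitFillTileSize::Apply above: a SplitStep over an
// iterator of extent 512 with two undefined lengths can be filled with any
// sampled factorization scheme such as (16, 4) or (8, 8), subject to the
// innermost factor not exceeding max_innermost_split_factor; the outermost
// extent is implied by the quotient (512 / 16 / 4 = 8 for the first scheme).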

PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(
    SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const {
  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
    return ResultKind::kValid;
  }

  for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) {
    const Stage& stage = (*state)->stages[stage_id];
    // Skip the inlined stages and placeholders
    if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) {
      continue;
    }
    // Skip the tiled stages
    if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
      continue;
    }

    std::vector<std::pair<int, int>> candidates =
        GetComputeLocationCandidates(policy->search_task, *state, stage_id);

    int choice = (*rand_gen)() % (candidates.size() + 2);

    if (choice == 0) {
      if (!HasReduceIter(stage)) {
        const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter;
        if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) {
          state->compute_inline(stage_id);
        }
      }
    } else if (choice == 1) {
      state->compute_root(stage_id);
    } else {
      choice = choice - 2;
      const Stage& stage = (*state)->stages[candidates[choice].first];
      state->compute_at(stage_id, candidates[choice].first,
                        stage->iters[candidates[choice].second]);
    }
  }

  try {
    *state = policy->search_task->compute_dag.InferBound(*state);
  } catch (std::exception& e) {
    return ResultKind::kInvalid;
  }
  return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state,
                                                         std::mt19937* rand_gen) const {
  std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
      annotate_parallel;
  annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state,
                                           int stage_id, int iter_offset) {
    const Stage& stage = (*state)->stages[stage_id];

    Array<Iterator> to_fuse;
    int64_t parallel_degree = 1;

    // Try to fuse and parallelize the outermost n iterators.
    // Stop if we meet a reduce iterator or already have enough parallel degree.
    size_t iter_id = iter_offset;
    for (; iter_id < stage->iters.size(); ++iter_id) {
      const Iterator& it = stage->iters[iter_id];
      if (it->iter_kind == IteratorKind::kReduction ||
          it->annotation != IteratorAnnotation::kNone) {
        break;
      }
      to_fuse.push_back(it);
      parallel_degree *= GetExtent(it);

      if (parallel_degree > policy.search_task->hardware_params->num_cores * 16) {
        break;
      }

      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
        break;
      }
    }

    if (parallel_degree == 1) {
      auto res =
          (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
      if (res != (*state)->attach_map->iter_to_attached_stages.end()) {
        for (int attached_stage_id : res->second) {
          annotate_parallel(policy, state, attached_stage_id, 0);
        }
        annotate_parallel(policy, state, stage_id, iter_id + 1);
      }
    }

    if (!to_fuse.empty()) {
      if (to_fuse.size() == 1) {
        state->parallel(stage_id, to_fuse[0]);
      } else {
        Iterator fused_iter = state->fuse(stage_id, to_fuse);
        state->parallel(stage_id, fused_iter);
      }
    }
  };

  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    if (stage->compute_at != ComputeAtKind::kRoot || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    annotate_parallel(*policy, state, stage_id, 0);
  }

  return ResultKind::kValid;
}
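
// Illustrative example for InitParallel::Apply above: with num_cores = 16, the
// fusion loop stops once the fused parallel degree exceeds 16 * 16 = 256, so
// outer loops of extents (8, 8, 8) are all fused (the limit is only crossed on
// the last one) and the result becomes a single parallel loop.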

PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state,
                                                       std::mt19937* rand_gen) const {
  std::vector<int>& auto_unroll_configs =
      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    // Skip the inlined stage and placeholder stage
    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    // Handle always_unroll_inner attr
    if (stage->op->attrs.count(SearchPolicyKey::always_unroll_inner)) {
      const auto& to_unroll_name_set =
          GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::always_unroll_inner);

      // Unroll the space iterators and reduce iterators listed in the attrs in the innermost
      // tile
      std::set<std::string> visited_names;
      for (int n = static_cast<int>(stage->iters.size()) - 1; n >= 0; n--) {
        const Iterator& it = stage->iters[n];

        // If we meet two iterators that come from the same original iterator,
        // then we are out of the innermost tile
        size_t size_before = visited_names.size();
        ExtractOriginalIterators(it->name, &visited_names);
        if (size_before == visited_names.size()) {
          break;
        }

        std::set<std::string> name;
        ExtractOriginalIterators(it->name, &name);
        if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) {
          if (it->annotation == IteratorAnnotation::kNone) {
            state->unroll(stage_id, it);
          }
        }
      }
    }

    if (HasReduceIter(stage)) {
      // Use auto unroll for multi-level tiled stages
      int value = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()];
      state->pragma(stage_id, (*state)->stages[stage_id]->iters[0],
                    std::string("auto_unroll_max_step") + "$" + std::to_string(value));
    }
  }

  return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
                                                              State* state,
                                                              std::mt19937* rand_gen) const {
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    const Stage& stage = (*state)->stages[stage_id];
    // Skip the inlined stage and placeholder stage
    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    // Try to fuse and vectorize the space iterators in the innermost tile
    int64_t cum_length_prod = 1;

    int num_fusible = 0;
    while (num_fusible < static_cast<int>(stage->iters.size())) {
      int iter_id = static_cast<int>(stage->iters.size()) - 1 - num_fusible;
      // Stop if this iterator is a compute_at attach point
      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
        break;
      }

      const Iterator& it = stage->iters[iter_id];
      // Stop if we meet a reduce iterator or an annotated iterator
      if (it->iter_kind == IteratorKind::kReduction ||
          it->annotation != IteratorAnnotation::kNone) {
        break;
      }

      // Stop if the memory access is not continuous (vectorizable)
      // Note: an exact check is too hard, so we use a heuristic here
      if (IsTiled(stage) && num_fusible != 0) {
        // If the stage is tiled, then the memory access must not be continuous
        // for the innermost two iterators
        break;
      }

      cum_length_prod *= GetExtent(it);
      if (cum_length_prod > GetIntParam(policy->params, SketchParamKey::max_vectorize_size)) {
        break;
      }

      num_fusible++;
    }

    if (num_fusible > 1) {
      // Select a random range to fuse
      num_fusible = 1 + (*rand_gen)() % (num_fusible - 1);
    }

    if (num_fusible == 1) {
      state->vectorize(stage_id, stage->iters.back());
    } else if (num_fusible > 1) {
      Array<Iterator> to_fuse(stage->iters.end() + (-num_fusible), stage->iters.end());
      state->vectorize(stage_id, state->fuse(stage_id, to_fuse));
    }
  }

  return ResultKind::kValid;
}
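
// Illustrative example for InitVectorization::Apply above: with
// max_vectorize_size = 16 and untiled innermost iterators of extents (4, 2),
// scanning from the innermost outwards gives cumulative products 2 and 8, both
// within the limit, so up to two iterators are fusible; a random number of
// them is then fused and vectorized.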

PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  // Collect all stages that are roots of stages that perform multi-level tiling.
  std::set<int> multi_level_tiling_root_set;
  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
    if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
      const Stage& stage = (*state)->stages[stage_id];
      if (stage->compute_at == ComputeAtKind::kInlined) {
        continue;
      } else if (stage->compute_at != ComputeAtKind::kIter) {
        // This stage is not multi-level tiled,
        // so it must be produced by RuleCrossThreadReduction.
        ICHECK(HasCrossThreadReduction(*state, stage_id));
      } else {
        const auto res = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
        ICHECK(res != (*state)->attach_map->stage_to_attach_iter.end());
        multi_level_tiling_root_set.insert(res->second.first);
      }
    }
  }

  *state = policy->search_task->compute_dag.InferBound(*state);

  for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; --stage_id) {
    const Stage& stage = (*state)->stages[stage_id];

    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
      continue;
    }

    // Deal with the cross-thread reduction generated by RuleCrossThreadReduction
    if (HasCrossThreadReduction(*state, stage_id)) {
      if (stage->compute_at != ComputeAtKind::kRoot) {
        continue;
      }

      Iterator fused_it;
      *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
      state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX);
      continue;
    }

    // Skip if this stage has already been annotated with threadIdx.x
    if (HasAnnotatedIter(stage, IteratorAnnotation::kThreadX)) {
      continue;
    }

    if (stage->compute_at == ComputeAtKind::kRoot) {
      // This stage has not been tiled, but in a GPU schedule we must tile the root stage
      // to do thread binding
      if (!multi_level_tiling_root_set.count(stage_id)) {
        Iterator fused_it;
        *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);

        if (GetExtent(fused_it) <= policy->search_task->hardware_params->warp_size) {
          state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
        } else {
          // Set threadIdx.x = default_warp_size by default.
          // The later EvolutionarySearch will try more possibilities
          const auto& split_its = state->split(
              stage_id, fused_it, {Integer(policy->search_task->hardware_params->warp_size)});
          state->bind(stage_id, split_its[0], IteratorAnnotation::kBlockX);
          state->bind(stage_id, split_its[1], IteratorAnnotation::kThreadX);
        }
        continue;
      }

      // Otherwise, this is a tiled root stage; we assume it is tiled with 3 space levels
      // in the outer iterators.
      // The remaining part deals with the thread binding for multi-level tiled stages
      auto pop = stage->op.as<te::ComputeOpNode>();
      std::vector<Iterator> to_fuse;
      int total_space_extent = 1;
      for (const auto& i : pop->root_iter_vars()) {
        ICHECK(i->dom.defined());
        const auto& pint = i->dom->extent.as<IntImmNode>();
        ICHECK(pint);
        total_space_extent *= pint->value;
      }

      bool check_min_thread_extent = true;
      // If the total space extent is too small, disable the check of minimal thread extent
      if (total_space_extent <= policy->search_task->hardware_params->warp_size * 2) {
        check_min_thread_extent = false;
      }

      // Fuse the outermost space tile as blockIdx
      for (size_t i = 0; i < pop->axis.size(); i++) {
        const auto& it = (*state)->stages[stage_id]->iters[i];
        // There may be some iterators that are marked with no split; stop if we reach the
        // next tiling level
        if (!StrEndsWith(it->name, ".0")) {
          break;
        }
        to_fuse.push_back(it);
      }
      const auto& blockidx_it = state->fuse(stage_id, to_fuse);
      state->bind(stage_id, blockidx_it, IteratorAnnotation::kBlockX);

      // Fuse the second outermost space tile as vthread
      to_fuse.clear();
      for (size_t i = 1; i < pop->axis.size() + 1; i++) {
        const auto& it = (*state)->stages[stage_id]->iters[i];
        // There may be some iterators that are marked with no split; stop if we reach the
        // next tiling level
        if (!StrEndsWith(it->name, ".1")) {
          break;
        }
        to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
      }
      const auto& vthread_it = state->fuse(stage_id, to_fuse);
      if (GetExtent(vthread_it) > policy->search_task->hardware_params->max_vthread_extent) {
        return ResultKind::kInvalid;
      }
      state->bind(stage_id, vthread_it, IteratorAnnotation::kVThread);

      // Fuse the third outermost space tile as threadIdx
      to_fuse.clear();
      for (size_t i = 2; i < pop->axis.size() + 2; i++) {
        const auto& it = (*state)->stages[stage_id]->iters[i];
        // There may be some iterators that are marked with no split; stop if we reach the
        // next tiling level
        if (!StrEndsWith(it->name, ".2")) {
          break;
        }
        to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
      }
      const auto& threadidx_it = state->fuse(stage_id, to_fuse);
      if (check_min_thread_extent &&
          GetExtent(threadidx_it) < policy->search_task->hardware_params->warp_size) {
        return ResultKind::kInvalid;
      }
      state->bind(stage_id, threadidx_it, IteratorAnnotation::kThreadX);
    } else if (stage->compute_at == ComputeAtKind::kIter &&
               StrEndsWith(stage->op->name, ".shared")) {
      // Do cooperative fetching for the cache read stage.
      // Get spatial_split_step_ids from the root stage
      const auto& it = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
      ICHECK(it != (*state)->attach_map->stage_to_attach_iter.end());
      Array<Integer> spatial_split_step_ids = GetSpatialSplitStepIds(*state, it->second.first);

      // Fuse all iterators to do cooperative fetching
      Iterator fused = state->fuse(stage_id, (*state)->stages[stage_id]->iters);
      // Split out an extra iterator for vectorization
      // The later EvolutionarySearch will try more possibilities
      const auto& iters0 = state->split(stage_id, fused, {Integer(1)});
      state->vectorize(stage_id, iters0[1]);
      // Follow split to keep the same thread extent as the root stage
      const auto& iters1 =
          state->follow_fused_split(stage_id, iters0[0], spatial_split_step_ids, 1, true);
      state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX);
    }
  }
  return ResultKind::kValid;
}
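
// Illustrative example for InitThreadBind::Apply above: for a multi-level
// tiled root stage with iterators (i.0, j.0, i.1, j.1, i.2, j.2, ...), the
// rule fuses and binds (i.0, j.0) to blockIdx.x, (i.1, j.1) to vthread, and
// (i.2, j.2) to threadIdx.x, rejecting the state if the vthread extent
// exceeds max_vthread_extent or the threadIdx extent is below warp_size.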

PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  int max_innermost_split_factor =
      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

  // Extract all SplitSteps
  std::vector<size_t> split_step_ids;
  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
    if (auto ps = (*state)->transform_steps[i].as<SplitStepNode>()) {
      if (!ps->extent.defined() || !ps->extent.value()->IsInstance<IntImmNode>()) {
        continue;
      }
      auto innermost_factor = ps->lengths.back().value_or(max_innermost_split_factor + 1);
      if (GetIntImm(innermost_factor) <= max_innermost_split_factor) {
        split_step_ids.push_back(i);
      }
    }
  }
  if (split_step_ids.empty()) {
    // No tile size could be mutated.
    return ResultKind::kInvalid;
  }

  // Select a SplitStep with extent larger than one to mutate.
  int retry_ct = 0;
  int64_t extent = 1;
  int step_id;
  const SplitStepNode* ps;

  do {
    step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()];
    ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
    ICHECK(ps != nullptr);
    extent = GetIntImm(ps->extent.value());
    retry_ct += 1;
  } while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && (extent == 1 || extent == 0));

  if (extent <= 1) {
    // Cannot find a step with extent larger than one.
    return ResultKind::kInvalid;
  }

  // Fetch the current tile sizes.
  std::vector<int> lengths(ps->lengths.size() + 1, 1);
  for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
    lengths[i + 1] = GetIntImm(ps->lengths[i].value());
  }
  lengths[0] = extent / ElementProduct(lengths);

  // Randomly permute the tile size order.
  std::vector<int> random_perm;
  RandomPermutation(lengths.size(), &random_perm, rand_gen);

  // Try to divide a factor from one tile size and multiply it into another.
  for (size_t i = 0; i < random_perm.size(); ++i) {
    size_t src_idx = random_perm[i];
    int length = lengths[src_idx];
    if (length <= 1) {
      continue;
    }

    // Divide one factor from lengths[src_idx] and multiply it into lengths[dst_idx]
    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
    const std::vector<int>& factors = policy->split_memo.GetFactors(length);
    ICHECK_GE(factors.size(), 1);

    int divide_factor;
    if (dst_idx == lengths.size() - 1) {
      // Maintain the restriction of hardware_params.max_innermost_split_factor.
      int max_factor_index = static_cast<int>(factors.size()) - 1;
      for (; max_factor_index >= 1; max_factor_index--) {
        if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) {
          break;
        }
      }
      if (max_factor_index == 0) {
        // Failed on this dst_idx, try the next one.
        continue;
      }
      divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)];
    } else {
      divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)];
    }

    // Build the new lengths with the factor moved from src_idx to dst_idx.
    Array<Integer> new_lengths;
    for (size_t j = 1; j < lengths.size(); ++j) {
      if (j == src_idx) {
        new_lengths.push_back(Integer(lengths[j] / divide_factor));
      } else if (j == dst_idx) {
        new_lengths.push_back(Integer(lengths[j] * divide_factor));
      } else {
        new_lengths.push_back(Integer(lengths[j]));
      }
    }

    ICHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor);

    StateNode* pstate = state->CopyOnWrite();
    pstate->transform_steps.Set(
        step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
                           Array<Optional<Integer>>(new_lengths.begin(), new_lengths.end()),
                           ps->inner_to_outer));
    return ResultKind::kValid;
  }
  return ResultKind::kInvalid;
}
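
// Worked example for MutateTileSize::Apply above: for a SplitStep with
// extent = 512 and lengths = (16, 4), the full tile-size vector becomes
// (8, 16, 4) after prepending the implied outermost length 512 / (16 * 4) = 8.
// Moving a factor of 2 from the 16 to the 4 yields (8, 8, 8), which is written
// back as the new SplitStep lengths (8, 8).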

PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state,
                                                             std::mt19937* rand_gen) const {
  // Extract all auto_unroll_max_step pragma steps.
  std::vector<int> pragma_steps;
  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
    if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) {
      if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) {
        pragma_steps.push_back(i);
      }
    }
  }
  if (pragma_steps.empty()) {
    return ResultKind::kInvalid;
  }

  std::vector<int>& auto_unroll_configs =
      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;

  // Randomly pick an auto unroll pragma step
  auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()];
  auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
  ICHECK(ps);

  // Mutate its value to a random candidate
  int val = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()];
  StateNode* pstate = state->CopyOnWrite();
  pstate->transform_steps.Set(
      step_id, PragmaStep(ps->stage_id, ps->iter_id,
                          std::string("auto_unroll_max_step") + "$" + std::to_string(val)));
  Stage new_stage = pstate->stages[ps->stage_id];
  new_stage.CopyOnWrite()->attrs.auto_unroll_max_step = val;
  pstate->stages.Set(ps->stage_id, new_stage);
  return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                  State* state,
                                                                  std::mt19937* rand_gen) const {
  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
    return ResultKind::kInvalid;
  }

  // Extract all compute_at steps.
  std::vector<int> compute_at_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) {
      int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id;

      if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) {
        continue;
      }

      if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) {
        continue;
      }
      compute_at_steps.push_back(s);
    }
  }
  if (compute_at_steps.empty()) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one step
  size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()];
  auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
  ICHECK(ps != nullptr);
  int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;

  // Randomly pick a new computation location
  std::vector<std::pair<int, int>> candidates =
      GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc);
  if (candidates.empty()) {
    return ResultKind::kInvalid;
  }
  int choice = (*rand_gen)() % (candidates.size());
  int new_compute_at_stage_id = candidates[choice].first;
  int new_compute_at_iter_id = candidates[choice].second;

  // Replay a new state.
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    if (s == step_id) {
      tmp_s.CopyOnWrite()->transform_steps.push_back(
          ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id));
    } else {
      tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]);
    }
    try {
      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
    } catch (Error& e) {
      return ResultKind::kInvalid;
    }
  }

  *state = tmp_s;
  return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state,
                                                           std::mt19937* rand_gen) const {
  // This mutation rule only focuses on the case where parallel was added to
  // the outermost loop and that loop was generated by fusing other loops.
  // In short, we mutate the fusion step before the parallel step.

  // Extract all parallel steps.
  std::vector<int> parallel_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    auto ps = (*state)->transform_steps[s].as<AnnotationStepNode>();
    if (!ps || ps->annotation != IteratorAnnotation::kParallel) {
      continue;
    }

    // Skip the step if it is not on the outermost loop or not preceded by a fuse step.
    if (ps->iter_id != 0 || s == 0 || !(*state)->transform_steps[s - 1].as<FuseStepNode>()) {
      continue;
    }
    auto fuse_step = (*state)->transform_steps[s - 1].as<FuseStepNode>();
    if (fuse_step->fused_ids[0] != 0) {
      continue;
    }

    parallel_steps.push_back(s);
  }
  if (parallel_steps.empty()) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one parallel step.
  size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];

  // Replay a new state until the picked fuse step.
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < step_id - 1; ++s) {
    const auto& step = (*state)->transform_steps[s];
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
  }

  // Compute all possible fusion granularities
  auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>();
  int stage_id = fuse_step->stage_id;
  const Stage& stage = tmp_s->stages[stage_id];
  size_t max_fusable_iter_id;
  for (max_fusable_iter_id = 0; max_fusable_iter_id < stage->iters.size(); ++max_fusable_iter_id) {
    const Iterator& it = stage->iters[max_fusable_iter_id];
    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
      break;
    }

    if (tmp_s->attach_map->iter_to_attached_stages.count(
            std::make_pair(stage_id, max_fusable_iter_id))) {
      break;
    }
  }

  if (max_fusable_iter_id == 0) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one granularity
  int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1;
  Array<Integer> fused_ids;
  for (int i = 0; i < fuse_to_iter_id; ++i) {
    fused_ids.push_back(i);
  }
  int iter_offset = fuse_step->fused_ids.back()->value - fused_ids.back()->value;
  if (iter_offset == 0) {
    return ResultKind::kInvalid;
  }

  // Replay the mutated fuse and annotation steps.
  auto new_fuse_step = FuseStep(stage_id, fused_ids);
  tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step);
  StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag);
  tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]);
  StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag);

  // Replay the rest of the steps.
  for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) {
    auto step = (*state)->transform_steps[s];
    if (step->stage_id == stage_id) {
      // Since we changed the loop structure, iterator IDs in later steps to the same stage
      // have to be adjusted.
      if (auto ps = step.as<AnnotationStepNode>()) {
        if (ps->iter_id == 0) {
          step = AnnotationStep(ps->stage_id, 0, ps->annotation);
        } else {
          ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation);
        }
      } else if (auto ps = step.as<PragmaStepNode>()) {
        if (ps->iter_id == 0) {
          step = PragmaStep(ps->stage_id, 0, ps->pragma_type);
        } else {
          ICHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = PragmaStep(ps->stage_id, ps->iter_id + iter_offset, ps->pragma_type);
        }
      } else {
        return ResultKind::kInvalid;
      }
    }
    if (IsStageNumberChangingStep(step)) {
      // For these steps, we would have to update stage_id because they make stage_id
      // out-dated. Here we simply give up this mutation for simplicity.
      // This is not an issue because it never happens in normal cases, where all such steps
      // come before the parallel steps.
      return ResultKind::kInvalid;
    }
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    try {
      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
    } catch (Error& e) {
      return ResultKind::kInvalid;
    }
  }

  *state = tmp_s;
  return ResultKind::kValid;
}
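
// Worked example for MutateParallel::Apply above: suppose the original state
// fused iterators (0, 1, 2) of a stage and parallelized the result. The
// mutation may re-fuse only (0, 1), giving iter_offset = 2 - 1 = 1, so a later
// AnnotationStep on iterator 3 of that stage is replayed on iterator
// 3 + 1 = 4, because one fewer loop was folded into the fused iterator.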

}  // namespace auto_scheduler
}  // namespace tvm