1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/search_policy/utils.cc
22 * \brief Common utilities
23 */
24
25#include "utils.h"
26
27#include <algorithm>
28
29namespace tvm {
30namespace auto_scheduler {
31
32Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
33 const auto& stage = s->stages[stage_id];
34 const auto& pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
35 ICHECK(pop != nullptr);
36 const std::set<std::string>& no_split_at_inner_name_set =
37 stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
38 ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
39 : std::set<std::string>();
40 size_t reduce_count = 0;
41 for (const auto axis : pop->reduce_axis) {
42 if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
43 reduce_count++;
44 }
45 }
46
47 Array<Integer> spatial_split_step_ids;
48 for (int i = s->transform_steps.size() - 1; i >= 0; --i) {
49 if (IsStageNumberChangingStep(s->transform_steps[i])) {
50 if (stage_id > s->transform_steps[i]->stage_id) {
51 stage_id--;
52 }
53 } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
54 if (stage_id == ps->stage_id) {
55 // Assume SplitStep on reduction axes are always after SplitStep on spatial axes.
56 if (reduce_count) {
57 reduce_count--;
58 } else {
59 spatial_split_step_ids.push_back(i);
60 }
61 }
62 }
63 }
64
65 return spatial_split_step_ids;
66}
67
/*!
 * \brief Enumerate candidate compute_at locations for a stage.
 * \return A list of (target_stage_id, iterator_index) pairs the stage could be
 *         computed at; empty if the stage has no single consumer.
 */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id) {
  // Only consider compute_at when the stage has exactly one consumer.
  int target_stage_id = GetSingleConsumerId(task, state, stage_id);
  if (target_stage_id < 0) {
    return {};
  }
  const Stage& target_stage = state->stages[target_stage_id];

  std::vector<std::pair<int, int>> candidates;
  bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
  bool target_is_tiled = IsTiled(target_stage);

  bool visited_reduce = false;
  // Enumerate compute_at location at target_stage
  // TODO(merrymercy): More analysis here to make smarter choices
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    const Iterator& target_iter = target_stage->iters[i];
    if (target_iter->iter_kind == IteratorKind::kReduction) {
      visited_reduce = true;
      if (!target_is_tiled) {  // Do not go into reduce iter
        break;
      }
    } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
      if (visited_reduce) {  // Do not go into inner tile
        break;
      }
    }

    if (target_iter->annotation == IteratorAnnotation::kUnroll) {
      // Do not go into the unroll region of const tensor indices
      break;
    }

    if (GetExtent(target_iter) == 1) {
      // Skip iterators with length of 1
      continue;
    }
    if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
        StrEndsWith(target_iter->name, ".0")) {
      // Skip the first level iterators if target stage compute_at another stage
      // In this case, the lengths of first level iterators are always one
      continue;
    }
    candidates.emplace_back(target_stage_id, i);

    // Stop after an iterator that already has other stages attached to it.
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
      break;
    }
  }

  // if the target_stage is already compute_at another stage X, try also compute_at X
  // We call stage X as `target_target_stage`
  if (target_compute_at_other) {
    int target_target_stage_id;
    target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
    const Stage& target_target_stage = state->stages[target_target_stage_id];

    for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
      const Iterator& target_target_iter = target_target_stage->iters[i];
      // Stop at reduction iterators or at iterators that already have attached stages.
      if (target_target_iter->iter_kind == IteratorKind::kReduction ||
          state->attach_map->iter_to_attached_stages.count(
              std::make_pair(target_target_stage_id, i))) {
        break;
      }

      if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
        // Do not go into the unroll region of const tensor indices
        break;
      }

      if (GetExtent(target_target_iter) == 1) {  // skip iterators with length of 1
        continue;
      }

      candidates.emplace_back(target_target_stage_id, i);
    }
  }

  return candidates;
}
148
/*!
 * \brief Apply multi-level tiling to a stage following a format string.
 * \param format e.g. "SSRSRS": each 's'/'S' is one spatial tile level,
 *        each 'r'/'R' one reduction tile level, in final loop order.
 * \param spatial_split_step_ids Output: indices of the generated spatial SplitSteps
 *        in the transform history; may be nullptr if the caller does not need them.
 * \return A new state with the splits and a final reorder applied.
 */
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids) {
  // Temporal object to be used if the input pointer is nullptr
  std::vector<int> temp_split_step_ids;
  if (spatial_split_step_ids == nullptr) {
    spatial_split_step_ids = &temp_split_step_ids;
  }
  spatial_split_step_ids->clear();

  // space_levels[i] / reduce_levels[i] collect the iterators of the i-th tiling level.
  std::vector<std::vector<Iterator>> space_levels;
  std::vector<std::vector<Iterator>> reduce_levels;
  std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;

  // Count requested levels; every character of `format` must be one of s/S/r/R.
  size_t n_space =
      std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S');
  size_t n_reduce =
      std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R');
  if (n_space + n_reduce != format.size()) {
    LOG(FATAL) << "Invalid multi-level tiling format: " << format;
  }
  space_levels.resize(n_space);
  reduce_levels.resize(n_reduce);

  State tmp_s = state;
  const Stage& stage = state->stages[stage_id];
  // Iterators named in this attribute are excluded from splitting.
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();

  // Split `iter` into `size` parts (factors left unspecified for the search to
  // fill in) and record the i-th resulting iterator at tiling level i.
  auto sr_levels = [&](int size, const Iterator& iter, std::vector<std::vector<Iterator>>& levels) {
    ICHECK_GE(size, 1);
    if (size == 1) {
      levels[0].push_back(iter);
    } else {
      Array<Iterator> split_res =
          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
      for (int i = 0; i < size; i++) {
        levels[i].push_back(split_res[i]);
      }
      if (iter->iter_kind == IteratorKind::kSpatial) {
        // Record the index of the SplitStep just appended to the history.
        spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
      }
    }
  };

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (!no_split_at_inner_name_set.count(iter->name)) {
      if (iter->iter_kind == IteratorKind::kSpatial) {
        sr_levels(n_space, iter, space_levels);
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        sr_levels(n_reduce, iter, reduce_levels);
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    } else {
      // No-split iterators are kept whole and later placed at the innermost
      // level of their kind.
      if (iter->iter_kind == IteratorKind::kSpatial) {
        space_inner.push_back(iter);
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        reduce_inner.push_back(iter);
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    }
  }

  // Prepend the extra iterators in `fill` to the given level.
  auto fill_levels = [&](std::vector<Iterator>& levels_iter, std::vector<Iterator>& fill) {
    if (!fill.empty()) {
      levels_iter.insert(levels_iter.begin(), std::make_move_iterator(fill.begin()),
                         std::make_move_iterator(fill.end()));
    }
  };
  if (!space_levels.empty()) {
    fill_levels(space_levels.front(), space_outer);
    fill_levels(space_levels.back(), space_inner);
  }
  if (!reduce_levels.empty()) {
    fill_levels(reduce_levels.front(), reduce_outer);
    fill_levels(reduce_levels.back(), reduce_inner);
  }

  // Build the final loop order by walking the format string level by level.
  Array<Iterator> order;
  int space_ct = 0, reduce_ct = 0;
  for (const auto c : format) {
    if (c == 's' || c == 'S') {
      order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
                   std::make_move_iterator(space_levels[space_ct].end()));
      space_ct++;
    } else if (c == 'r' || c == 'R') {
      order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
                   std::make_move_iterator(reduce_levels[reduce_ct].end()));
      reduce_ct++;
    } else {
      LOG(FATAL) << "Invalid multi level tiling format: " << format;
    }
  }

  tmp_s.reorder(stage_id, order);
  return tmp_s;
}
249
250State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
251 int n_split) {
252 if (n_split < 1 || n_split > 3) {
253 LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3";
254 }
255 // Apply up to three-level tiling structure: space_L0, space_L1, space_L2
256 std::vector<Iterator> space_0, space_1, space_2, space_3, tmp_order;
257 Array<Iterator> split_res;
258
259 auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>();
260 ICHECK(pop != nullptr);
261 const Stage& stage = state->stages[stage_id];
262 const std::set<std::string>& no_split_at_inner_name_set =
263 stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
264 ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
265 : std::set<std::string>();
266 int no_split_at_inner_name_in_stage_cnt = 0;
267 for (const auto& iter : state->stages[stage_id]->iters) {
268 no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name);
269 }
270
271 ICHECK_EQ(state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt,
272 split_step_ids.size());
273
274 State tmp_s = state;
275 int ct = 0;
276 for (const auto& iter : state->stages[stage_id]->iters) {
277 if (iter->iter_kind == IteratorKind::kSpatial) {
278 // For spatial iterator, split it into multi iterators
279 if (!no_split_at_inner_name_set.count(iter->name)) {
280 IteratorAnnotation ann_type = iter->annotation;
281 split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], n_split);
282 // Restore annotation. Move unroll and vectorize to inner, move parallel
283 // to outer
284 switch (ann_type) {
285 case IteratorAnnotation::kUnroll:
286 split_res.Set(n_split, tmp_s.unroll(stage_id, split_res[n_split]));
287 break;
288 case IteratorAnnotation::kVectorize:
289 split_res.Set(n_split, tmp_s.vectorize(stage_id, split_res[n_split]));
290 break;
291 case IteratorAnnotation::kParallel:
292 split_res.Set(0, tmp_s.parallel(stage_id, split_res[0]));
293 break;
294 default:
295 break;
296 }
297
298 space_0.push_back(split_res[0]);
299 space_1.push_back(split_res[1]);
300 if (n_split >= 2) {
301 space_2.push_back(split_res[2]);
302 if (n_split == 3) {
303 space_3.push_back(split_res[3]);
304 }
305 }
306 ct++;
307 } else {
308 if (no_split_at_inner_name_set.count(iter->name)) {
309 if (n_split == 1) {
310 space_1.push_back(iter);
311 } else if (n_split == 2) {
312 space_2.push_back(iter);
313 } else {
314 ICHECK_EQ(n_split, 3);
315 space_3.push_back(iter);
316 }
317 }
318 }
319 } else {
320 LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
321 }
322 }
323
324 if (n_split == 3) {
325 ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3);
326 } else if (n_split == 2) {
327 ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2);
328 } else {
329 ConcatenateMove(&tmp_order, &space_0, &space_1);
330 }
331 tmp_s.reorder(stage_id, tmp_order);
332 return tmp_s;
333}
334
335// Return whether a state has nested parallel, which is invalid on CPUs
336bool HasNestedParallel(const State& state) {
337 std::function<void(int stage_id, size_t*)> count_parallel_ct;
338
339 count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) {
340 const Stage& stage = state->stages[stage_id];
341
342 if (stage->compute_at == ComputeAtKind::kInlined) {
343 return;
344 }
345
346 for (size_t i = 0; i < stage->iters.size(); ++i) {
347 if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) {
348 (*parallel_ct)++;
349 }
350
351 IterKey iter_key(stage_id, i);
352 auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
353 if (pair != state->attach_map->iter_to_attached_stages.end()) {
354 for (const auto& attach_stage_id : pair->second) {
355 count_parallel_ct(attach_stage_id, parallel_ct);
356 }
357 }
358 }
359 };
360
361 for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) {
362 size_t parallel_ct = 0;
363
364 if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) {
365 count_parallel_ct(stage_id, &parallel_ct);
366 if (parallel_ct >= 2) {
367 return true;
368 }
369 }
370 }
371
372 return false;
373}
374
375void PruneInvalidState(const SearchTask& task, Array<State>* states) {
376 size_t pt = 0;
377 for (size_t i = 0; i < states->size(); ++i) {
378 if (!(*states)[i].defined()) {
379 continue;
380 }
381 if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) {
382 continue;
383 }
384
385 if (i != pt) {
386 states->Set(pt, (*states)[i]);
387 }
388 pt++;
389 }
390
391 if (pt == 0) {
392 LOG(FATAL) << "Internal error: All states are invalid.";
393 } else {
394 states->resize(pt);
395 }
396}
397
398/********** SplitFactorizationMemo **********/
399const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
400 int extent, int n_lengths, int max_innermost_factor) {
401 QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
402 const auto& it = memory_.find(key);
403 if (it != memory_.end()) {
404 return it->second;
405 }
406
407 tmp_stack_ = Array<Integer>(n_lengths, Integer());
408 results_ = &memory_[key];
409 n_lengths_ = n_lengths;
410
411 DfsEnumerate(0, extent, max_innermost_factor);
412
413 return *results_;
414}
415
416void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
417 if (now == n_lengths_) {
418 if (tmp_stack_.back().as<IntImmNode>()->value <= max_innermost_factor) {
419 results_->push_back(tmp_stack_);
420 }
421 } else {
422 for (const auto& f : GetFactors(remaining_length)) {
423 tmp_stack_.Set(now, Integer(f));
424 DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
425 }
426 }
427}
428
429const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {
430 auto it = factor_memory_.find(n);
431 if (it != factor_memory_.end()) {
432 return it->second;
433 }
434
435 std::vector<int>& res = factor_memory_[n];
436 int step = n % 2 == 0 ? 1 : 2;
437 for (size_t i = 1; i < static_cast<size_t>(std::sqrt(n)) + 1; i += step) {
438 if (n % i == 0) {
439 res.push_back(i);
440 if (n / i != i) {
441 res.push_back(n / i);
442 }
443 }
444 }
445 std::sort(res.begin(), res.end());
446 return res;
447}
448
449/********** Utils interface API for ffi **********/
450
// Expose GetConsumers to the FFI. The consumer stage ids are returned as a map
// with identical keys and values, since the FFI has no plain integer-set container.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsGetConsumers")
    .set_body_typed([](const SearchTask& task, const State& state, int stage_id) {
      const std::set<int>& consumers = GetConsumers(task, state, stage_id);
      tvm::Map<IntImm, IntImm> ret;
      for (const auto& i : consumers) {
        ret.Set(Integer(i), Integer(i));
      }
      return ret;
    });

// FFI wrapper: whether two stages match elementwise.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch")
    .set_body_typed([](const SearchTask& task, const State& state, int stage_id,
                       int target_stage_id) {
      return ElementwiseMatch(task, state, stage_id, target_stage_id);
    });

// FFI wrapper: whether a stage has been tiled.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
    .set_body_typed([](const Stage& stage) { return IsTiled(stage); });

// FFI wrapper: whether a cache-read stage exists for the given stage.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });

// FFI wrapper: whether a cache-write stage exists for the given stage.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });

// FFI wrapper: whether an rfactor stage exists for the given stage.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
    .set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });

// FFI wrapper: whether the stage uses cross-thread reduction.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
    .set_body_typed([](const State& s, int stage_id) {
      return HasCrossThreadReduction(s, stage_id);
    });
483
484} // namespace auto_scheduler
485} // namespace tvm
486