1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/utils.cc |
22 | * \brief Common utilities |
23 | */ |
24 | |
25 | #include "utils.h" |
26 | |
27 | #include <algorithm> |
28 | |
29 | namespace tvm { |
30 | namespace auto_scheduler { |
31 | |
/*!
 * \brief Collect the indices of the SplitSteps that split *spatial* axes of a stage.
 * \param s The state whose transform history is scanned.
 * \param stage_id The index of the stage in the final state `s`.
 * \return Indices into `s->transform_steps` of the spatial SplitSteps of this stage.
 */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
  const auto& stage = s->stages[stage_id];
  const auto& pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
  ICHECK(pop != nullptr);
  // Axes named in the `no_split_at_inner` attr are never split, so they do not
  // contribute SplitSteps to the transform history.
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  // Number of reduction axes that may carry SplitSteps; used below to skip the
  // reduction splits while scanning the history backwards.
  size_t reduce_count = 0;
  for (const auto axis : pop->reduce_axis) {
    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
      reduce_count++;
    }
  }

  Array<Integer> spatial_split_step_ids;
  // Walk the transform history from the latest step to the earliest.
  for (int i = s->transform_steps.size() - 1; i >= 0; --i) {
    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      // Un-apply stage-inserting steps (cache read/write, rfactor, ...) so that
      // `stage_id` refers to the stage's index at the time step i was recorded.
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
      if (stage_id == ps->stage_id) {
        // Assume SplitStep on reduction axes are always after SplitStep on spatial axes.
        // Scanning backwards, the first `reduce_count` matches are therefore the
        // reduction splits; every earlier match is a spatial split.
        if (reduce_count) {
          reduce_count--;
        } else {
          spatial_split_step_ids.push_back(i);
        }
      }
    }
  }

  return spatial_split_step_ids;
}
67 | |
/*!
 * \brief Enumerate the (stage_id, iterator_index) pairs at which `stage_id` could be
 *        compute_at-ed.
 * \param task The search task, used for consumer analysis.
 * \param state The state to analyze.
 * \param stage_id The stage looking for a compute_at location.
 * \return Candidate (target_stage_id, iterator_index) pairs. Empty when the stage does
 *         not have exactly one consumer.
 */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id) {
  int target_stage_id = GetSingleConsumerId(task, state, stage_id);
  if (target_stage_id < 0) {
    // No single consumer: no compute_at candidates.
    return {};
  }
  const Stage& target_stage = state->stages[target_stage_id];

  std::vector<std::pair<int, int>> candidates;
  bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
  bool target_is_tiled = IsTiled(target_stage);

  bool visited_reduce = false;
  // Enumerate compute_at location at target_stage
  // TODO(merrymercy): More analysis here to make smarter choices
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    const Iterator& target_iter = target_stage->iters[i];
    if (target_iter->iter_kind == IteratorKind::kReduction) {
      visited_reduce = true;
      if (!target_is_tiled) {  // Do not go into reduce iter
        break;
      }
    } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
      if (visited_reduce) {  // Do not go into inner tile
        break;
      }
    }

    if (target_iter->annotation == IteratorAnnotation::kUnroll) {
      // Do not go into the unroll region of const tensor indices
      break;
    }

    if (GetExtent(target_iter) == 1) {
      // Skip iterators with length of 1
      continue;
    }
    if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
        StrEndsWith(target_iter->name, ".0")) {
      // Skip the first level iterators if target stage compute_at another stage
      // In this case, the lengths of first level iterators are always one
      continue;
    }
    candidates.emplace_back(target_stage_id, i);

    // Stop once another stage is already attached at this iterator.
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
      break;
    }
  }

  // if the target_stage is already compute_at another stage X, try also compute_at X
  // We call stage X as `target_target_stage`
  if (target_compute_at_other) {
    int target_target_stage_id;
    target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
    const Stage& target_target_stage = state->stages[target_target_stage_id];

    for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
      const Iterator& target_target_iter = target_target_stage->iters[i];
      // Stop at reduction iterators or at iterators that already have attached stages.
      if (target_target_iter->iter_kind == IteratorKind::kReduction ||
          state->attach_map->iter_to_attached_stages.count(
              std::make_pair(target_target_stage_id, i))) {
        break;
      }

      if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
        // Do not go into the unroll region of const tensor indices
        break;
      }

      if (GetExtent(target_target_iter) == 1) {  // skip iterators with length of 1
        continue;
      }

      candidates.emplace_back(target_target_stage_id, i);
    }
  }

  return candidates;
}
148 | |
/*!
 * \brief Apply multi-level tiling to a stage according to a format string.
 * \param state The input state.
 * \param stage_id The stage to tile.
 * \param format The tiling structure, e.g. "SSRSRS". Each 's'/'S' denotes one spatial
 *        level and each 'r'/'R' one reduction level, ordered from outermost to innermost.
 * \param spatial_split_step_ids If non-null, filled with the indices of the SplitSteps
 *        performed on spatial iterators.
 * \return The tiled state.
 */
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids) {
  // Temporal object to be used if the input pointer is nullptr
  std::vector<int> temp_split_step_ids;
  if (spatial_split_step_ids == nullptr) {
    spatial_split_step_ids = &temp_split_step_ids;
  }
  spatial_split_step_ids->clear();

  // levels[i] collects the i-th (outer-to-inner) part of every split iterator.
  std::vector<std::vector<Iterator>> space_levels;
  std::vector<std::vector<Iterator>> reduce_levels;
  // NOTE(review): `space_outer`/`reduce_outer` are declared for symmetry but never
  // populated in this function; only the `*_inner` vectors receive iterators.
  std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;

  // Number of spatial and reduction tiling levels requested by `format`.
  size_t n_space =
      std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S');
  size_t n_reduce =
      std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R');
  if (n_space + n_reduce != format.size()) {
    LOG(FATAL) << "Invalid multi-level tiling format: " << format;
  }
  space_levels.resize(n_space);
  reduce_levels.resize(n_reduce);

  State tmp_s = state;
  const Stage& stage = state->stages[stage_id];
  // Axes named in the `no_split_at_inner` attr are not split; they are placed
  // directly at the innermost level below.
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();

  // Split `iter` into `size` parts (no split when size == 1) and append part i to
  // levels[i]. Records the split step id for spatial iterators.
  auto sr_levels = [&](int size, const Iterator& iter, std::vector<std::vector<Iterator>>& levels) {
    ICHECK_GE(size, 1);
    if (size == 1) {
      levels[0].push_back(iter);
    } else {
      Array<Iterator> split_res =
          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
      for (int i = 0; i < size; i++) {
        levels[i].push_back(split_res[i]);
      }
      if (iter->iter_kind == IteratorKind::kSpatial) {
        // The split we just performed is the last recorded transform step.
        spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
      }
    }
  };

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (!no_split_at_inner_name_set.count(iter->name)) {
      if (iter->iter_kind == IteratorKind::kSpatial) {
        sr_levels(n_space, iter, space_levels);
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        sr_levels(n_reduce, iter, reduce_levels);
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    } else {
      // Unsplit axes go to the innermost level of their kind.
      if (iter->iter_kind == IteratorKind::kSpatial) {
        space_inner.push_back(iter);
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        reduce_inner.push_back(iter);
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    }
  }

  // Prepend the collected iterators in `fill` to the given level.
  auto fill_levels = [&](std::vector<Iterator>& levels_iter, std::vector<Iterator>& fill) {
    if (!fill.empty()) {
      levels_iter.insert(levels_iter.begin(), std::make_move_iterator(fill.begin()),
                         std::make_move_iterator(fill.end()));
    }
  };
  if (!space_levels.empty()) {
    fill_levels(space_levels.front(), space_outer);
    fill_levels(space_levels.back(), space_inner);
  }
  if (!reduce_levels.empty()) {
    fill_levels(reduce_levels.front(), reduce_outer);
    fill_levels(reduce_levels.back(), reduce_inner);
  }

  // Build the final loop order by walking `format` outermost-to-innermost.
  Array<Iterator> order;
  int space_ct = 0, reduce_ct = 0;
  for (const auto c : format) {
    if (c == 's' || c == 'S') {
      order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
                   std::make_move_iterator(space_levels[space_ct].end()));
      space_ct++;
    } else if (c == 'r' || c == 'R') {
      order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
                   std::make_move_iterator(reduce_levels[reduce_ct].end()));
      reduce_ct++;
    } else {
      LOG(FATAL) << "Invalid multi level tiling format: " << format;
    }
  }

  tmp_s.reorder(stage_id, order);
  return tmp_s;
}
249 | |
250 | State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids, |
251 | int n_split) { |
252 | if (n_split < 1 || n_split > 3) { |
253 | LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3" ; |
254 | } |
255 | // Apply up to three-level tiling structure: space_L0, space_L1, space_L2 |
256 | std::vector<Iterator> space_0, space_1, space_2, space_3, tmp_order; |
257 | Array<Iterator> split_res; |
258 | |
259 | auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>(); |
260 | ICHECK(pop != nullptr); |
261 | const Stage& stage = state->stages[stage_id]; |
262 | const std::set<std::string>& no_split_at_inner_name_set = |
263 | stage->op->attrs.count(SearchPolicyKey::no_split_at_inner) |
264 | ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner) |
265 | : std::set<std::string>(); |
266 | int no_split_at_inner_name_in_stage_cnt = 0; |
267 | for (const auto& iter : state->stages[stage_id]->iters) { |
268 | no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name); |
269 | } |
270 | |
271 | ICHECK_EQ(state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt, |
272 | split_step_ids.size()); |
273 | |
274 | State tmp_s = state; |
275 | int ct = 0; |
276 | for (const auto& iter : state->stages[stage_id]->iters) { |
277 | if (iter->iter_kind == IteratorKind::kSpatial) { |
278 | // For spatial iterator, split it into multi iterators |
279 | if (!no_split_at_inner_name_set.count(iter->name)) { |
280 | IteratorAnnotation ann_type = iter->annotation; |
281 | split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], n_split); |
282 | // Restore annotation. Move unroll and vectorize to inner, move parallel |
283 | // to outer |
284 | switch (ann_type) { |
285 | case IteratorAnnotation::kUnroll: |
286 | split_res.Set(n_split, tmp_s.unroll(stage_id, split_res[n_split])); |
287 | break; |
288 | case IteratorAnnotation::kVectorize: |
289 | split_res.Set(n_split, tmp_s.vectorize(stage_id, split_res[n_split])); |
290 | break; |
291 | case IteratorAnnotation::kParallel: |
292 | split_res.Set(0, tmp_s.parallel(stage_id, split_res[0])); |
293 | break; |
294 | default: |
295 | break; |
296 | } |
297 | |
298 | space_0.push_back(split_res[0]); |
299 | space_1.push_back(split_res[1]); |
300 | if (n_split >= 2) { |
301 | space_2.push_back(split_res[2]); |
302 | if (n_split == 3) { |
303 | space_3.push_back(split_res[3]); |
304 | } |
305 | } |
306 | ct++; |
307 | } else { |
308 | if (no_split_at_inner_name_set.count(iter->name)) { |
309 | if (n_split == 1) { |
310 | space_1.push_back(iter); |
311 | } else if (n_split == 2) { |
312 | space_2.push_back(iter); |
313 | } else { |
314 | ICHECK_EQ(n_split, 3); |
315 | space_3.push_back(iter); |
316 | } |
317 | } |
318 | } |
319 | } else { |
320 | LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind); |
321 | } |
322 | } |
323 | |
324 | if (n_split == 3) { |
325 | ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); |
326 | } else if (n_split == 2) { |
327 | ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2); |
328 | } else { |
329 | ConcatenateMove(&tmp_order, &space_0, &space_1); |
330 | } |
331 | tmp_s.reorder(stage_id, tmp_order); |
332 | return tmp_s; |
333 | } |
334 | |
335 | // Return whether a state has nested parallel, which is invalid on CPUs |
336 | bool HasNestedParallel(const State& state) { |
337 | std::function<void(int stage_id, size_t*)> count_parallel_ct; |
338 | |
339 | count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) { |
340 | const Stage& stage = state->stages[stage_id]; |
341 | |
342 | if (stage->compute_at == ComputeAtKind::kInlined) { |
343 | return; |
344 | } |
345 | |
346 | for (size_t i = 0; i < stage->iters.size(); ++i) { |
347 | if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) { |
348 | (*parallel_ct)++; |
349 | } |
350 | |
351 | IterKey iter_key(stage_id, i); |
352 | auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); |
353 | if (pair != state->attach_map->iter_to_attached_stages.end()) { |
354 | for (const auto& attach_stage_id : pair->second) { |
355 | count_parallel_ct(attach_stage_id, parallel_ct); |
356 | } |
357 | } |
358 | } |
359 | }; |
360 | |
361 | for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) { |
362 | size_t parallel_ct = 0; |
363 | |
364 | if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) { |
365 | count_parallel_ct(stage_id, ¶llel_ct); |
366 | if (parallel_ct >= 2) { |
367 | return true; |
368 | } |
369 | } |
370 | } |
371 | |
372 | return false; |
373 | } |
374 | |
375 | void PruneInvalidState(const SearchTask& task, Array<State>* states) { |
376 | size_t pt = 0; |
377 | for (size_t i = 0; i < states->size(); ++i) { |
378 | if (!(*states)[i].defined()) { |
379 | continue; |
380 | } |
381 | if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) { |
382 | continue; |
383 | } |
384 | |
385 | if (i != pt) { |
386 | states->Set(pt, (*states)[i]); |
387 | } |
388 | pt++; |
389 | } |
390 | |
391 | if (pt == 0) { |
392 | LOG(FATAL) << "Internal error: All states are invalid." ; |
393 | } else { |
394 | states->resize(pt); |
395 | } |
396 | } |
397 | |
398 | /********** SplitFactorizationMemo **********/ |
/*!
 * \brief Return (memoized) all factorization schemes of `extent` into `n_lengths` factors
 *        whose innermost factor is at most `max_innermost_factor`.
 */
const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
    int extent, int n_lengths, int max_innermost_factor) {
  QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
  const auto& it = memory_.find(key);
  if (it != memory_.end()) {
    // Cache hit: return the previously computed schemes.
    return it->second;
  }

  // Cache miss: set up the member scratch state consumed by DfsEnumerate.
  tmp_stack_ = Array<Integer>(n_lengths, Integer());
  // NOTE(review): storing a pointer into memory_ assumes inserting further keys
  // does not invalidate references to existing values — confirm memory_'s map type.
  results_ = &memory_[key];
  n_lengths_ = n_lengths;

  DfsEnumerate(0, extent, max_innermost_factor);

  return *results_;
}
415 | |
/*!
 * \brief Recursively enumerate factorization schemes into `results_`.
 * \param now Index of the factor currently being chosen (0-based).
 * \param remaining_length The extent still to be distributed over positions [now, n_lengths_).
 * \param max_innermost_factor Upper bound for the last (innermost) factor.
 */
void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
  if (now == n_lengths_) {
    // A complete scheme: keep it only if the innermost factor is small enough.
    if (tmp_stack_.back().as<IntImmNode>()->value <= max_innermost_factor) {
      results_->push_back(tmp_stack_);
    }
  } else {
    // Try every divisor of the remaining extent at position `now`.
    for (const auto& f : GetFactors(remaining_length)) {
      tmp_stack_.Set(now, Integer(f));
      DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
    }
  }
}
428 | |
429 | const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) { |
430 | auto it = factor_memory_.find(n); |
431 | if (it != factor_memory_.end()) { |
432 | return it->second; |
433 | } |
434 | |
435 | std::vector<int>& res = factor_memory_[n]; |
436 | int step = n % 2 == 0 ? 1 : 2; |
437 | for (size_t i = 1; i < static_cast<size_t>(std::sqrt(n)) + 1; i += step) { |
438 | if (n % i == 0) { |
439 | res.push_back(i); |
440 | if (n / i != i) { |
441 | res.push_back(n / i); |
442 | } |
443 | } |
444 | } |
445 | std::sort(res.begin(), res.end()); |
446 | return res; |
447 | } |
448 | |
449 | /********** Utils interface API for ffi **********/ |
450 | |
// Expose the consumer stage ids of `stage_id`. The result is returned as a Map whose
// values duplicate the keys — presumably because the ffi has no set container; verify.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsGetConsumers")
    .set_body_typed([](const SearchTask& task, const State& state, int stage_id) {
      const std::set<int>& consumers = GetConsumers(task, state, stage_id);
      tvm::Map<IntImm, IntImm> ret;
      for (const auto& i : consumers) {
        ret.Set(Integer(i), Integer(i));
      }
      return ret;
    });

// Whether two stages are elementwise-matched (see ElementwiseMatch in utils).
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch")
    .set_body_typed([](const SearchTask& task, const State& state, int stage_id,
                       int target_stage_id) {
      return ElementwiseMatch(task, state, stage_id, target_stage_id);
    });

// Simple predicate wrappers; each forwards directly to the utility of the same name.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
    .set_body_typed([](const Stage& stage) { return IsTiled(stage); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
    .set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
    .set_body_typed([](const State& s, int stage_id) {
      return HasCrossThreadReduction(s, stage_id);
    });
483 | |
484 | } // namespace auto_scheduler |
485 | } // namespace tvm |
486 | |